Compare commits
13 Commits
0.2.5
...
03659164ba
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
03659164ba | ||
|
|
709dd1dba3 | ||
|
|
26595396cf | ||
|
|
fb7892b52b | ||
|
|
b08e055dac | ||
|
|
9d993cfc8a | ||
|
|
f4f8ea429a | ||
|
|
38bea3735f | ||
|
|
1c48026690 | ||
|
|
a57312786a | ||
|
|
55e3bf0a26 | ||
|
|
743c8c5f67 | ||
|
|
d32f7a4c40 |
@@ -24,6 +24,7 @@ pub fn wav_channel<F>(
|
||||
handler_fn: F
|
||||
) where F: FnMut(Result<rust_whisper_lib::Transcribed, String>) + Send + 'static {
|
||||
flags.model_path = None;
|
||||
flags.model_buffer = Some(include_bytes!("../../models/ggml-distil-medium.en.bin").to_vec());
|
||||
flags.model_buffer = Some(include_bytes!("../../models/ggml-base.en.bin").to_vec());
|
||||
rust_whisper_lib::wav_channel(flags, handler_fn);
|
||||
}
|
||||
|
||||
53
rust-whisper-baked/Cargo.lock
generated
53
rust-whisper-baked/Cargo.lock
generated
@@ -383,6 +383,12 @@ dependencies = [
|
||||
"hashbrown",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "itoa"
|
||||
version = "1.0.10"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b1a46d1a171d865aa5f83f92695765caa047a9b4cbae2cbf37dbd613a793fd4c"
|
||||
|
||||
[[package]]
|
||||
name = "jni"
|
||||
version = "0.19.0"
|
||||
@@ -765,6 +771,7 @@ dependencies = [
|
||||
"listen-lib",
|
||||
"rust-whisper-baked-lib",
|
||||
"rust-whisper-lib",
|
||||
"stop-words",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -802,6 +809,12 @@ dependencies = [
|
||||
"windows-sys",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ryu"
|
||||
version = "1.0.16"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f98d2aa92eebf49b69786be48e4477826b256916e84a57ff2a4f21923b48eb4c"
|
||||
|
||||
[[package]]
|
||||
name = "same-file"
|
||||
version = "1.0.6"
|
||||
@@ -817,6 +830,37 @@ version = "1.2.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49"
|
||||
|
||||
[[package]]
|
||||
name = "serde"
|
||||
version = "1.0.193"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "25dd9975e68d0cb5aa1120c288333fc98731bd1dd12f561e468ea4728c042b89"
|
||||
dependencies = [
|
||||
"serde_derive",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde_derive"
|
||||
version = "1.0.193"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "43576ca501357b9b071ac53cdc7da8ef0cbd9493d8df094cd821777ea6e894d3"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.41",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde_json"
|
||||
version = "1.0.109"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "cb0652c533506ad7a2e353cce269330d6afd8bdfb6d75e0ace5b35aacbd7b9e9"
|
||||
dependencies = [
|
||||
"itoa",
|
||||
"ryu",
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "shlex"
|
||||
version = "1.2.0"
|
||||
@@ -848,6 +892,15 @@ version = "1.11.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4dccd0940a2dcdf68d092b8cbab7dc0ad8fa938bf95787e1b916b0e3d0e8e970"
|
||||
|
||||
[[package]]
|
||||
name = "stop-words"
|
||||
version = "0.8.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8500024d809de02ecbf998472b7bed3c4fca380df2be68917f6a473bdb28ddcc"
|
||||
dependencies = [
|
||||
"serde_json",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "strsim"
|
||||
version = "0.10.0"
|
||||
|
||||
@@ -10,3 +10,4 @@ rust-whisper-lib = { path = "../rust-whisper-lib" }
|
||||
rust-whisper-baked-lib = { path = "../rust-whisper-baked-lib" }
|
||||
listen-lib = { path = "../listen-lib" }
|
||||
clap = { version = "4.4.10", features = ["derive"] }
|
||||
stop-words = "0.8.0"
|
||||
|
||||
@@ -84,52 +84,104 @@ fn channel(flags: rust_whisper_lib::Flags) {
|
||||
}
|
||||
|
||||
struct Destutterer {
|
||||
prev: Option<String>,
|
||||
prev: Words,
|
||||
}
|
||||
|
||||
fn new_destutterer() -> Destutterer {
|
||||
Destutterer{prev: None}
|
||||
Destutterer{prevs: vec![]}
|
||||
}
|
||||
|
||||
impl Destutterer {
|
||||
fn step(&mut self, next: String) -> String {
|
||||
let next = next.trim().to_string();
|
||||
if next.len() == 0 {
|
||||
return next;
|
||||
}
|
||||
match &self.prev {
|
||||
None => {
|
||||
self.prev = Some(next.clone());
|
||||
next
|
||||
},
|
||||
Some(prev) => {
|
||||
let without_trailing_punctuation = {
|
||||
let mut next = next.clone();
|
||||
while next.ends_with("?") || next.ends_with(".") {
|
||||
next = next[..next.len()-1].to_string();
|
||||
}
|
||||
next
|
||||
};
|
||||
let trailing_punctuation = next[without_trailing_punctuation.len() ..].to_string();
|
||||
let next = without_trailing_punctuation;
|
||||
let next = {
|
||||
let mut n = prev.len().clamp(0, next.len());
|
||||
|
||||
let next_words = Words::from_string(next.clone());
|
||||
let mut n = self.prevs.len().clamp(0, next_words.len());
|
||||
while n > 0 {
|
||||
if prev[prev.len() - n..] == next[..n] {
|
||||
let prev_s, _ = self.prevs.last_n_comparable_to_string(n);
|
||||
let next_s, _ = next_words.first_n_comparable_to_string(n);
|
||||
eprintln!("prevs => '{}'", &prev_s);
|
||||
eprintln!("nexts => '{}'", &next_s);
|
||||
if prev_s == next_s {
|
||||
break;
|
||||
}
|
||||
n -= 1;
|
||||
}
|
||||
next[n..].to_string()
|
||||
};
|
||||
if next.len() == 0 {
|
||||
return "".to_string();
|
||||
self.prevs = next_words;
|
||||
Word::to_string(nexts[n..].to_vec())
|
||||
}
|
||||
self.prev = Some(next.clone());
|
||||
next + &trailing_punctuation
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
struct Words {
|
||||
raw: Vec<String>,
|
||||
}
|
||||
|
||||
impl Words {
|
||||
fn from_string(s: String) -> Words {
|
||||
let mut result = Words{raw: vec![]};
|
||||
for word in s.split(" ") {
|
||||
let word = word.trim();
|
||||
if word.len() > 0 {
|
||||
result.raw.push(word.to_string());
|
||||
}
|
||||
}
|
||||
result
|
||||
}
|
||||
|
||||
fn last_n_comparable_to_string(&self, n: usize) -> (String, usize) {
|
||||
let v = self.to_comparable_words();
|
||||
v = v[(v.len() - n).clamp(0, v.len())..].to_vec();
|
||||
v.iter().map(|x| x.s).collect().join(" "), v[v.len()-1].idx
|
||||
}
|
||||
|
||||
fn first_n_comparable_to_string(&self, n: usize) -> (String, usize){
|
||||
let v = self.to_comparable_words();
|
||||
v = v[0..n.clamp(0, v.len())].to_vec();
|
||||
v.iter().map(|x| x.s).collect().join(" "), v[0].idx
|
||||
}
|
||||
|
||||
fn comparable_len(&self) -> usize {
|
||||
self.to_comparable_words().len()
|
||||
}
|
||||
|
||||
fn to_comparable_words(&self) -> Vec<Word> {
|
||||
self.to_words().iter().filter(|x| x.s.is_some()).collect()
|
||||
}
|
||||
|
||||
fn to_words(&self) -> Vec<Word> {
|
||||
let skips = stop_words::get("en");
|
||||
let strs = self.raw.iter()
|
||||
.map(|w| w.to_lowercase())
|
||||
.map(|word| word.chars().filter(|c| c.is_ascii_alphanumeric()).collect::<String>())
|
||||
.collect::<Vec<String>>();
|
||||
let mut result = vec![];
|
||||
for i in 0..strs.len() {
|
||||
result.push(Word{
|
||||
s: match skips.contains(&strs[i]) {
|
||||
true => None,
|
||||
false => Some(strs[i]),
|
||||
},
|
||||
idx: i as usize,
|
||||
});
|
||||
}
|
||||
result
|
||||
}
|
||||
|
||||
fn to_string(&self) -> String {
|
||||
self.raw.iter()
|
||||
.map(|x| x.clone())
|
||||
.collect::<Vec<String>>()
|
||||
.join(" ")
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct Word {
|
||||
s: Option<String>,
|
||||
idx: usize,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
@@ -137,11 +189,26 @@ mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_destutterer() {
|
||||
fn test_destutterer_stop_words() {
|
||||
let mut w = new_destutterer();
|
||||
assert_eq!("abcde".to_string(), w.step("abcde".to_string()));
|
||||
assert_eq!("fg".to_string(), w.step("cdefg".to_string()));
|
||||
assert_eq!("hij".to_string(), w.step("fghij".to_string()));
|
||||
assert_eq!("fghij".to_string(), w.step("fghij".to_string()));
|
||||
assert_eq!("welcome to the internet".to_string(), w.step("welcome to the internet".to_string()));
|
||||
assert_eq!("have a look around".to_string(), w.step("welcome to the a internet; have a look around".to_string()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_destutterer_punctuation() {
|
||||
let mut w = new_destutterer();
|
||||
assert_eq!("a, b. c? d!".to_string(), w.step("a, b. c? d!".to_string()));
|
||||
assert_eq!("e! f g".to_string(), w.step("d, e! f g".to_string()));
|
||||
assert_eq!("hij".to_string(), w.step("f g hij".to_string()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_destutterer_letters() {
|
||||
let mut w = new_destutterer();
|
||||
assert_eq!("a b c d e".to_string(), w.step("a b c d e".to_string()));
|
||||
assert_eq!("f g".to_string(), w.step(" c d e f g".to_string()));
|
||||
assert_eq!("h i j".to_string(), w.step("f g h i j ".to_string()));
|
||||
assert_eq!("a g h i j".to_string(), w.step("a g h i j".to_string()));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -14,13 +14,13 @@ pub struct Flags {
|
||||
#[arg(long, default_value = "8")]
|
||||
pub threads: i32,
|
||||
|
||||
#[arg(long, default_value = "5")]
|
||||
#[arg(long, default_value = "30")]
|
||||
pub stream_step: u64,
|
||||
#[arg(long, default_value = "0.6")]
|
||||
#[arg(long, default_value = "28.0")]
|
||||
pub stream_retain: f32,
|
||||
#[arg(long, default_value = "0.3")]
|
||||
#[arg(long, default_value = "0.1")]
|
||||
pub stream_head: f32,
|
||||
#[arg(long, default_value = "0.3")]
|
||||
#[arg(long, default_value = "0.1")]
|
||||
pub stream_tail: f32,
|
||||
|
||||
#[arg(long, default_value = "false")]
|
||||
|
||||
Reference in New Issue
Block a user