13 Commits

Author SHA1 Message Date
bel
03659164ba wip 2024-01-02 21:14:02 -07:00
bel
709dd1dba3 tod 2024-01-02 21:12:33 -07:00
bel
26595396cf tod 2024-01-02 21:01:01 -07:00
Bel LaPointe
fb7892b52b todo 2024-01-02 18:38:13 -07:00
Bel LaPointe
b08e055dac todo 2024-01-02 18:23:03 -07:00
Bel LaPointe
9d993cfc8a update destutterer to do punctuation-free words 2024-01-02 18:20:46 -07:00
Bel LaPointe
f4f8ea429a merge 2024-01-02 17:51:29 -07:00
Bel LaPointe
38bea3735f todo 2024-01-02 17:51:14 -07:00
Bel LaPointe
1c48026690 need to overlap without ANY puctuation, which i can do by breaking into words 2024-01-02 17:49:47 -07:00
Bel LaPointe
a57312786a gr 2024-01-02 17:48:17 -07:00
Bel LaPointe
55e3bf0a26 update defaults 2024-01-02 17:47:00 -07:00
Bel LaPointe
743c8c5f67 time cargo run -- --wav $HOME/Downloads/41A6C472-6E4D-4953-9A90-2497D2DAD8C9.wav --stream-step 30 --stream-retain 25 --stream-{head,tail}=1 2> /dev/null 2024-01-02 16:45:04 -07:00
Bel LaPointe
d32f7a4c40 destutterer doesnt drop stutter for prev 2024-01-02 16:36:39 -07:00
6 changed files with 172 additions and 43 deletions

View File

@@ -24,6 +24,7 @@ pub fn wav_channel<F>(
handler_fn: F handler_fn: F
) where F: FnMut(Result<rust_whisper_lib::Transcribed, String>) + Send + 'static { ) where F: FnMut(Result<rust_whisper_lib::Transcribed, String>) + Send + 'static {
flags.model_path = None; flags.model_path = None;
flags.model_buffer = Some(include_bytes!("../../models/ggml-distil-medium.en.bin").to_vec());
flags.model_buffer = Some(include_bytes!("../../models/ggml-base.en.bin").to_vec()); flags.model_buffer = Some(include_bytes!("../../models/ggml-base.en.bin").to_vec());
rust_whisper_lib::wav_channel(flags, handler_fn); rust_whisper_lib::wav_channel(flags, handler_fn);
} }

View File

@@ -383,6 +383,12 @@ dependencies = [
"hashbrown", "hashbrown",
] ]
[[package]]
name = "itoa"
version = "1.0.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b1a46d1a171d865aa5f83f92695765caa047a9b4cbae2cbf37dbd613a793fd4c"
[[package]] [[package]]
name = "jni" name = "jni"
version = "0.19.0" version = "0.19.0"
@@ -765,6 +771,7 @@ dependencies = [
"listen-lib", "listen-lib",
"rust-whisper-baked-lib", "rust-whisper-baked-lib",
"rust-whisper-lib", "rust-whisper-lib",
"stop-words",
] ]
[[package]] [[package]]
@@ -802,6 +809,12 @@ dependencies = [
"windows-sys", "windows-sys",
] ]
[[package]]
name = "ryu"
version = "1.0.16"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f98d2aa92eebf49b69786be48e4477826b256916e84a57ff2a4f21923b48eb4c"
[[package]] [[package]]
name = "same-file" name = "same-file"
version = "1.0.6" version = "1.0.6"
@@ -817,6 +830,37 @@ version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49"
[[package]]
name = "serde"
version = "1.0.193"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "25dd9975e68d0cb5aa1120c288333fc98731bd1dd12f561e468ea4728c042b89"
dependencies = [
"serde_derive",
]
[[package]]
name = "serde_derive"
version = "1.0.193"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "43576ca501357b9b071ac53cdc7da8ef0cbd9493d8df094cd821777ea6e894d3"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.41",
]
[[package]]
name = "serde_json"
version = "1.0.109"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cb0652c533506ad7a2e353cce269330d6afd8bdfb6d75e0ace5b35aacbd7b9e9"
dependencies = [
"itoa",
"ryu",
"serde",
]
[[package]] [[package]]
name = "shlex" name = "shlex"
version = "1.2.0" version = "1.2.0"
@@ -848,6 +892,15 @@ version = "1.11.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4dccd0940a2dcdf68d092b8cbab7dc0ad8fa938bf95787e1b916b0e3d0e8e970" checksum = "4dccd0940a2dcdf68d092b8cbab7dc0ad8fa938bf95787e1b916b0e3d0e8e970"
[[package]]
name = "stop-words"
version = "0.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8500024d809de02ecbf998472b7bed3c4fca380df2be68917f6a473bdb28ddcc"
dependencies = [
"serde_json",
]
[[package]] [[package]]
name = "strsim" name = "strsim"
version = "0.10.0" version = "0.10.0"

View File

@@ -10,3 +10,4 @@ rust-whisper-lib = { path = "../rust-whisper-lib" }
rust-whisper-baked-lib = { path = "../rust-whisper-baked-lib" } rust-whisper-baked-lib = { path = "../rust-whisper-baked-lib" }
listen-lib = { path = "../listen-lib" } listen-lib = { path = "../listen-lib" }
clap = { version = "4.4.10", features = ["derive"] } clap = { version = "4.4.10", features = ["derive"] }
stop-words = "0.8.0"

View File

@@ -84,64 +84,131 @@ fn channel(flags: rust_whisper_lib::Flags) {
} }
struct Destutterer { struct Destutterer {
prev: Option<String>, prev: Words,
} }
fn new_destutterer() -> Destutterer { fn new_destutterer() -> Destutterer {
Destutterer{prev: None} Destutterer{prevs: vec![]}
} }
impl Destutterer { impl Destutterer {
fn step(&mut self, next: String) -> String { fn step(&mut self, next: String) -> String {
let next = next.trim().to_string();
if next.len() == 0 { if next.len() == 0 {
return next; return next;
} }
match &self.prev {
None => { let next_words = Words::from_string(next.clone());
self.prev = Some(next.clone()); let mut n = self.prevs.len().clamp(0, next_words.len());
next while n > 0 {
}, let prev_s, _ = self.prevs.last_n_comparable_to_string(n);
Some(prev) => { let next_s, _ = next_words.first_n_comparable_to_string(n);
let without_trailing_punctuation = { eprintln!("prevs => '{}'", &prev_s);
let mut next = next.clone(); eprintln!("nexts => '{}'", &next_s);
while next.ends_with("?") || next.ends_with(".") { if prev_s == next_s {
next = next[..next.len()-1].to_string(); break;
} }
next n -= 1;
};
let trailing_punctuation = next[without_trailing_punctuation.len() ..].to_string();
let next = without_trailing_punctuation;
let next = {
let mut n = prev.len().clamp(0, next.len());
while n > 0 {
if prev[prev.len() - n..] == next[..n] {
break;
}
n -= 1;
}
next[n..].to_string()
};
if next.len() == 0 {
return "".to_string();
}
self.prev = Some(next.clone());
next + &trailing_punctuation
},
} }
self.prevs = next_words;
Word::to_string(nexts[n..].to_vec())
} }
} }
#[derive(Clone, Debug)]
struct Words {
raw: Vec<String>,
}
impl Words {
fn from_string(s: String) -> Words {
let mut result = Words{raw: vec![]};
for word in s.split(" ") {
let word = word.trim();
if word.len() > 0 {
result.raw.push(word.to_string());
}
}
result
}
fn last_n_comparable_to_string(&self, n: usize) -> (String, usize) {
let v = self.to_comparable_words();
v = v[(v.len() - n).clamp(0, v.len())..].to_vec();
v.iter().map(|x| x.s).collect().join(" "), v[v.len()-1].idx
}
fn first_n_comparable_to_string(&self, n: usize) -> (String, usize){
let v = self.to_comparable_words();
v = v[0..n.clamp(0, v.len())].to_vec();
v.iter().map(|x| x.s).collect().join(" "), v[0].idx
}
fn comparable_len(&self) -> usize {
self.to_comparable_words().len()
}
fn to_comparable_words(&self) -> Vec<Word> {
self.to_words().iter().filter(|x| x.s.is_some()).collect()
}
fn to_words(&self) -> Vec<Word> {
let skips = stop_words::get("en");
let strs = self.raw.iter()
.map(|w| w.to_lowercase())
.map(|word| word.chars().filter(|c| c.is_ascii_alphanumeric()).collect::<String>())
.collect::<Vec<String>>();
let mut result = vec![];
for i in 0..strs.len() {
result.push(Word{
s: match skips.contains(&strs[i]) {
true => None,
false => Some(strs[i]),
},
idx: i as usize,
});
}
result
}
fn to_string(&self) -> String {
self.raw.iter()
.map(|x| x.clone())
.collect::<Vec<String>>()
.join(" ")
}
}
#[derive(Debug)]
struct Word {
s: Option<String>,
idx: usize,
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
#[test] #[test]
fn test_destutterer() { fn test_destutterer_stop_words() {
let mut w = new_destutterer(); let mut w = new_destutterer();
assert_eq!("abcde".to_string(), w.step("abcde".to_string())); assert_eq!("welcome to the internet".to_string(), w.step("welcome to the internet".to_string()));
assert_eq!("fg".to_string(), w.step("cdefg".to_string())); assert_eq!("have a look around".to_string(), w.step("welcome to the a internet; have a look around".to_string()));
assert_eq!("hij".to_string(), w.step("fghij".to_string())); }
assert_eq!("fghij".to_string(), w.step("fghij".to_string()));
#[test]
fn test_destutterer_punctuation() {
let mut w = new_destutterer();
assert_eq!("a, b. c? d!".to_string(), w.step("a, b. c? d!".to_string()));
assert_eq!("e! f g".to_string(), w.step("d, e! f g".to_string()));
assert_eq!("hij".to_string(), w.step("f g hij".to_string()));
}
#[test]
fn test_destutterer_letters() {
let mut w = new_destutterer();
assert_eq!("a b c d e".to_string(), w.step("a b c d e".to_string()));
assert_eq!("f g".to_string(), w.step(" c d e f g".to_string()));
assert_eq!("h i j".to_string(), w.step("f g h i j ".to_string()));
assert_eq!("a g h i j".to_string(), w.step("a g h i j".to_string()));
} }
} }

View File

@@ -14,13 +14,13 @@ pub struct Flags {
#[arg(long, default_value = "8")] #[arg(long, default_value = "8")]
pub threads: i32, pub threads: i32,
#[arg(long, default_value = "5")] #[arg(long, default_value = "30")]
pub stream_step: u64, pub stream_step: u64,
#[arg(long, default_value = "0.6")] #[arg(long, default_value = "28.0")]
pub stream_retain: f32, pub stream_retain: f32,
#[arg(long, default_value = "0.3")] #[arg(long, default_value = "0.1")]
pub stream_head: f32, pub stream_head: f32,
#[arg(long, default_value = "0.3")] #[arg(long, default_value = "0.1")]
pub stream_tail: f32, pub stream_tail: f32,
#[arg(long, default_value = "false")] #[arg(long, default_value = "false")]

7
todo.yaml Executable file
View File

@@ -0,0 +1,7 @@
todo:
- overlap without stop words
- rust-whisper warn when transcription time ~ input time
scheduled: []
done:
- todo: need to overlap without ANY puctuation, which i can do by breaking into words
ts: Tue Jan 2 18:23:00 MST 2024