19 Commits

Author SHA1 Message Date
Bel LaPointe
9bc009996c oop 2024-01-03 08:38:24 -07:00
Bel LaPointe
cbc8a4f9fd cargo run -- --stream-step 8 --stream-retain 4 --stream-head=2 --stream-tail=0 2> /dev/null 2024-01-03 08:37:27 -07:00
Bel LaPointe
a8c8140d18 functionize at least 2024-01-03 08:28:22 -07:00
Bel LaPointe
5bc3209070 x=2; cargo run -- --wav $HOME/Downloads/41A6C472-6E4D-4953-9A90-2497D2DAD8C9.wav --stream-step $((x*4)) --stream-retain $((x*2)) --stream-{head,tail}=$((x)) 2> /dev/null 2024-01-03 08:22:45 -07:00
Bel LaPointe
8b5c18e65e todo 2024-01-03 08:22:15 -07:00
Bel LaPointe
ec47d8142a destutter with stopwords impl 2024-01-03 07:54:21 -07:00
bel
03659164ba wip 2024-01-02 21:14:02 -07:00
bel
709dd1dba3 tod 2024-01-02 21:12:33 -07:00
bel
26595396cf tod 2024-01-02 21:01:01 -07:00
Bel LaPointe
fb7892b52b todo 2024-01-02 18:38:13 -07:00
Bel LaPointe
b08e055dac todo 2024-01-02 18:23:03 -07:00
Bel LaPointe
9d993cfc8a update destutterer to do punctuation-free words 2024-01-02 18:20:46 -07:00
Bel LaPointe
f4f8ea429a merge 2024-01-02 17:51:29 -07:00
Bel LaPointe
38bea3735f todo 2024-01-02 17:51:14 -07:00
Bel LaPointe
1c48026690 need to overlap without ANY punctuation, which i can do by breaking into words 2024-01-02 17:49:47 -07:00
Bel LaPointe
a57312786a gr 2024-01-02 17:48:17 -07:00
Bel LaPointe
55e3bf0a26 update defaults 2024-01-02 17:47:00 -07:00
Bel LaPointe
743c8c5f67 time cargo run -- --wav $HOME/Downloads/41A6C472-6E4D-4953-9A90-2497D2DAD8C9.wav --stream-step 30 --stream-retain 25 --stream-{head,tail}=1 2> /dev/null 2024-01-02 16:45:04 -07:00
Bel LaPointe
d32f7a4c40 destutterer doesnt drop stutter for prev 2024-01-02 16:36:39 -07:00
6 changed files with 199 additions and 48 deletions

View File

@@ -6,7 +6,7 @@ pub fn channel<F>(
stream: std::sync::mpsc::Receiver<Vec<f32>>, stream: std::sync::mpsc::Receiver<Vec<f32>>,
) where F: FnMut(Result<rust_whisper_lib::Transcribed, String>) + Send + 'static { ) where F: FnMut(Result<rust_whisper_lib::Transcribed, String>) + Send + 'static {
flags.model_path = None; flags.model_path = None;
flags.model_buffer = Some(include_bytes!("../../models/ggml-tiny.en.bin").to_vec()); flags.model_buffer = Some(get_fast());
rust_whisper_lib::channel(flags.clone(), handler_fn, stream); rust_whisper_lib::channel(flags.clone(), handler_fn, stream);
} }
@@ -15,7 +15,7 @@ pub fn wav<F>(
handler_fn: F handler_fn: F
) where F: FnMut(Result<rust_whisper_lib::Transcribed, String>) + Send + 'static { ) where F: FnMut(Result<rust_whisper_lib::Transcribed, String>) + Send + 'static {
flags.model_path = None; flags.model_path = None;
flags.model_buffer = Some(include_bytes!("../../models/ggml-distil-medium.en.bin").to_vec()); flags.model_buffer = Some(get_good());
rust_whisper_lib::wav(flags.clone(), handler_fn, flags.wav.unwrap()); rust_whisper_lib::wav(flags.clone(), handler_fn, flags.wav.unwrap());
} }
@@ -24,10 +24,18 @@ pub fn wav_channel<F>(
handler_fn: F handler_fn: F
) where F: FnMut(Result<rust_whisper_lib::Transcribed, String>) + Send + 'static { ) where F: FnMut(Result<rust_whisper_lib::Transcribed, String>) + Send + 'static {
flags.model_path = None; flags.model_path = None;
flags.model_buffer = Some(include_bytes!("../../models/ggml-base.en.bin").to_vec()); flags.model_buffer = Some(get_good());
rust_whisper_lib::wav_channel(flags, handler_fn); rust_whisper_lib::wav_channel(flags, handler_fn);
} }
pub fn f32_from_wav_file(path: &String) -> Result<Vec<f32>, String> { pub fn f32_from_wav_file(path: &String) -> Result<Vec<f32>, String> {
rust_whisper_lib::f32_from_wav_file(path) rust_whisper_lib::f32_from_wav_file(path)
} }
/// Returns an owned copy of the `ggml-small.en.bin` model bytes.
///
/// The file is embedded into the binary at compile time by `include_bytes!`;
/// every call allocates a fresh `Vec<u8>` holding the full contents.
fn get_fast() -> Vec<u8> {
    let model: &[u8] = include_bytes!("../../models/ggml-small.en.bin");
    model.to_vec()
}
/// Returns an owned copy of the `ggml-distil-medium.en.bin` model bytes.
///
/// The file is embedded into the binary at compile time by `include_bytes!`;
/// every call allocates a fresh `Vec<u8>` holding the full contents.
fn get_good() -> Vec<u8> {
    let model: &[u8] = include_bytes!("../../models/ggml-distil-medium.en.bin");
    model.to_vec()
}

View File

@@ -383,6 +383,12 @@ dependencies = [
"hashbrown", "hashbrown",
] ]
[[package]]
name = "itoa"
version = "1.0.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b1a46d1a171d865aa5f83f92695765caa047a9b4cbae2cbf37dbd613a793fd4c"
[[package]] [[package]]
name = "jni" name = "jni"
version = "0.19.0" version = "0.19.0"
@@ -765,6 +771,7 @@ dependencies = [
"listen-lib", "listen-lib",
"rust-whisper-baked-lib", "rust-whisper-baked-lib",
"rust-whisper-lib", "rust-whisper-lib",
"stop-words",
] ]
[[package]] [[package]]
@@ -802,6 +809,12 @@ dependencies = [
"windows-sys", "windows-sys",
] ]
[[package]]
name = "ryu"
version = "1.0.16"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f98d2aa92eebf49b69786be48e4477826b256916e84a57ff2a4f21923b48eb4c"
[[package]] [[package]]
name = "same-file" name = "same-file"
version = "1.0.6" version = "1.0.6"
@@ -817,6 +830,37 @@ version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49"
[[package]]
name = "serde"
version = "1.0.193"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "25dd9975e68d0cb5aa1120c288333fc98731bd1dd12f561e468ea4728c042b89"
dependencies = [
"serde_derive",
]
[[package]]
name = "serde_derive"
version = "1.0.193"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "43576ca501357b9b071ac53cdc7da8ef0cbd9493d8df094cd821777ea6e894d3"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.41",
]
[[package]]
name = "serde_json"
version = "1.0.109"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cb0652c533506ad7a2e353cce269330d6afd8bdfb6d75e0ace5b35aacbd7b9e9"
dependencies = [
"itoa",
"ryu",
"serde",
]
[[package]] [[package]]
name = "shlex" name = "shlex"
version = "1.2.0" version = "1.2.0"
@@ -848,6 +892,15 @@ version = "1.11.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4dccd0940a2dcdf68d092b8cbab7dc0ad8fa938bf95787e1b916b0e3d0e8e970" checksum = "4dccd0940a2dcdf68d092b8cbab7dc0ad8fa938bf95787e1b916b0e3d0e8e970"
[[package]]
name = "stop-words"
version = "0.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8500024d809de02ecbf998472b7bed3c4fca380df2be68917f6a473bdb28ddcc"
dependencies = [
"serde_json",
]
[[package]] [[package]]
name = "strsim" name = "strsim"
version = "0.10.0" version = "0.10.0"

View File

@@ -10,3 +10,4 @@ rust-whisper-lib = { path = "../rust-whisper-lib" }
rust-whisper-baked-lib = { path = "../rust-whisper-baked-lib" } rust-whisper-baked-lib = { path = "../rust-whisper-baked-lib" }
listen-lib = { path = "../listen-lib" } listen-lib = { path = "../listen-lib" }
clap = { version = "4.4.10", features = ["derive"] } clap = { version = "4.4.10", features = ["derive"] }
stop-words = "0.8.0"

View File

@@ -48,11 +48,15 @@ fn channel(flags: rust_whisper_lib::Flags) {
eprintln!("rust whisper baked lib channel..."); eprintln!("rust whisper baked lib channel...");
thread::spawn(move || { thread::spawn(move || {
let mut w = new_destutterer();
rust_whisper_baked_lib::channel( rust_whisper_baked_lib::channel(
flags.clone(), flags.clone(),
|result: Result<rust_whisper_lib::Transcribed, String>| { move |result: Result<rust_whisper_lib::Transcribed, String>| {
match result { match result {
Ok(transcribed) => { println!("{}", transcribed.to_string()); }, Ok(transcribed) => {
let s = w.step(transcribed.to_string());
println!("{}", s);
},
Err(msg) => { eprintln!("error: {}", msg); }, Err(msg) => { eprintln!("error: {}", msg); },
}; };
}, },
@@ -84,52 +88,114 @@ fn channel(flags: rust_whisper_lib::Flags) {
} }
struct Destutterer { struct Destutterer {
prev: Option<String>, prev: Words,
} }
fn new_destutterer() -> Destutterer { fn new_destutterer() -> Destutterer {
Destutterer{prev: None} Destutterer{prev: new_words()}
} }
impl Destutterer { impl Destutterer {
fn step(&mut self, next: String) -> String { fn step(&mut self, next: String) -> String {
let next = next.trim().to_string();
if next.len() == 0 { if next.len() == 0 {
return next; return next;
} }
match &self.prev {
None => { let next_words = Words::from_string(next.clone());
self.prev = Some(next.clone()); let mut n = self.prev.comparable_len().clamp(0, next_words.comparable_len());
next //println!("n={} prev='{:?}' next='{:?}'", n, self.prev.to_comparable_words(), next_words.to_comparable_words());
},
Some(prev) => {
let without_trailing_punctuation = {
let mut next = next.clone();
while next.ends_with("?") || next.ends_with(".") {
next = next[..next.len()-1].to_string();
}
next
};
let trailing_punctuation = next[without_trailing_punctuation.len() ..].to_string();
let next = without_trailing_punctuation;
let next = {
let mut n = prev.len().clamp(0, next.len());
while n > 0 { while n > 0 {
if prev[prev.len() - n..] == next[..n] { let (prev_s, _) = self.prev.last_n_comparable_to_string(n);
break; let (next_s, next_idx) = next_words.first_n_comparable_to_string(n);
if prev_s == next_s {
self.prev = next_words;
return self.prev.skip(next_idx+1).to_string();
} }
n -= 1; n -= 1;
} }
next[n..].to_string() self.prev = next_words;
}; self.prev.to_string()
if next.len() == 0 {
return "".to_string();
} }
self.prev = Some(next.clone()); }
next + &trailing_punctuation
/// A piece of text broken into its individual tokens.
///
/// Built by `Words::from_string`; the normalised, comparison-friendly view of
/// the same tokens is produced on demand by `Words::to_words`.
#[derive(Clone, Debug)]
struct Words {
// Non-empty tokens in their original order, stored as they appeared in the
// input (case and punctuation are kept).
raw: Vec<String>,
}
/// Creates a `Words` that holds no tokens.
fn new_words() -> Words {
    Words { raw: Vec::new() }
}
impl Words {
fn from_string(s: String) -> Words {
let mut result = Words{raw: vec![]};
for word in s.split(" ") {
let word = word.trim();
if word.len() > 0 {
result.raw.push(word.to_string());
}
}
result
}
fn skip(&self, n: usize) -> Words {
Words{
raw: self.raw.iter().skip(n).map(|x| x.clone()).collect(),
}
}
fn last_n_comparable_to_string(&self, n: usize) -> (String, usize) {
let v = self.to_comparable_words();
let v = v[(v.len() - n).clamp(0, v.len())..].to_vec();
return (v.iter().map(|x| x.s.clone().unwrap()).collect::<Vec<String>>().join(" "), v[0].idx)
}
fn first_n_comparable_to_string(&self, n: usize) -> (String, usize){
let v = self.to_comparable_words();
let v = v[0..n.clamp(0, v.len())].to_vec();
return (v.iter().map(|x| x.s.clone().unwrap()).collect::<Vec<String>>().join(" "), v[v.len()-1].idx)
}
fn comparable_len(&self) -> usize {
self.to_comparable_words().len()
}
fn to_comparable_words(&self) -> Vec<Word> {
self.to_words().iter().filter(|x| x.s.is_some()).map(|x| x.clone()).collect()
}
fn to_words(&self) -> Vec<Word> {
let skips = stop_words::get("en");
let strs = self.raw.iter()
.map(|w| w.to_lowercase())
.map(|word| word.chars().filter(|c| c.is_ascii_alphanumeric()).collect::<String>())
.collect::<Vec<String>>();
let mut result = vec![];
for i in 0..strs.len() {
result.push(Word{
s: match skips.contains(&strs[i]) {
true => None,
false => Some(strs[i].clone()),
}, },
idx: i as usize,
});
}
result
}
fn to_string(&self) -> String {
self.raw.iter()
.map(|x| x.clone())
.collect::<Vec<String>>()
.join(" ")
} }
} }
/// One token of a `Words` value, as produced by `Words::to_words`.
#[derive(Debug, Clone)]
struct Word {
// Normalised text of the token (lowercased, ASCII alphanumerics only), or
// `None` when the token is excluded from comparisons, e.g. because it is a
// stop word.
s: Option<String>,
// Position of the token within `Words::raw`.
idx: usize,
}
#[cfg(test)] #[cfg(test)]
@@ -137,11 +203,24 @@ mod tests {
use super::*; use super::*;
#[test] #[test]
fn test_destutterer() { fn test_destutterer_stop_words() {
let mut w = new_destutterer(); let mut w = new_destutterer();
assert_eq!("abcde".to_string(), w.step("abcde".to_string())); assert_eq!("welcome to the internet".to_string(), w.step("welcome to the internet".to_string()));
assert_eq!("fg".to_string(), w.step("cdefg".to_string())); assert_eq!("have a look around".to_string(), w.step("welcome to the a internet; have a look around".to_string()));
assert_eq!("hij".to_string(), w.step("fghij".to_string())); }
assert_eq!("fghij".to_string(), w.step("fghij".to_string()));
#[test]
fn test_destutterer_punctuation() {
let mut w = new_destutterer();
assert_eq!("cat, dog. cow? moose!".to_string(), w.step("cat, dog. cow? moose!".to_string()));
assert_eq!("elephant! fez gator".to_string(), w.step("moose, elephant! fez gator".to_string()));
assert_eq!("hij".to_string(), w.step("fez gator hij".to_string()));
}
#[test]
fn test_destutterer_basic() {
let mut w = new_destutterer();
assert_eq!("cat dog cow".to_string(), w.step(" cat dog cow ".to_string()));
assert_eq!("moose".to_string(), w.step(" dog cow moose ".to_string()));
} }
} }

View File

@@ -14,13 +14,13 @@ pub struct Flags {
#[arg(long, default_value = "8")] #[arg(long, default_value = "8")]
pub threads: i32, pub threads: i32,
#[arg(long, default_value = "5")] #[arg(long, default_value = "30")]
pub stream_step: u64, pub stream_step: u64,
#[arg(long, default_value = "0.6")] #[arg(long, default_value = "28.0")]
pub stream_retain: f32, pub stream_retain: f32,
#[arg(long, default_value = "0.3")] #[arg(long, default_value = "0.1")]
pub stream_head: f32, pub stream_head: f32,
#[arg(long, default_value = "0.3")] #[arg(long, default_value = "0.1")]
pub stream_tail: f32, pub stream_tail: f32,
#[arg(long, default_value = "false")] #[arg(long, default_value = "false")]

10
todo.yaml Executable file
View File

@@ -0,0 +1,10 @@
todo:
- whisper trims outside silence so head and tail never get hit
- split on silence-ish instead of duration
- rust-whisper warn when transcription time ~ input time
scheduled: []
done:
- todo: need to overlap without ANY punctuation, which i can do by breaking into words
ts: Tue Jan 2 18:23:00 MST 2024
- todo: overlap without stop words
ts: Wed Jan 3 08:22:14 MST 2024