rust-whisper-baked de-stutters wav_channel output

master
Bel LaPointe 2024-01-02 16:10:11 -07:00
parent dd6f980266
commit 6082f7e446
1 changed files with 68 additions and 1 deletions

View File

@ -13,11 +13,15 @@ fn main() {
}
fn wav_channel(flags: rust_whisper_lib::Flags) {
let mut w = new_destutterer();
rust_whisper_baked_lib::wav_channel(
flags.clone(),
move |result: Result<rust_whisper_lib::Transcribed, String>| {
match result {
Ok(transcribed) => { println!("{}", transcribed.to_string()); },
Ok(transcribed) => {
let s = w.step(transcribed.to_string());
println!("{}", s);
},
Err(msg) => { eprintln!("error: {}", msg); },
};
},
@ -74,3 +78,66 @@ fn channel(flags: rust_whisper_lib::Flags) {
}
eprintln!("/listen lib main...");
}
/// Removes text that is repeated across consecutive chunks of output.
///
/// Each call to [`Destutterer::step`] compares the beginning of the new chunk
/// with the end of the text remembered from an earlier call and returns only
/// the part that does not overlap.
#[derive(Debug, Default)]
struct Destutterer {
    /// Text remembered from the last call that produced non-empty output;
    /// `None` until the first non-empty chunk has been seen.
    prev: Option<String>,
}

/// Creates a `Destutterer` that has not seen any text yet.
fn new_destutterer() -> Destutterer {
    Destutterer::default()
}

impl Destutterer {
    /// Feeds the next chunk and returns the portion that was not already
    /// covered by the remembered text.
    ///
    /// * Surrounding whitespace of `next` is trimmed first; a blank chunk
    ///   yields `""` and leaves the remembered text untouched.
    /// * The first non-empty chunk is returned (trimmed) as-is and remembered.
    /// * For later chunks, trailing `?`/`.` characters are set aside, the
    ///   longest prefix of the chunk that equals a suffix of the remembered
    ///   text is dropped, and the set-aside punctuation is appended to what
    ///   is left. If nothing is left, `""` is returned and the remembered
    ///   text is not updated.
    ///
    /// Never panics on non-ASCII input: overlap candidates are only tried at
    /// UTF-8 character boundaries.
    fn step(&mut self, next: String) -> String {
        let next = next.trim();
        if next.is_empty() {
            return String::new();
        }

        // Split "<body><trailing ?/.>" so punctuation does not take part in
        // the overlap comparison. Both characters are single-byte, so
        // `body.len()` is always a valid char boundary in `next`.
        let body = next.trim_end_matches(|c: char| c == '?' || c == '.');
        let punctuation = &next[body.len()..];

        let overlap = match &self.prev {
            None => {
                // NOTE(review): the first chunk is remembered *with* its
                // trailing punctuation, whereas later chunks are remembered
                // without it. Kept as-is to preserve existing behaviour.
                self.prev = Some(next.to_string());
                return next.to_string();
            }
            Some(prev) => Self::overlap_len(prev, body),
        };

        let fresh = &body[overlap..];
        if fresh.is_empty() {
            return String::new();
        }
        // Only the newly emitted remainder is remembered, not the whole chunk.
        self.prev = Some(fresh.to_string());
        format!("{}{}", fresh, punctuation)
    }

    /// Returns the byte length of the longest prefix of `next` that is also a
    /// suffix of `prev`, or 0 when there is no such overlap.
    ///
    /// Candidate lengths that would split a multi-byte character in `next`
    /// are skipped, so the returned length is always safe to slice `next` at.
    fn overlap_len(prev: &str, next: &str) -> usize {
        (1..=prev.len().min(next.len()))
            .rev()
            .find(|&n| next.is_char_boundary(n) && prev.ends_with(&next[..n]))
            .unwrap_or(0)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Feeds a sequence of overlapping chunks through a single destutterer and
    /// checks the text returned for each one.
    #[test]
    fn test_destutterer() {
        // (input chunk, expected output). Order matters: every step updates
        // the text the destutterer remembers for the following step.
        let cases = [
            ("abcde", "abcde"),
            ("cdefg", "fg"),
            ("fghij", "hij"),
            ("fghij", "fghij"),
        ];

        let mut destutterer = new_destutterer();
        for &(input, expected) in cases.iter() {
            let actual = destutterer.step(input.to_string());
            assert_eq!(expected.to_string(), actual);
        }
    }
}