diff --git a/rust-whisper-baked/src/main.rs b/rust-whisper-baked/src/main.rs index 6683fce..b88c327 100644 --- a/rust-whisper-baked/src/main.rs +++ b/rust-whisper-baked/src/main.rs @@ -84,51 +84,64 @@ fn channel(flags: rust_whisper_lib::Flags) { } struct Destutterer { - prev: Option, + prevs: Vec, } fn new_destutterer() -> Destutterer { - Destutterer{prev: None} + Destutterer{prevs: vec![]} } impl Destutterer { fn step(&mut self, next: String) -> String { - let next = next.trim().to_string(); if next.len() == 0 { return next; } - match &self.prev { - None => { - self.prev = Some(next.clone()); - next - }, - Some(prev) => { - let without_trailing_punctuation = { - let mut next = next.clone(); - while next.ends_with("?") || next.ends_with(".") { - next = next[..next.len()-1].to_string(); - } - next - }; - let trailing_punctuation = next[without_trailing_punctuation.len() ..].to_string(); - let next = without_trailing_punctuation.clone(); - let next = { - let mut n = prev.len().clamp(0, next.len()); - while n > 0 { - if prev[prev.len() - n..] == next[..n] { - break; - } - n -= 1; - } - next[n..].to_string() - }; - if next.len() == 0 { - return "".to_string(); - } - self.prev = Some(without_trailing_punctuation); - next + &trailing_punctuation - }, + + let nexts = Word::from_string(next.clone()); + + let mut n = self.prevs.len().clamp(0, nexts.len()); + while n > 0 { + let prev_s = Word::to_comparable_string(self.prevs[self.prevs.len() - n..].to_vec()); + let next_s = Word::to_comparable_string(nexts[..n].to_vec()); + if prev_s == next_s { + break; + } + n -= 1; } + self.prevs = nexts.clone(); + Word::to_string(nexts[n..].to_vec()) + } +} + +#[derive(Clone)] +struct Word { + raw: String, +} + +impl Word { + fn from_string(s: String) -> Vec { + let mut result = vec![]; + for word in s.split(" ") { + let word = word.trim(); + if word.len() > 0 { + result.push(Word{raw: word.to_string()}); + } + } + result + } + + fn to_comparable_string(v: Vec) -> String { + v.iter() + .map(|x| x.raw.chars().filter(|c| c.is_ascii_alphanumeric()).collect()) + .collect::>() + .join(" ") + } + + fn to_string(v: Vec) -> String { + v.iter() + .map(|x| x.raw.clone()) + .collect::>() + .join(" ") } } @@ -137,11 +150,19 @@ mod tests { use super::*; #[test] - fn test_destutterer() { + fn test_destutterer_punctuation() { let mut w = new_destutterer(); - assert_eq!("abcde".to_string(), w.step("abcde".to_string())); - assert_eq!("fg".to_string(), w.step("cdefg".to_string())); - assert_eq!("hij".to_string(), w.step("fghij".to_string())); - assert_eq!("fghij".to_string(), w.step("fghij".to_string())); + assert_eq!("a, b. c? d!".to_string(), w.step("a, b. c? d!".to_string())); + assert_eq!("e! f g".to_string(), w.step("d, e! f g".to_string())); + assert_eq!("hij".to_string(), w.step("f g hij".to_string())); + } + + #[test] + fn test_destutterer_letters() { + let mut w = new_destutterer(); + assert_eq!("a b c d e".to_string(), w.step("a b c d e".to_string())); + assert_eq!("f g".to_string(), w.step(" c d e f g".to_string())); + assert_eq!("h i j".to_string(), w.step("f g h i j ".to_string())); + assert_eq!("a g h i j".to_string(), w.step("a g h i j".to_string())); } }