rust-whisper-baked de-stutters wav_channel output
parent
dd6f980266
commit
6082f7e446
|
|
@ -13,11 +13,15 @@ fn main() {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn wav_channel(flags: rust_whisper_lib::Flags) {
|
fn wav_channel(flags: rust_whisper_lib::Flags) {
|
||||||
|
let mut w = new_destutterer();
|
||||||
rust_whisper_baked_lib::wav_channel(
|
rust_whisper_baked_lib::wav_channel(
|
||||||
flags.clone(),
|
flags.clone(),
|
||||||
move |result: Result<rust_whisper_lib::Transcribed, String>| {
|
move |result: Result<rust_whisper_lib::Transcribed, String>| {
|
||||||
match result {
|
match result {
|
||||||
Ok(transcribed) => { println!("{}", transcribed.to_string()); },
|
Ok(transcribed) => {
|
||||||
|
let s = w.step(transcribed.to_string());
|
||||||
|
println!("{}", s);
|
||||||
|
},
|
||||||
Err(msg) => { eprintln!("error: {}", msg); },
|
Err(msg) => { eprintln!("error: {}", msg); },
|
||||||
};
|
};
|
||||||
},
|
},
|
||||||
|
|
@ -74,3 +78,66 @@ fn channel(flags: rust_whisper_lib::Flags) {
|
||||||
}
|
}
|
||||||
eprintln!("/listen lib main...");
|
eprintln!("/listen lib main...");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct Destutterer {
|
||||||
|
prev: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn new_destutterer() -> Destutterer {
|
||||||
|
Destutterer{prev: None}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Destutterer {
|
||||||
|
fn step(&mut self, next: String) -> String {
|
||||||
|
let next = next.trim().to_string();
|
||||||
|
if next.len() == 0 {
|
||||||
|
return next;
|
||||||
|
}
|
||||||
|
match &self.prev {
|
||||||
|
None => {
|
||||||
|
self.prev = Some(next.clone());
|
||||||
|
next
|
||||||
|
},
|
||||||
|
Some(prev) => {
|
||||||
|
let without_trailing_punctuation = {
|
||||||
|
let mut next = next.clone();
|
||||||
|
while next.ends_with("?") || next.ends_with(".") {
|
||||||
|
next = next[..next.len()-1].to_string();
|
||||||
|
}
|
||||||
|
next
|
||||||
|
};
|
||||||
|
let trailing_punctuation = next[without_trailing_punctuation.len() ..].to_string();
|
||||||
|
let next = without_trailing_punctuation;
|
||||||
|
let next = {
|
||||||
|
let mut n = prev.len().clamp(0, next.len());
|
||||||
|
while n > 0 {
|
||||||
|
if prev[prev.len() - n..] == next[..n] {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
n -= 1;
|
||||||
|
}
|
||||||
|
next[n..].to_string()
|
||||||
|
};
|
||||||
|
if next.len() == 0 {
|
||||||
|
return "".to_string();
|
||||||
|
}
|
||||||
|
self.prev = Some(next.clone());
|
||||||
|
next + &trailing_punctuation
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_destutterer() {
|
||||||
|
let mut w = new_destutterer();
|
||||||
|
assert_eq!("abcde".to_string(), w.step("abcde".to_string()));
|
||||||
|
assert_eq!("fg".to_string(), w.step("cdefg".to_string()));
|
||||||
|
assert_eq!("hij".to_string(), w.step("fghij".to_string()));
|
||||||
|
assert_eq!("fghij".to_string(), w.step("fghij".to_string()));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue