// 148 lines, 4.0 KiB, Rust
use rust_whisper_lib;
|
|
use rust_whisper_baked_lib;
|
|
use clap::Parser;
|
|
use listen_lib;
|
|
use std::thread;
|
|
|
|
fn main() {
|
|
let flags = rust_whisper_lib::Flags::parse();
|
|
match flags.wav.clone() {
|
|
Some(_) => wav_channel(flags),
|
|
None => channel(flags),
|
|
};
|
|
}
|
|
|
|
fn wav_channel(flags: rust_whisper_lib::Flags) {
|
|
let mut w = new_destutterer();
|
|
rust_whisper_baked_lib::wav_channel(
|
|
flags.clone(),
|
|
move |result: Result<rust_whisper_lib::Transcribed, String>| {
|
|
match result {
|
|
Ok(transcribed) => {
|
|
let s = w.step(transcribed.to_string());
|
|
println!("{}", s);
|
|
},
|
|
Err(msg) => { eprintln!("error: {}", msg); },
|
|
};
|
|
},
|
|
);
|
|
}
|
|
|
|
fn wav(flags: rust_whisper_lib::Flags, _path: String) {
|
|
let mut w = new_destutterer();
|
|
rust_whisper_baked_lib::wav(flags,
|
|
move |result: Result<rust_whisper_lib::Transcribed, String>| {
|
|
match result {
|
|
Ok(transcribed) => {
|
|
let s = w.step(transcribed.to_string());
|
|
println!("{}", s);
|
|
},
|
|
Err(msg) => { eprintln!("error: {}", msg); },
|
|
};
|
|
},
|
|
);
|
|
}
|
|
|
|
fn channel(flags: rust_whisper_lib::Flags) {
|
|
let (send, recv) = std::sync::mpsc::sync_channel(100);
|
|
|
|
eprintln!("rust whisper baked lib channel...");
|
|
thread::spawn(move || {
|
|
rust_whisper_baked_lib::channel(
|
|
flags.clone(),
|
|
|result: Result<rust_whisper_lib::Transcribed, String>| {
|
|
match result {
|
|
Ok(transcribed) => { println!("{}", transcribed.to_string()); },
|
|
Err(msg) => { eprintln!("error: {}", msg); },
|
|
};
|
|
},
|
|
recv,
|
|
);
|
|
});
|
|
|
|
eprintln!("listen lib main...");
|
|
let flags = rust_whisper_lib::Flags::parse();
|
|
match flags.stream_device {
|
|
Some(device_name) => {
|
|
if device_name == "" {
|
|
for device in listen_lib::devices() {
|
|
eprintln!("{}", device);
|
|
}
|
|
} else {
|
|
listen_lib::main_with(|data| {
|
|
send.send(data).unwrap();
|
|
}, device_name);
|
|
}
|
|
},
|
|
None => {
|
|
listen_lib::main(|data| {
|
|
send.send(data).unwrap();
|
|
});
|
|
}
|
|
}
|
|
eprintln!("/listen lib main...");
|
|
}
|
|
|
|
/// Remembers the most recent transcript chunk so that the following chunk can
/// have any text it repeats from that one removed before it is printed.
#[derive(Debug, Clone, Default)]
struct Destutterer {
    /// The last chunk that produced output; `None` before the first chunk.
    prev: Option<String>,
}
|
|
|
|
fn new_destutterer() -> Destutterer {
|
|
Destutterer{prev: None}
|
|
}
|
|
|
|
impl Destutterer {
|
|
fn step(&mut self, next: String) -> String {
|
|
let next = next.trim().to_string();
|
|
if next.len() == 0 {
|
|
return next;
|
|
}
|
|
match &self.prev {
|
|
None => {
|
|
self.prev = Some(next.clone());
|
|
next
|
|
},
|
|
Some(prev) => {
|
|
let without_trailing_punctuation = {
|
|
let mut next = next.clone();
|
|
while next.ends_with("?") || next.ends_with(".") {
|
|
next = next[..next.len()-1].to_string();
|
|
}
|
|
next
|
|
};
|
|
let trailing_punctuation = next[without_trailing_punctuation.len() ..].to_string();
|
|
let next = without_trailing_punctuation.clone();
|
|
let next = {
|
|
let mut n = prev.len().clamp(0, next.len());
|
|
while n > 0 {
|
|
if prev[prev.len() - n..] == next[..n] {
|
|
break;
|
|
}
|
|
n -= 1;
|
|
}
|
|
next[n..].to_string()
|
|
};
|
|
if next.len() == 0 {
|
|
return "".to_string();
|
|
}
|
|
self.prev = Some(without_trailing_punctuation);
|
|
next + &trailing_punctuation
|
|
},
|
|
}
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_destutterer() {
        let mut w = new_destutterer();
        assert_eq!("abcde".to_string(), w.step("abcde".to_string()));
        assert_eq!("fg".to_string(), w.step("cdefg".to_string()));
        assert_eq!("hij".to_string(), w.step("fghij".to_string()));
        // A chunk that is entirely overlap contributes nothing new: `step`
        // returns "" (and keeps the remembered chunk) rather than echoing it.
        assert_eq!("".to_string(), w.step("fghij".to_string()));
    }

    #[test]
    fn test_destutterer_whitespace_and_punctuation() {
        let mut w = new_destutterer();
        assert_eq!("abc".to_string(), w.step("abc".to_string()));
        // Blank chunks are dropped without affecting the remembered chunk.
        assert_eq!("".to_string(), w.step("   ".to_string()));
        // Trailing punctuation is ignored for overlap and re-attached to the
        // new text.
        assert_eq!("d?".to_string(), w.step("bcd?".to_string()));
        // Only punctuation would be new, so nothing is emitted.
        assert_eq!("".to_string(), w.step("d.".to_string()));
    }
}
|