diff --git a/src/main.rs b/src/main.rs index 0ba8fb7..fcb3751 100644 --- a/src/main.rs +++ b/src/main.rs @@ -31,17 +31,10 @@ struct Flags { debug: bool, } -static mut G_HANDLER: Handler = new_handler(&0.0, &0.0); - fn main() { let flags = Flags::parse(); - unsafe { - G_HANDLER.head = flags.stream_head.clone(); - G_HANDLER.tail = flags.stream_tail.clone(); - } - - let w = new_whisper(flags.model, flags.threads, new_handler(&flags.stream_head, &flags.stream_tail)).unwrap(); + let w = new_whisper_service(flags.model, flags.threads, flags.stream_head, flags.stream_tail, new_handler()).unwrap(); let stream_retain = (flags.stream_retain * 16_000.0) as usize; let stream_step = Duration::new(flags.stream_step, 0); match flags.wav { @@ -66,15 +59,7 @@ fn main() { new_listener().listen(move |data: Vec| { data.iter().for_each(|x| buffer.push(*x)); if Instant::now() - last > stream_step { - match w.transcribe_async(&buffer, |result| { - match result { - Ok(whispered) => unsafe { G_HANDLER.on_success(&whispered) }, - Err(msg) => unsafe { G_HANDLER.on_error(msg) }, - }; - }) { - Ok(_) => (), - Err(msg) => eprintln!("{}", msg), - }; + w.transcribe_async(&buffer).unwrap(); match &flags.debug { true => { @@ -101,43 +86,37 @@ fn main() { }; } -struct Whisper { +struct WhisperService { jobs: std::sync::mpsc::SyncSender, } -struct WhisperEngine { - ctx: WhisperContext, - threads: i32, - handler: Handler, -} - -fn new_whisper(model_path: String, threads: i32, handler: Handler) -> Result { - match new_whisper_engine(model_path, threads, handler) { +fn new_whisper_service(model_path: String, threads: i32, stream_head: f32, stream_tail: f32, handler: Handler) -> Result { + match new_whisper_engine(model_path, threads) { Ok(engine) => { + let whisper = new_whisper_impl(engine, stream_head, stream_tail, handler); let (send, recv) = std::sync::mpsc::sync_channel(100); - thread::spawn(move || { engine.transcribe_asyncs(recv); }); - Ok(Whisper{jobs: send}) + thread::spawn(move || { whisper.transcribe_asyncs(recv); }); + Ok(WhisperService{jobs: send}) }, Err(msg) => Err(format!("failed to initialize engine: {}", msg)), } } -impl Whisper { +impl WhisperService { fn transcribe(&self, data: &Vec) { let (send, recv) = std::sync::mpsc::sync_channel(0); - self._transcribe_async(data, Some(send), None).unwrap(); + self._transcribe_async(data, Some(send)).unwrap(); recv.recv().unwrap(); } - fn transcribe_async(&self, data: &Vec, callback: fn(Result)) -> Result<(), String> { - self._transcribe_async(data, None, Some(callback)) + fn transcribe_async(&self, data: &Vec) -> Result<(), String> { + self._transcribe_async(data, None) } - fn _transcribe_async(&self, data: &Vec, ack: Option>, callback: Option)>) -> Result<(), String> { + fn _transcribe_async(&self, data: &Vec, ack: Option>) -> Result<(), String> { match self.jobs.try_send(AWhisper{ data: data.clone().to_vec(), ack: ack, - callback: callback, }) { Ok(_) => Ok(()), Err(msg) => Err(format!("failed to enqueue transcription: {}", msg)), @@ -145,48 +124,80 @@ impl Whisper { } } -fn new_whisper_engine(model_path: String, threads: i32, handler: Handler) -> Result { - match WhisperContext::new(&model_path) { - Ok(ctx) => Ok(WhisperEngine{ctx: ctx, threads: threads, handler: handler}), - Err(msg) => Err(format!("failed to load {}: {}", model_path, msg)), +struct WhisperImpl { + engine: WhisperEngine, + stream_head: f32, + stream_tail: f32, + handler: Handler, +} + +fn new_whisper_impl(engine: WhisperEngine, stream_head: f32, stream_tail: f32, handler: Handler) -> WhisperImpl { + WhisperImpl { + engine: engine, + stream_head: stream_head, + stream_tail: stream_tail, + handler: handler, } } -impl WhisperEngine { +impl WhisperImpl { fn transcribe_asyncs(&self, recv: std::sync::mpsc::Receiver) { loop { match recv.recv() { Ok(job) => { - match self.transcribe(&job.data) { - Ok(result) => { - self.handler.on_success(&result); - match job.ack { - Some(ack) => { let _ = ack.send(true); }, - None => (), - }; - match job.callback { - Some(foo) => foo(Ok(result)), - None => (), - }; - }, - Err(msg) => { - self.handler.on_error(format!("failed to transcribe: {}", &msg)); - match job.ack { - Some(ack) => { let _ = ack.send(false); }, - None => (), - }; - match job.callback { - Some(foo) => foo(Err(format!("failed to transcribe: {}", &msg))), - None => (), - }; + let result = self.transcribe(&job).is_ok(); + match job.ack { + Some(ack) => { + ack.send(result).unwrap(); }, + None => (), }; - }, + } Err(_) => return, }; } } + fn transcribe(&self, a_whisper: &AWhisper) -> Result<(), ()> { + match self.engine.transcribe(&a_whisper.data) { + Ok(result) => { + self.on_success(&result); + Ok(()) + }, + Err(msg) => { + self.on_error(msg.to_string()); + Err(()) + }, + } + } + + fn on_success(&self, whispered: &Whispered) { + eprintln!("{}: {:?}", chrono::Local::now(), whispered); + self.handler.on_success( + &whispered + .after(&(self.stream_head * 100.0)) + .before(&(self.stream_tail * 100.0)), + ); + } + + fn on_error(&self, msg: String) { + self.handler.on_error(format!("failed to transcribe: {}", &msg)); + } +} + +struct WhisperEngine { + ctx: WhisperContext, + threads: i32, +} + +fn new_whisper_engine(model_path: String, threads: i32) -> Result { + match WhisperContext::new(&model_path) { + Ok(ctx) => Ok(WhisperEngine{ctx: ctx, threads: threads}), + Err(msg) => Err(format!("failed to load {}: {}", model_path, msg)), + } +} + +impl WhisperEngine { fn transcribe(&self, data: &Vec) -> Result { let mut params = FullParams::new(SamplingStrategy::Greedy { best_of: 0 }); params.set_no_context(true); @@ -218,7 +229,6 @@ impl WhisperEngine { struct AWhisper { data: Vec, ack: Option>, - callback: Option)>, } #[derive(Clone, Debug)] @@ -289,24 +299,18 @@ impl Whispered { } #[derive(Clone)] -struct Handler { - head: f32, - tail: f32, -} +struct Handler {} -const fn new_handler(seconds_head: &f32, seconds_tail: &f32) -> Handler { - Handler{ - head: *seconds_head, - tail: *seconds_tail, - } +const fn new_handler() -> Handler { + Handler{} } impl Handler { fn on_success(&self, result: &Whispered) { eprintln!("{}: {:?}", chrono::Local::now(), &result); println!("{}", result - .after(&(self.head * 100.0)) - .before(&(self.tail * 100.0)) + .after(&(0.25 * 100.0)) + .before(&(0.25 * 100.0)) .to_string(), ); } @@ -347,7 +351,7 @@ impl Listener { let downsample_ratio = cfg.channels() as f32 * (cfg.sample_rate().0 as f32 / 16_000.0); let stream = device.build_input_stream( &cfg.clone().into(), - move |data: &[f32], _: &cpal::InputCallbackInfo| { + move |data: &[f32], _: &cpal::InputCallbackInfo| { // TODO why cant i do this... let mut downsampled_data = vec![]; for i in 0..(data.len() as f32 / downsample_ratio) as usize { let mut upsampled = i as f32 * downsample_ratio;