From deffc420caadf66337201bcd6327856d702cdd01 Mon Sep 17 00:00:00 2001 From: Bel LaPointe Date: Thu, 30 Nov 2023 12:00:16 -0700 Subject: [PATCH] at least it complies --- src/main.rs | 97 ++++++++++++++++++++++++++++++++++++++--------------- 1 file changed, 70 insertions(+), 27 deletions(-) diff --git a/src/main.rs b/src/main.rs index 6506ead..bfb4504 100644 --- a/src/main.rs +++ b/src/main.rs @@ -25,7 +25,7 @@ struct Flags { fn main() { let flags = Flags::parse(); - let w = new_whisper(flags.model, flags.threads).unwrap(); + let w = new_whisper(flags.model, flags.threads, Handler{}).unwrap(); let stream_retain = (flags.stream_retain * 16_000.0) as usize; let stream_step = Duration::new(flags.stream_step, 0); match flags.wav { @@ -38,10 +38,7 @@ fn main() { let data16 = data.as_sixteen().expect("wav is not 32bit floats"); let audio_data = &whisper_rs::convert_integer_to_float_audio(&data16); - w.transcribe(&audio_data, - |result| println!("{}", result), - |err| eprintln!("failed to transcribe wav: {}", err), - ); + w.transcribe(&audio_data); }, None => { let mut buffer = vec![]; @@ -49,13 +46,7 @@ fn main() { new_listener().listen(move |data: Vec| { data.iter().for_each(|x| buffer.push(*x)); if Instant::now() - last > stream_step { - w.transcribe(&buffer, - |result| { - eprintln!("{}", chrono::Local::now()); - println!("{}", result); - }, - |err| eprintln!("failed to transcribe stream: {}", err), - ); + w.transcribe(&buffer); for i in stream_retain..buffer.len() { buffer[i - stream_retain] = buffer[i] @@ -69,29 +60,65 @@ fn main() { } struct Whisper { - ctx: WhisperContext, - threads: i32, + jobs: std::sync::mpsc::SyncSender, } -fn new_whisper(model_path: String, threads: i32) -> Result { - match WhisperContext::new(&model_path) { - Ok(ctx) => Ok(Whisper{ - ctx: ctx, - threads: threads, - }), - Err(msg) => Err(format!("failed to load {}: {}", model_path, msg)), +struct WhisperEngine { + ctx: WhisperContext, + threads: i32, + handler: Handler, +} + +fn new_whisper(model_path: String, threads: i32, handler: Handler) -> Result { + match new_whisper_engine(model_path, threads, handler) { + Ok(engine) => { + let (send, recv) = std::sync::mpsc::sync_channel(100); + thread::spawn(move || { engine.transcribe_asyncs(recv); }); + Ok(Whisper{jobs: send}) + }, + Err(msg) => Err(format!("failed to initialize engine: {}", msg)), } } impl Whisper { - fn transcribe(&self, data: &Vec, on_success: impl Fn(String), on_error: impl Fn(String)) { - match self._transcribe(&data) { - Ok(result) => on_success(result), - Err(msg) => on_error(format!("failed to transcribe: {}", msg)), - }; + fn transcribe(&self, data: &Vec) { + self.transcribe_async(data).unwrap(); + // TODO block } - fn _transcribe(&self, data: &Vec) -> Result { + fn transcribe_async(&self, data: &Vec) -> Result<(), String> { + match self.jobs.try_send(AWhisper{ + data: data.clone().to_vec(), + }) { + Ok(_) => Ok(()), + Err(msg) => Err(format!("failed to enqueue transcription: {}", msg)), + } + } +} + +fn new_whisper_engine(model_path: String, threads: i32, handler: Handler) -> Result { + match WhisperContext::new(&model_path) { + Ok(ctx) => Ok(WhisperEngine{ctx: ctx, threads: threads, handler: handler}), + Err(msg) => Err(format!("failed to load {}: {}", model_path, msg)), + } +} + +impl WhisperEngine { + fn transcribe_asyncs(&self, recv: std::sync::mpsc::Receiver) { + loop { + match recv.recv() { + Ok(job) => { + match self.transcribe(&job.data) { + Ok(result) => self.handler.on_success(result), + Err(msg) => self.handler.on_error(format!("failed to transcribe: {}", msg)), + }; + }, + Err(_) => return, + }; + } + } + + fn transcribe(&self, data: &Vec) -> Result { let mut params = FullParams::new(SamplingStrategy::Greedy { best_of: 0 }); params.set_no_context(true); params.set_n_threads(self.threads); @@ -117,6 +144,22 @@ impl Whisper { } } +struct AWhisper { + data: Vec, +} + +struct Handler {} + +impl Handler { + fn on_success(&self, result: String) { + eprintln!("{}", chrono::Local::now()); + println!("{}", result); + } + fn on_error(&self, msg: String) { + eprintln!("error: {}", msg); + } +} + struct Listener { }