From 2391d07994f825d90076f91b332347dfe4338900 Mon Sep 17 00:00:00 2001 From: Bel LaPointe Date: Thu, 30 Nov 2023 09:58:28 -0700 Subject: [PATCH] transcribing results as callbacks --- src/main.rs | 47 ++++++++++++++++++++++++++--------------------- 1 file changed, 26 insertions(+), 21 deletions(-) diff --git a/src/main.rs b/src/main.rs index 44cb7a3..6506ead 100644 --- a/src/main.rs +++ b/src/main.rs @@ -14,8 +14,8 @@ struct Flags { #[arg(long, default_value = "8")] threads: i32, - #[arg(long, default_value = "0.8")] - stream_churn: f32, + #[arg(long, default_value = "1.0")] + stream_retain: f32, #[arg(long, default_value = "5")] stream_step: u64, @@ -26,7 +26,7 @@ fn main() { let flags = Flags::parse(); let w = new_whisper(flags.model, flags.threads).unwrap(); - let stream_churn = flags.stream_churn; + let stream_retain = (flags.stream_retain * 16_000.0) as usize; let stream_step = Duration::new(flags.stream_step, 0); match flags.wav { Some(wav) => { @@ -34,12 +34,14 @@ fn main() { &mut std::fs::File::open(wav).expect("failed to open $WAV"), ).expect("failed to decode $WAV"); assert!(header.channel_count == 1); - assert!(header.sampling_rate == 16000); + assert!(header.sampling_rate == 16_000); let data16 = data.as_sixteen().expect("wav is not 32bit floats"); let audio_data = &whisper_rs::convert_integer_to_float_audio(&data16); - let result = w.transcribe(&audio_data).unwrap(); - println!("{}", result); + w.transcribe(&audio_data, + |result| println!("{}", result), + |err| eprintln!("failed to transcribe wav: {}", err), + ); }, None => { let mut buffer = vec![]; @@ -47,15 +49,18 @@ fn main() { new_listener().listen(move |data: Vec| { data.iter().for_each(|x| buffer.push(*x)); if Instant::now() - last > stream_step { - let result = w.transcribe(&buffer).unwrap(); - eprintln!("{}", chrono::Local::now()); - println!("{}", result); + w.transcribe(&buffer, + |result| { + eprintln!("{}", chrono::Local::now()); + println!("{}", result); + }, + |err| eprintln!("failed to transcribe stream: {}", err), + ); - let retain = buffer.len() - (buffer.len() as f32 * stream_churn) as usize; - for i in retain..buffer.len() { - buffer[i - retain] = buffer[i] + for i in stream_retain..buffer.len() { + buffer[i - stream_retain] = buffer[i] } - buffer.truncate(retain); + buffer.truncate(stream_retain); last = Instant::now(); } }); @@ -79,11 +84,11 @@ fn new_whisper(model_path: String, threads: i32) -> Result { } impl Whisper { - fn transcribe(&self, data: &Vec) -> Result { + fn transcribe(&self, data: &Vec, on_success: impl Fn(String), on_error: impl Fn(String)) { match self._transcribe(&data) { - Ok(result) => Ok(result), - Err(msg) => Err(format!("failed to transcribe: {}", msg)), - } + Ok(result) => on_success(result), + Err(msg) => on_error(format!("failed to transcribe: {}", msg)), + }; } fn _transcribe(&self, data: &Vec) -> Result { @@ -121,7 +126,7 @@ fn new_listener() -> Listener { impl Listener { fn listen(self, mut cb: impl FnMut(Vec)) { - let (send, recv) = std::sync::mpsc::sync_channel(5_000_000); + let (send, recv) = std::sync::mpsc::sync_channel(5_000); thread::spawn(move || { self._listen(send); }); loop { match recv.recv() { @@ -131,7 +136,7 @@ impl Listener { } } - fn _listen(self, ch: std::sync::mpsc::SyncSender>) { + fn _listen(self, send: std::sync::mpsc::SyncSender>) { let host = cpal::default_host(); let device = host.default_input_device().unwrap(); let cfg = device.supported_input_configs() @@ -141,7 +146,7 @@ impl Listener { .unwrap() .with_max_sample_rate(); - let downsample_ratio = cfg.channels() as f32 * (cfg.sample_rate().0 as f32 / 16000.0); + let downsample_ratio = cfg.channels() as f32 * (cfg.sample_rate().0 as f32 / 16_000.0); let stream = device.build_input_stream( &cfg.clone().into(), move |data: &[f32], _: &cpal::InputCallbackInfo| { @@ -153,7 +158,7 @@ impl Listener { } downsampled_data.push(data[upsampled as usize]); } - match ch.send(downsampled_data) { + match send.try_send(downsampled_data) { Ok(_) => (), Err(msg) => eprintln!("failed to ingest audio: {}", msg), };