From cd339de3343caddbbb1701578ae524501e9ef9f7 Mon Sep 17 00:00:00 2001 From: Bel LaPointe <153096461+breel-render@users.noreply.github.com> Date: Tue, 2 Jan 2024 14:15:53 -0700 Subject: [PATCH] rust-whisper-lib::wav_channel() --- rust-whisper-lib/src/lib.rs | 53 +++++++++++++++++++++++++++++++++++-- 1 file changed, 51 insertions(+), 2 deletions(-) diff --git a/rust-whisper-lib/src/lib.rs b/rust-whisper-lib/src/lib.rs index 3e492ed..4b7e109 100644 --- a/rust-whisper-lib/src/lib.rs +++ b/rust-whisper-lib/src/lib.rs @@ -72,7 +72,30 @@ pub fn f32_from_wav_file(path: &String) -> Result, String> { } } -pub fn channel(flags: Flags, handler_fn: F, stream: std::sync::mpsc::Receiver>) where F: FnMut(Result) + Send + 'static { +pub fn wav_channel(flags: Flags, handler_fn: F) where F: FnMut(Result) + Send + 'static { + let path = flags.wav.as_ref().unwrap(); + let mut audio = f32_from_wav_file(&path).unwrap(); + let mut iter = vec![]; + let n = audio.len() / match audio.len() % 100 { + 0 => 100, + _ => 99, + }; + for _ in 0..100 { + iter.push(audio.drain(0..n.clamp(0, audio.len())).collect()); + } + let (fin_send, fin_recv) = std::sync::mpsc::sync_channel::>(1); + channel_and_close(flags.clone(), handler_fn, iter, move || { fin_send.send(None).unwrap(); }); + match fin_recv.recv() { + Ok(_) => {}, + Err(x) => panic!("failed to receive: {}", x), + }; +} + +pub fn channel(flags: Flags, handler_fn: F, stream: I) where F: FnMut(Result) + Send + 'static, I: IntoIterator> { + channel_and_close(flags, handler_fn, stream, || {}); +} + +fn channel_and_close(flags: Flags, handler_fn: F, stream: I, mut close_fn: G) where F: FnMut(Result) + Send + 'static, I: IntoIterator>, G: FnMut() + Send + 'static { let w = new_service( flags.model_path, flags.model_buffer, @@ -87,7 +110,7 @@ pub fn channel(flags: Flags, handler_fn: F, stream: std::sync::mpsc::Receiver false => {}, }; let mut buffer = vec![]; - for data in stream.iter() { + for data in stream { data.iter().for_each(|x| buffer.push(*x)); if buffer.len() >= (flags.stream_step * 16_000) as usize { w.transcribe_async(&buffer).unwrap(); @@ -112,6 +135,10 @@ pub fn channel(flags: Flags, handler_fn: F, stream: std::sync::mpsc::Receiver buffer.truncate(stream_retain); } } + if buffer.len() > 0 { + w.transcribe(&buffer); + } + close_fn(); } struct Service { @@ -339,6 +366,28 @@ impl Transcribed { mod tests { use super::*; + #[test] + fn test_transcribe_tiny_jfk_wav_whisper_rs_wav_channel() { + wav_channel( + Flags { + model_path: None, + model_buffer: Some(include_bytes!("../../models/ggml-tiny.en.bin").to_vec()), + threads: 8, + stream_step: 30, + stream_retain: 0.0, + stream_head: 0.0, + stream_tail: 0.0, + wav: Some("../gitea-whisper-rs/sys/whisper.cpp/bindings/go/samples/jfk.wav".to_string()), + debug: false, + stream_device: None, + }, + move | result | { + assert!(result.is_ok()); + assert_eq!(result.unwrap().to_string(), " And so my fellow Americans ask not what your country can do for you ask what you can do for your country."); + }, + ); + } + #[test] fn test_transcribe_tiny_jfk_wav_whisper_rs() { wav(