use whisper_rs::{WhisperContext, FullParams, SamplingStrategy, WhisperError}; use clap::Parser; use std::thread; use std::fs::File; use std::io::Write; #[derive(Parser, Debug, Clone)] pub struct Flags { #[arg(long, default_value = "../models/ggml-tiny.en.bin")] pub model_path: Option, #[arg(long, default_value = None)] pub model_buffer: Option>, #[arg(long, default_value = "8")] pub threads: i32, #[arg(long, default_value = "5")] pub stream_step: u64, #[arg(long, default_value = "0.6")] pub stream_retain: f32, #[arg(long, default_value = "0.3")] pub stream_head: f32, #[arg(long, default_value = "0.3")] pub stream_tail: f32, #[arg(long, default_value = "false")] pub debug: bool, #[arg(long, default_value = None)] pub wav: Option, #[arg(long, default_value = None)] pub stream_device: Option, } pub fn wav(flags: Flags, handler_fn: F, wav_path: String) where F: FnMut(Result) + Send + 'static { let w = new_service( flags.model_path, flags.model_buffer, flags.threads, flags.stream_head, flags.stream_tail, handler_fn, ).unwrap(); w.transcribe(&f32_from_wav_file(&wav_path).unwrap()) } pub fn f32_from_wav_file(path: &String) -> Result, String> { let f = std::fs::File::open(path); if let Some(err) = f.as_ref().err() { return Err(format!("failed to open wav file: {}", err)); } let wav_read = wav::read(&mut f.unwrap()); if let Some(err) = wav_read.as_ref().err() { return Err(format!("failed to parse wav file: {}", err)); } let (header, data) = wav_read.unwrap(); if header.channel_count != 1 { return Err("!= 1 channel".to_string()); } if header.sampling_rate != 16_000 { return Err("!= 16_000 hz".to_string()); } match data.as_sixteen() { Some(data16) => { let mut floats = Vec::with_capacity(data16.len()); for sample in data16 { floats.push(*sample as f32 / 32768.0); } Ok(floats) }, None => Err(format!("couldnt translate wav to 16s")), } } pub fn wav_channel(flags: Flags, handler_fn: F) where F: FnMut(Result) + Send + 'static { let path = flags.wav.as_ref().unwrap(); let mut audio = f32_from_wav_file(&path).unwrap(); let mut iter = vec![]; let n = audio.len() / match audio.len() % 100 { 0 => 100, _ => 99, }; for _ in 0..100 { iter.push(audio.drain(0..n.clamp(0, audio.len())).collect()); } let (fin_send, fin_recv) = std::sync::mpsc::sync_channel::>(1); channel_and_close(flags.clone(), handler_fn, iter, move || { fin_send.send(None).unwrap(); }); match fin_recv.recv() { Ok(_) => {}, Err(x) => panic!("failed to receive: {}", x), }; } pub fn channel(flags: Flags, handler_fn: F, stream: I) where F: FnMut(Result) + Send + 'static, I: IntoIterator> { channel_and_close(flags, handler_fn, stream, || {}); } fn channel_and_close(flags: Flags, handler_fn: F, stream: I, mut close_fn: G) where F: FnMut(Result) + Send + 'static, I: IntoIterator>, G: FnMut() + Send + 'static { let w = new_service( flags.model_path, flags.model_buffer, flags.threads, flags.stream_head, flags.stream_tail, handler_fn, ).unwrap(); let stream_retain = (flags.stream_retain * 16_000.0) as usize; match &flags.debug { true => { File::create("/tmp/page.rawf32audio").unwrap(); }, false => {}, }; let mut buffer = vec![]; for data in stream { data.iter().for_each(|x| buffer.push(*x)); if buffer.len() >= (flags.stream_step * 16_000) as usize { w.transcribe_async(&buffer).unwrap(); match &flags.debug { true => { let mut f = File::options().append(true).open("/tmp/page.rawf32audio").unwrap(); let mut wav_data = vec![]; for i in buffer.iter() { for j in i.to_le_bytes() { wav_data.push(j); } } f.write_all(wav_data.as_slice()).unwrap(); }, false => {}, }; for i in 0..stream_retain { buffer[i] = buffer[buffer.len() - stream_retain + i]; } buffer.truncate(stream_retain); } } if buffer.len() > 0 { w.transcribe(&buffer); } close_fn(); } struct Service { jobs: std::sync::mpsc::SyncSender, } fn new_service(model_path: Option, model_buffer: Option>, threads: i32, stream_head: f32, stream_tail: f32, handler_fn: F) -> Result where F: FnMut(Result) + Send + 'static { match new_engine(model_path, model_buffer, threads) { Ok(engine) => { let mut whisper = new_whisper_impl(engine, stream_head, stream_tail, handler_fn); let (send, recv) = std::sync::mpsc::sync_channel(100); thread::spawn(move || { whisper.transcribe_asyncs(recv); }); Ok(Service{jobs: send}) }, Err(msg) => Err(format!("failed to initialize engine: {}", msg)), } } impl Service { fn transcribe(&self, data: &Vec) { let (send, recv) = std::sync::mpsc::sync_channel(0); self._transcribe_async(data, Some(send)).unwrap(); recv.recv().unwrap(); } fn transcribe_async(&self, data: &Vec) -> Result<(), String> { self._transcribe_async(data, None) } fn _transcribe_async(&self, data: &Vec, ack: Option>) -> Result<(), String> { match self.jobs.try_send(ATranscribe{ data: data.clone().to_vec(), ack: ack, }) { Ok(_) => Ok(()), Err(msg) => Err(format!("failed to enqueue transcription: {}", msg)), } } } struct Impl { engine: Engine, stream_head: f32, stream_tail: f32, handler_fn: Option) + Send + 'static>> } fn new_whisper_impl(engine: Engine, stream_head: f32, stream_tail: f32, handler_fn: F) -> Impl where F: FnMut(Result) + Send + 'static { Impl { engine: engine, stream_head: stream_head, stream_tail: stream_tail, handler_fn: Some(Box::new(handler_fn)), } } impl Impl { fn transcribe_asyncs(&mut self, recv: std::sync::mpsc::Receiver) { loop { match recv.recv() { Ok(job) => { let result = self.transcribe(&job).is_ok(); match job.ack { Some(ack) => { ack.send(result).unwrap(); }, None => (), }; } Err(_) => return, }; } } fn transcribe(&mut self, a_whisper: &ATranscribe) -> Result<(), ()> { match self.engine.transcribe(&a_whisper.data) { Ok(result) => { self.on_success(&result); Ok(()) }, Err(msg) => { self.on_error(msg.to_string()); Err(()) }, } } fn on_success(&mut self, whispered: &Transcribed) { let result = whispered .after(&(self.stream_head * 100.0)) .before(&(self.stream_tail * 100.0)); (self.handler_fn.as_mut().unwrap())(Ok(result)); } fn on_error(&mut self, msg: String) { (self.handler_fn.as_mut().unwrap())(Err(format!("failed to transcribe: {}", &msg))); } } struct Engine { ctx: WhisperContext, threads: i32, } fn new_engine(model_path: Option, model_buffer: Option>, threads: i32) -> Result { let whisper_context_result = match model_path { Some(model_path) => WhisperContext::new(&model_path), None => WhisperContext::new_from_buffer(&model_buffer.unwrap()), }; match whisper_context_result { Ok(ctx) => Ok(Engine{ctx: ctx, threads: threads}), Err(msg) => Err(format!("failed to load model: {}", msg)), } } impl Engine { fn transcribe(&self, data: &Vec) -> Result { match self._transcribe(data) { Ok(transcribed) => Ok(transcribed), Err(msg) => Err(format!("{}", msg)), } } fn _transcribe(&self, data: &Vec) -> Result { let mut params = FullParams::new(SamplingStrategy::Greedy { best_of: 0 }); params.set_no_context(true); params.set_n_threads(self.threads); params.set_translate(false); params.set_detect_language(false); params.set_language(Some("en")); params.set_print_special(false); params.set_print_progress(false); params.set_print_realtime(false); params.set_print_timestamps(false); let mut state = self.ctx.create_state()?; state.full(params, &data[..])?; let mut result = new_whispered(); let num_segments = state.full_n_segments()?; for i in 0..num_segments { let data = state.full_get_segment_text(i)?; let start = state.full_get_segment_t0(i)?; let stop = state.full_get_segment_t1(i)?; result.push(data, start, stop); } Ok(result) } } struct ATranscribe { data: Vec, ack: Option>, } #[derive(Clone, Debug)] pub struct Transcribed { data: Vec, } #[derive(Clone, Debug)] struct ATranscribed { pub data: String, pub offset: i64, pub length: i64, } fn new_whispered() -> Transcribed { Transcribed{data: vec![]} } fn new_a_whispered(data: String, start: i64, stop: i64) -> ATranscribed { ATranscribed{ data: data, offset: start.clone(), length: stop - start, } } impl Transcribed { pub fn to_string(&self) -> String { let mut result = "".to_string(); for i in 0..self.data.len() { result = format!("{} {}", result, &self.data[i].data); } result } fn after(&self, t: &f32) -> Transcribed { let mut result = new_whispered(); self.data .iter() .filter(|x| x.offset as f32 >= *t) .for_each(|x| result.data.push(x.clone())); result } fn before(&self, t: &f32) -> Transcribed { let mut result = new_whispered(); let end = match self.data.iter().map(|x| x.offset + x.length).max() { Some(x) => x, None => 1, }; let t = (end as f32) - *t; self.data .iter() .filter(|x| ((x.offset) as f32) <= t) .for_each(|x| result.data.push(x.clone())); result } fn push(&mut self, data: String, start: i64, stop: i64) { let words: Vec<_> = data.split_whitespace().collect(); let per_word = (stop - start) / (words.len() as i64); for i in 0..words.len() { let start = (i as i64) * per_word; let stop = start.clone() + per_word; self.data.push(new_a_whispered(words[i].to_string(), start, stop)); } } } #[cfg(test)] mod tests { use super::*; #[test] fn test_transcribe_tiny_jfk_wav_whisper_rs_wav_channel() { wav_channel( Flags { model_path: None, model_buffer: Some(include_bytes!("../../models/ggml-tiny.en.bin").to_vec()), threads: 8, stream_step: 30, stream_retain: 0.0, stream_head: 0.0, stream_tail: 0.0, wav: Some("../gitea-whisper-rs/sys/whisper.cpp/bindings/go/samples/jfk.wav".to_string()), debug: false, stream_device: None, }, move | result | { assert!(result.is_ok()); assert_eq!(result.unwrap().to_string(), " And so my fellow Americans ask not what your country can do for you ask what you can do for your country."); }, ); } #[test] fn test_transcribe_tiny_jfk_wav_whisper_rs() { wav( Flags { model_path: None, model_buffer: Some(include_bytes!("../../models/ggml-tiny.en.bin").to_vec()), threads: 8, stream_step: 0, stream_retain: 0.0, stream_head: 0.0, stream_tail: 0.0, wav: Some("../gitea-whisper-rs/sys/whisper.cpp/bindings/go/samples/jfk.wav".to_string()), debug: false, stream_device: None, }, | result | { assert!(result.is_ok()); assert_eq!(result.unwrap().to_string(), " And so my fellow Americans ask not what your country can do for you ask what you can do for your country."); }, "../gitea-whisper-rs/sys/whisper.cpp/bindings/go/samples/jfk.wav".to_string(), ); } }