use whisper_rs::{WhisperContext, FullParams, SamplingStrategy, WhisperError}; use cpal::traits::{HostTrait, DeviceTrait, StreamTrait}; use signal_hook::{iterator::Signals, consts::signal::SIGINT}; use std::time::{Duration, Instant}; use chrono; use clap::Parser; use std::thread; use std::fs::File; use std::io::Write; #[derive(Parser, Debug)] struct Flags { #[arg(long, default_value = "./models/ggml-tiny.en.bin")] model: String, #[arg(long, default_value = "8")] threads: i32, #[arg(long, default_value = "5")] stream_step: u64, #[arg(long, default_value = "0.6")] stream_retain: f32, #[arg(long, default_value = "0.3")] stream_head: f32, #[arg(long, default_value = "0.3")] stream_tail: f32, wav: Option, #[arg(long, default_value = "false")] debug: bool, } fn main() { let flags = Flags::parse(); let w = new_whisper_service( flags.model, flags.threads, flags.stream_head, flags.stream_tail, |result| { match result { Ok(whispered) => { eprintln!("{}: {:?}", chrono::Local::now(), whispered); println!("{}", whispered.to_string()); }, Err(msg) => { eprintln!("Error whispering: {}", msg); }, }; }, ).unwrap(); let stream_retain = (flags.stream_retain * 16_000.0) as usize; let stream_step = Duration::new(flags.stream_step, 0); match flags.wav { Some(wav) => { let (header, data) = wav::read( &mut std::fs::File::open(wav).expect("failed to open $WAV"), ).expect("failed to decode $WAV"); assert!(header.channel_count == 1); assert!(header.sampling_rate == 16_000); let data16 = data.as_sixteen().expect("wav is not 32bit floats"); let audio_data = &whisper_rs::convert_integer_to_float_audio(&data16); w.transcribe(&audio_data); }, None => { match &flags.debug { true => { File::create("/tmp/page.rawf32audio").unwrap(); }, false => {}, }; let mut buffer = vec![]; let mut last = Instant::now(); new_listener().listen(move |data: Vec| { data.iter().for_each(|x| buffer.push(*x)); if Instant::now() - last > stream_step { w.transcribe_async(&buffer).unwrap(); match &flags.debug { true => { let mut f = File::options().append(true).open("/tmp/page.rawf32audio").unwrap(); let mut wav_data = vec![]; for i in buffer.iter() { for j in i.to_le_bytes() { wav_data.push(j); } } f.write_all(wav_data.as_slice()).unwrap(); }, false => {}, }; for i in 0..stream_retain { buffer[i] = buffer[buffer.len() - stream_retain + i]; } buffer.truncate(stream_retain); last = Instant::now(); } }); }, }; } struct WhisperService { jobs: std::sync::mpsc::SyncSender, } fn new_whisper_service(model_path: String, threads: i32, stream_head: f32, stream_tail: f32, handler_fn: fn(Result)) -> Result { match new_whisper_engine(model_path, threads) { Ok(engine) => { let whisper = new_whisper_impl(engine, stream_head, stream_tail, handler_fn); let (send, recv) = std::sync::mpsc::sync_channel(100); thread::spawn(move || { whisper.transcribe_asyncs(recv); }); Ok(WhisperService{jobs: send}) }, Err(msg) => Err(format!("failed to initialize engine: {}", msg)), } } impl WhisperService { fn transcribe(&self, data: &Vec) { let (send, recv) = std::sync::mpsc::sync_channel(0); self._transcribe_async(data, Some(send)).unwrap(); recv.recv().unwrap(); } fn transcribe_async(&self, data: &Vec) -> Result<(), String> { self._transcribe_async(data, None) } fn _transcribe_async(&self, data: &Vec, ack: Option>) -> Result<(), String> { match self.jobs.try_send(AWhisper{ data: data.clone().to_vec(), ack: ack, }) { Ok(_) => Ok(()), Err(msg) => Err(format!("failed to enqueue transcription: {}", msg)), } } } struct WhisperImpl { engine: WhisperEngine, stream_head: f32, stream_tail: f32, handler_fn: fn(Result), } fn new_whisper_impl(engine: WhisperEngine, stream_head: f32, stream_tail: f32, handler_fn: fn(Result)) -> WhisperImpl { WhisperImpl { engine: engine, stream_head: stream_head, stream_tail: stream_tail, handler_fn: handler_fn, } } impl WhisperImpl { fn transcribe_asyncs(&self, recv: std::sync::mpsc::Receiver) { loop { match recv.recv() { Ok(job) => { let result = self.transcribe(&job).is_ok(); match job.ack { Some(ack) => { ack.send(result).unwrap(); }, None => (), }; } Err(_) => return, }; } } fn transcribe(&self, a_whisper: &AWhisper) -> Result<(), ()> { match self.engine.transcribe(&a_whisper.data) { Ok(result) => { self.on_success(&result); Ok(()) }, Err(msg) => { self.on_error(msg.to_string()); Err(()) }, } } fn on_success(&self, whispered: &Whispered) { let result = whispered .after(&(self.stream_head * 100.0)) .before(&(self.stream_tail * 100.0)); (self.handler_fn)(Ok(result)); } fn on_error(&self, msg: String) { (self.handler_fn)(Err(format!("failed to transcribe: {}", &msg))); } } struct WhisperEngine { ctx: WhisperContext, threads: i32, } fn new_whisper_engine(model_path: String, threads: i32) -> Result { match WhisperContext::new(&model_path) { Ok(ctx) => Ok(WhisperEngine{ctx: ctx, threads: threads}), Err(msg) => Err(format!("failed to load {}: {}", model_path, msg)), } } impl WhisperEngine { fn transcribe(&self, data: &Vec) -> Result { let mut params = FullParams::new(SamplingStrategy::Greedy { best_of: 0 }); params.set_no_context(true); params.set_n_threads(self.threads); params.set_translate(false); params.set_detect_language(false); params.set_language(Some("en")); params.set_print_special(false); params.set_print_progress(false); params.set_print_realtime(false); params.set_print_timestamps(false); let mut state = self.ctx.create_state()?; state.full(params, &data[..])?; let mut result = new_whispered(); let num_segments = state.full_n_segments()?; for i in 0..num_segments { let data = state.full_get_segment_text(i)?; let start = state.full_get_segment_t0(i)?; let stop = state.full_get_segment_t1(i)?; result.push(data, start, stop); } Ok(result) } } struct AWhisper { data: Vec, ack: Option>, } #[derive(Clone, Debug)] struct Whispered { data: Vec, } #[derive(Clone, Debug)] struct AWhispered { data: String, offset: i64, length: i64, } fn new_whispered() -> Whispered { Whispered{data: vec![]} } fn new_a_whispered(data: String, start: i64, stop: i64) -> AWhispered { AWhispered{ data: data, offset: start.clone(), length: stop - start, } } impl Whispered { fn to_string(&self) -> String { let mut result = "".to_string(); for i in 0..self.data.len() { result = format!("{} {}", result, &self.data[i].data); } result } fn after(&self, t: &f32) -> Whispered { let mut result = new_whispered(); self.data .iter() .filter(|x| x.offset as f32 >= *t) .for_each(|x| result.data.push(x.clone())); result } fn before(&self, t: &f32) -> Whispered { let mut result = new_whispered(); let end = match self.data.iter().map(|x| x.offset + x.length).max() { Some(x) => x, None => 1, }; let t = (end as f32) - *t; self.data .iter() .filter(|x| ((x.offset) as f32) <= t) .for_each(|x| result.data.push(x.clone())); result } fn push(&mut self, data: String, start: i64, stop: i64) { let words: Vec<_> = data.split_whitespace().collect(); let per_word = (stop - start) / (words.len() as i64); for i in 0..words.len() { let start = (i as i64) * per_word; let stop = start.clone() + per_word; self.data.push(new_a_whispered(words[i].to_string(), start, stop)); } } } struct Listener { } fn new_listener() -> Listener { Listener{} } impl Listener { fn listen(self, mut cb: impl FnMut(Vec)) { let (send, recv) = std::sync::mpsc::sync_channel(100); thread::spawn(move || { self._listen(send); }); loop { match recv.recv() { Ok(msg) => cb(msg), Err(_) => return, }; } } fn _listen(self, send: std::sync::mpsc::SyncSender>) { let host = cpal::default_host(); let device = host.default_input_device().unwrap(); let cfg = device.supported_input_configs() .unwrap() .filter(|x| x.sample_format() == cpal::SampleFormat::F32) .nth(0) .unwrap() .with_max_sample_rate(); let downsample_ratio = cfg.channels() as f32 * (cfg.sample_rate().0 as f32 / 16_000.0); let stream = device.build_input_stream( &cfg.clone().into(), move |data: &[f32], _: &cpal::InputCallbackInfo| { // TODO why cant i do this... let mut downsampled_data = vec![]; for i in 0..(data.len() as f32 / downsample_ratio) as usize { let mut upsampled = i as f32 * downsample_ratio; if upsampled > (data.len()-1) as f32 { upsampled = (data.len()-1) as f32 } downsampled_data.push(data[upsampled as usize]); } match send.try_send(downsampled_data) { Ok(_) => (), Err(msg) => eprintln!("failed to ingest audio: {}", msg), }; }, move |err| { eprintln!("input error: {}", err) }, None, ).unwrap(); stream.play().unwrap(); eprintln!("listening on {}", device.name().unwrap()); let mut signals = Signals::new(&[SIGINT]).unwrap(); for sig in signals.forever() { eprintln!("sig {}", sig); break; } stream.pause().unwrap(); } }