use whisper_rs::{WhisperContext, FullParams, SamplingStrategy, WhisperError}; use cpal::traits::{HostTrait, DeviceTrait, StreamTrait}; use signal_hook::{iterator::Signals, consts::signal::SIGINT}; use std::time::{Duration, Instant}; use chrono; use clap::Parser; #[derive(Parser, Debug)] struct Flags { #[arg(long, default_value = "../models/ggml-tiny.en.bin")] model: String, #[arg(long, default_value = "8")] threads: i32, #[arg(long, default_value = "0.8")] stream_churn: f32, #[arg(long, default_value = "5")] stream_step: u64, wav: Option, } fn main() { let flags = Flags::parse(); let w = new_whisper(flags.model, flags.threads).unwrap(); let stream_churn = flags.stream_churn; let stream_step = Duration::new(flags.stream_step, 0); match flags.wav { Some(wav) => { let (header, data) = wav::read( &mut std::fs::File::open(wav).expect("failed to open $WAV"), ).expect("failed to decode $WAV"); assert!(header.channel_count == 1); assert!(header.sampling_rate == 16000); let data16 = data.as_sixteen().expect("wav is not 32bit floats"); let audio_data = &whisper_rs::convert_integer_to_float_audio(&data16); let result = w.transcribe(&audio_data).unwrap(); println!("{}", result); }, None => { let host = cpal::default_host(); let device = host.default_input_device().unwrap(); let cfg = device.supported_input_configs() .unwrap() .filter(|x| x.sample_format() == cpal::SampleFormat::F32) .nth(0) .unwrap() .with_max_sample_rate(); let channels = cfg.channels(); let downsample_ratio = cfg.sample_rate().0 as f32 / 16000.0; let mut buffer = vec![]; let mut last = Instant::now(); let stream = device.build_input_stream( &cfg.clone().into(), move |data: &[f32], _: &cpal::InputCallbackInfo| { let mono_data: Vec = data.iter().map(|x| *x).step_by(channels.into()).collect(); let mut downsampled_data = vec![]; for i in 0..(mono_data.len() as f32 / downsample_ratio) as usize { let mut upsampled = i as f32 * downsample_ratio; if upsampled > (mono_data.len()-1) as f32 { upsampled = (mono_data.len()-1) as f32 } downsampled_data.push(mono_data[upsampled as usize]); } downsampled_data.iter().for_each(|x| buffer.push(*x)); if Instant::now() - last > stream_step { let result = w.transcribe(&buffer).unwrap(); eprintln!("{}", chrono::Local::now()); println!("{}", result); let retain = buffer.len() - (buffer.len() as f32 * stream_churn) as usize; for i in retain..buffer.len() { buffer[i - retain] = buffer[i] } buffer.truncate(retain); last = Instant::now(); } }, move |err| { eprintln!("input error: {}", err) }, None, ).unwrap(); stream.play().unwrap(); eprintln!("listening on {}", device.name().unwrap()); let mut signals = Signals::new(&[SIGINT]).unwrap(); for sig in signals.forever() { eprintln!("sig {}", sig); break; } stream.pause().unwrap(); }, }; } struct Whisper { ctx: WhisperContext, threads: i32, } fn new_whisper(model_path: String, threads: i32) -> Result { match WhisperContext::new(&model_path) { Ok(ctx) => Ok(Whisper{ ctx: ctx, threads: threads, }), Err(msg) => Err(format!("failed to load {}: {}", model_path, msg)), } } impl Whisper { fn transcribe(&self, data: &Vec) -> Result { match self._transcribe(&data) { Ok(result) => Ok(result), Err(msg) => Err(format!("failed to transcribe: {}", msg)), } } fn _transcribe(&self, data: &Vec) -> Result { let mut params = FullParams::new(SamplingStrategy::Greedy { best_of: 0 }); params.set_no_context(true); params.set_n_threads(self.threads); params.set_translate(false); params.set_detect_language(false); params.set_language(Some("en")); params.set_print_special(false); params.set_print_progress(false); params.set_print_realtime(false); params.set_print_timestamps(false); let mut state = self.ctx.create_state()?; state.full(params, &data[..])?; let num_segments = state.full_n_segments()?; let mut result = "".to_string(); for i in 0..num_segments { let segment = state.full_get_segment_text(i)?; result = format!("{} {}", result, segment); } Ok(result) } }