use whisper_rs::{WhisperContext, FullParams, SamplingStrategy, WhisperError}; use cpal::traits::{HostTrait, DeviceTrait, StreamTrait}; use signal_hook::{iterator::Signals, consts::signal::SIGINT}; use std::time::{Duration, Instant}; fn main() { let w = new_whisper( std::env::var("MODEL") .unwrap_or(String::from("../models/ggml-tiny.en.bin")), std::env::var("P") .unwrap_or(String::from("8")) .parse::().expect("$P must be a number"), ).unwrap(); let stream_churn = std::env::var("STREAM_CHURN") .unwrap_or(String::from("0.8")) .parse::().expect("$STREAM_CHURN must be a number"); let stream_step = Duration::new( std::env::var("STREAM_STEP") .unwrap_or(String::from("5")) .parse::().expect("$STREAM_STEP must be a number"), 0, ); match std::env::var("WAV") { Ok(wav) => { let (header, data) = wav::read( &mut std::fs::File::open(wav).expect("failed to open $WAV"), ).expect("failed to decode $WAV"); assert!(header.channel_count == 1); assert!(header.sampling_rate == 16000); let data16 = data.as_sixteen().expect("wav is not 32bit floats"); let audio_data = &whisper_rs::convert_integer_to_float_audio(&data16); let result = w.transcribe(&audio_data).unwrap(); println!("{}", result); }, Err(_) => { let host = cpal::default_host(); let device = host.default_input_device().unwrap(); let cfg = device.supported_input_configs() .unwrap() .filter(|x| x.sample_format() == cpal::SampleFormat::F32) .nth(0) .unwrap() .with_max_sample_rate(); let channels = cfg.channels(); let downsample_ratio = cfg.sample_rate().0 as f32 / 16000.0; let mut buffer = vec![]; let mut last = Instant::now(); let stream = device.build_input_stream( &cfg.clone().into(), move |data: &[f32], _: &cpal::InputCallbackInfo| { let mono_data: Vec = data.iter().map(|x| *x).step_by(channels.into()).collect(); let mut downsampled_data = vec![]; for i in 0..(mono_data.len() as f32 / downsample_ratio) as usize { let mut upsampled = i as f32 * downsample_ratio; if upsampled > (mono_data.len()-1) as f32 { upsampled = (mono_data.len()-1) as f32 } downsampled_data.push(mono_data[upsampled as usize]); } downsampled_data.iter().for_each(|x| buffer.push(*x)); if Instant::now() - last > stream_step { let result = w.transcribe(&buffer).unwrap(); println!("{}", result); let retain = buffer.len() - (buffer.len() as f32 * stream_churn) as usize; for i in retain..buffer.len() { buffer[i - retain] = buffer[i] } buffer.truncate(retain); last = Instant::now(); } }, move |err| { eprintln!("input error: {}", err) }, None, ).unwrap(); stream.play().unwrap(); eprintln!("listening on {}", device.name().unwrap()); let mut signals = Signals::new(&[SIGINT]).unwrap(); for sig in signals.forever() { println!("sig {}", sig); break; } stream.pause().unwrap(); }, }; } struct Whisper { ctx: WhisperContext, threads: i32, } fn new_whisper(model_path: String, threads: i32) -> Result { match WhisperContext::new(&model_path) { Ok(ctx) => Ok(Whisper{ ctx: ctx, threads: threads, }), Err(msg) => Err(format!("failed to load {}: {}", model_path, msg)), } } impl Whisper { fn transcribe(&self, data: &Vec) -> Result { match self._transcribe(&data) { Ok(result) => Ok(result), Err(msg) => Err(format!("failed to transcribe: {}", msg)), } } fn _transcribe(&self, data: &Vec) -> Result { //eprintln!("{:?} ({})", data, data.len()); let mut params = FullParams::new(SamplingStrategy::Greedy { best_of: 0 }); params.set_n_threads(self.threads); params.set_translate(false); params.set_detect_language(false); params.set_language(Some("en")); params.set_print_special(false); params.set_print_progress(false); params.set_print_realtime(false); params.set_print_timestamps(false); params.set_no_context(true); let mut state = self.ctx.create_state()?; state.full(params, &data[..])?; let num_segments = state.full_n_segments()?; let mut result = "".to_string(); for i in 0..num_segments { let segment = state.full_get_segment_text(i)?; result = format!("{} {}", result, segment); } Ok(result) } }