diff --git a/rust-whisper.d/src/main.rs b/rust-whisper.d/src/main.rs index 5a7ece4..9a4a66e 100644 --- a/rust-whisper.d/src/main.rs +++ b/rust-whisper.d/src/main.rs @@ -3,7 +3,6 @@ use cpal::traits::{HostTrait, DeviceTrait, StreamTrait}; use signal_hook::{iterator::Signals, consts::signal::SIGINT}; use std::time::{Duration, Instant}; use std::fs; -use std::io::Write; use byteorder::WriteBytesExt; fn main() { @@ -37,6 +36,7 @@ fn main() { .unwrap() .with_max_sample_rate(); let channels = cfg.channels(); + let downsample_ratio = cfg.sample_rate().0 as f32 / 16000.0; let output_cfg = output_device.supported_output_configs() .unwrap() @@ -49,7 +49,7 @@ fn main() { let mut buffer = vec![]; let mut last = Instant::now(); - let five_seconds = Duration::new(15, 0); + let five_seconds = Duration::new(5, 0); device.build_output_stream( &output_cfg.into(), move |data: &mut [f32], _: &cpal::OutputCallbackInfo| { @@ -66,18 +66,23 @@ fn main() { let stream = device.build_input_stream( &cfg.clone().into(), move |data: &[f32], _: &cpal::InputCallbackInfo| { - data.iter() - .map(|x| *x) - .step_by(channels.into()) - .step_by((cfg.sample_rate().0 / 16000) as usize) - .for_each(|x| buffer.push(x)); + let mono_data: Vec = data.iter().map(|x| *x).step_by(channels.into()).collect(); + let mut downsampled_data = vec![]; + for i in 0..(mono_data.len() as f32 / downsample_ratio) as usize { + let mut upsampled = i as f32 * downsample_ratio; + if upsampled > (mono_data.len()-1) as f32 { + upsampled = (mono_data.len()-1) as f32 + } + downsampled_data.push(mono_data[upsampled as usize]); + } + downsampled_data.iter().for_each(|x| buffer.push(*x)); if Instant::now() - last > five_seconds { - let result = w.transcribe(&buffer).unwrap(); - println!("({} from {:?}) {}", buffer.len(), cfg, result); let mut f = fs::File::create("/tmp/transcribed.pcm").unwrap(); for i in &buffer { f.write_f32::(*i).unwrap(); } + let result = w.transcribe(&buffer).unwrap(); + println!("({} from {:?} and downsampled {} * {} ({} -> {})) {}", buffer.len(), cfg, channels, downsample_ratio, data.len(), downsampled_data.len(), result); let retain = buffer.len() - buffer.len() / 10; for i in retain..buffer.len() { @@ -140,6 +145,7 @@ impl Whisper { params.set_print_progress(false); params.set_print_realtime(false); params.set_print_timestamps(false); + params.set_no_context(true); let mut state = self.ctx.create_state()?; state.full(params, &data[..])?;