From 479cfb055f428d00e2c949ede212daf941e6d431 Mon Sep 17 00:00:00 2001 From: Bel LaPointe Date: Thu, 30 Nov 2023 09:39:43 -0700 Subject: [PATCH] threaded something i guess --- src/main.rs | 121 ++++++++++++++++++++++++++++++++-------------------- 1 file changed, 74 insertions(+), 47 deletions(-) diff --git a/src/main.rs b/src/main.rs index 379fe74..3929704 100644 --- a/src/main.rs +++ b/src/main.rs @@ -4,6 +4,7 @@ use signal_hook::{iterator::Signals, consts::signal::SIGINT}; use std::time::{Duration, Instant}; use chrono; use clap::Parser; +use std::thread; #[derive(Parser, Debug)] struct Flags { @@ -41,57 +42,23 @@ fn main() { println!("{}", result); }, None => { - let host = cpal::default_host(); - let device = host.default_input_device().unwrap(); - let cfg = device.supported_input_configs() - .unwrap() - .filter(|x| x.sample_format() == cpal::SampleFormat::F32) - .nth(0) - .unwrap() - .with_max_sample_rate(); - - let downsample_ratio = cfg.channels() as f32 * (cfg.sample_rate().0 as f32 / 16000.0); let mut buffer = vec![]; let mut last = Instant::now(); - let stream = device.build_input_stream( - &cfg.clone().into(), - move |data: &[f32], _: &cpal::InputCallbackInfo| { - let mut downsampled_data = vec![]; - for i in 0..(data.len() as f32 / downsample_ratio) as usize { - let mut upsampled = i as f32 * downsample_ratio; - if upsampled > (data.len()-1) as f32 { - upsampled = (data.len()-1) as f32 - } - downsampled_data.push(data[upsampled as usize]); - } - downsampled_data.iter().for_each(|x| buffer.push(*x)); - if Instant::now() - last > stream_step { - let result = w.transcribe(&buffer).unwrap(); - eprintln!("{}", chrono::Local::now()); - println!("{}", result); + new_listener().listen(move |data: Vec| { + data.iter().for_each(|x| buffer.push(*x)); + if Instant::now() - last > stream_step { + let result = w.transcribe(&buffer).unwrap(); + eprintln!("{}", chrono::Local::now()); + println!("{}", result); - let retain = buffer.len() - (buffer.len() as f32 * stream_churn) as usize; - for i in retain..buffer.len() { - buffer[i - retain] = buffer[i] - } - buffer.truncate(retain); - last = Instant::now(); + let retain = buffer.len() - (buffer.len() as f32 * stream_churn) as usize; + for i in retain..buffer.len() { + buffer[i - retain] = buffer[i] } - }, - move |err| { - eprintln!("input error: {}", err) - }, - None, - ).unwrap(); - stream.play().unwrap(); - - eprintln!("listening on {}", device.name().unwrap()); - let mut signals = Signals::new(&[SIGINT]).unwrap(); - for sig in signals.forever() { - eprintln!("sig {}", sig); - break; - } - stream.pause().unwrap(); + buffer.truncate(retain); + last = Instant::now(); + } + }); }, }; } @@ -144,3 +111,63 @@ impl Whisper { Ok(result) } } + +struct Listener { +} + +fn new_listener() -> Listener { + Listener{} +} + +impl Listener { + fn listen(self, mut cb: impl FnMut(Vec)) { + let (send, recv) = std::sync::mpsc::channel(); + thread::spawn(move || { self._listen(send); }); + loop { + match recv.recv() { + Ok(msg) => cb(msg), + Err(msg) => return, + }; + } + } + + fn _listen(self, ch: std::sync::mpsc::Sender>) { + let host = cpal::default_host(); + let device = host.default_input_device().unwrap(); + let cfg = device.supported_input_configs() + .unwrap() + .filter(|x| x.sample_format() == cpal::SampleFormat::F32) + .nth(0) + .unwrap() + .with_max_sample_rate(); + + let downsample_ratio = cfg.channels() as f32 * (cfg.sample_rate().0 as f32 / 16000.0); + let stream = device.build_input_stream( + &cfg.clone().into(), + move |data: &[f32], _: &cpal::InputCallbackInfo| { + let mut downsampled_data = vec![]; + for i in 0..(data.len() as f32 / downsample_ratio) as usize { + let mut upsampled = i as f32 * downsample_ratio; + if upsampled > (data.len()-1) as f32 { + upsampled = (data.len()-1) as f32 + } + downsampled_data.push(data[upsampled as usize]); + } + ch.send(downsampled_data).unwrap(); + }, + move |err| { + eprintln!("input error: {}", err) + }, + None, + ).unwrap(); + stream.play().unwrap(); + + eprintln!("listening on {}", device.name().unwrap()); + let mut signals = Signals::new(&[SIGINT]).unwrap(); + for sig in signals.forever() { + eprintln!("sig {}", sig); + break; + } + stream.pause().unwrap(); + } +}