3 Commits

Author SHA1 Message Date
Bel LaPointe
532ae22908 back to mvp 2023-11-30 12:28:35 -07:00
Bel LaPointe
deffc420ca at least it complies 2023-11-30 12:00:16 -07:00
Bel LaPointe
2391d07994 transcribing results as callbacks 2023-11-30 09:58:28 -07:00

View File

@@ -14,8 +14,8 @@ struct Flags {
#[arg(long, default_value = "8")] #[arg(long, default_value = "8")]
threads: i32, threads: i32,
#[arg(long, default_value = "0.8")] #[arg(long, default_value = "1.0")]
stream_churn: f32, stream_retain: f32,
#[arg(long, default_value = "5")] #[arg(long, default_value = "5")]
stream_step: u64, stream_step: u64,
@@ -25,8 +25,8 @@ struct Flags {
fn main() { fn main() {
let flags = Flags::parse(); let flags = Flags::parse();
let w = new_whisper(flags.model, flags.threads).unwrap(); let w = new_whisper(flags.model, flags.threads, Handler{}).unwrap();
let stream_churn = flags.stream_churn; let stream_retain = (flags.stream_retain * 16_000.0) as usize;
let stream_step = Duration::new(flags.stream_step, 0); let stream_step = Duration::new(flags.stream_step, 0);
match flags.wav { match flags.wav {
Some(wav) => { Some(wav) => {
@@ -34,12 +34,11 @@ fn main() {
&mut std::fs::File::open(wav).expect("failed to open $WAV"), &mut std::fs::File::open(wav).expect("failed to open $WAV"),
).expect("failed to decode $WAV"); ).expect("failed to decode $WAV");
assert!(header.channel_count == 1); assert!(header.channel_count == 1);
assert!(header.sampling_rate == 16000); assert!(header.sampling_rate == 16_000);
let data16 = data.as_sixteen().expect("wav is not 32bit floats"); let data16 = data.as_sixteen().expect("wav is not 32bit floats");
let audio_data = &whisper_rs::convert_integer_to_float_audio(&data16); let audio_data = &whisper_rs::convert_integer_to_float_audio(&data16);
let result = w.transcribe(&audio_data).unwrap(); w.transcribe(&audio_data);
println!("{}", result);
}, },
None => { None => {
let mut buffer = vec![]; let mut buffer = vec![];
@@ -47,15 +46,15 @@ fn main() {
new_listener().listen(move |data: Vec<f32>| { new_listener().listen(move |data: Vec<f32>| {
data.iter().for_each(|x| buffer.push(*x)); data.iter().for_each(|x| buffer.push(*x));
if Instant::now() - last > stream_step { if Instant::now() - last > stream_step {
let result = w.transcribe(&buffer).unwrap(); match w.transcribe_async(&buffer) {
eprintln!("{}", chrono::Local::now()); Ok(_) => (),
println!("{}", result); Err(msg) => eprintln!("{}", msg),
};
let retain = buffer.len() - (buffer.len() as f32 * stream_churn) as usize; for i in stream_retain..buffer.len() {
for i in retain..buffer.len() { buffer[i - stream_retain] = buffer[i]
buffer[i - retain] = buffer[i]
} }
buffer.truncate(retain); buffer.truncate(stream_retain);
last = Instant::now(); last = Instant::now();
} }
}); });
@@ -64,29 +63,83 @@ fn main() {
} }
struct Whisper { struct Whisper {
ctx: WhisperContext, jobs: std::sync::mpsc::SyncSender<AWhisper>,
threads: i32,
} }
fn new_whisper(model_path: String, threads: i32) -> Result<Whisper, String> { struct WhisperEngine {
match WhisperContext::new(&model_path) { ctx: WhisperContext,
Ok(ctx) => Ok(Whisper{ threads: i32,
ctx: ctx, handler: Handler,
threads: threads, }
}),
Err(msg) => Err(format!("failed to load {}: {}", model_path, msg)), fn new_whisper(model_path: String, threads: i32, handler: Handler) -> Result<Whisper, String> {
match new_whisper_engine(model_path, threads, handler) {
Ok(engine) => {
let (send, recv) = std::sync::mpsc::sync_channel(100);
thread::spawn(move || { engine.transcribe_asyncs(recv); });
Ok(Whisper{jobs: send})
},
Err(msg) => Err(format!("failed to initialize engine: {}", msg)),
} }
} }
impl Whisper { impl Whisper {
fn transcribe(&self, data: &Vec<f32>) -> Result<String, String> { fn transcribe(&self, data: &Vec<f32>) {
match self._transcribe(&data) { let (send, recv) = std::sync::mpsc::sync_channel(1);
Ok(result) => Ok(result), self._transcribe_async(data, Some(send)).unwrap();
Err(msg) => Err(format!("failed to transcribe: {}", msg)), recv.recv().unwrap();
}
/// Queues `data` for transcription and returns without waiting for the result.
///
/// No acknowledgement channel is attached to the job. Returns `Err` with a
/// description if the job could not be enqueued.
fn transcribe_async(&self, data: &Vec<f32>) -> Result<(), String> {
    // Fire-and-forget: nobody listens for a completion signal.
    let ack = None;
    self._transcribe_async(data, ack)
}
/// Copies `data` into a new job and offers it to the job channel without blocking.
///
/// * `data` - audio samples to transcribe; copied so the caller keeps its buffer.
/// * `ack`  - optional channel on which completion of this job is signalled.
///
/// Returns `Err("failed to enqueue transcription: ...")` when `try_send`
/// rejects the job (the channel is full or its receiver is gone).
fn _transcribe_async(&self, data: &Vec<f32>, ack: Option<std::sync::mpsc::SyncSender<bool>>) -> Result<(), String> {
    // A single clone is enough: the previous `data.clone().to_vec()` duplicated
    // the whole sample buffer twice for every enqueued job.
    let job = AWhisper {
        data: data.clone(),
        ack,
    };
    self.jobs
        .try_send(job)
        .map_err(|msg| format!("failed to enqueue transcription: {}", msg))
}
}
/// Creates a `WhisperContext` from `model_path` and bundles it with the
/// thread count and result handler into a `WhisperEngine`.
///
/// Returns `Err("failed to load <model_path>: <cause>")` if the context
/// cannot be created.
fn new_whisper_engine(model_path: String, threads: i32, handler: Handler) -> Result<WhisperEngine, String> {
    WhisperContext::new(&model_path)
        .map(|ctx| WhisperEngine { ctx, threads, handler })
        .map_err(|cause| format!("failed to load {}: {}", model_path, cause))
}
impl WhisperEngine {
fn transcribe_asyncs(&self, recv: std::sync::mpsc::Receiver<AWhisper>) {
loop {
match recv.recv() {
Ok(job) => {
match self.transcribe(&job.data) {
Ok(result) => {
self.handler.on_success(result);
match job.ack {
Some(ack) => { let _ = ack.send(true); },
None => (),
};
},
Err(msg) => {
self.handler.on_error(format!("failed to transcribe: {}", msg));
match job.ack {
Some(ack) => { let _ = ack.send(false); },
None => (),
};
},
};
},
Err(_) => return,
};
} }
} }
fn _transcribe(&self, data: &Vec<f32>) -> Result<String, WhisperError> { fn transcribe(&self, data: &Vec<f32>) -> Result<String, WhisperError> {
let mut params = FullParams::new(SamplingStrategy::Greedy { best_of: 0 }); let mut params = FullParams::new(SamplingStrategy::Greedy { best_of: 0 });
params.set_no_context(true); params.set_no_context(true);
params.set_n_threads(self.threads); params.set_n_threads(self.threads);
@@ -112,6 +165,23 @@ impl Whisper {
} }
} }
/// One transcription job handed to the engine through the job channel.
struct AWhisper {
/// Audio samples to transcribe.
data: Vec<f32>,
/// Optional completion channel: the engine sends `true` after a successful
/// transcription and `false` after a failed one. `None` means nobody waits.
ack: Option<std::sync::mpsc::SyncSender<bool>>,
}
/// Field-less sink for transcription outcomes.
struct Handler {}

impl Handler {
    /// Reports a finished transcription: the current local time is written
    /// to stderr and the transcript itself to stdout.
    fn on_success(&self, result: String) {
        let stamp = chrono::Local::now();
        eprintln!("{}", stamp);
        println!("{}", result);
    }

    /// Reports a failure on stderr, prefixed with `error: `.
    fn on_error(&self, msg: String) {
        eprintln!("error: {}", msg);
    }
}
struct Listener { struct Listener {
} }
@@ -121,7 +191,7 @@ fn new_listener() -> Listener {
impl Listener { impl Listener {
fn listen(self, mut cb: impl FnMut(Vec<f32>)) { fn listen(self, mut cb: impl FnMut(Vec<f32>)) {
let (send, recv) = std::sync::mpsc::sync_channel(5_000_000); let (send, recv) = std::sync::mpsc::sync_channel(5_000);
thread::spawn(move || { self._listen(send); }); thread::spawn(move || { self._listen(send); });
loop { loop {
match recv.recv() { match recv.recv() {
@@ -131,7 +201,7 @@ impl Listener {
} }
} }
fn _listen(self, ch: std::sync::mpsc::SyncSender<Vec<f32>>) { fn _listen(self, send: std::sync::mpsc::SyncSender<Vec<f32>>) {
let host = cpal::default_host(); let host = cpal::default_host();
let device = host.default_input_device().unwrap(); let device = host.default_input_device().unwrap();
let cfg = device.supported_input_configs() let cfg = device.supported_input_configs()
@@ -141,7 +211,7 @@ impl Listener {
.unwrap() .unwrap()
.with_max_sample_rate(); .with_max_sample_rate();
let downsample_ratio = cfg.channels() as f32 * (cfg.sample_rate().0 as f32 / 16000.0); let downsample_ratio = cfg.channels() as f32 * (cfg.sample_rate().0 as f32 / 16_000.0);
let stream = device.build_input_stream( let stream = device.build_input_stream(
&cfg.clone().into(), &cfg.clone().into(),
move |data: &[f32], _: &cpal::InputCallbackInfo| { move |data: &[f32], _: &cpal::InputCallbackInfo| {
@@ -153,7 +223,7 @@ impl Listener {
} }
downsampled_data.push(data[upsampled as usize]); downsampled_data.push(data[upsampled as usize]);
} }
match ch.send(downsampled_data) { match send.try_send(downsampled_data) {
Ok(_) => (), Ok(_) => (),
Err(msg) => eprintln!("failed to ingest audio: {}", msg), Err(msg) => eprintln!("failed to ingest audio: {}", msg),
}; };