rust-whisper-lib::wav_channel()
parent
393100973c
commit
cd339de334
|
|
@ -72,7 +72,30 @@ pub fn f32_from_wav_file(path: &String) -> Result<Vec<f32>, String> {
|
|||
}
|
||||
}
|
||||
|
||||
pub fn channel<F>(flags: Flags, handler_fn: F, stream: std::sync::mpsc::Receiver<Vec<f32>>) where F: FnMut(Result<Transcribed, String>) + Send + 'static {
|
||||
pub fn wav_channel<F>(flags: Flags, handler_fn: F) where F: FnMut(Result<Transcribed, String>) + Send + 'static {
|
||||
let path = flags.wav.as_ref().unwrap();
|
||||
let mut audio = f32_from_wav_file(&path).unwrap();
|
||||
let mut iter = vec![];
|
||||
let n = audio.len() / match audio.len() % 100 {
|
||||
0 => 100,
|
||||
_ => 99,
|
||||
};
|
||||
for _ in 0..100 {
|
||||
iter.push(audio.drain(0..n.clamp(0, audio.len())).collect());
|
||||
}
|
||||
let (fin_send, fin_recv) = std::sync::mpsc::sync_channel::<Option<i32>>(1);
|
||||
channel_and_close(flags.clone(), handler_fn, iter, move || { fin_send.send(None).unwrap(); });
|
||||
match fin_recv.recv() {
|
||||
Ok(_) => {},
|
||||
Err(x) => panic!("failed to receive: {}", x),
|
||||
};
|
||||
}
|
||||
|
||||
pub fn channel<F, I>(flags: Flags, handler_fn: F, stream: I) where F: FnMut(Result<Transcribed, String>) + Send + 'static, I: IntoIterator<Item = Vec<f32>> {
|
||||
channel_and_close(flags, handler_fn, stream, || {});
|
||||
}
|
||||
|
||||
fn channel_and_close<F, I, G>(flags: Flags, handler_fn: F, stream: I, mut close_fn: G) where F: FnMut(Result<Transcribed, String>) + Send + 'static, I: IntoIterator<Item = Vec<f32>>, G: FnMut() + Send + 'static {
|
||||
let w = new_service(
|
||||
flags.model_path,
|
||||
flags.model_buffer,
|
||||
|
|
@ -87,7 +110,7 @@ pub fn channel<F>(flags: Flags, handler_fn: F, stream: std::sync::mpsc::Receiver
|
|||
false => {},
|
||||
};
|
||||
let mut buffer = vec![];
|
||||
for data in stream.iter() {
|
||||
for data in stream {
|
||||
data.iter().for_each(|x| buffer.push(*x));
|
||||
if buffer.len() >= (flags.stream_step * 16_000) as usize {
|
||||
w.transcribe_async(&buffer).unwrap();
|
||||
|
|
@ -112,6 +135,10 @@ pub fn channel<F>(flags: Flags, handler_fn: F, stream: std::sync::mpsc::Receiver
|
|||
buffer.truncate(stream_retain);
|
||||
}
|
||||
}
|
||||
if buffer.len() > 0 {
|
||||
w.transcribe(&buffer);
|
||||
}
|
||||
close_fn();
|
||||
}
|
||||
|
||||
struct Service {
|
||||
|
|
@ -339,6 +366,28 @@ impl Transcribed {
|
|||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_transcribe_tiny_jfk_wav_whisper_rs_wav_channel() {
|
||||
wav_channel(
|
||||
Flags {
|
||||
model_path: None,
|
||||
model_buffer: Some(include_bytes!("../../models/ggml-tiny.en.bin").to_vec()),
|
||||
threads: 8,
|
||||
stream_step: 30,
|
||||
stream_retain: 0.0,
|
||||
stream_head: 0.0,
|
||||
stream_tail: 0.0,
|
||||
wav: Some("../gitea-whisper-rs/sys/whisper.cpp/bindings/go/samples/jfk.wav".to_string()),
|
||||
debug: false,
|
||||
stream_device: None,
|
||||
},
|
||||
move | result | {
|
||||
assert!(result.is_ok());
|
||||
assert_eq!(result.unwrap().to_string(), " And so my fellow Americans ask not what your country can do for you ask what you can do for your country.");
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_transcribe_tiny_jfk_wav_whisper_rs() {
|
||||
wav(
|
||||
|
|
|
|||
Loading…
Reference in New Issue