just stream and wav

This commit is contained in:
Bel LaPointe
2023-12-20 09:33:57 -05:00
parent 6e32967067
commit 4f44697ef0
5 changed files with 8 additions and 13 deletions

View File

@@ -31,11 +31,7 @@ pub struct Flags {
pub wav: Option<String>,
}
pub fn main<F>(flags: Flags, handler_fn: F) where F: FnMut(Result<Transcribed, String>) + Send + 'static {
main_wav(flags.clone(), handler_fn, flags.wav.unwrap());
}
pub fn main_wav<F>(flags: Flags, handler_fn: F, wav_path: String) where F: FnMut(Result<Transcribed, String>) + Send + 'static {
pub fn wav<F>(flags: Flags, handler_fn: F, wav_path: String) where F: FnMut(Result<Transcribed, String>) + Send + 'static {
let w = new_service(
flags.model_path,
flags.model_buffer,
@@ -55,7 +51,7 @@ pub fn main_wav<F>(flags: Flags, handler_fn: F, wav_path: String) where F: FnMut
w.transcribe(&audio_data);
}
pub fn main_stream<F>(flags: Flags, handler_fn: F, _stream: std::sync::mpsc::Receiver<Vec<f32>>) where F: FnMut(Result<Transcribed, String>) + Send + 'static {
pub fn channel<F>(flags: Flags, handler_fn: F, stream: std::sync::mpsc::Receiver<Vec<f32>>) where F: FnMut(Result<Transcribed, String>) + Send + 'static {
let w = new_service(
flags.model_path,
flags.model_buffer,
@@ -72,7 +68,7 @@ pub fn main_stream<F>(flags: Flags, handler_fn: F, _stream: std::sync::mpsc::Rec
};
let mut buffer = vec![];
let mut last = Instant::now();
for data in _stream.iter() {
for data in stream.iter() {
data.iter().for_each(|x| buffer.push(*x));
if Instant::now() - last > stream_step {
w.transcribe_async(&buffer).unwrap();
@@ -334,7 +330,7 @@ mod tests {
#[test]
fn test_transcribe_tiny_jfk_wav() {
main(
wav(
Flags {
model_path: None,
model_buffer: Some(include_bytes!("../../models/ggml-tiny.en.bin").to_vec()),
@@ -350,6 +346,7 @@ mod tests {
assert!(result.is_ok());
assert_eq!(result.unwrap().to_string(), " And so my fellow Americans ask not what your country can do for you ask what you can do for your country.");
},
"../gitea-whisper-rs/sys/whisper.cpp/bindings/go/samples/jfk.wav".to_string(),
);
}
}