From 437d7cac39da1ac37fa9d1b9151cd937c4699748 Mon Sep 17 00:00:00 2001 From: Bel LaPointe Date: Tue, 28 Nov 2023 19:18:05 -0700 Subject: [PATCH] successful refactor --- rust-whisper.d/src/main.rs | 55 +++++++++++++++----------------------- 1 file changed, 21 insertions(+), 34 deletions(-) diff --git a/rust-whisper.d/src/main.rs b/rust-whisper.d/src/main.rs index a77247a..0e4397f 100644 --- a/rust-whisper.d/src/main.rs +++ b/rust-whisper.d/src/main.rs @@ -1,23 +1,10 @@ use whisper_rs::{WhisperContext, FullParams, SamplingStrategy, WhisperError}; fn main() { - let ctx = WhisperContext::new( - &std::env::var("MODEL").unwrap_or(String::from("../models/ggml-tiny.en.bin")) - ).expect("failed to load $MODEL"); - let mut state = ctx.create_state().expect("failed to create state"); - - // create a params object - let mut params = FullParams::new(SamplingStrategy::Greedy { best_of: 0 }); - params.set_n_threads( - std::env::var("P").unwrap_or(String::from("8")).parse::().expect("$P must be a number") - ); - params.set_translate(false); - params.set_detect_language(false); - params.set_language(Some("en")); - params.set_print_special(false); - params.set_print_progress(false); - params.set_print_realtime(false); - params.set_print_timestamps(false); + let w = new_whisper( + std::env::var("MODEL").unwrap_or(String::from("../models/ggml-tiny.en.bin")), + std::env::var("P").unwrap_or(String::from("8")).parse::().expect("$P must be a number"), + ).unwrap(); let (header, data) = wav::read(&mut std::fs::File::open( &std::env::var("WAV").unwrap_or(String::from("../git.d/samples/jfk.wav")) @@ -27,43 +14,40 @@ fn main() { let data16 = data.as_sixteen().expect("wav is not 32bit floats"); let audio_data = &whisper_rs::convert_integer_to_float_audio(&data16); - state.full(params, &audio_data[..]).expect("failed to run model"); - - let num_segments = state.full_n_segments().expect("failed to get number of segments"); - for i in 0..num_segments { - let segment = state.full_get_segment_text(i).expect("failed to get segment"); - print!("{} ", segment); + for _ in 0..3 { + let result = w.transcribe(&audio_data).unwrap(); + println!("{}", result); } - println!(""); } struct Whisper { ctx: WhisperContext, + threads: i32, } -fn new_whisper(model_path: String) -> Result { +fn new_whisper(model_path: String, threads: i32) -> Result { match WhisperContext::new(&model_path) { - Ok(ctx) => Ok(Whisper{ctx: ctx}), + Ok(ctx) => Ok(Whisper{ + ctx: ctx, + threads: threads, + }), Err(msg) => Err(format!("failed to load {}: {}", model_path, msg)), } } impl Whisper { - fn transcribe(&self, data: Vec) -> Result { - match self._transcribe(data) { + fn transcribe(&self, data: &Vec) -> Result { + match self._transcribe(&data) { Ok(result) => Ok(result), Err(msg) => Err(format!("failed to transcribe: {}", msg)), } } - fn _transcribe(&self, data: Vec) -> Result { - println!("{:?}", data); + fn _transcribe(&self, data: &Vec) -> Result { + //eprintln!("{:?} ({})", data, data.len()); - let mut state = self.ctx.create_state()?; let mut params = FullParams::new(SamplingStrategy::Greedy { best_of: 0 }); - params.set_n_threads( - std::env::var("P").unwrap_or(String::from("8")).parse::().expect("$P must be a number") - ); + params.set_n_threads(self.threads); params.set_translate(false); params.set_detect_language(false); params.set_language(Some("en")); @@ -71,6 +55,8 @@ impl Whisper { params.set_print_progress(false); params.set_print_realtime(false); params.set_print_timestamps(false); + + let mut state = self.ctx.create_state()?; state.full(params, &data[..])?; let num_segments = state.full_n_segments()?; @@ -79,6 +65,7 @@ impl Whisper { let segment = state.full_get_segment_text(i)?; result = format!("{} {}", result, segment); } + Ok(result) } }