From 3093a91d848a9e78d502ff706599c7df544de9a5 Mon Sep 17 00:00:00 2001 From: Bel LaPointe Date: Tue, 28 Nov 2023 19:10:07 -0700 Subject: [PATCH] wip --- rust-whisper.d/src/main.rs | 48 +++++++++++++++++++++++++++++++++++++- 1 file changed, 47 insertions(+), 1 deletion(-) diff --git a/rust-whisper.d/src/main.rs b/rust-whisper.d/src/main.rs index 05c0d5a..a77247a 100644 --- a/rust-whisper.d/src/main.rs +++ b/rust-whisper.d/src/main.rs @@ -1,4 +1,4 @@ -use whisper_rs::{WhisperContext, FullParams, SamplingStrategy}; +use whisper_rs::{WhisperContext, FullParams, SamplingStrategy, WhisperError}; fn main() { let ctx = WhisperContext::new( @@ -36,3 +36,49 @@ fn main() { } println!(""); } + +struct Whisper { + ctx: WhisperContext, +} + +fn new_whisper(model_path: String) -> Result { + match WhisperContext::new(&model_path) { + Ok(ctx) => Ok(Whisper{ctx: ctx}), + Err(msg) => Err(format!("failed to load {}: {}", model_path, msg)), + } +} + +impl Whisper { + fn transcribe(&self, data: Vec) -> Result { + match self._transcribe(data) { + Ok(result) => Ok(result), + Err(msg) => Err(format!("failed to transcribe: {}", msg)), + } + } + + fn _transcribe(&self, data: Vec) -> Result { + println!("{:?}", data); + + let mut state = self.ctx.create_state()?; + let mut params = FullParams::new(SamplingStrategy::Greedy { best_of: 0 }); + params.set_n_threads( + std::env::var("P").unwrap_or(String::from("8")).parse::().expect("$P must be a number") + ); + params.set_translate(false); + params.set_detect_language(false); + params.set_language(Some("en")); + params.set_print_special(false); + params.set_print_progress(false); + params.set_print_realtime(false); + params.set_print_timestamps(false); + state.full(params, &data[..])?; + + let num_segments = state.full_n_segments()?; + let mut result = "".to_string(); + for i in 0..num_segments { + let segment = state.full_get_segment_text(i)?; + result = format!("{} {}", result, segment); + } + Ok(result) + } +}