From 95f638d3d03cd3c28348f6b574b82d68da4ec7c9 Mon Sep 17 00:00:00 2001
From: Bel LaPointe <153096461+breel-render@users.noreply.github.com>
Date: Tue, 19 Dec 2023 22:14:22 -0500
Subject: [PATCH] accept model_path or model_buffer in flags

---
 rust-whisper-lib/src/lib.rs | 29 +++++++++++++++++++++--------
 1 file changed, 21 insertions(+), 8 deletions(-)

diff --git a/rust-whisper-lib/src/lib.rs b/rust-whisper-lib/src/lib.rs
index 04b0a5c..fe4f2ee 100644
--- a/rust-whisper-lib/src/lib.rs
+++ b/rust-whisper-lib/src/lib.rs
@@ -10,7 +10,8 @@ use std::io::Write;
 #[derive(Parser, Debug)]
 pub struct Flags {
     #[arg(long, default_value = "./models/ggml-tiny.en.bin")]
-    pub model: String,
+    pub model_path: Option<String>,
+    pub model_buffer: Option<Vec<u8>>,
 
     #[arg(long, default_value = "8")]
     pub threads: i32,
@@ -32,7 +33,8 @@ pub struct Flags {
 
 pub fn main<F>(flags: Flags, handler_fn: F) where F: FnMut(Result<String, String>) + Send + 'static {
     let w = new_whisper_service(
-        flags.model,
+        flags.model_path,
+        flags.model_buffer,
         flags.threads,
         flags.stream_head,
         flags.stream_tail,
@@ -93,8 +95,8 @@ struct WhisperService {
     jobs: std::sync::mpsc::SyncSender<Job>,
 }
 
-fn new_whisper_service<F>(model_path: String, threads: i32, stream_head: f32, stream_tail: f32, handler_fn: F) -> Result<WhisperService, String> where F: FnMut(Result<String, String>) + Send + 'static {
-    match new_whisper_engine(model_path, threads) {
+fn new_whisper_service<F>(model_path: Option<String>, model_buffer: Option<Vec<u8>>, threads: i32, stream_head: f32, stream_tail: f32, handler_fn: F) -> Result<WhisperService, String> where F: FnMut(Result<String, String>) + Send + 'static {
+    match new_whisper_engine(model_path, model_buffer, threads) {
         Ok(engine) => {
             let mut whisper = new_whisper_impl(engine, stream_head, stream_tail, handler_fn);
             let (send, recv) = std::sync::mpsc::sync_channel(100);
@@ -191,11 +193,22 @@ struct WhisperEngine {
     threads: i32,
 }
 
-fn new_whisper_engine(model_path: String, threads: i32) -> Result<WhisperEngine, String> {
-    match WhisperContext::new(&model_path) {
-        Ok(ctx) => Ok(WhisperEngine{ctx: ctx, threads: threads}),
-        Err(msg) =>
-            Err(format!("failed to load {}: {}", model_path, msg)),
+fn new_whisper_engine(model_path: Option<String>, model_buffer: Option<Vec<u8>>, threads: i32) -> Result<WhisperEngine, String> {
+    if model_path.is_some() {
+        let model_path = model_path.unwrap();
+        return match WhisperContext::new(&model_path.clone()) {
+            Ok(ctx) => Ok(WhisperEngine{ctx: ctx, threads: threads}),
+            Err(msg) => Err(format!("failed to load {}: {}", model_path, msg)),
+        };
     }
+    if model_buffer.is_some() {
+        let model_buffer = model_buffer.unwrap();
+        return match WhisperContext::new_from_buffer(&model_buffer) {
+            Ok(ctx) => Ok(WhisperEngine{ctx: ctx, threads: threads}),
+            Err(msg) => Err(format!("failed to load buffer: {}", msg)),
+        };
+    }
+    Err("neither model path nor buffer provided".to_string())
 }
 
 impl WhisperEngine {