accept model_path or model_buffer in flags

master
Bel LaPointe 2023-12-19 22:14:22 -05:00
parent 82b24cabeb
commit 95f638d3d0
1 changed file with 21 additions and 8 deletions

View File

@ -10,7 +10,8 @@ use std::io::Write;
#[derive(Parser, Debug)] #[derive(Parser, Debug)]
pub struct Flags { pub struct Flags {
#[arg(long, default_value = "./models/ggml-tiny.en.bin")] #[arg(long, default_value = "./models/ggml-tiny.en.bin")]
pub model: String, pub model_path: Option<String>,
pub model_buffer: Option<Vec<u8>>,
#[arg(long, default_value = "8")] #[arg(long, default_value = "8")]
pub threads: i32, pub threads: i32,
@ -32,7 +33,8 @@ pub struct Flags {
pub fn main<F>(flags: Flags, handler_fn: F) where F: FnMut(Result<Whispered, String>) + Send + 'static { pub fn main<F>(flags: Flags, handler_fn: F) where F: FnMut(Result<Whispered, String>) + Send + 'static {
let w = new_whisper_service( let w = new_whisper_service(
flags.model, flags.model_path,
flags.model_buffer,
flags.threads, flags.threads,
flags.stream_head, flags.stream_head,
flags.stream_tail, flags.stream_tail,
@ -93,8 +95,8 @@ struct WhisperService {
jobs: std::sync::mpsc::SyncSender<AWhisper>, jobs: std::sync::mpsc::SyncSender<AWhisper>,
} }
fn new_whisper_service<F>(model_path: String, threads: i32, stream_head: f32, stream_tail: f32, handler_fn: F) -> Result<WhisperService, String> where F: FnMut(Result<Whispered, String>) + Send + 'static { fn new_whisper_service<F>(model_path: Option<String>, model_buffer: Option<Vec<u8>>, threads: i32, stream_head: f32, stream_tail: f32, handler_fn: F) -> Result<WhisperService, String> where F: FnMut(Result<Whispered, String>) + Send + 'static {
match new_whisper_engine(model_path, threads) { match new_whisper_engine(model_path, model_buffer, threads) {
Ok(engine) => { Ok(engine) => {
let mut whisper = new_whisper_impl(engine, stream_head, stream_tail, handler_fn); let mut whisper = new_whisper_impl(engine, stream_head, stream_tail, handler_fn);
let (send, recv) = std::sync::mpsc::sync_channel(100); let (send, recv) = std::sync::mpsc::sync_channel(100);
@ -191,11 +193,22 @@ struct WhisperEngine {
threads: i32, threads: i32,
} }
/// Builds a `WhisperEngine` from either a model file on disk or an
/// in-memory model buffer.
///
/// Exactly one source is consumed: `model_path` takes precedence when
/// both are supplied. Returns `Err` with a human-readable message when
/// loading fails or when neither source is provided.
fn new_whisper_engine(model_path: Option<String>, model_buffer: Option<Vec<u8>>, threads: i32) -> Result<WhisperEngine, String> {
    // Prefer the on-disk model when a path was given.
    if let Some(path) = model_path {
        return match WhisperContext::new(&path) {
            Ok(ctx) => Ok(WhisperEngine { ctx, threads }),
            Err(msg) => Err(format!("failed to load {}: {}", path, msg)),
        };
    }
    // Fall back to an in-memory ggml model buffer.
    if let Some(buffer) = model_buffer {
        return match WhisperContext::new_from_buffer(&buffer) {
            Ok(ctx) => Ok(WhisperEngine { ctx, threads }),
            Err(msg) => Err(format!("failed to load buffer: {}", msg)),
        };
    }
    Err("neither model path nor buffer provided".to_string())
}
impl WhisperEngine { impl WhisperEngine {