accept model_path or model_buffer in flags
parent
82b24cabeb
commit
95f638d3d0
|
|
@ -10,7 +10,8 @@ use std::io::Write;
|
|||
#[derive(Parser, Debug)]
|
||||
pub struct Flags {
|
||||
#[arg(long, default_value = "./models/ggml-tiny.en.bin")]
|
||||
pub model: String,
|
||||
pub model_path: Option<String>,
|
||||
pub model_buffer: Option<Vec<u8>>,
|
||||
|
||||
#[arg(long, default_value = "8")]
|
||||
pub threads: i32,
|
||||
|
|
@ -32,7 +33,8 @@ pub struct Flags {
|
|||
|
||||
pub fn main<F>(flags: Flags, handler_fn: F) where F: FnMut(Result<Whispered, String>) + Send + 'static {
|
||||
let w = new_whisper_service(
|
||||
flags.model,
|
||||
flags.model_path,
|
||||
flags.model_buffer,
|
||||
flags.threads,
|
||||
flags.stream_head,
|
||||
flags.stream_tail,
|
||||
|
|
@ -93,8 +95,8 @@ struct WhisperService {
|
|||
jobs: std::sync::mpsc::SyncSender<AWhisper>,
|
||||
}
|
||||
|
||||
fn new_whisper_service<F>(model_path: String, threads: i32, stream_head: f32, stream_tail: f32, handler_fn: F) -> Result<WhisperService, String> where F: FnMut(Result<Whispered, String>) + Send + 'static {
|
||||
match new_whisper_engine(model_path, threads) {
|
||||
fn new_whisper_service<F>(model_path: Option<String>, model_buffer: Option<Vec<u8>>, threads: i32, stream_head: f32, stream_tail: f32, handler_fn: F) -> Result<WhisperService, String> where F: FnMut(Result<Whispered, String>) + Send + 'static {
|
||||
match new_whisper_engine(model_path, model_buffer, threads) {
|
||||
Ok(engine) => {
|
||||
let mut whisper = new_whisper_impl(engine, stream_head, stream_tail, handler_fn);
|
||||
let (send, recv) = std::sync::mpsc::sync_channel(100);
|
||||
|
|
@ -191,11 +193,22 @@ struct WhisperEngine {
|
|||
threads: i32,
|
||||
}
|
||||
|
||||
fn new_whisper_engine(model_path: String, threads: i32) -> Result<WhisperEngine, String> {
|
||||
match WhisperContext::new(&model_path) {
|
||||
fn new_whisper_engine(model_path: Option<String>, model_buffer: Option<Vec<u8>>, threads: i32) -> Result<WhisperEngine, String> {
|
||||
if model_path.is_some() {
|
||||
let model_path = model_path.unwrap();
|
||||
return match WhisperContext::new(&model_path.clone()) {
|
||||
Ok(ctx) => Ok(WhisperEngine{ctx: ctx, threads: threads}),
|
||||
Err(msg) => Err(format!("failed to load {}: {}", model_path, msg)),
|
||||
};
|
||||
}
|
||||
if model_buffer.is_some() {
|
||||
let model_buffer = model_buffer.unwrap();
|
||||
return match WhisperContext::new_from_buffer(&model_buffer) {
|
||||
Ok(ctx) => Ok(WhisperEngine{ctx: ctx, threads: threads}),
|
||||
Err(msg) => Err(format!("failed to load buffer: {}", msg)),
|
||||
};
|
||||
}
|
||||
Err("neither model path nor buffer provided".to_string())
|
||||
}
|
||||
|
||||
impl WhisperEngine {
|
||||
|
|
|
|||
Loading…
Reference in New Issue