accept model_path or model_buffer in flags
parent
82b24cabeb
commit
95f638d3d0
|
|
@ -10,7 +10,8 @@ use std::io::Write;
|
||||||
#[derive(Parser, Debug)]
|
#[derive(Parser, Debug)]
|
||||||
pub struct Flags {
|
pub struct Flags {
|
||||||
#[arg(long, default_value = "./models/ggml-tiny.en.bin")]
|
#[arg(long, default_value = "./models/ggml-tiny.en.bin")]
|
||||||
pub model: String,
|
pub model_path: Option<String>,
|
||||||
|
pub model_buffer: Option<Vec<u8>>,
|
||||||
|
|
||||||
#[arg(long, default_value = "8")]
|
#[arg(long, default_value = "8")]
|
||||||
pub threads: i32,
|
pub threads: i32,
|
||||||
|
|
@ -32,7 +33,8 @@ pub struct Flags {
|
||||||
|
|
||||||
pub fn main<F>(flags: Flags, handler_fn: F) where F: FnMut(Result<Whispered, String>) + Send + 'static {
|
pub fn main<F>(flags: Flags, handler_fn: F) where F: FnMut(Result<Whispered, String>) + Send + 'static {
|
||||||
let w = new_whisper_service(
|
let w = new_whisper_service(
|
||||||
flags.model,
|
flags.model_path,
|
||||||
|
flags.model_buffer,
|
||||||
flags.threads,
|
flags.threads,
|
||||||
flags.stream_head,
|
flags.stream_head,
|
||||||
flags.stream_tail,
|
flags.stream_tail,
|
||||||
|
|
@ -93,8 +95,8 @@ struct WhisperService {
|
||||||
jobs: std::sync::mpsc::SyncSender<AWhisper>,
|
jobs: std::sync::mpsc::SyncSender<AWhisper>,
|
||||||
}
|
}
|
||||||
|
|
||||||
fn new_whisper_service<F>(model_path: String, threads: i32, stream_head: f32, stream_tail: f32, handler_fn: F) -> Result<WhisperService, String> where F: FnMut(Result<Whispered, String>) + Send + 'static {
|
fn new_whisper_service<F>(model_path: Option<String>, model_buffer: Option<Vec<u8>>, threads: i32, stream_head: f32, stream_tail: f32, handler_fn: F) -> Result<WhisperService, String> where F: FnMut(Result<Whispered, String>) + Send + 'static {
|
||||||
match new_whisper_engine(model_path, threads) {
|
match new_whisper_engine(model_path, model_buffer, threads) {
|
||||||
Ok(engine) => {
|
Ok(engine) => {
|
||||||
let mut whisper = new_whisper_impl(engine, stream_head, stream_tail, handler_fn);
|
let mut whisper = new_whisper_impl(engine, stream_head, stream_tail, handler_fn);
|
||||||
let (send, recv) = std::sync::mpsc::sync_channel(100);
|
let (send, recv) = std::sync::mpsc::sync_channel(100);
|
||||||
|
|
@ -191,11 +193,22 @@ struct WhisperEngine {
|
||||||
threads: i32,
|
threads: i32,
|
||||||
}
|
}
|
||||||
|
|
||||||
fn new_whisper_engine(model_path: String, threads: i32) -> Result<WhisperEngine, String> {
|
fn new_whisper_engine(model_path: Option<String>, model_buffer: Option<Vec<u8>>, threads: i32) -> Result<WhisperEngine, String> {
|
||||||
match WhisperContext::new(&model_path) {
|
if model_path.is_some() {
|
||||||
|
let model_path = model_path.unwrap();
|
||||||
|
return match WhisperContext::new(&model_path.clone()) {
|
||||||
Ok(ctx) => Ok(WhisperEngine{ctx: ctx, threads: threads}),
|
Ok(ctx) => Ok(WhisperEngine{ctx: ctx, threads: threads}),
|
||||||
Err(msg) => Err(format!("failed to load {}: {}", model_path, msg)),
|
Err(msg) => Err(format!("failed to load {}: {}", model_path, msg)),
|
||||||
|
};
|
||||||
}
|
}
|
||||||
|
if model_buffer.is_some() {
|
||||||
|
let model_buffer = model_buffer.unwrap();
|
||||||
|
return match WhisperContext::new_from_buffer(&model_buffer) {
|
||||||
|
Ok(ctx) => Ok(WhisperEngine{ctx: ctx, threads: threads}),
|
||||||
|
Err(msg) => Err(format!("failed to load buffer: {}", msg)),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
Err("neither model path nor buffer provided".to_string())
|
||||||
}
|
}
|
||||||
|
|
||||||
impl WhisperEngine {
|
impl WhisperEngine {
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue