refactor to whisper_service enqueues, whisper_impl transforms, whisper_engine provides raw

master
Bel LaPointe 2023-12-19 21:08:59 -05:00
parent 091958e08d
commit a2fee32fbc
1 changed file with 79 additions and 75 deletions

View File

@ -31,17 +31,10 @@ struct Flags {
debug: bool, debug: bool,
} }
static mut G_HANDLER: Handler = new_handler(&0.0, &0.0);
fn main() { fn main() {
let flags = Flags::parse(); let flags = Flags::parse();
unsafe { let w = new_whisper_service(flags.model, flags.threads, flags.stream_head, flags.stream_tail, new_handler()).unwrap();
G_HANDLER.head = flags.stream_head.clone();
G_HANDLER.tail = flags.stream_tail.clone();
}
let w = new_whisper(flags.model, flags.threads, new_handler(&flags.stream_head, &flags.stream_tail)).unwrap();
let stream_retain = (flags.stream_retain * 16_000.0) as usize; let stream_retain = (flags.stream_retain * 16_000.0) as usize;
let stream_step = Duration::new(flags.stream_step, 0); let stream_step = Duration::new(flags.stream_step, 0);
match flags.wav { match flags.wav {
@ -66,15 +59,7 @@ fn main() {
new_listener().listen(move |data: Vec<f32>| { new_listener().listen(move |data: Vec<f32>| {
data.iter().for_each(|x| buffer.push(*x)); data.iter().for_each(|x| buffer.push(*x));
if Instant::now() - last > stream_step { if Instant::now() - last > stream_step {
match w.transcribe_async(&buffer, |result| { w.transcribe_async(&buffer).unwrap();
match result {
Ok(whispered) => unsafe { G_HANDLER.on_success(&whispered) },
Err(msg) => unsafe { G_HANDLER.on_error(msg) },
};
}) {
Ok(_) => (),
Err(msg) => eprintln!("{}", msg),
};
match &flags.debug { match &flags.debug {
true => { true => {
@ -101,43 +86,37 @@ fn main() {
}; };
} }
struct Whisper { struct WhisperService {
jobs: std::sync::mpsc::SyncSender<AWhisper>, jobs: std::sync::mpsc::SyncSender<AWhisper>,
} }
struct WhisperEngine { fn new_whisper_service(model_path: String, threads: i32, stream_head: f32, stream_tail: f32, handler: Handler) -> Result<WhisperService, String> {
ctx: WhisperContext, match new_whisper_engine(model_path, threads) {
threads: i32,
handler: Handler,
}
fn new_whisper(model_path: String, threads: i32, handler: Handler) -> Result<Whisper, String> {
match new_whisper_engine(model_path, threads, handler) {
Ok(engine) => { Ok(engine) => {
let whisper = new_whisper_impl(engine, stream_head, stream_tail, handler);
let (send, recv) = std::sync::mpsc::sync_channel(100); let (send, recv) = std::sync::mpsc::sync_channel(100);
thread::spawn(move || { engine.transcribe_asyncs(recv); }); thread::spawn(move || { whisper.transcribe_asyncs(recv); });
Ok(Whisper{jobs: send}) Ok(WhisperService{jobs: send})
}, },
Err(msg) => Err(format!("failed to initialize engine: {}", msg)), Err(msg) => Err(format!("failed to initialize engine: {}", msg)),
} }
} }
impl Whisper { impl WhisperService {
fn transcribe(&self, data: &Vec<f32>) { fn transcribe(&self, data: &Vec<f32>) {
let (send, recv) = std::sync::mpsc::sync_channel(0); let (send, recv) = std::sync::mpsc::sync_channel(0);
self._transcribe_async(data, Some(send), None).unwrap(); self._transcribe_async(data, Some(send)).unwrap();
recv.recv().unwrap(); recv.recv().unwrap();
} }
fn transcribe_async(&self, data: &Vec<f32>, callback: fn(Result<Whispered, String>)) -> Result<(), String> { fn transcribe_async(&self, data: &Vec<f32>) -> Result<(), String> {
self._transcribe_async(data, None, Some(callback)) self._transcribe_async(data, None)
} }
fn _transcribe_async(&self, data: &Vec<f32>, ack: Option<std::sync::mpsc::SyncSender<bool>>, callback: Option<fn(Result<Whispered, String>)>) -> Result<(), String> { fn _transcribe_async(&self, data: &Vec<f32>, ack: Option<std::sync::mpsc::SyncSender<bool>>) -> Result<(), String> {
match self.jobs.try_send(AWhisper{ match self.jobs.try_send(AWhisper{
data: data.clone().to_vec(), data: data.clone().to_vec(),
ack: ack, ack: ack,
callback: callback,
}) { }) {
Ok(_) => Ok(()), Ok(_) => Ok(()),
Err(msg) => Err(format!("failed to enqueue transcription: {}", msg)), Err(msg) => Err(format!("failed to enqueue transcription: {}", msg)),
@ -145,48 +124,80 @@ impl Whisper {
} }
} }
fn new_whisper_engine(model_path: String, threads: i32, handler: Handler) -> Result<WhisperEngine, String> { struct WhisperImpl {
match WhisperContext::new(&model_path) { engine: WhisperEngine,
Ok(ctx) => Ok(WhisperEngine{ctx: ctx, threads: threads, handler: handler}), stream_head: f32,
Err(msg) => Err(format!("failed to load {}: {}", model_path, msg)), stream_tail: f32,
handler: Handler,
}
fn new_whisper_impl(engine: WhisperEngine, stream_head: f32, stream_tail: f32, handler: Handler) -> WhisperImpl {
WhisperImpl {
engine: engine,
stream_head: stream_head,
stream_tail: stream_tail,
handler: handler,
} }
} }
impl WhisperEngine { impl WhisperImpl {
fn transcribe_asyncs(&self, recv: std::sync::mpsc::Receiver<AWhisper>) { fn transcribe_asyncs(&self, recv: std::sync::mpsc::Receiver<AWhisper>) {
loop { loop {
match recv.recv() { match recv.recv() {
Ok(job) => { Ok(job) => {
match self.transcribe(&job.data) { let result = self.transcribe(&job).is_ok();
Ok(result) => {
self.handler.on_success(&result);
match job.ack { match job.ack {
Some(ack) => { let _ = ack.send(true); }, Some(ack) => {
None => (), ack.send(result).unwrap();
};
match job.callback {
Some(foo) => foo(Ok(result)),
None => (),
};
}, },
Err(msg) => {
self.handler.on_error(format!("failed to transcribe: {}", &msg));
match job.ack {
Some(ack) => { let _ = ack.send(false); },
None => (), None => (),
}; };
match job.callback { }
Some(foo) => foo(Err(format!("failed to transcribe: {}", &msg))),
None => (),
};
},
};
},
Err(_) => return, Err(_) => return,
}; };
} }
} }
/// Runs one queued job through the engine and reports the outcome.
///
/// The job's audio samples (`a_whisper.data`) are handed to
/// `self.engine.transcribe`. A successful result is forwarded to
/// `self.on_success`; a failure is stringified and forwarded to
/// `self.on_error`.
///
/// Returns `Ok(())` when the engine produced a result and `Err(())` when it
/// did not. The error detail is delivered through `on_error`, not through
/// the return value.
fn transcribe(&self, a_whisper: &AWhisper) -> Result<(), ()> {
    self.engine
        .transcribe(&a_whisper.data)
        .map(|whispered| self.on_success(&whispered))
        .map_err(|failure| self.on_error(failure.to_string()))
}
/// Logs a finished transcription and passes a windowed copy to the handler.
///
/// The full result is written to stderr with a local timestamp first. The
/// handler then receives the value produced by chaining `after` and `before`
/// with the configured head and tail values, each multiplied by 100.
///
/// NOTE(review): the factor of 100 presumably converts seconds into the
/// timestamp unit used by `Whispered::after` / `Whispered::before` — confirm
/// against those methods.
fn on_success(&self, whispered: &Whispered) {
    eprintln!("{}: {:?}", chrono::Local::now(), whispered);

    let head_cutoff = self.stream_head * 100.0;
    let tail_cutoff = self.stream_tail * 100.0;

    // The chain stays a single expression so any temporaries returned by
    // `after`/`before` live exactly as long as the handler call.
    self.handler
        .on_success(&whispered.after(&head_cutoff).before(&tail_cutoff));
}
/// Forwards a transcription failure to the handler.
///
/// `msg` is prefixed with "failed to transcribe: " before being handed to
/// `self.handler.on_error`.
fn on_error(&self, msg: String) {
    let report = format!("failed to transcribe: {}", msg);
    self.handler.on_error(report);
}
}
/// Pairs a loaded `WhisperContext` with the thread count given at
/// construction time (see `new_whisper_engine`).
struct WhisperEngine {
/// Model context created by `WhisperContext::new` from the model path.
ctx: WhisperContext,
/// Thread count supplied by the caller.
/// NOTE(review): presumably applied to the decoding parameters inside
/// `WhisperEngine::transcribe` — confirm there.
threads: i32,
}
/// Builds a `WhisperEngine` by loading the model found at `model_path`.
///
/// `threads` is stored on the engine unchanged.
///
/// # Errors
///
/// Returns `Err` with the message "failed to load <model_path>: <cause>"
/// when `WhisperContext::new` rejects the path.
fn new_whisper_engine(model_path: String, threads: i32) -> Result<WhisperEngine, String> {
    WhisperContext::new(&model_path)
        .map(|ctx| WhisperEngine { ctx, threads })
        .map_err(|cause| format!("failed to load {}: {}", model_path, cause))
}
impl WhisperEngine {
fn transcribe(&self, data: &Vec<f32>) -> Result<Whispered, WhisperError> { fn transcribe(&self, data: &Vec<f32>) -> Result<Whispered, WhisperError> {
let mut params = FullParams::new(SamplingStrategy::Greedy { best_of: 0 }); let mut params = FullParams::new(SamplingStrategy::Greedy { best_of: 0 });
params.set_no_context(true); params.set_no_context(true);
@ -218,7 +229,6 @@ impl WhisperEngine {
struct AWhisper { struct AWhisper {
data: Vec<f32>, data: Vec<f32>,
ack: Option<std::sync::mpsc::SyncSender<bool>>, ack: Option<std::sync::mpsc::SyncSender<bool>>,
callback: Option<fn(Result<Whispered, String>)>,
} }
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
@ -289,24 +299,18 @@ impl Whispered {
} }
#[derive(Clone)] #[derive(Clone)]
struct Handler { struct Handler {}
head: f32,
tail: f32,
}
const fn new_handler(seconds_head: &f32, seconds_tail: &f32) -> Handler { const fn new_handler() -> Handler {
Handler{ Handler{}
head: *seconds_head,
tail: *seconds_tail,
}
} }
impl Handler { impl Handler {
fn on_success(&self, result: &Whispered) { fn on_success(&self, result: &Whispered) {
eprintln!("{}: {:?}", chrono::Local::now(), &result); eprintln!("{}: {:?}", chrono::Local::now(), &result);
println!("{}", result println!("{}", result
.after(&(self.head * 100.0)) .after(&(0.25 * 100.0))
.before(&(self.tail * 100.0)) .before(&(0.25 * 100.0))
.to_string(), .to_string(),
); );
} }
@ -347,7 +351,7 @@ impl Listener {
let downsample_ratio = cfg.channels() as f32 * (cfg.sample_rate().0 as f32 / 16_000.0); let downsample_ratio = cfg.channels() as f32 * (cfg.sample_rate().0 as f32 / 16_000.0);
let stream = device.build_input_stream( let stream = device.build_input_stream(
&cfg.clone().into(), &cfg.clone().into(),
move |data: &[f32], _: &cpal::InputCallbackInfo| { move |data: &[f32], _: &cpal::InputCallbackInfo| { // TODO why cant i do this...
let mut downsampled_data = vec![]; let mut downsampled_data = vec![];
for i in 0..(data.len() as f32 / downsample_ratio) as usize { for i in 0..(data.len() as f32 / downsample_ratio) as usize {
let mut upsampled = i as f32 * downsample_ratio; let mut upsampled = i as f32 * downsample_ratio;