refactor to whisper_service enqueues, whisper_impl transforms, whisper_engine provides raw
parent
091958e08d
commit
a2fee32fbc
150
src/main.rs
150
src/main.rs
|
|
@ -31,17 +31,10 @@ struct Flags {
|
|||
debug: bool,
|
||||
}
|
||||
|
||||
static mut G_HANDLER: Handler = new_handler(&0.0, &0.0);
|
||||
|
||||
fn main() {
|
||||
let flags = Flags::parse();
|
||||
|
||||
unsafe {
|
||||
G_HANDLER.head = flags.stream_head.clone();
|
||||
G_HANDLER.tail = flags.stream_tail.clone();
|
||||
}
|
||||
|
||||
let w = new_whisper(flags.model, flags.threads, new_handler(&flags.stream_head, &flags.stream_tail)).unwrap();
|
||||
let w = new_whisper_service(flags.model, flags.threads, flags.stream_head, flags.stream_tail, new_handler()).unwrap();
|
||||
let stream_retain = (flags.stream_retain * 16_000.0) as usize;
|
||||
let stream_step = Duration::new(flags.stream_step, 0);
|
||||
match flags.wav {
|
||||
|
|
@ -66,15 +59,7 @@ fn main() {
|
|||
new_listener().listen(move |data: Vec<f32>| {
|
||||
data.iter().for_each(|x| buffer.push(*x));
|
||||
if Instant::now() - last > stream_step {
|
||||
match w.transcribe_async(&buffer, |result| {
|
||||
match result {
|
||||
Ok(whispered) => unsafe { G_HANDLER.on_success(&whispered) },
|
||||
Err(msg) => unsafe { G_HANDLER.on_error(msg) },
|
||||
};
|
||||
}) {
|
||||
Ok(_) => (),
|
||||
Err(msg) => eprintln!("{}", msg),
|
||||
};
|
||||
w.transcribe_async(&buffer).unwrap();
|
||||
|
||||
match &flags.debug {
|
||||
true => {
|
||||
|
|
@ -101,43 +86,37 @@ fn main() {
|
|||
};
|
||||
}
|
||||
|
||||
struct Whisper {
|
||||
struct WhisperService {
|
||||
jobs: std::sync::mpsc::SyncSender<AWhisper>,
|
||||
}
|
||||
|
||||
struct WhisperEngine {
|
||||
ctx: WhisperContext,
|
||||
threads: i32,
|
||||
handler: Handler,
|
||||
}
|
||||
|
||||
fn new_whisper(model_path: String, threads: i32, handler: Handler) -> Result<Whisper, String> {
|
||||
match new_whisper_engine(model_path, threads, handler) {
|
||||
fn new_whisper_service(model_path: String, threads: i32, stream_head: f32, stream_tail: f32, handler: Handler) -> Result<WhisperService, String> {
|
||||
match new_whisper_engine(model_path, threads) {
|
||||
Ok(engine) => {
|
||||
let whisper = new_whisper_impl(engine, stream_head, stream_tail, handler);
|
||||
let (send, recv) = std::sync::mpsc::sync_channel(100);
|
||||
thread::spawn(move || { engine.transcribe_asyncs(recv); });
|
||||
Ok(Whisper{jobs: send})
|
||||
thread::spawn(move || { whisper.transcribe_asyncs(recv); });
|
||||
Ok(WhisperService{jobs: send})
|
||||
},
|
||||
Err(msg) => Err(format!("failed to initialize engine: {}", msg)),
|
||||
}
|
||||
}
|
||||
|
||||
impl Whisper {
|
||||
impl WhisperService {
|
||||
fn transcribe(&self, data: &Vec<f32>) {
|
||||
let (send, recv) = std::sync::mpsc::sync_channel(0);
|
||||
self._transcribe_async(data, Some(send), None).unwrap();
|
||||
self._transcribe_async(data, Some(send)).unwrap();
|
||||
recv.recv().unwrap();
|
||||
}
|
||||
|
||||
fn transcribe_async(&self, data: &Vec<f32>, callback: fn(Result<Whispered, String>)) -> Result<(), String> {
|
||||
self._transcribe_async(data, None, Some(callback))
|
||||
fn transcribe_async(&self, data: &Vec<f32>) -> Result<(), String> {
|
||||
self._transcribe_async(data, None)
|
||||
}
|
||||
|
||||
fn _transcribe_async(&self, data: &Vec<f32>, ack: Option<std::sync::mpsc::SyncSender<bool>>, callback: Option<fn(Result<Whispered, String>)>) -> Result<(), String> {
|
||||
fn _transcribe_async(&self, data: &Vec<f32>, ack: Option<std::sync::mpsc::SyncSender<bool>>) -> Result<(), String> {
|
||||
match self.jobs.try_send(AWhisper{
|
||||
data: data.clone().to_vec(),
|
||||
ack: ack,
|
||||
callback: callback,
|
||||
}) {
|
||||
Ok(_) => Ok(()),
|
||||
Err(msg) => Err(format!("failed to enqueue transcription: {}", msg)),
|
||||
|
|
@ -145,48 +124,80 @@ impl Whisper {
|
|||
}
|
||||
}
|
||||
|
||||
fn new_whisper_engine(model_path: String, threads: i32, handler: Handler) -> Result<WhisperEngine, String> {
|
||||
match WhisperContext::new(&model_path) {
|
||||
Ok(ctx) => Ok(WhisperEngine{ctx: ctx, threads: threads, handler: handler}),
|
||||
Err(msg) => Err(format!("failed to load {}: {}", model_path, msg)),
|
||||
struct WhisperImpl {
|
||||
engine: WhisperEngine,
|
||||
stream_head: f32,
|
||||
stream_tail: f32,
|
||||
handler: Handler,
|
||||
}
|
||||
|
||||
fn new_whisper_impl(engine: WhisperEngine, stream_head: f32, stream_tail: f32, handler: Handler) -> WhisperImpl {
|
||||
WhisperImpl {
|
||||
engine: engine,
|
||||
stream_head: stream_head,
|
||||
stream_tail: stream_tail,
|
||||
handler: handler,
|
||||
}
|
||||
}
|
||||
|
||||
impl WhisperEngine {
|
||||
impl WhisperImpl {
|
||||
fn transcribe_asyncs(&self, recv: std::sync::mpsc::Receiver<AWhisper>) {
|
||||
loop {
|
||||
match recv.recv() {
|
||||
Ok(job) => {
|
||||
match self.transcribe(&job.data) {
|
||||
Ok(result) => {
|
||||
self.handler.on_success(&result);
|
||||
let result = self.transcribe(&job).is_ok();
|
||||
match job.ack {
|
||||
Some(ack) => { let _ = ack.send(true); },
|
||||
None => (),
|
||||
};
|
||||
match job.callback {
|
||||
Some(foo) => foo(Ok(result)),
|
||||
None => (),
|
||||
};
|
||||
Some(ack) => {
|
||||
ack.send(result).unwrap();
|
||||
},
|
||||
Err(msg) => {
|
||||
self.handler.on_error(format!("failed to transcribe: {}", &msg));
|
||||
match job.ack {
|
||||
Some(ack) => { let _ = ack.send(false); },
|
||||
None => (),
|
||||
};
|
||||
match job.callback {
|
||||
Some(foo) => foo(Err(format!("failed to transcribe: {}", &msg))),
|
||||
None => (),
|
||||
};
|
||||
},
|
||||
};
|
||||
},
|
||||
}
|
||||
Err(_) => return,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
fn transcribe(&self, a_whisper: &AWhisper) -> Result<(), ()> {
|
||||
match self.engine.transcribe(&a_whisper.data) {
|
||||
Ok(result) => {
|
||||
self.on_success(&result);
|
||||
Ok(())
|
||||
},
|
||||
Err(msg) => {
|
||||
self.on_error(msg.to_string());
|
||||
Err(())
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
fn on_success(&self, whispered: &Whispered) {
|
||||
eprintln!("{}: {:?}", chrono::Local::now(), whispered);
|
||||
self.handler.on_success(
|
||||
&whispered
|
||||
.after(&(self.stream_head * 100.0))
|
||||
.before(&(self.stream_tail * 100.0)),
|
||||
);
|
||||
}
|
||||
|
||||
fn on_error(&self, msg: String) {
|
||||
self.handler.on_error(format!("failed to transcribe: {}", &msg));
|
||||
}
|
||||
}
|
||||
|
||||
struct WhisperEngine {
|
||||
ctx: WhisperContext,
|
||||
threads: i32,
|
||||
}
|
||||
|
||||
fn new_whisper_engine(model_path: String, threads: i32) -> Result<WhisperEngine, String> {
|
||||
match WhisperContext::new(&model_path) {
|
||||
Ok(ctx) => Ok(WhisperEngine{ctx: ctx, threads: threads}),
|
||||
Err(msg) => Err(format!("failed to load {}: {}", model_path, msg)),
|
||||
}
|
||||
}
|
||||
|
||||
impl WhisperEngine {
|
||||
fn transcribe(&self, data: &Vec<f32>) -> Result<Whispered, WhisperError> {
|
||||
let mut params = FullParams::new(SamplingStrategy::Greedy { best_of: 0 });
|
||||
params.set_no_context(true);
|
||||
|
|
@ -218,7 +229,6 @@ impl WhisperEngine {
|
|||
struct AWhisper {
|
||||
data: Vec<f32>,
|
||||
ack: Option<std::sync::mpsc::SyncSender<bool>>,
|
||||
callback: Option<fn(Result<Whispered, String>)>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
|
|
@ -289,24 +299,18 @@ impl Whispered {
|
|||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct Handler {
|
||||
head: f32,
|
||||
tail: f32,
|
||||
}
|
||||
struct Handler {}
|
||||
|
||||
const fn new_handler(seconds_head: &f32, seconds_tail: &f32) -> Handler {
|
||||
Handler{
|
||||
head: *seconds_head,
|
||||
tail: *seconds_tail,
|
||||
}
|
||||
const fn new_handler() -> Handler {
|
||||
Handler{}
|
||||
}
|
||||
|
||||
impl Handler {
|
||||
fn on_success(&self, result: &Whispered) {
|
||||
eprintln!("{}: {:?}", chrono::Local::now(), &result);
|
||||
println!("{}", result
|
||||
.after(&(self.head * 100.0))
|
||||
.before(&(self.tail * 100.0))
|
||||
.after(&(0.25 * 100.0))
|
||||
.before(&(0.25 * 100.0))
|
||||
.to_string(),
|
||||
);
|
||||
}
|
||||
|
|
@ -347,7 +351,7 @@ impl Listener {
|
|||
let downsample_ratio = cfg.channels() as f32 * (cfg.sample_rate().0 as f32 / 16_000.0);
|
||||
let stream = device.build_input_stream(
|
||||
&cfg.clone().into(),
|
||||
move |data: &[f32], _: &cpal::InputCallbackInfo| {
|
||||
move |data: &[f32], _: &cpal::InputCallbackInfo| { // TODO why cant i do this...
|
||||
let mut downsampled_data = vec![];
|
||||
for i in 0..(data.len() as f32 / downsample_ratio) as usize {
|
||||
let mut upsampled = i as f32 * downsample_ratio;
|
||||
|
|
|
|||
Loading…
Reference in New Issue