transcribing results as callbacks
parent
eea4b75bc8
commit
2391d07994
47
src/main.rs
47
src/main.rs
|
|
@ -14,8 +14,8 @@ struct Flags {
|
|||
#[arg(long, default_value = "8")]
|
||||
threads: i32,
|
||||
|
||||
#[arg(long, default_value = "0.8")]
|
||||
stream_churn: f32,
|
||||
#[arg(long, default_value = "1.0")]
|
||||
stream_retain: f32,
|
||||
#[arg(long, default_value = "5")]
|
||||
stream_step: u64,
|
||||
|
||||
|
|
@ -26,7 +26,7 @@ fn main() {
|
|||
let flags = Flags::parse();
|
||||
|
||||
let w = new_whisper(flags.model, flags.threads).unwrap();
|
||||
let stream_churn = flags.stream_churn;
|
||||
let stream_retain = (flags.stream_retain * 16_000.0) as usize;
|
||||
let stream_step = Duration::new(flags.stream_step, 0);
|
||||
match flags.wav {
|
||||
Some(wav) => {
|
||||
|
|
@ -34,12 +34,14 @@ fn main() {
|
|||
&mut std::fs::File::open(wav).expect("failed to open $WAV"),
|
||||
).expect("failed to decode $WAV");
|
||||
assert!(header.channel_count == 1);
|
||||
assert!(header.sampling_rate == 16000);
|
||||
assert!(header.sampling_rate == 16_000);
|
||||
let data16 = data.as_sixteen().expect("wav is not 32bit floats");
|
||||
let audio_data = &whisper_rs::convert_integer_to_float_audio(&data16);
|
||||
|
||||
let result = w.transcribe(&audio_data).unwrap();
|
||||
println!("{}", result);
|
||||
w.transcribe(&audio_data,
|
||||
|result| println!("{}", result),
|
||||
|err| eprintln!("failed to transcribe wav: {}", err),
|
||||
);
|
||||
},
|
||||
None => {
|
||||
let mut buffer = vec![];
|
||||
|
|
@ -47,15 +49,18 @@ fn main() {
|
|||
new_listener().listen(move |data: Vec<f32>| {
|
||||
data.iter().for_each(|x| buffer.push(*x));
|
||||
if Instant::now() - last > stream_step {
|
||||
let result = w.transcribe(&buffer).unwrap();
|
||||
eprintln!("{}", chrono::Local::now());
|
||||
println!("{}", result);
|
||||
w.transcribe(&buffer,
|
||||
|result| {
|
||||
eprintln!("{}", chrono::Local::now());
|
||||
println!("{}", result);
|
||||
},
|
||||
|err| eprintln!("failed to transcribe stream: {}", err),
|
||||
);
|
||||
|
||||
let retain = buffer.len() - (buffer.len() as f32 * stream_churn) as usize;
|
||||
for i in retain..buffer.len() {
|
||||
buffer[i - retain] = buffer[i]
|
||||
for i in stream_retain..buffer.len() {
|
||||
buffer[i - stream_retain] = buffer[i]
|
||||
}
|
||||
buffer.truncate(retain);
|
||||
buffer.truncate(stream_retain);
|
||||
last = Instant::now();
|
||||
}
|
||||
});
|
||||
|
|
@ -79,11 +84,11 @@ fn new_whisper(model_path: String, threads: i32) -> Result<Whisper, String> {
|
|||
}
|
||||
|
||||
impl Whisper {
|
||||
fn transcribe(&self, data: &Vec<f32>) -> Result<String, String> {
|
||||
fn transcribe(&self, data: &Vec<f32>, on_success: impl Fn(String), on_error: impl Fn(String)) {
|
||||
match self._transcribe(&data) {
|
||||
Ok(result) => Ok(result),
|
||||
Err(msg) => Err(format!("failed to transcribe: {}", msg)),
|
||||
}
|
||||
Ok(result) => on_success(result),
|
||||
Err(msg) => on_error(format!("failed to transcribe: {}", msg)),
|
||||
};
|
||||
}
|
||||
|
||||
fn _transcribe(&self, data: &Vec<f32>) -> Result<String, WhisperError> {
|
||||
|
|
@ -121,7 +126,7 @@ fn new_listener() -> Listener {
|
|||
|
||||
impl Listener {
|
||||
fn listen(self, mut cb: impl FnMut(Vec<f32>)) {
|
||||
let (send, recv) = std::sync::mpsc::sync_channel(5_000_000);
|
||||
let (send, recv) = std::sync::mpsc::sync_channel(5_000);
|
||||
thread::spawn(move || { self._listen(send); });
|
||||
loop {
|
||||
match recv.recv() {
|
||||
|
|
@ -131,7 +136,7 @@ impl Listener {
|
|||
}
|
||||
}
|
||||
|
||||
fn _listen(self, ch: std::sync::mpsc::SyncSender<Vec<f32>>) {
|
||||
fn _listen(self, send: std::sync::mpsc::SyncSender<Vec<f32>>) {
|
||||
let host = cpal::default_host();
|
||||
let device = host.default_input_device().unwrap();
|
||||
let cfg = device.supported_input_configs()
|
||||
|
|
@ -141,7 +146,7 @@ impl Listener {
|
|||
.unwrap()
|
||||
.with_max_sample_rate();
|
||||
|
||||
let downsample_ratio = cfg.channels() as f32 * (cfg.sample_rate().0 as f32 / 16000.0);
|
||||
let downsample_ratio = cfg.channels() as f32 * (cfg.sample_rate().0 as f32 / 16_000.0);
|
||||
let stream = device.build_input_stream(
|
||||
&cfg.clone().into(),
|
||||
move |data: &[f32], _: &cpal::InputCallbackInfo| {
|
||||
|
|
@ -153,7 +158,7 @@ impl Listener {
|
|||
}
|
||||
downsampled_data.push(data[upsampled as usize]);
|
||||
}
|
||||
match ch.send(downsampled_data) {
|
||||
match send.try_send(downsampled_data) {
|
||||
Ok(_) => (),
|
||||
Err(msg) => eprintln!("failed to ingest audio: {}", msg),
|
||||
};
|
||||
|
|
|
|||
Loading…
Reference in New Issue