Compare commits
3 Commits
eea4b75bc8
...
532ae22908
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
532ae22908 | ||
|
|
deffc420ca | ||
|
|
2391d07994 |
134
src/main.rs
134
src/main.rs
@@ -14,8 +14,8 @@ struct Flags {
|
|||||||
#[arg(long, default_value = "8")]
|
#[arg(long, default_value = "8")]
|
||||||
threads: i32,
|
threads: i32,
|
||||||
|
|
||||||
#[arg(long, default_value = "0.8")]
|
#[arg(long, default_value = "1.0")]
|
||||||
stream_churn: f32,
|
stream_retain: f32,
|
||||||
#[arg(long, default_value = "5")]
|
#[arg(long, default_value = "5")]
|
||||||
stream_step: u64,
|
stream_step: u64,
|
||||||
|
|
||||||
@@ -25,8 +25,8 @@ struct Flags {
|
|||||||
fn main() {
|
fn main() {
|
||||||
let flags = Flags::parse();
|
let flags = Flags::parse();
|
||||||
|
|
||||||
let w = new_whisper(flags.model, flags.threads).unwrap();
|
let w = new_whisper(flags.model, flags.threads, Handler{}).unwrap();
|
||||||
let stream_churn = flags.stream_churn;
|
let stream_retain = (flags.stream_retain * 16_000.0) as usize;
|
||||||
let stream_step = Duration::new(flags.stream_step, 0);
|
let stream_step = Duration::new(flags.stream_step, 0);
|
||||||
match flags.wav {
|
match flags.wav {
|
||||||
Some(wav) => {
|
Some(wav) => {
|
||||||
@@ -34,12 +34,11 @@ fn main() {
|
|||||||
&mut std::fs::File::open(wav).expect("failed to open $WAV"),
|
&mut std::fs::File::open(wav).expect("failed to open $WAV"),
|
||||||
).expect("failed to decode $WAV");
|
).expect("failed to decode $WAV");
|
||||||
assert!(header.channel_count == 1);
|
assert!(header.channel_count == 1);
|
||||||
assert!(header.sampling_rate == 16000);
|
assert!(header.sampling_rate == 16_000);
|
||||||
let data16 = data.as_sixteen().expect("wav is not 32bit floats");
|
let data16 = data.as_sixteen().expect("wav is not 32bit floats");
|
||||||
let audio_data = &whisper_rs::convert_integer_to_float_audio(&data16);
|
let audio_data = &whisper_rs::convert_integer_to_float_audio(&data16);
|
||||||
|
|
||||||
let result = w.transcribe(&audio_data).unwrap();
|
w.transcribe(&audio_data);
|
||||||
println!("{}", result);
|
|
||||||
},
|
},
|
||||||
None => {
|
None => {
|
||||||
let mut buffer = vec![];
|
let mut buffer = vec![];
|
||||||
@@ -47,15 +46,15 @@ fn main() {
|
|||||||
new_listener().listen(move |data: Vec<f32>| {
|
new_listener().listen(move |data: Vec<f32>| {
|
||||||
data.iter().for_each(|x| buffer.push(*x));
|
data.iter().for_each(|x| buffer.push(*x));
|
||||||
if Instant::now() - last > stream_step {
|
if Instant::now() - last > stream_step {
|
||||||
let result = w.transcribe(&buffer).unwrap();
|
match w.transcribe_async(&buffer) {
|
||||||
eprintln!("{}", chrono::Local::now());
|
Ok(_) => (),
|
||||||
println!("{}", result);
|
Err(msg) => eprintln!("{}", msg),
|
||||||
|
};
|
||||||
|
|
||||||
let retain = buffer.len() - (buffer.len() as f32 * stream_churn) as usize;
|
for i in stream_retain..buffer.len() {
|
||||||
for i in retain..buffer.len() {
|
buffer[i - stream_retain] = buffer[i]
|
||||||
buffer[i - retain] = buffer[i]
|
|
||||||
}
|
}
|
||||||
buffer.truncate(retain);
|
buffer.truncate(stream_retain);
|
||||||
last = Instant::now();
|
last = Instant::now();
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
@@ -64,29 +63,83 @@ fn main() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
struct Whisper {
|
struct Whisper {
|
||||||
ctx: WhisperContext,
|
jobs: std::sync::mpsc::SyncSender<AWhisper>,
|
||||||
threads: i32,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn new_whisper(model_path: String, threads: i32) -> Result<Whisper, String> {
|
struct WhisperEngine {
|
||||||
match WhisperContext::new(&model_path) {
|
ctx: WhisperContext,
|
||||||
Ok(ctx) => Ok(Whisper{
|
threads: i32,
|
||||||
ctx: ctx,
|
handler: Handler,
|
||||||
threads: threads,
|
}
|
||||||
}),
|
|
||||||
Err(msg) => Err(format!("failed to load {}: {}", model_path, msg)),
|
fn new_whisper(model_path: String, threads: i32, handler: Handler) -> Result<Whisper, String> {
|
||||||
|
match new_whisper_engine(model_path, threads, handler) {
|
||||||
|
Ok(engine) => {
|
||||||
|
let (send, recv) = std::sync::mpsc::sync_channel(100);
|
||||||
|
thread::spawn(move || { engine.transcribe_asyncs(recv); });
|
||||||
|
Ok(Whisper{jobs: send})
|
||||||
|
},
|
||||||
|
Err(msg) => Err(format!("failed to initialize engine: {}", msg)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Whisper {
|
impl Whisper {
|
||||||
fn transcribe(&self, data: &Vec<f32>) -> Result<String, String> {
|
fn transcribe(&self, data: &Vec<f32>) {
|
||||||
match self._transcribe(&data) {
|
let (send, recv) = std::sync::mpsc::sync_channel(1);
|
||||||
Ok(result) => Ok(result),
|
self._transcribe_async(data, Some(send)).unwrap();
|
||||||
Err(msg) => Err(format!("failed to transcribe: {}", msg)),
|
recv.recv().unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
fn transcribe_async(&self, data: &Vec<f32>) -> Result<(), String> {
|
||||||
|
self._transcribe_async(data, None)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn _transcribe_async(&self, data: &Vec<f32>, ack: Option<std::sync::mpsc::SyncSender<bool>>) -> Result<(), String> {
|
||||||
|
match self.jobs.try_send(AWhisper{
|
||||||
|
data: data.clone().to_vec(),
|
||||||
|
ack: ack,
|
||||||
|
}) {
|
||||||
|
Ok(_) => Ok(()),
|
||||||
|
Err(msg) => Err(format!("failed to enqueue transcription: {}", msg)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn new_whisper_engine(model_path: String, threads: i32, handler: Handler) -> Result<WhisperEngine, String> {
|
||||||
|
match WhisperContext::new(&model_path) {
|
||||||
|
Ok(ctx) => Ok(WhisperEngine{ctx: ctx, threads: threads, handler: handler}),
|
||||||
|
Err(msg) => Err(format!("failed to load {}: {}", model_path, msg)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl WhisperEngine {
|
||||||
|
fn transcribe_asyncs(&self, recv: std::sync::mpsc::Receiver<AWhisper>) {
|
||||||
|
loop {
|
||||||
|
match recv.recv() {
|
||||||
|
Ok(job) => {
|
||||||
|
match self.transcribe(&job.data) {
|
||||||
|
Ok(result) => {
|
||||||
|
self.handler.on_success(result);
|
||||||
|
match job.ack {
|
||||||
|
Some(ack) => { let _ = ack.send(true); },
|
||||||
|
None => (),
|
||||||
|
};
|
||||||
|
},
|
||||||
|
Err(msg) => {
|
||||||
|
self.handler.on_error(format!("failed to transcribe: {}", msg));
|
||||||
|
match job.ack {
|
||||||
|
Some(ack) => { let _ = ack.send(false); },
|
||||||
|
None => (),
|
||||||
|
};
|
||||||
|
},
|
||||||
|
};
|
||||||
|
},
|
||||||
|
Err(_) => return,
|
||||||
|
};
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn _transcribe(&self, data: &Vec<f32>) -> Result<String, WhisperError> {
|
fn transcribe(&self, data: &Vec<f32>) -> Result<String, WhisperError> {
|
||||||
let mut params = FullParams::new(SamplingStrategy::Greedy { best_of: 0 });
|
let mut params = FullParams::new(SamplingStrategy::Greedy { best_of: 0 });
|
||||||
params.set_no_context(true);
|
params.set_no_context(true);
|
||||||
params.set_n_threads(self.threads);
|
params.set_n_threads(self.threads);
|
||||||
@@ -112,6 +165,23 @@ impl Whisper {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct AWhisper {
|
||||||
|
data: Vec<f32>,
|
||||||
|
ack: Option<std::sync::mpsc::SyncSender<bool>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
struct Handler {}
|
||||||
|
|
||||||
|
impl Handler {
|
||||||
|
fn on_success(&self, result: String) {
|
||||||
|
eprintln!("{}", chrono::Local::now());
|
||||||
|
println!("{}", result);
|
||||||
|
}
|
||||||
|
fn on_error(&self, msg: String) {
|
||||||
|
eprintln!("error: {}", msg);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
struct Listener {
|
struct Listener {
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -121,7 +191,7 @@ fn new_listener() -> Listener {
|
|||||||
|
|
||||||
impl Listener {
|
impl Listener {
|
||||||
fn listen(self, mut cb: impl FnMut(Vec<f32>)) {
|
fn listen(self, mut cb: impl FnMut(Vec<f32>)) {
|
||||||
let (send, recv) = std::sync::mpsc::sync_channel(5_000_000);
|
let (send, recv) = std::sync::mpsc::sync_channel(5_000);
|
||||||
thread::spawn(move || { self._listen(send); });
|
thread::spawn(move || { self._listen(send); });
|
||||||
loop {
|
loop {
|
||||||
match recv.recv() {
|
match recv.recv() {
|
||||||
@@ -131,7 +201,7 @@ impl Listener {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn _listen(self, ch: std::sync::mpsc::SyncSender<Vec<f32>>) {
|
fn _listen(self, send: std::sync::mpsc::SyncSender<Vec<f32>>) {
|
||||||
let host = cpal::default_host();
|
let host = cpal::default_host();
|
||||||
let device = host.default_input_device().unwrap();
|
let device = host.default_input_device().unwrap();
|
||||||
let cfg = device.supported_input_configs()
|
let cfg = device.supported_input_configs()
|
||||||
@@ -141,7 +211,7 @@ impl Listener {
|
|||||||
.unwrap()
|
.unwrap()
|
||||||
.with_max_sample_rate();
|
.with_max_sample_rate();
|
||||||
|
|
||||||
let downsample_ratio = cfg.channels() as f32 * (cfg.sample_rate().0 as f32 / 16000.0);
|
let downsample_ratio = cfg.channels() as f32 * (cfg.sample_rate().0 as f32 / 16_000.0);
|
||||||
let stream = device.build_input_stream(
|
let stream = device.build_input_stream(
|
||||||
&cfg.clone().into(),
|
&cfg.clone().into(),
|
||||||
move |data: &[f32], _: &cpal::InputCallbackInfo| {
|
move |data: &[f32], _: &cpal::InputCallbackInfo| {
|
||||||
@@ -153,7 +223,7 @@ impl Listener {
|
|||||||
}
|
}
|
||||||
downsampled_data.push(data[upsampled as usize]);
|
downsampled_data.push(data[upsampled as usize]);
|
||||||
}
|
}
|
||||||
match ch.send(downsampled_data) {
|
match send.try_send(downsampled_data) {
|
||||||
Ok(_) => (),
|
Ok(_) => (),
|
||||||
Err(msg) => eprintln!("failed to ingest audio: {}", msg),
|
Err(msg) => eprintln!("failed to ingest audio: {}", msg),
|
||||||
};
|
};
|
||||||
|
|||||||
Reference in New Issue
Block a user