414 lines
12 KiB
Rust
414 lines
12 KiB
Rust
use whisper_rs::{WhisperContext, FullParams, SamplingStrategy, WhisperError};
|
|
use clap::Parser;
|
|
use std::thread;
|
|
use std::fs::File;
|
|
use std::io::Write;
|
|
|
|
#[derive(Parser, Debug, Clone)]
|
|
pub struct Flags {
|
|
#[arg(long, default_value = "../models/ggml-tiny.en.bin")]
|
|
pub model_path: Option<String>,
|
|
#[arg(long, default_value = None)]
|
|
pub model_buffer: Option<Vec<u8>>,
|
|
|
|
#[arg(long, default_value = "8")]
|
|
pub threads: i32,
|
|
|
|
#[arg(long, default_value = "5")]
|
|
pub stream_step: u64,
|
|
#[arg(long, default_value = "0.6")]
|
|
pub stream_retain: f32,
|
|
#[arg(long, default_value = "0.3")]
|
|
pub stream_head: f32,
|
|
#[arg(long, default_value = "0.3")]
|
|
pub stream_tail: f32,
|
|
|
|
#[arg(long, default_value = "false")]
|
|
pub debug: bool,
|
|
|
|
#[arg(long, default_value = None)]
|
|
pub wav: Option<String>,
|
|
#[arg(long, default_value = None)]
|
|
pub stream_device: Option<String>,
|
|
}
|
|
|
|
pub fn wav<F>(flags: Flags, handler_fn: F, wav_path: String) where F: FnMut(Result<Transcribed, String>) + Send + 'static {
|
|
let w = new_service(
|
|
flags.model_path,
|
|
flags.model_buffer,
|
|
flags.threads,
|
|
flags.stream_head,
|
|
flags.stream_tail,
|
|
handler_fn,
|
|
).unwrap();
|
|
w.transcribe(&f32_from_wav_file(&wav_path).unwrap())
|
|
}
|
|
|
|
pub fn f32_from_wav_file(path: &String) -> Result<Vec<f32>, String> {
|
|
let f = std::fs::File::open(path);
|
|
if let Some(err) = f.as_ref().err() {
|
|
return Err(format!("failed to open wav file: {}", err));
|
|
}
|
|
let wav_read = wav::read(&mut f.unwrap());
|
|
if let Some(err) = wav_read.as_ref().err() {
|
|
return Err(format!("failed to parse wav file: {}", err));
|
|
}
|
|
let (header, data) = wav_read.unwrap();
|
|
if header.channel_count != 1 {
|
|
return Err("!= 1 channel".to_string());
|
|
}
|
|
if header.sampling_rate != 16_000 {
|
|
return Err("!= 16_000 hz".to_string());
|
|
}
|
|
match data.as_sixteen() {
|
|
Some(data16) => {
|
|
let mut floats = Vec::with_capacity(data16.len());
|
|
for sample in data16 {
|
|
floats.push(*sample as f32 / 32768.0);
|
|
}
|
|
Ok(floats)
|
|
},
|
|
None => Err(format!("couldnt translate wav to 16s")),
|
|
}
|
|
}
|
|
|
|
pub fn wav_channel<F>(flags: Flags, handler_fn: F) where F: FnMut(Result<Transcribed, String>) + Send + 'static {
|
|
let path = flags.wav.as_ref().unwrap();
|
|
let mut audio = f32_from_wav_file(&path).unwrap();
|
|
let mut iter = vec![];
|
|
let n = audio.len() / match audio.len() % 100 {
|
|
0 => 100,
|
|
_ => 99,
|
|
};
|
|
for _ in 0..100 {
|
|
iter.push(audio.drain(0..n.clamp(0, audio.len())).collect());
|
|
}
|
|
let (fin_send, fin_recv) = std::sync::mpsc::sync_channel::<Option<i32>>(1);
|
|
channel_and_close(flags.clone(), handler_fn, iter, move || { fin_send.send(None).unwrap(); });
|
|
match fin_recv.recv() {
|
|
Ok(_) => {},
|
|
Err(x) => panic!("failed to receive: {}", x),
|
|
};
|
|
}
|
|
|
|
/// Runs the streaming transcription pipeline over `stream`, delivering
/// results to `handler_fn`. Equivalent to `channel_and_close` with a
/// completion callback that does nothing.
pub fn channel<F, I>(flags: Flags, handler_fn: F, stream: I)
where
    F: FnMut(Result<Transcribed, String>) + Send + 'static,
    I: IntoIterator<Item = Vec<f32>>,
{
    // Callers of this entry point are not notified on completion.
    let on_finished = || {};
    channel_and_close(flags, handler_fn, stream, on_finished);
}
|
|
|
|
fn channel_and_close<F, I, G>(flags: Flags, handler_fn: F, stream: I, mut close_fn: G) where F: FnMut(Result<Transcribed, String>) + Send + 'static, I: IntoIterator<Item = Vec<f32>>, G: FnMut() + Send + 'static {
|
|
let w = new_service(
|
|
flags.model_path,
|
|
flags.model_buffer,
|
|
flags.threads,
|
|
flags.stream_head,
|
|
flags.stream_tail,
|
|
handler_fn,
|
|
).unwrap();
|
|
let stream_retain = (flags.stream_retain * 16_000.0) as usize;
|
|
match &flags.debug {
|
|
true => { File::create("/tmp/page.rawf32audio").unwrap(); },
|
|
false => {},
|
|
};
|
|
let mut buffer = vec![];
|
|
for data in stream {
|
|
data.iter().for_each(|x| buffer.push(*x));
|
|
if buffer.len() >= (flags.stream_step * 16_000) as usize {
|
|
w.transcribe_async(&buffer).unwrap();
|
|
|
|
match &flags.debug {
|
|
true => {
|
|
let mut f = File::options().append(true).open("/tmp/page.rawf32audio").unwrap();
|
|
let mut wav_data = vec![];
|
|
for i in buffer.iter() {
|
|
for j in i.to_le_bytes() {
|
|
wav_data.push(j);
|
|
}
|
|
}
|
|
f.write_all(wav_data.as_slice()).unwrap();
|
|
},
|
|
false => {},
|
|
};
|
|
|
|
for i in 0..stream_retain {
|
|
buffer[i] = buffer[buffer.len() - stream_retain + i];
|
|
}
|
|
buffer.truncate(stream_retain);
|
|
}
|
|
}
|
|
if buffer.len() > 0 {
|
|
w.transcribe(&buffer);
|
|
}
|
|
close_fn();
|
|
}
|
|
|
|
struct Service {
|
|
jobs: std::sync::mpsc::SyncSender<ATranscribe>,
|
|
}
|
|
|
|
fn new_service<F>(model_path: Option<String>, model_buffer: Option<Vec<u8>>, threads: i32, stream_head: f32, stream_tail: f32, handler_fn: F) -> Result<Service, String> where F: FnMut(Result<Transcribed, String>) + Send + 'static {
|
|
match new_engine(model_path, model_buffer, threads) {
|
|
Ok(engine) => {
|
|
let mut whisper = new_whisper_impl(engine, stream_head, stream_tail, handler_fn);
|
|
let (send, recv) = std::sync::mpsc::sync_channel(100);
|
|
thread::spawn(move || { whisper.transcribe_asyncs(recv); });
|
|
Ok(Service{jobs: send})
|
|
},
|
|
Err(msg) => Err(format!("failed to initialize engine: {}", msg)),
|
|
}
|
|
}
|
|
|
|
impl Service {
|
|
fn transcribe(&self, data: &Vec<f32>) {
|
|
let (send, recv) = std::sync::mpsc::sync_channel(0);
|
|
self._transcribe_async(data, Some(send)).unwrap();
|
|
recv.recv().unwrap();
|
|
}
|
|
|
|
fn transcribe_async(&self, data: &Vec<f32>) -> Result<(), String> {
|
|
self._transcribe_async(data, None)
|
|
}
|
|
|
|
fn _transcribe_async(&self, data: &Vec<f32>, ack: Option<std::sync::mpsc::SyncSender<bool>>) -> Result<(), String> {
|
|
match self.jobs.try_send(ATranscribe{
|
|
data: data.clone().to_vec(),
|
|
ack: ack,
|
|
}) {
|
|
Ok(_) => Ok(()),
|
|
Err(msg) => Err(format!("failed to enqueue transcription: {}", msg)),
|
|
}
|
|
}
|
|
}
|
|
|
|
struct Impl {
|
|
engine: Engine,
|
|
stream_head: f32,
|
|
stream_tail: f32,
|
|
handler_fn: Option<Box<dyn FnMut(Result<Transcribed, String>) + Send + 'static>>
|
|
}
|
|
|
|
fn new_whisper_impl<F>(engine: Engine, stream_head: f32, stream_tail: f32, handler_fn: F) -> Impl where F: FnMut(Result<Transcribed, String>) + Send + 'static {
|
|
Impl {
|
|
engine: engine,
|
|
stream_head: stream_head,
|
|
stream_tail: stream_tail,
|
|
handler_fn: Some(Box::new(handler_fn)),
|
|
}
|
|
}
|
|
|
|
impl Impl {
|
|
fn transcribe_asyncs(&mut self, recv: std::sync::mpsc::Receiver<ATranscribe>) {
|
|
loop {
|
|
match recv.recv() {
|
|
Ok(job) => {
|
|
let result = self.transcribe(&job).is_ok();
|
|
match job.ack {
|
|
Some(ack) => {
|
|
ack.send(result).unwrap();
|
|
},
|
|
None => (),
|
|
};
|
|
}
|
|
Err(_) => return,
|
|
};
|
|
}
|
|
}
|
|
|
|
fn transcribe(&mut self, a_whisper: &ATranscribe) -> Result<(), ()> {
|
|
match self.engine.transcribe(&a_whisper.data) {
|
|
Ok(result) => {
|
|
self.on_success(&result);
|
|
Ok(())
|
|
},
|
|
Err(msg) => {
|
|
self.on_error(msg.to_string());
|
|
Err(())
|
|
},
|
|
}
|
|
}
|
|
|
|
fn on_success(&mut self, whispered: &Transcribed) {
|
|
let result = whispered
|
|
.after(&(self.stream_head * 100.0))
|
|
.before(&(self.stream_tail * 100.0));
|
|
(self.handler_fn.as_mut().unwrap())(Ok(result));
|
|
}
|
|
|
|
fn on_error(&mut self, msg: String) {
|
|
(self.handler_fn.as_mut().unwrap())(Err(format!("failed to transcribe: {}", &msg)));
|
|
}
|
|
}
|
|
|
|
struct Engine {
|
|
ctx: WhisperContext,
|
|
threads: i32,
|
|
}
|
|
|
|
fn new_engine(model_path: Option<String>, model_buffer: Option<Vec<u8>>, threads: i32) -> Result<Engine, String> {
|
|
let whisper_context_result = match model_path {
|
|
Some(model_path) => WhisperContext::new(&model_path),
|
|
None => WhisperContext::new_from_buffer(&model_buffer.unwrap()),
|
|
};
|
|
match whisper_context_result {
|
|
Ok(ctx) => Ok(Engine{ctx: ctx, threads: threads}),
|
|
Err(msg) => Err(format!("failed to load model: {}", msg)),
|
|
}
|
|
}
|
|
|
|
impl Engine {
|
|
fn transcribe(&self, data: &Vec<f32>) -> Result<Transcribed, String> {
|
|
match self._transcribe(data) {
|
|
Ok(transcribed) => Ok(transcribed),
|
|
Err(msg) => Err(format!("{}", msg)),
|
|
}
|
|
}
|
|
|
|
fn _transcribe(&self, data: &Vec<f32>) -> Result<Transcribed, WhisperError> {
|
|
let mut params = FullParams::new(SamplingStrategy::Greedy { best_of: 0 });
|
|
params.set_no_context(true);
|
|
params.set_n_threads(self.threads);
|
|
params.set_translate(false);
|
|
params.set_detect_language(false);
|
|
params.set_language(Some("en"));
|
|
params.set_print_special(false);
|
|
params.set_print_progress(false);
|
|
params.set_print_realtime(false);
|
|
params.set_print_timestamps(false);
|
|
|
|
let mut state = self.ctx.create_state()?;
|
|
state.full(params, &data[..])?;
|
|
|
|
let mut result = new_whispered();
|
|
let num_segments = state.full_n_segments()?;
|
|
for i in 0..num_segments {
|
|
let data = state.full_get_segment_text(i)?;
|
|
let start = state.full_get_segment_t0(i)?;
|
|
let stop = state.full_get_segment_t1(i)?;
|
|
result.push(data, start, stop);
|
|
}
|
|
|
|
Ok(result)
|
|
}
|
|
}
|
|
|
|
/// One unit of work for the transcription worker.
struct ATranscribe {
    /// Audio samples to transcribe.
    data: Vec<f32>,
    /// When present, the worker sends the outcome here after processing.
    ack: Option<std::sync::mpsc::SyncSender<bool>>,
}
|
|
|
|
/// A transcription result: a sequence of words with timing information.
#[derive(Clone, Debug)]
pub struct Transcribed {
    /// Words in the order they were pushed.
    data: Vec<ATranscribed>,
}

/// A single word and where it sits in the transcribed audio.
#[derive(Clone, Debug)]
struct ATranscribed {
    /// The word itself.
    pub data: String,
    /// Start time of the word, in the same units as the segment times given
    /// to `Transcribed::push`.
    pub offset: i64,
    /// Duration of the word, in the same units as `offset`.
    pub length: i64,
}

/// Creates an empty transcription.
fn new_whispered() -> Transcribed {
    Transcribed { data: vec![] }
}

/// Builds a word entry spanning `start..stop`.
fn new_a_whispered(data: String, start: i64, stop: i64) -> ATranscribed {
    ATranscribed {
        data,
        offset: start,
        length: stop - start,
    }
}

impl Transcribed {
    /// Renders the words as text, each word preceded by one space (so a
    /// non-empty result starts with a space and an empty one is `""`).
    pub fn to_string(&self) -> String {
        let mut result = String::new();
        for word in &self.data {
            result.push(' ');
            result.push_str(&word.data);
        }
        result
    }

    /// Returns the words whose offset is at or after `t`.
    fn after(&self, t: &f32) -> Transcribed {
        Transcribed {
            data: self
                .data
                .iter()
                .filter(|x| x.offset as f32 >= *t)
                .cloned()
                .collect(),
        }
    }

    /// Returns the words that start no later than `t` before the end of the
    /// transcription, where the end is the largest `offset + length`
    /// (or 1 when there are no words).
    fn before(&self, t: &f32) -> Transcribed {
        let end = self
            .data
            .iter()
            .map(|x| x.offset + x.length)
            .max()
            .unwrap_or(1);
        let cutoff = (end as f32) - *t;
        Transcribed {
            data: self
                .data
                .iter()
                .filter(|x| (x.offset as f32) <= cutoff)
                .cloned()
                .collect(),
        }
    }

    /// Splits a segment's text into words and appends them, dividing the
    /// segment's `start..stop` span evenly between the words.
    ///
    /// A segment with no words is ignored (it used to divide by zero), and
    /// word offsets are measured from `start` rather than from zero so that
    /// words from later segments keep their real position.
    fn push(&mut self, data: String, start: i64, stop: i64) {
        let words: Vec<_> = data.split_whitespace().collect();
        if words.is_empty() {
            return;
        }
        let per_word = (stop - start) / (words.len() as i64);
        for (i, word) in words.iter().enumerate() {
            let word_start = start + (i as i64) * per_word;
            self.data
                .push(new_a_whispered(word.to_string(), word_start, word_start + per_word));
        }
    }
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;

    // Sample recording shared by both tests.
    const JFK_WAV: &str = "../gitea-whisper-rs/sys/whisper.cpp/bindings/go/samples/jfk.wav";
    // Text both entry points are expected to produce for the sample.
    const JFK_TEXT: &str = " And so my fellow Americans ask not what your country can do for you ask what you can do for your country.";

    // Flags for the embedded tiny English model with no overlap and no
    // head/tail trimming; only the window length differs between the tests.
    fn jfk_flags(stream_step: u64) -> Flags {
        Flags {
            model_path: None,
            model_buffer: Some(include_bytes!("../../models/ggml-tiny.en.bin").to_vec()),
            threads: 8,
            stream_step,
            stream_retain: 0.0,
            stream_head: 0.0,
            stream_tail: 0.0,
            wav: Some(JFK_WAV.to_string()),
            debug: false,
            stream_device: None,
        }
    }

    // Streaming entry point, with a window long enough to hold the whole
    // recording.
    #[test]
    fn test_transcribe_tiny_jfk_wav_whisper_rs_wav_channel() {
        wav_channel(jfk_flags(30), move |result| {
            assert!(result.is_ok());
            assert_eq!(result.unwrap().to_string(), JFK_TEXT);
        });
    }

    // One-shot entry point.
    #[test]
    fn test_transcribe_tiny_jfk_wav_whisper_rs() {
        wav(
            jfk_flags(0),
            |result| {
                assert!(result.is_ok());
                assert_eq!(result.unwrap().to_string(), JFK_TEXT);
            },
            JFK_WAV.to_string(),
        );
    }
}
|