drop original root-level
parent
9055cd4aee
commit
2d4451f97a
File diff suppressed because it is too large
Load Diff
16
Cargo.toml
16
Cargo.toml
|
|
@ -1,16 +0,0 @@
|
|||
[package]
|
||||
name = "rust-whisper"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
[dependencies]
|
||||
whisper-rs = { path = "./gitea-whisper-rs", version = "0.8.0" }
|
||||
wav = "1"
|
||||
tokio = "1.27"
|
||||
cpal = "0.15.2"
|
||||
signal-hook = "0.3.17"
|
||||
byteorder = "1.5.0"
|
||||
chrono = "0.4.31"
|
||||
clap = { version = "4.4.10", features = ["derive"] }
|
||||
374
src/main.rs
374
src/main.rs
|
|
@ -1,374 +0,0 @@
|
|||
use whisper_rs::{WhisperContext, FullParams, SamplingStrategy, WhisperError};
|
||||
use cpal::traits::{HostTrait, DeviceTrait, StreamTrait};
|
||||
use signal_hook::{iterator::Signals, consts::signal::SIGINT};
|
||||
use std::time::{Duration, Instant};
|
||||
use chrono;
|
||||
use clap::Parser;
|
||||
use std::thread;
|
||||
use std::fs::File;
|
||||
use std::io::Write;
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
struct Flags {
|
||||
#[arg(long, default_value = "./models/ggml-tiny.en.bin")]
|
||||
model: String,
|
||||
|
||||
#[arg(long, default_value = "8")]
|
||||
threads: i32,
|
||||
|
||||
#[arg(long, default_value = "5")]
|
||||
stream_step: u64,
|
||||
#[arg(long, default_value = "0.6")]
|
||||
stream_retain: f32,
|
||||
#[arg(long, default_value = "0.3")]
|
||||
stream_head: f32,
|
||||
#[arg(long, default_value = "0.3")]
|
||||
stream_tail: f32,
|
||||
|
||||
wav: Option<String>,
|
||||
|
||||
#[arg(long, default_value = "false")]
|
||||
debug: bool,
|
||||
}
|
||||
|
||||
fn main() {
|
||||
let flags = Flags::parse();
|
||||
|
||||
let w = new_whisper_service(
|
||||
flags.model,
|
||||
flags.threads,
|
||||
flags.stream_head,
|
||||
flags.stream_tail,
|
||||
|result: Result<Whispered, String>| {
|
||||
match result {
|
||||
Ok(whispered) => {
|
||||
eprintln!("{}: {:?}", chrono::Local::now(), whispered);
|
||||
println!("{}", whispered.to_string());
|
||||
},
|
||||
Err(msg) => { eprintln!("Error whispering: {}", msg); },
|
||||
};
|
||||
},
|
||||
).unwrap();
|
||||
let stream_retain = (flags.stream_retain * 16_000.0) as usize;
|
||||
let stream_step = Duration::new(flags.stream_step, 0);
|
||||
match flags.wav {
|
||||
Some(wav) => {
|
||||
let (header, data) = wav::read(
|
||||
&mut std::fs::File::open(wav).expect("failed to open $WAV"),
|
||||
).expect("failed to decode $WAV");
|
||||
assert!(header.channel_count == 1);
|
||||
assert!(header.sampling_rate == 16_000);
|
||||
let data16 = data.as_sixteen().expect("wav is not 32bit floats");
|
||||
let audio_data = &whisper_rs::convert_integer_to_float_audio(&data16);
|
||||
|
||||
w.transcribe(&audio_data);
|
||||
},
|
||||
None => {
|
||||
match &flags.debug {
|
||||
true => { File::create("/tmp/page.rawf32audio").unwrap(); },
|
||||
false => {},
|
||||
};
|
||||
let mut buffer = vec![];
|
||||
let mut last = Instant::now();
|
||||
new_listener().listen(move |data: Vec<f32>| {
|
||||
data.iter().for_each(|x| buffer.push(*x));
|
||||
if Instant::now() - last > stream_step {
|
||||
w.transcribe_async(&buffer).unwrap();
|
||||
|
||||
match &flags.debug {
|
||||
true => {
|
||||
let mut f = File::options().append(true).open("/tmp/page.rawf32audio").unwrap();
|
||||
let mut wav_data = vec![];
|
||||
for i in buffer.iter() {
|
||||
for j in i.to_le_bytes() {
|
||||
wav_data.push(j);
|
||||
}
|
||||
}
|
||||
f.write_all(wav_data.as_slice()).unwrap();
|
||||
},
|
||||
false => {},
|
||||
};
|
||||
|
||||
for i in 0..stream_retain {
|
||||
buffer[i] = buffer[buffer.len() - stream_retain + i];
|
||||
}
|
||||
buffer.truncate(stream_retain);
|
||||
last = Instant::now();
|
||||
}
|
||||
});
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
struct WhisperService {
|
||||
jobs: std::sync::mpsc::SyncSender<AWhisper>,
|
||||
}
|
||||
|
||||
fn new_whisper_service<F>(model_path: String, threads: i32, stream_head: f32, stream_tail: f32, handler_fn: F) -> Result<WhisperService, String> where F: FnMut(Result<Whispered, String>) + Send + 'static {
|
||||
match new_whisper_engine(model_path, threads) {
|
||||
Ok(engine) => {
|
||||
let mut whisper = new_whisper_impl(engine, stream_head, stream_tail, handler_fn);
|
||||
let (send, recv) = std::sync::mpsc::sync_channel(100);
|
||||
thread::spawn(move || { whisper.transcribe_asyncs(recv); });
|
||||
Ok(WhisperService{jobs: send})
|
||||
},
|
||||
Err(msg) => Err(format!("failed to initialize engine: {}", msg)),
|
||||
}
|
||||
}
|
||||
|
||||
impl WhisperService {
|
||||
fn transcribe(&self, data: &Vec<f32>) {
|
||||
let (send, recv) = std::sync::mpsc::sync_channel(0);
|
||||
self._transcribe_async(data, Some(send)).unwrap();
|
||||
recv.recv().unwrap();
|
||||
}
|
||||
|
||||
fn transcribe_async(&self, data: &Vec<f32>) -> Result<(), String> {
|
||||
self._transcribe_async(data, None)
|
||||
}
|
||||
|
||||
fn _transcribe_async(&self, data: &Vec<f32>, ack: Option<std::sync::mpsc::SyncSender<bool>>) -> Result<(), String> {
|
||||
match self.jobs.try_send(AWhisper{
|
||||
data: data.clone().to_vec(),
|
||||
ack: ack,
|
||||
}) {
|
||||
Ok(_) => Ok(()),
|
||||
Err(msg) => Err(format!("failed to enqueue transcription: {}", msg)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct WhisperImpl {
|
||||
engine: WhisperEngine,
|
||||
stream_head: f32,
|
||||
stream_tail: f32,
|
||||
handler_fn: Option<Box<dyn FnMut(Result<Whispered, String>) + Send + 'static>>
|
||||
}
|
||||
|
||||
fn new_whisper_impl<F>(engine: WhisperEngine, stream_head: f32, stream_tail: f32, handler_fn: F) -> WhisperImpl where F: FnMut(Result<Whispered, String>) + Send + 'static {
|
||||
WhisperImpl {
|
||||
engine: engine,
|
||||
stream_head: stream_head,
|
||||
stream_tail: stream_tail,
|
||||
handler_fn: Some(Box::new(handler_fn)),
|
||||
}
|
||||
}
|
||||
|
||||
impl WhisperImpl {
|
||||
fn transcribe_asyncs(&mut self, recv: std::sync::mpsc::Receiver<AWhisper>) {
|
||||
loop {
|
||||
match recv.recv() {
|
||||
Ok(job) => {
|
||||
let result = self.transcribe(&job).is_ok();
|
||||
match job.ack {
|
||||
Some(ack) => {
|
||||
ack.send(result).unwrap();
|
||||
},
|
||||
None => (),
|
||||
};
|
||||
}
|
||||
Err(_) => return,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
fn transcribe(&mut self, a_whisper: &AWhisper) -> Result<(), ()> {
|
||||
match self.engine.transcribe(&a_whisper.data) {
|
||||
Ok(result) => {
|
||||
self.on_success(&result);
|
||||
Ok(())
|
||||
},
|
||||
Err(msg) => {
|
||||
self.on_error(msg.to_string());
|
||||
Err(())
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
fn on_success(&mut self, whispered: &Whispered) {
|
||||
let result = whispered
|
||||
.after(&(self.stream_head * 100.0))
|
||||
.before(&(self.stream_tail * 100.0));
|
||||
(self.handler_fn.as_mut().unwrap())(Ok(result));
|
||||
}
|
||||
|
||||
fn on_error(&mut self, msg: String) {
|
||||
(self.handler_fn.as_mut().unwrap())(Err(format!("failed to transcribe: {}", &msg)));
|
||||
}
|
||||
}
|
||||
|
||||
struct WhisperEngine {
|
||||
ctx: WhisperContext,
|
||||
threads: i32,
|
||||
}
|
||||
|
||||
fn new_whisper_engine(model_path: String, threads: i32) -> Result<WhisperEngine, String> {
|
||||
match WhisperContext::new(&model_path) {
|
||||
Ok(ctx) => Ok(WhisperEngine{ctx: ctx, threads: threads}),
|
||||
Err(msg) => Err(format!("failed to load {}: {}", model_path, msg)),
|
||||
}
|
||||
}
|
||||
|
||||
impl WhisperEngine {
|
||||
fn transcribe(&self, data: &Vec<f32>) -> Result<Whispered, WhisperError> {
|
||||
let mut params = FullParams::new(SamplingStrategy::Greedy { best_of: 0 });
|
||||
params.set_no_context(true);
|
||||
params.set_n_threads(self.threads);
|
||||
params.set_translate(false);
|
||||
params.set_detect_language(false);
|
||||
params.set_language(Some("en"));
|
||||
params.set_print_special(false);
|
||||
params.set_print_progress(false);
|
||||
params.set_print_realtime(false);
|
||||
params.set_print_timestamps(false);
|
||||
|
||||
let mut state = self.ctx.create_state()?;
|
||||
state.full(params, &data[..])?;
|
||||
|
||||
let mut result = new_whispered();
|
||||
let num_segments = state.full_n_segments()?;
|
||||
for i in 0..num_segments {
|
||||
let data = state.full_get_segment_text(i)?;
|
||||
let start = state.full_get_segment_t0(i)?;
|
||||
let stop = state.full_get_segment_t1(i)?;
|
||||
result.push(data, start, stop);
|
||||
}
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
}
|
||||
|
||||
struct AWhisper {
|
||||
data: Vec<f32>,
|
||||
ack: Option<std::sync::mpsc::SyncSender<bool>>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
struct Whispered {
|
||||
data: Vec<AWhispered>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
struct AWhispered {
|
||||
data: String,
|
||||
offset: i64,
|
||||
length: i64,
|
||||
}
|
||||
|
||||
fn new_whispered() -> Whispered {
|
||||
Whispered{data: vec![]}
|
||||
}
|
||||
|
||||
fn new_a_whispered(data: String, start: i64, stop: i64) -> AWhispered {
|
||||
AWhispered{
|
||||
data: data,
|
||||
offset: start.clone(),
|
||||
length: stop - start,
|
||||
}
|
||||
}
|
||||
|
||||
impl Whispered {
|
||||
fn to_string(&self) -> String {
|
||||
let mut result = "".to_string();
|
||||
for i in 0..self.data.len() {
|
||||
result = format!("{} {}", result, &self.data[i].data);
|
||||
}
|
||||
result
|
||||
}
|
||||
|
||||
fn after(&self, t: &f32) -> Whispered {
|
||||
let mut result = new_whispered();
|
||||
self.data
|
||||
.iter()
|
||||
.filter(|x| x.offset as f32 >= *t)
|
||||
.for_each(|x| result.data.push(x.clone()));
|
||||
result
|
||||
}
|
||||
|
||||
fn before(&self, t: &f32) -> Whispered {
|
||||
let mut result = new_whispered();
|
||||
let end = match self.data.iter().map(|x| x.offset + x.length).max() {
|
||||
Some(x) => x,
|
||||
None => 1,
|
||||
};
|
||||
let t = (end as f32) - *t;
|
||||
self.data
|
||||
.iter()
|
||||
.filter(|x| ((x.offset) as f32) <= t)
|
||||
.for_each(|x| result.data.push(x.clone()));
|
||||
result
|
||||
}
|
||||
|
||||
fn push(&mut self, data: String, start: i64, stop: i64) {
|
||||
let words: Vec<_> = data.split_whitespace().collect();
|
||||
let per_word = (stop - start) / (words.len() as i64);
|
||||
for i in 0..words.len() {
|
||||
let start = (i as i64) * per_word;
|
||||
let stop = start.clone() + per_word;
|
||||
self.data.push(new_a_whispered(words[i].to_string(), start, stop));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct Listener {
|
||||
}
|
||||
|
||||
fn new_listener() -> Listener {
|
||||
Listener{}
|
||||
}
|
||||
|
||||
impl Listener {
|
||||
fn listen(self, mut cb: impl FnMut(Vec<f32>)) {
|
||||
let (send, recv) = std::sync::mpsc::sync_channel(100);
|
||||
thread::spawn(move || { self._listen(send); });
|
||||
loop {
|
||||
match recv.recv() {
|
||||
Ok(msg) => cb(msg),
|
||||
Err(_) => return,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
fn _listen(self, send: std::sync::mpsc::SyncSender<Vec<f32>>) {
|
||||
let host = cpal::default_host();
|
||||
let device = host.default_input_device().unwrap();
|
||||
let cfg = device.supported_input_configs()
|
||||
.unwrap()
|
||||
.filter(|x| x.sample_format() == cpal::SampleFormat::F32)
|
||||
.nth(0)
|
||||
.unwrap()
|
||||
.with_max_sample_rate();
|
||||
|
||||
let downsample_ratio = cfg.channels() as f32 * (cfg.sample_rate().0 as f32 / 16_000.0);
|
||||
let stream = device.build_input_stream(
|
||||
&cfg.clone().into(),
|
||||
move |data: &[f32], _: &cpal::InputCallbackInfo| {
|
||||
let mut downsampled_data = vec![];
|
||||
for i in 0..(data.len() as f32 / downsample_ratio) as usize {
|
||||
let mut upsampled = i as f32 * downsample_ratio;
|
||||
if upsampled > (data.len()-1) as f32 {
|
||||
upsampled = (data.len()-1) as f32
|
||||
}
|
||||
downsampled_data.push(data[upsampled as usize]);
|
||||
}
|
||||
match send.try_send(downsampled_data) {
|
||||
Ok(_) => (),
|
||||
Err(msg) => eprintln!("failed to ingest audio: {}", msg),
|
||||
};
|
||||
},
|
||||
move |err| {
|
||||
eprintln!("input error: {}", err)
|
||||
},
|
||||
None,
|
||||
).unwrap();
|
||||
stream.play().unwrap();
|
||||
|
||||
eprintln!("listening on {}", device.name().unwrap());
|
||||
let mut signals = Signals::new(&[SIGINT]).unwrap();
|
||||
for sig in signals.forever() {
|
||||
eprintln!("sig {}", sig);
|
||||
break;
|
||||
}
|
||||
stream.pause().unwrap();
|
||||
}
|
||||
}
|
||||
Loading…
Reference in New Issue