IT RUNS WITH RWHISPER WOO

master
Bel LaPointe 2023-12-21 14:20:19 -05:00
parent a420ef5381
commit 0d8315550a
3 changed files with 2312 additions and 16 deletions

File diff suppressed because it is too large Load Diff

View File

@ -12,3 +12,6 @@ tokio = "1.27"
byteorder = "1.5.0"
chrono = "0.4.31"
clap = { version = "4.4.10", features = ["derive"] }
rwhisper = "0.1.2"
rodio = "0.17.3"
futures = "0.3.29"

View File

@ -3,6 +3,7 @@ use clap::Parser;
use std::thread;
use std::fs::File;
use std::io::Write;
use rwhisper;
#[derive(Parser, Debug, Clone)]
pub struct Flags {
@ -52,6 +53,28 @@ pub fn wav<F>(flags: Flags, handler_fn: F, wav_path: String) where F: FnMut(Resu
w.transcribe(&audio_data);
}
/// Loads the WAV file at `path` and returns its samples converted to `f32`.
///
/// Only single-channel audio sampled at 16 000 Hz is accepted; the sample
/// data must be retrievable through `as_sixteen()`.
///
/// # Errors
///
/// Returns a descriptive `String` when:
/// - the file cannot be opened,
/// - the file cannot be parsed as WAV,
/// - the header reports a channel count other than 1,
/// - the header reports a sampling rate other than 16 000,
/// - `as_sixteen()` yields no data.
fn f32_from_wav_file(path: String) -> Result<Vec<f32>, String> {
    let mut file = std::fs::File::open(path)
        .map_err(|err| format!("failed to open wav file: {}", err))?;
    let (header, data) = wav::read(&mut file)
        .map_err(|err| format!("failed to parse wav file: {}", err))?;

    // The rest of this function assumes mono 16 kHz input, so reject
    // anything else up front instead of producing garbage samples.
    if header.channel_count != 1 {
        return Err("!= 1 channel".to_string());
    }
    if header.sampling_rate != 16_000 {
        return Err("!= 16_000 hz".to_string());
    }

    match data.as_sixteen() {
        Some(data16) => Ok(whisper_rs::convert_integer_to_float_audio(&data16)),
        None => Err("couldnt translate wav to 16s".to_string()),
    }
}
pub fn channel<F>(flags: Flags, handler_fn: F, stream: std::sync::mpsc::Receiver<Vec<f32>>) where F: FnMut(Result<Transcribed, String>) + Send + 'static {
let w = new_service(
flags.model_path,
@ -243,6 +266,46 @@ impl Engine {
}
}
/// Transcription engine whose methods drive the `rwhisper` crate.
///
/// The struct itself holds no fields.
struct Engine2 {}
/// Constructs an `Engine2`.
///
/// `model_path`, `model_buffer` and `threads` are accepted but not read by
/// this body, and the function always returns `Ok`.
// NOTE(review): the signature presumably mirrors the other engine
// constructor in this file so the two are interchangeable — confirm.
fn new_engine2(model_path: Option<String>, model_buffer: Option<Vec<u8>>, threads: i32) -> Result<Engine2, String> {
    let engine = Engine2 {};
    Ok(engine)
}
impl Engine2 {
    /// Transcribes `data` and returns the resulting `Transcribed` value.
    ///
    /// Thin wrapper that forwards to [`Engine2::_transcribe`].
    ///
    /// # Errors
    ///
    /// Propagates any error string produced by `_transcribe`.
    fn transcribe(&self, data: &Vec<f32>) -> Result<Transcribed, String> {
        self._transcribe(data)
    }

    /// Builds an `rwhisper` model (CPU, English, `TinyEn` source), feeds it
    /// `data` as a 1-channel 16 000 Hz sample buffer and collects the whole
    /// text stream into a single pushed entry.
    ///
    /// The model is rebuilt on every call.
    ///
    /// # Errors
    ///
    /// Returns `Err` (instead of panicking) when the model cannot be built,
    /// transcription cannot be started, the text stream cannot be written
    /// out, or the produced bytes are not valid UTF-8.
    fn _transcribe(&self, data: &Vec<f32>) -> Result<Transcribed, String> {
        let model = rwhisper::WhisperBuilder::default()
            .with_cpu(true)
            .with_language(Some(rwhisper::WhisperLanguage::English))
            .with_source(rwhisper::WhisperSource::TinyEn)
            .build()
            .map_err(|err| format!("failed to build rwhisper model: {:?}", err))?;

        // SamplesBuffer takes ownership of its samples, hence the copy.
        let buffer = rodio::buffer::SamplesBuffer::new(1, 16_000, data.clone());

        // The rwhisper stream is async; drive it to completion on this thread.
        let text_bytes = futures::executor::block_on(async {
            let mut w: Vec<u8> = vec![];
            let mut stream = model
                .transcribe(buffer)
                .map_err(|err| format!("failed to start transcription: {:?}", err))?;
            stream
                .write_to(&mut w)
                .await
                .map_err(|err| format!("failed to read transcription stream: {:?}", err))?;
            Ok::<Vec<u8>, String>(w)
        })?;

        let text = String::from_utf8(text_bytes)
            .map_err(|err| format!("transcription was not valid utf-8: {:?}", err))?;

        let mut result = new_whispered();
        // 100 * samples / 16_000 is the clip length in hundredths of a second;
        // presumably `push` takes (text, start, stop) in that unit — confirm.
        result.push(text, 0, (100 * data.len() / 16_000) as i64);
        Ok(result)
    }
}
struct ATranscribe {
data: Vec<f32>,
ack: Option<std::sync::mpsc::SyncSender<bool>>,
@ -320,7 +383,7 @@ mod tests {
use super::*;
#[test]
fn test_transcribe_tiny_jfk_wav() {
fn test_transcribe_tiny_jfk_wav_whisper_rs() {
wav(
Flags {
model_path: None,
@ -341,4 +404,16 @@ mod tests {
"../gitea-whisper-rs/sys/whisper.cpp/bindings/go/samples/jfk.wav".to_string(),
);
}
// Transcribes the jfk.wav sample via Engine2 and compares the rendered
// result against the expected text.
#[test]
fn test_transcribe_tiny_jfk_wav_rwhisper() {
    let model_path = Some("../models/ggml-tiny.en.bin".to_string());
    let engine = new_engine2(model_path, None, 4).expect("failed to make new engine2");

    let wav_path = "../gitea-whisper-rs/sys/whisper.cpp/bindings/go/samples/jfk.wav".to_string();
    let samples = f32_from_wav_file(wav_path).expect("failed to read jfk.wav");

    let transcript = engine.transcribe(&samples).expect("failed to transcribe");

    let expected = " And so my fellow American asked not what your country can do for you, ask what you can do for your country.".to_string();
    assert_eq!(expected, transcript.to_string());
}
}