Compare commits
3 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3b79024bdd | ||
|
|
de099d9917 | ||
|
|
0d8315550a |
2248
rust-whisper-lib/Cargo.lock
generated
2248
rust-whisper-lib/Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -12,3 +12,6 @@ tokio = "1.27"
|
|||||||
byteorder = "1.5.0"
|
byteorder = "1.5.0"
|
||||||
chrono = "0.4.31"
|
chrono = "0.4.31"
|
||||||
clap = { version = "4.4.10", features = ["derive"] }
|
clap = { version = "4.4.10", features = ["derive"] }
|
||||||
|
rwhisper = "0.1.2"
|
||||||
|
rodio = "0.17.3"
|
||||||
|
futures = "0.3.29"
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ use clap::Parser;
|
|||||||
use std::thread;
|
use std::thread;
|
||||||
use std::fs::File;
|
use std::fs::File;
|
||||||
use std::io::Write;
|
use std::io::Write;
|
||||||
|
use rwhisper;
|
||||||
|
|
||||||
#[derive(Parser, Debug, Clone)]
|
#[derive(Parser, Debug, Clone)]
|
||||||
pub struct Flags {
|
pub struct Flags {
|
||||||
@@ -41,15 +42,29 @@ pub fn wav<F>(flags: Flags, handler_fn: F, wav_path: String) where F: FnMut(Resu
|
|||||||
flags.stream_tail,
|
flags.stream_tail,
|
||||||
handler_fn,
|
handler_fn,
|
||||||
).unwrap();
|
).unwrap();
|
||||||
let (header, data) = wav::read(
|
w.transcribe(&f32_from_wav_file(&wav_path).unwrap())
|
||||||
&mut std::fs::File::open(wav_path).expect("failed to open $WAV"),
|
}
|
||||||
).expect("failed to decode $WAV");
|
|
||||||
assert!(header.channel_count == 1);
|
|
||||||
assert!(header.sampling_rate == 16_000);
|
|
||||||
let data16 = data.as_sixteen().expect("wav is not 32bit floats");
|
|
||||||
let audio_data = &whisper_rs::convert_integer_to_float_audio(&data16);
|
|
||||||
|
|
||||||
w.transcribe(&audio_data);
|
fn f32_from_wav_file(path: &String) -> Result<Vec<f32>, String> {
|
||||||
|
let f = std::fs::File::open(path);
|
||||||
|
if let Some(err) = f.as_ref().err() {
|
||||||
|
return Err(format!("failed to open wav file: {}", err));
|
||||||
|
}
|
||||||
|
let wav_read = wav::read(&mut f.unwrap());
|
||||||
|
if let Some(err) = wav_read.as_ref().err() {
|
||||||
|
return Err(format!("failed to parse wav file: {}", err));
|
||||||
|
}
|
||||||
|
let (header, data) = wav_read.unwrap();
|
||||||
|
if header.channel_count != 1 {
|
||||||
|
return Err("!= 1 channel".to_string());
|
||||||
|
}
|
||||||
|
if header.sampling_rate != 16_000 {
|
||||||
|
return Err("!= 16_000 hz".to_string());
|
||||||
|
}
|
||||||
|
match data.as_sixteen() {
|
||||||
|
Some(data16) => Ok(whisper_rs::convert_integer_to_float_audio(&data16)),
|
||||||
|
None => Err(format!("couldnt translate wav to 16s")),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn channel<F>(flags: Flags, handler_fn: F, stream: std::sync::mpsc::Receiver<Vec<f32>>) where F: FnMut(Result<Transcribed, String>) + Send + 'static {
|
pub fn channel<F>(flags: Flags, handler_fn: F, stream: std::sync::mpsc::Receiver<Vec<f32>>) where F: FnMut(Result<Transcribed, String>) + Send + 'static {
|
||||||
@@ -243,6 +258,54 @@ impl Engine {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct Engine2 {
|
||||||
|
model: rwhisper::Whisper,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn new_engine2(model_path: Option<String>, model_buffer: Option<Vec<u8>>, threads: i32) -> Result<Engine2, String> {
|
||||||
|
match rwhisper::WhisperBuilder::default()
|
||||||
|
.with_cpu(true)
|
||||||
|
.with_language(Some(rwhisper::WhisperLanguage::English))
|
||||||
|
.with_source(rwhisper::WhisperSource::TinyEn)
|
||||||
|
.build() {
|
||||||
|
Ok(model) => Ok(Engine2{model: model}),
|
||||||
|
Err(msg) => Err(format!("failed to create model: {}", msg)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Engine2 {
|
||||||
|
fn transcribe(&self, data: &Vec<f32>) -> Result<Transcribed, String> {
|
||||||
|
self._transcribe(data)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn _transcribe(&self, data: &Vec<f32>) -> Result<Transcribed, String> {
|
||||||
|
let buffer = rodio::buffer::SamplesBuffer::new(1, 16_000, data.clone());
|
||||||
|
let stream = self.model.transcribe(buffer);
|
||||||
|
if stream.as_ref().is_err() {
|
||||||
|
return Err(format!("failed to start transcribing: {}", stream.err().unwrap()));
|
||||||
|
}
|
||||||
|
let stream = stream.unwrap();
|
||||||
|
|
||||||
|
let future = async {
|
||||||
|
let mut w: Vec<u8> = vec![];
|
||||||
|
stream.write_to(&mut w).await.unwrap();
|
||||||
|
w
|
||||||
|
};
|
||||||
|
let w = futures::executor::block_on(future);
|
||||||
|
let mut result = new_whispered();
|
||||||
|
result.push(String::from_utf8(w).unwrap(), 0, (100 * data.len() / 16_000) as i64);
|
||||||
|
Ok(result)
|
||||||
|
|
||||||
|
/*
|
||||||
|
let mut w: Vec<u8> = vec![];
|
||||||
|
stream.write_to(&mut w).await.unwrap();
|
||||||
|
let mut result = new_whispered();
|
||||||
|
result.push(String::from_utf8(w).unwrap(), 0, (100 * data.len() / 16_000) as i64);
|
||||||
|
Ok(result)
|
||||||
|
*/
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
struct ATranscribe {
|
struct ATranscribe {
|
||||||
data: Vec<f32>,
|
data: Vec<f32>,
|
||||||
ack: Option<std::sync::mpsc::SyncSender<bool>>,
|
ack: Option<std::sync::mpsc::SyncSender<bool>>,
|
||||||
@@ -320,7 +383,7 @@ mod tests {
|
|||||||
use super::*;
|
use super::*;
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_transcribe_tiny_jfk_wav() {
|
fn test_transcribe_tiny_jfk_wav_whisper_rs() {
|
||||||
wav(
|
wav(
|
||||||
Flags {
|
Flags {
|
||||||
model_path: None,
|
model_path: None,
|
||||||
@@ -341,4 +404,20 @@ mod tests {
|
|||||||
"../gitea-whisper-rs/sys/whisper.cpp/bindings/go/samples/jfk.wav".to_string(),
|
"../gitea-whisper-rs/sys/whisper.cpp/bindings/go/samples/jfk.wav".to_string(),
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_transcribe_tiny_jfk_wav_rwhisper() {
|
||||||
|
let engine_2 = new_engine2(
|
||||||
|
Some("../models/ggml-tiny.en.bin".to_string()),
|
||||||
|
None,
|
||||||
|
4,
|
||||||
|
).expect("failed to make new engine2");
|
||||||
|
let data = f32_from_wav_file(&"../gitea-whisper-rs/sys/whisper.cpp/bindings/go/samples/jfk.wav".to_string()).expect("failed to read jfk.wav");
|
||||||
|
let start = std::time::Instant::now();
|
||||||
|
for i in 0..2 {
|
||||||
|
let result = engine_2.transcribe(&data).expect("failed to transcribe");
|
||||||
|
println!("rwhisper = {}s", start.elapsed().as_secs_f32());
|
||||||
|
assert_eq!(" And so my fellow American asked not what your country can do for you, ask what you can do for your country.".to_string(), result.to_string());
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user