Compare commits
7 Commits
393100973c
...
0.2.5
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d94cbd6927 | ||
|
|
7a5db3b2ac | ||
|
|
d0dc9571d7 | ||
|
|
6082f7e446 | ||
|
|
dd6f980266 | ||
|
|
601fe517d7 | ||
|
|
cd339de334 |
@@ -19,6 +19,15 @@ pub fn wav<F>(
|
|||||||
rust_whisper_lib::wav(flags.clone(), handler_fn, flags.wav.unwrap());
|
rust_whisper_lib::wav(flags.clone(), handler_fn, flags.wav.unwrap());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn wav_channel<F>(
|
||||||
|
mut flags: rust_whisper_lib::Flags,
|
||||||
|
handler_fn: F
|
||||||
|
) where F: FnMut(Result<rust_whisper_lib::Transcribed, String>) + Send + 'static {
|
||||||
|
flags.model_path = None;
|
||||||
|
flags.model_buffer = Some(include_bytes!("../../models/ggml-base.en.bin").to_vec());
|
||||||
|
rust_whisper_lib::wav_channel(flags, handler_fn);
|
||||||
|
}
|
||||||
|
|
||||||
pub fn f32_from_wav_file(path: &String) -> Result<Vec<f32>, String> {
|
pub fn f32_from_wav_file(path: &String) -> Result<Vec<f32>, String> {
|
||||||
rust_whisper_lib::f32_from_wav_file(path)
|
rust_whisper_lib::f32_from_wav_file(path)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,50 +7,36 @@ use std::thread;
|
|||||||
fn main() {
|
fn main() {
|
||||||
let flags = rust_whisper_lib::Flags::parse();
|
let flags = rust_whisper_lib::Flags::parse();
|
||||||
match flags.wav.clone() {
|
match flags.wav.clone() {
|
||||||
Some(path) => wav_ch(flags, path),
|
Some(_) => wav_channel(flags),
|
||||||
None => channel(flags),
|
None => channel(flags),
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
fn wav_ch(flags: rust_whisper_lib::Flags, path: String) {
|
fn wav_channel(flags: rust_whisper_lib::Flags) {
|
||||||
let mut audio = rust_whisper_baked_lib::f32_from_wav_file(&path).unwrap();
|
let mut w = new_destutterer();
|
||||||
|
rust_whisper_baked_lib::wav_channel(
|
||||||
let (send, recv) = std::sync::mpsc::sync_channel(100);
|
|
||||||
let n = audio.len() / match audio.len() % 100 {
|
|
||||||
0 => 100,
|
|
||||||
_ => 99,
|
|
||||||
};
|
|
||||||
for _ in 0..100 {
|
|
||||||
send.send(audio.drain(0..n).collect()).unwrap();
|
|
||||||
}
|
|
||||||
|
|
||||||
let (fin_send, fin_recv) = std::sync::mpsc::sync_channel::<Option<i32>>(1);
|
|
||||||
thread::spawn(move || {
|
|
||||||
let mut i = 0;
|
|
||||||
rust_whisper_baked_lib::channel(
|
|
||||||
flags.clone(),
|
flags.clone(),
|
||||||
move |result: Result<rust_whisper_lib::Transcribed, String>| {
|
move |result: Result<rust_whisper_lib::Transcribed, String>| {
|
||||||
match result {
|
match result {
|
||||||
Ok(transcribed) => { println!("{}", transcribed.to_string()); },
|
Ok(transcribed) => {
|
||||||
|
let s = w.step(transcribed.to_string());
|
||||||
|
println!("{}", s);
|
||||||
|
},
|
||||||
Err(msg) => { eprintln!("error: {}", msg); },
|
Err(msg) => { eprintln!("error: {}", msg); },
|
||||||
};
|
};
|
||||||
i += 1;
|
|
||||||
if i == 100 {
|
|
||||||
fin_send.send(None).unwrap();
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
recv,
|
|
||||||
);
|
);
|
||||||
});
|
|
||||||
|
|
||||||
let _ = fin_recv.recv();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn wav(flags: rust_whisper_lib::Flags, _path: String) {
|
fn wav(flags: rust_whisper_lib::Flags, _path: String) {
|
||||||
|
let mut w = new_destutterer();
|
||||||
rust_whisper_baked_lib::wav(flags,
|
rust_whisper_baked_lib::wav(flags,
|
||||||
|result: Result<rust_whisper_lib::Transcribed, String>| {
|
move |result: Result<rust_whisper_lib::Transcribed, String>| {
|
||||||
match result {
|
match result {
|
||||||
Ok(transcribed) => { println!("{}", transcribed.to_string()); },
|
Ok(transcribed) => {
|
||||||
|
let s = w.step(transcribed.to_string());
|
||||||
|
println!("{}", s);
|
||||||
|
},
|
||||||
Err(msg) => { eprintln!("error: {}", msg); },
|
Err(msg) => { eprintln!("error: {}", msg); },
|
||||||
};
|
};
|
||||||
},
|
},
|
||||||
@@ -96,3 +82,66 @@ fn channel(flags: rust_whisper_lib::Flags) {
|
|||||||
}
|
}
|
||||||
eprintln!("/listen lib main...");
|
eprintln!("/listen lib main...");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Removes repeated text from consecutive, overlapping transcript fragments.
///
/// Each call to [`Destutterer::step`] compares the new fragment against the
/// fragment that was emitted last and returns only the part that was not
/// already printed.
struct Destutterer {
    /// The most recently emitted (non-empty) fragment, if any.
    prev: Option<String>,
}

/// Creates a `Destutterer` that has not seen any text yet.
fn new_destutterer() -> Destutterer {
    Destutterer { prev: None }
}

impl Destutterer {
    /// Feeds the next transcript fragment and returns the text that is new
    /// relative to the previously emitted fragment.
    ///
    /// * Surrounding whitespace is trimmed; a blank fragment yields `""` and
    ///   leaves the remembered fragment untouched.
    /// * The very first non-blank fragment is returned (trimmed) as is.
    /// * Afterwards, trailing `?` / `.` characters are set aside, the longest
    ///   prefix of the fragment that is also a suffix of the previously emitted
    ///   fragment is dropped, and the punctuation is appended again.
    /// * If nothing is left after dropping the overlap, `""` is returned and
    ///   the remembered fragment is not updated.
    ///
    /// All slicing happens on UTF-8 character boundaries, so fragments that
    /// contain multi-byte characters never cause a panic.
    fn step(&mut self, next: String) -> String {
        let next = next.trim();
        if next.is_empty() {
            return String::new();
        }

        let (fresh, punctuation) = match &self.prev {
            // Nothing to compare against yet: emit the whole fragment.
            None => (next, ""),
            Some(prev) => {
                let body = next.trim_end_matches(|c: char| c == '?' || c == '.');
                let overlap = Self::overlap_len(prev, body);
                (&body[overlap..], &next[body.len()..])
            }
        };

        if fresh.is_empty() {
            return String::new();
        }

        // Only the newly emitted text (without re-attached punctuation) is
        // remembered for the next comparison.
        self.prev = Some(fresh.to_string());
        format!("{}{}", fresh, punctuation)
    }

    /// Returns the byte length of the longest prefix of `next` that is also a
    /// suffix of `prev` (0 when there is no overlap).
    fn overlap_len(prev: &str, next: &str) -> usize {
        let longest = prev.len().min(next.len());
        (0..=longest)
            .rev()
            // Candidates that would split a multi-byte character are skipped
            // before slicing; `ends_with` compares without indexing `prev`.
            .find(|&n| next.is_char_boundary(n) && prev.ends_with(&next[..n]))
            .unwrap_or(0)
    }
}
|
||||||
|
|
||||||
|
#[cfg(test)]
mod tests {
    use super::*;

    /// Pushes a sequence of overlapping fragments through a single
    /// `Destutterer` and checks what each step emits.
    #[test]
    fn test_destutterer() {
        let mut destutterer = new_destutterer();

        // (fragment fed in, text expected back)
        let cases = [
            ("abcde", "abcde"),
            ("cdefg", "fg"),
            ("fghij", "hij"),
            // The previously emitted text was "hij", which does not overlap the
            // start of "fghij", so the whole fragment comes back.
            ("fghij", "fghij"),
        ];

        for (input, expected) in cases.iter() {
            assert_eq!(expected.to_string(), destutterer.step(input.to_string()));
        }
    }
}
|
||||||
|
|||||||
@@ -72,7 +72,30 @@ pub fn f32_from_wav_file(path: &String) -> Result<Vec<f32>, String> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn channel<F>(flags: Flags, handler_fn: F, stream: std::sync::mpsc::Receiver<Vec<f32>>) where F: FnMut(Result<Transcribed, String>) + Send + 'static {
|
pub fn wav_channel<F>(flags: Flags, handler_fn: F) where F: FnMut(Result<Transcribed, String>) + Send + 'static {
|
||||||
|
let path = flags.wav.as_ref().unwrap();
|
||||||
|
let mut audio = f32_from_wav_file(&path).unwrap();
|
||||||
|
let mut iter = vec![];
|
||||||
|
let n = audio.len() / match audio.len() % 100 {
|
||||||
|
0 => 100,
|
||||||
|
_ => 99,
|
||||||
|
};
|
||||||
|
for _ in 0..100 {
|
||||||
|
iter.push(audio.drain(0..n.clamp(0, audio.len())).collect());
|
||||||
|
}
|
||||||
|
let (fin_send, fin_recv) = std::sync::mpsc::sync_channel::<Option<i32>>(1);
|
||||||
|
channel_and_close(flags.clone(), handler_fn, iter, move || { fin_send.send(None).unwrap(); });
|
||||||
|
match fin_recv.recv() {
|
||||||
|
Ok(_) => {},
|
||||||
|
Err(x) => panic!("failed to receive: {}", x),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Feeds every chunk produced by `stream` through `channel_and_close` without
/// any completion notification.
///
/// This is `channel_and_close` with a close callback that does nothing;
/// `flags`, `handler_fn` and `stream` are forwarded unchanged.
pub fn channel<F, I>(flags: Flags, handler_fn: F, stream: I)
where
    F: FnMut(Result<Transcribed, String>) + Send + 'static,
    I: IntoIterator<Item = Vec<f32>>,
{
    let on_close = || {};
    channel_and_close(flags, handler_fn, stream, on_close);
}
|
||||||
|
|
||||||
|
fn channel_and_close<F, I, G>(flags: Flags, handler_fn: F, stream: I, mut close_fn: G) where F: FnMut(Result<Transcribed, String>) + Send + 'static, I: IntoIterator<Item = Vec<f32>>, G: FnMut() + Send + 'static {
|
||||||
let w = new_service(
|
let w = new_service(
|
||||||
flags.model_path,
|
flags.model_path,
|
||||||
flags.model_buffer,
|
flags.model_buffer,
|
||||||
@@ -87,7 +110,7 @@ pub fn channel<F>(flags: Flags, handler_fn: F, stream: std::sync::mpsc::Receiver
|
|||||||
false => {},
|
false => {},
|
||||||
};
|
};
|
||||||
let mut buffer = vec![];
|
let mut buffer = vec![];
|
||||||
for data in stream.iter() {
|
for data in stream {
|
||||||
data.iter().for_each(|x| buffer.push(*x));
|
data.iter().for_each(|x| buffer.push(*x));
|
||||||
if buffer.len() >= (flags.stream_step * 16_000) as usize {
|
if buffer.len() >= (flags.stream_step * 16_000) as usize {
|
||||||
w.transcribe_async(&buffer).unwrap();
|
w.transcribe_async(&buffer).unwrap();
|
||||||
@@ -112,6 +135,10 @@ pub fn channel<F>(flags: Flags, handler_fn: F, stream: std::sync::mpsc::Receiver
|
|||||||
buffer.truncate(stream_retain);
|
buffer.truncate(stream_retain);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if buffer.len() > 0 {
|
||||||
|
w.transcribe(&buffer);
|
||||||
|
}
|
||||||
|
close_fn();
|
||||||
}
|
}
|
||||||
|
|
||||||
struct Service {
|
struct Service {
|
||||||
@@ -203,8 +230,10 @@ impl Impl {
|
|||||||
let result = whispered
|
let result = whispered
|
||||||
.after(&(self.stream_head * 100.0))
|
.after(&(self.stream_head * 100.0))
|
||||||
.before(&(self.stream_tail * 100.0));
|
.before(&(self.stream_tail * 100.0));
|
||||||
|
if result.to_string().trim().len() > 0 {
|
||||||
(self.handler_fn.as_mut().unwrap())(Ok(result));
|
(self.handler_fn.as_mut().unwrap())(Ok(result));
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn on_error(&mut self, msg: String) {
|
fn on_error(&mut self, msg: String) {
|
||||||
(self.handler_fn.as_mut().unwrap())(Err(format!("failed to transcribe: {}", &msg)));
|
(self.handler_fn.as_mut().unwrap())(Err(format!("failed to transcribe: {}", &msg)));
|
||||||
@@ -339,6 +368,28 @@ impl Transcribed {
|
|||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
||||||
|
/// Runs `wav_channel` over the JFK sample with the tiny English model and
/// checks every result handed to the callback.
#[test]
fn test_transcribe_tiny_jfk_wav_whisper_rs_wav_channel() {
    let flags = Flags {
        model_path: None,
        model_buffer: Some(include_bytes!("../../models/ggml-tiny.en.bin").to_vec()),
        threads: 8,
        stream_step: 30,
        stream_retain: 0.0,
        stream_head: 0.0,
        stream_tail: 0.0,
        wav: Some("../gitea-whisper-rs/sys/whisper.cpp/bindings/go/samples/jfk.wav".to_string()),
        debug: false,
        stream_device: None,
    };
    let expected = " And so my fellow Americans ask not what your country can do for you ask what you can do for your country.";

    wav_channel(flags, move |result| {
        assert!(result.is_ok());
        assert_eq!(result.unwrap().to_string(), expected);
    });
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_transcribe_tiny_jfk_wav_whisper_rs() {
|
fn test_transcribe_tiny_jfk_wav_whisper_rs() {
|
||||||
wav(
|
wav(
|
||||||
|
|||||||
Reference in New Issue
Block a user