found burn depends on zstd that doesnt wasm

This commit is contained in:
Bel LaPointe
2023-12-21 22:13:23 -05:00
parent d63d05505c
commit 1c9d646a50
6 changed files with 7282 additions and 636 deletions

View File

@@ -1,4 +1,3 @@
use whisper_rs::{WhisperContext, FullParams, SamplingStrategy, WhisperError};
use clap::Parser;
use std::thread;
use std::fs::File;
@@ -61,7 +60,13 @@ fn f32_from_wav_file(path: &String) -> Result<Vec<f32>, String> {
return Err("!= 16_000 hz".to_string());
}
match data.as_sixteen() {
Some(data16) => Ok(whisper_rs::convert_integer_to_float_audio(&data16)),
Some(data16) => {
let mut floats = Vec::with_capacity(data16.len());
for sample in data16 {
floats.push(*sample as f32 / 32768.0);
}
Ok(floats)
},
None => Err(format!("couldnt translate wav to 16s")),
}
}
@@ -113,7 +118,7 @@ struct Service {
}
fn new_service<F>(model_path: Option<String>, model_buffer: Option<Vec<u8>>, threads: i32, stream_head: f32, stream_tail: f32, handler_fn: F) -> Result<Service, String> where F: FnMut(Result<Transcribed, String>) + Send + 'static {
match new_engine(model_path, model_buffer, threads) {
match new_engine_2() {
Ok(engine) => {
let mut whisper = new_whisper_impl(engine, stream_head, stream_tail, handler_fn);
let (send, recv) = std::sync::mpsc::sync_channel(100);
@@ -147,13 +152,13 @@ impl Service {
}
struct Impl {
engine: Engine,
engine: Engine2,
stream_head: f32,
stream_tail: f32,
handler_fn: Option<Box<dyn FnMut(Result<Transcribed, String>) + Send + 'static>>
}
fn new_whisper_impl<F>(engine: Engine, stream_head: f32, stream_tail: f32, handler_fn: F) -> Impl where F: FnMut(Result<Transcribed, String>) + Send + 'static {
fn new_whisper_impl<F>(engine: Engine2, stream_head: f32, stream_tail: f32, handler_fn: F) -> Impl where F: FnMut(Result<Transcribed, String>) + Send + 'static {
Impl {
engine: engine,
stream_head: stream_head,
@@ -205,58 +210,6 @@ impl Impl {
}
}
struct Engine {
ctx: WhisperContext,
threads: i32,
}
fn new_engine(model_path: Option<String>, model_buffer: Option<Vec<u8>>, threads: i32) -> Result<Engine, String> {
let whisper_context_result = match model_path {
Some(model_path) => WhisperContext::new(&model_path),
None => WhisperContext::new_from_buffer(&model_buffer.unwrap()),
};
match whisper_context_result {
Ok(ctx) => Ok(Engine{ctx: ctx, threads: threads}),
Err(msg) => Err(format!("failed to load model: {}", msg)),
}
}
impl Engine {
fn transcribe(&self, data: &Vec<f32>) -> Result<Transcribed, String> {
match self._transcribe(data) {
Ok(transcribed) => Ok(transcribed),
Err(msg) => Err(format!("{}", msg)),
}
}
fn _transcribe(&self, data: &Vec<f32>) -> Result<Transcribed, WhisperError> {
let mut params = FullParams::new(SamplingStrategy::Greedy { best_of: 0 });
params.set_no_context(true);
params.set_n_threads(self.threads);
params.set_translate(false);
params.set_detect_language(false);
params.set_language(Some("en"));
params.set_print_special(false);
params.set_print_progress(false);
params.set_print_realtime(false);
params.set_print_timestamps(false);
let mut state = self.ctx.create_state()?;
state.full(params, &data[..])?;
let mut result = new_whispered();
let num_segments = state.full_n_segments()?;
for i in 0..num_segments {
let data = state.full_get_segment_text(i)?;
let start = state.full_get_segment_t0(i)?;
let stop = state.full_get_segment_t1(i)?;
result.push(data, start, stop);
}
Ok(result)
}
}
struct Engine2 {
}
@@ -346,29 +299,6 @@ impl Transcribed {
mod tests {
use super::*;
#[test]
fn test_transcribe_tiny_jfk_wav_whisper_rs() {
wav(
Flags {
model_path: None,
model_buffer: Some(include_bytes!("../../models/ggml-tiny.en.bin").to_vec()),
threads: 8,
stream_step: 0,
stream_retain: 0.0,
stream_head: 0.0,
stream_tail: 0.0,
wav: Some("../gitea-whisper-rs/sys/whisper.cpp/bindings/go/samples/jfk.wav".to_string()),
debug: false,
stream_device: None,
},
| result | {
assert!(result.is_ok());
assert_eq!(result.unwrap().to_string(), " And so my fellow Americans ask not what your country can do for you ask what you can do for your country.");
},
"../gitea-whisper-rs/sys/whisper.cpp/bindings/go/samples/jfk.wav".to_string(),
);
}
#[test]
fn test_transcribe_tiny_jfk_wav_candle() {
let wav_path = "../gitea-whisper-rs/sys/whisper.cpp/bindings/go/samples/jfk.wav".to_string();