successful refactor

master
Bel LaPointe 2023-11-28 19:18:05 -07:00
parent 3093a91d84
commit 437d7cac39
1 changed files with 21 additions and 34 deletions

View File

@ -1,23 +1,10 @@
use whisper_rs::{WhisperContext, FullParams, SamplingStrategy, WhisperError}; use whisper_rs::{WhisperContext, FullParams, SamplingStrategy, WhisperError};
fn main() { fn main() {
let ctx = WhisperContext::new( let w = new_whisper(
&std::env::var("MODEL").unwrap_or(String::from("../models/ggml-tiny.en.bin")) std::env::var("MODEL").unwrap_or(String::from("../models/ggml-tiny.en.bin")),
).expect("failed to load $MODEL"); std::env::var("P").unwrap_or(String::from("8")).parse::<i32>().expect("$P must be a number"),
let mut state = ctx.create_state().expect("failed to create state"); ).unwrap();
// create a params object
let mut params = FullParams::new(SamplingStrategy::Greedy { best_of: 0 });
params.set_n_threads(
std::env::var("P").unwrap_or(String::from("8")).parse::<i32>().expect("$P must be a number")
);
params.set_translate(false);
params.set_detect_language(false);
params.set_language(Some("en"));
params.set_print_special(false);
params.set_print_progress(false);
params.set_print_realtime(false);
params.set_print_timestamps(false);
let (header, data) = wav::read(&mut std::fs::File::open( let (header, data) = wav::read(&mut std::fs::File::open(
&std::env::var("WAV").unwrap_or(String::from("../git.d/samples/jfk.wav")) &std::env::var("WAV").unwrap_or(String::from("../git.d/samples/jfk.wav"))
@ -27,43 +14,40 @@ fn main() {
let data16 = data.as_sixteen().expect("wav is not 32bit floats"); let data16 = data.as_sixteen().expect("wav is not 32bit floats");
let audio_data = &whisper_rs::convert_integer_to_float_audio(&data16); let audio_data = &whisper_rs::convert_integer_to_float_audio(&data16);
state.full(params, &audio_data[..]).expect("failed to run model"); for _ in 0..3 {
let result = w.transcribe(&audio_data).unwrap();
let num_segments = state.full_n_segments().expect("failed to get number of segments"); println!("{}", result);
for i in 0..num_segments {
let segment = state.full_get_segment_text(i).expect("failed to get segment");
print!("{} ", segment);
} }
println!("");
} }
struct Whisper { struct Whisper {
ctx: WhisperContext, ctx: WhisperContext,
threads: i32,
} }
fn new_whisper(model_path: String) -> Result<Whisper, String> { fn new_whisper(model_path: String, threads: i32) -> Result<Whisper, String> {
match WhisperContext::new(&model_path) { match WhisperContext::new(&model_path) {
Ok(ctx) => Ok(Whisper{ctx: ctx}), Ok(ctx) => Ok(Whisper{
ctx: ctx,
threads: threads,
}),
Err(msg) => Err(format!("failed to load {}: {}", model_path, msg)), Err(msg) => Err(format!("failed to load {}: {}", model_path, msg)),
} }
} }
impl Whisper { impl Whisper {
fn transcribe(&self, data: Vec<f32>) -> Result<String, String> { fn transcribe(&self, data: &Vec<f32>) -> Result<String, String> {
match self._transcribe(data) { match self._transcribe(&data) {
Ok(result) => Ok(result), Ok(result) => Ok(result),
Err(msg) => Err(format!("failed to transcribe: {}", msg)), Err(msg) => Err(format!("failed to transcribe: {}", msg)),
} }
} }
fn _transcribe(&self, data: Vec<f32>) -> Result<String, WhisperError> { fn _transcribe(&self, data: &Vec<f32>) -> Result<String, WhisperError> {
println!("{:?}", data); //eprintln!("{:?} ({})", data, data.len());
let mut state = self.ctx.create_state()?;
let mut params = FullParams::new(SamplingStrategy::Greedy { best_of: 0 }); let mut params = FullParams::new(SamplingStrategy::Greedy { best_of: 0 });
params.set_n_threads( params.set_n_threads(self.threads);
std::env::var("P").unwrap_or(String::from("8")).parse::<i32>().expect("$P must be a number")
);
params.set_translate(false); params.set_translate(false);
params.set_detect_language(false); params.set_detect_language(false);
params.set_language(Some("en")); params.set_language(Some("en"));
@ -71,6 +55,8 @@ impl Whisper {
params.set_print_progress(false); params.set_print_progress(false);
params.set_print_realtime(false); params.set_print_realtime(false);
params.set_print_timestamps(false); params.set_print_timestamps(false);
let mut state = self.ctx.create_state()?;
state.full(params, &data[..])?; state.full(params, &data[..])?;
let num_segments = state.full_n_segments()?; let num_segments = state.full_n_segments()?;
@ -79,6 +65,7 @@ impl Whisper {
let segment = state.full_get_segment_text(i)?; let segment = state.full_get_segment_text(i)?;
result = format!("{} {}", result, segment); result = format!("{} {}", result, segment);
} }
Ok(result) Ok(result)
} }
} }