master
Bel LaPointe 2023-11-28 19:10:07 -07:00
parent f58e3a0331
commit 3093a91d84
1 changed files with 47 additions and 1 deletions

View File

@ -1,4 +1,4 @@
use whisper_rs::{WhisperContext, FullParams, SamplingStrategy}; use whisper_rs::{WhisperContext, FullParams, SamplingStrategy, WhisperError};
fn main() { fn main() {
let ctx = WhisperContext::new( let ctx = WhisperContext::new(
@ -36,3 +36,49 @@ fn main() {
} }
println!(""); println!("");
} }
struct Whisper {
ctx: WhisperContext,
}
fn new_whisper(model_path: String) -> Result<Whisper, String> {
match WhisperContext::new(&model_path) {
Ok(ctx) => Ok(Whisper{ctx: ctx}),
Err(msg) => Err(format!("failed to load {}: {}", model_path, msg)),
}
}
impl Whisper {
fn transcribe(&self, data: Vec<f32>) -> Result<String, String> {
match self._transcribe(data) {
Ok(result) => Ok(result),
Err(msg) => Err(format!("failed to transcribe: {}", msg)),
}
}
fn _transcribe(&self, data: Vec<f32>) -> Result<String, WhisperError> {
println!("{:?}", data);
let mut state = self.ctx.create_state()?;
let mut params = FullParams::new(SamplingStrategy::Greedy { best_of: 0 });
params.set_n_threads(
std::env::var("P").unwrap_or(String::from("8")).parse::<i32>().expect("$P must be a number")
);
params.set_translate(false);
params.set_detect_language(false);
params.set_language(Some("en"));
params.set_print_special(false);
params.set_print_progress(false);
params.set_print_realtime(false);
params.set_print_timestamps(false);
state.full(params, &data[..])?;
let num_segments = state.full_n_segments()?;
let mut result = "".to_string();
for i in 0..num_segments {
let segment = state.full_get_segment_text(i)?;
result = format!("{} {}", result, segment);
}
Ok(result)
}
}