WIP: trim the head and tail from the text output because confidence is low there

master
Bel LaPointe 2023-12-19 09:09:38 -05:00
parent 116f3f58c9
commit 15a3f8430a
1 changed file with 77 additions and 8 deletions

View File

@ -18,6 +18,10 @@ struct Flags {
stream_retain: f32, stream_retain: f32,
#[arg(long, default_value = "5")] #[arg(long, default_value = "5")]
stream_step: u64, stream_step: u64,
#[arg(long, default_value = "1.0")]
stream_head: f32,
#[arg(long, default_value = "1.0")]
stream_tail: f32,
wav: Option<String>, wav: Option<String>,
} }
@ -25,7 +29,11 @@ struct Flags {
fn main() { fn main() {
let flags = Flags::parse(); let flags = Flags::parse();
let w = new_whisper(flags.model, flags.threads, Handler{}).unwrap(); let handler = Handler{
head: flags.stream_head,
tail: flags.stream_tail,
};
let w = new_whisper(flags.model, flags.threads, handler).unwrap();
let stream_retain = (flags.stream_retain * 16_000.0) as usize; let stream_retain = (flags.stream_retain * 16_000.0) as usize;
let stream_step = Duration::new(flags.stream_step, 0); let stream_step = Duration::new(flags.stream_step, 0);
match flags.wav { match flags.wav {
@ -139,7 +147,7 @@ impl WhisperEngine {
} }
} }
fn transcribe(&self, data: &Vec<f32>) -> Result<String, WhisperError> { fn transcribe(&self, data: &Vec<f32>) -> Result<Whispered, WhisperError> {
let mut params = FullParams::new(SamplingStrategy::Greedy { best_of: 0 }); let mut params = FullParams::new(SamplingStrategy::Greedy { best_of: 0 });
params.set_no_context(true); params.set_no_context(true);
params.set_n_threads(self.threads); params.set_n_threads(self.threads);
@ -154,11 +162,13 @@ impl WhisperEngine {
let mut state = self.ctx.create_state()?; let mut state = self.ctx.create_state()?;
state.full(params, &data[..])?; state.full(params, &data[..])?;
let mut result = new_whispered();
let num_segments = state.full_n_segments()?; let num_segments = state.full_n_segments()?;
let mut result = "".to_string();
for i in 0..num_segments { for i in 0..num_segments {
let segment = state.full_get_segment_text(i)?; let data = state.full_get_segment_text(i)?;
result = format!("{} {}", result, segment); let start = state.full_get_segment_t0(i)?;
let stop = state.full_get_segment_t1(i)?;
result.push(data, start, stop);
} }
Ok(result) Ok(result)
@ -170,12 +180,71 @@ struct AWhisper {
ack: Option<std::sync::mpsc::SyncSender<bool>>, ack: Option<std::sync::mpsc::SyncSender<bool>>,
} }
struct Handler {} #[derive(Clone)]
/// Collection of transcribed segments, appended to via `Whispered::push`.
struct Whispered {
/// Segments in the order they were pushed.
data: Vec<AWhispered>,
}
/// A single transcribed segment: its text plus its position on the timeline.
///
/// NOTE(review): the time unit is whatever the caller passes to
/// `new_a_whispered`; presumably whisper segment timestamps — confirm the unit.
#[derive(Clone)]
struct AWhispered {
/// Text of the segment.
data: String,
/// Start of the segment (the `start` value given at construction).
offset: i64,
/// Length of the segment, stored as `stop - start`.
length: i64,
}
/// Builds a `Whispered` that holds no segments yet.
fn new_whispered() -> Whispered {
    Whispered { data: Vec::new() }
}
/// Builds one segment from its text and its `start`/`stop` bounds.
///
/// The bounds are stored as an offset (`start`) plus a length (`stop - start`).
fn new_a_whispered(data: String, start: i64, stop: i64) -> AWhispered {
    AWhispered {
        data,
        // `i64` is `Copy`; it is used by value, no `.clone()` required.
        offset: start,
        length: stop - start,
    }
}
impl Whispered {
fn as_string(&self) -> String {
let mut result = "".to_string();
for i in 0..self.data.len() {
result = format!("{} {}", result, &self.data[i].data);
}
result
}
fn after(&self, t: &f32) -> Whispered {
let mut result = new_whispered();
self.data
.iter()
.filter(|x| x.offset as f32 > *t)
.for_each(|x| result.data.push(x.clone()));
result
}
fn before(&self, t: &f32) -> Whispered {
let mut result = new_whispered();
self.data
.iter()
.filter(|x| ((x.offset + x.length) as f32) < *t)
.for_each(|x| result.data.push(x.clone()));
result
}
fn push(&mut self, data: String, start: i64, stop: i64) {
self.data.push(new_a_whispered(data, start, stop));
}
}
/// Receives transcription results and errors; carries the trim thresholds
/// applied to successful results before they are printed.
struct Handler {
/// Threshold passed to `Whispered::after` in `on_success`; set from the
/// `stream_head` flag in `main`.
head: f32,
/// Threshold passed to `Whispered::before` in `on_success`; set from the
/// `stream_tail` flag in `main`.
tail: f32,
}
impl Handler { impl Handler {
fn on_success(&self, result: String) { fn on_success(&self, result: Whispered) {
eprintln!("{}", chrono::Local::now()); eprintln!("{}", chrono::Local::now());
println!("{}", result); println!("{}", result.after(&self.head).before(&self.tail).as_string());
} }
fn on_error(&self, msg: String) { fn on_error(&self, msg: String) {
eprintln!("error: {}", msg); eprintln!("error: {}", msg);