Compare commits
12 Commits
f58e3a0331
...
v0.1.0
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1dd631872c | ||
|
|
72a1420638 | ||
|
|
1009c4230e | ||
|
|
30e5515da1 | ||
|
|
b4c9ecb98b | ||
|
|
4ef419e6c0 | ||
|
|
54964ec59b | ||
|
|
62e764436a | ||
|
|
d631def834 | ||
|
|
3168968cae | ||
|
|
437d7cac39 | ||
|
|
3093a91d84 |
859
rust-whisper.d/Cargo.lock
generated
859
rust-whisper.d/Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -9,3 +9,8 @@ edition = "2021"
|
||||
whisper-rs = { path = "./gitea-whisper-rs", version = "0.8.0" }
|
||||
wav = "1"
|
||||
tokio = "1.27"
|
||||
cpal = "0.15.2"
|
||||
signal-hook = "0.3.17"
|
||||
byteorder = "1.5.0"
|
||||
chrono = "0.4.31"
|
||||
clap = { version = "4.4.10", features = ["derive"] }
|
||||
|
||||
@@ -1,16 +1,130 @@
|
||||
use whisper_rs::{WhisperContext, FullParams, SamplingStrategy};
|
||||
use whisper_rs::{WhisperContext, FullParams, SamplingStrategy, WhisperError};
|
||||
use cpal::traits::{HostTrait, DeviceTrait, StreamTrait};
|
||||
use signal_hook::{iterator::Signals, consts::signal::SIGINT};
|
||||
use std::time::{Duration, Instant};
|
||||
use chrono;
|
||||
use clap::Parser;
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
struct Flags {
|
||||
#[arg(long, default_value = "../models/ggml-tiny.en.bin")]
|
||||
model: String,
|
||||
|
||||
#[arg(long, default_value = "8")]
|
||||
threads: i32,
|
||||
|
||||
#[arg(long, default_value = "0.8")]
|
||||
stream_churn: f32,
|
||||
#[arg(long, default_value = "5")]
|
||||
stream_step: u64,
|
||||
|
||||
wav: Option<String>,
|
||||
}
|
||||
|
||||
fn main() {
|
||||
let ctx = WhisperContext::new(
|
||||
&std::env::var("MODEL").unwrap_or(String::from("../models/ggml-tiny.en.bin"))
|
||||
).expect("failed to load $MODEL");
|
||||
let mut state = ctx.create_state().expect("failed to create state");
|
||||
let flags = Flags::parse();
|
||||
|
||||
// create a params object
|
||||
let w = new_whisper(flags.model, flags.threads).unwrap();
|
||||
let stream_churn = flags.stream_churn;
|
||||
let stream_step = Duration::new(flags.stream_step, 0);
|
||||
match flags.wav {
|
||||
Some(wav) => {
|
||||
let (header, data) = wav::read(
|
||||
&mut std::fs::File::open(wav).expect("failed to open $WAV"),
|
||||
).expect("failed to decode $WAV");
|
||||
assert!(header.channel_count == 1);
|
||||
assert!(header.sampling_rate == 16000);
|
||||
let data16 = data.as_sixteen().expect("wav is not 32bit floats");
|
||||
let audio_data = &whisper_rs::convert_integer_to_float_audio(&data16);
|
||||
|
||||
let result = w.transcribe(&audio_data).unwrap();
|
||||
println!("{}", result);
|
||||
},
|
||||
None => {
|
||||
let host = cpal::default_host();
|
||||
let device = host.default_input_device().unwrap();
|
||||
let cfg = device.supported_input_configs()
|
||||
.unwrap()
|
||||
.filter(|x| x.sample_format() == cpal::SampleFormat::F32)
|
||||
.nth(0)
|
||||
.unwrap()
|
||||
.with_max_sample_rate();
|
||||
|
||||
let channels = cfg.channels();
|
||||
let downsample_ratio = cfg.sample_rate().0 as f32 / 16000.0;
|
||||
let mut buffer = vec![];
|
||||
let mut last = Instant::now();
|
||||
let stream = device.build_input_stream(
|
||||
&cfg.clone().into(),
|
||||
move |data: &[f32], _: &cpal::InputCallbackInfo| {
|
||||
let mono_data: Vec<f32> = data.iter().map(|x| *x).step_by(channels.into()).collect();
|
||||
let mut downsampled_data = vec![];
|
||||
for i in 0..(mono_data.len() as f32 / downsample_ratio) as usize {
|
||||
let mut upsampled = i as f32 * downsample_ratio;
|
||||
if upsampled > (mono_data.len()-1) as f32 {
|
||||
upsampled = (mono_data.len()-1) as f32
|
||||
}
|
||||
downsampled_data.push(mono_data[upsampled as usize]);
|
||||
}
|
||||
downsampled_data.iter().for_each(|x| buffer.push(*x));
|
||||
if Instant::now() - last > stream_step {
|
||||
let result = w.transcribe(&buffer).unwrap();
|
||||
eprintln!("{}", chrono::Local::now());
|
||||
println!("{}", result);
|
||||
|
||||
let retain = buffer.len() - (buffer.len() as f32 * stream_churn) as usize;
|
||||
for i in retain..buffer.len() {
|
||||
buffer[i - retain] = buffer[i]
|
||||
}
|
||||
buffer.truncate(retain);
|
||||
last = Instant::now();
|
||||
}
|
||||
},
|
||||
move |err| {
|
||||
eprintln!("input error: {}", err)
|
||||
},
|
||||
None,
|
||||
).unwrap();
|
||||
stream.play().unwrap();
|
||||
|
||||
eprintln!("listening on {}", device.name().unwrap());
|
||||
let mut signals = Signals::new(&[SIGINT]).unwrap();
|
||||
for sig in signals.forever() {
|
||||
eprintln!("sig {}", sig);
|
||||
break;
|
||||
}
|
||||
stream.pause().unwrap();
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
struct Whisper {
|
||||
ctx: WhisperContext,
|
||||
threads: i32,
|
||||
}
|
||||
|
||||
fn new_whisper(model_path: String, threads: i32) -> Result<Whisper, String> {
|
||||
match WhisperContext::new(&model_path) {
|
||||
Ok(ctx) => Ok(Whisper{
|
||||
ctx: ctx,
|
||||
threads: threads,
|
||||
}),
|
||||
Err(msg) => Err(format!("failed to load {}: {}", model_path, msg)),
|
||||
}
|
||||
}
|
||||
|
||||
impl Whisper {
|
||||
fn transcribe(&self, data: &Vec<f32>) -> Result<String, String> {
|
||||
match self._transcribe(&data) {
|
||||
Ok(result) => Ok(result),
|
||||
Err(msg) => Err(format!("failed to transcribe: {}", msg)),
|
||||
}
|
||||
}
|
||||
|
||||
fn _transcribe(&self, data: &Vec<f32>) -> Result<String, WhisperError> {
|
||||
let mut params = FullParams::new(SamplingStrategy::Greedy { best_of: 0 });
|
||||
params.set_n_threads(
|
||||
std::env::var("P").unwrap_or(String::from("8")).parse::<i32>().expect("$P must be a number")
|
||||
);
|
||||
params.set_no_context(true);
|
||||
params.set_n_threads(self.threads);
|
||||
params.set_translate(false);
|
||||
params.set_detect_language(false);
|
||||
params.set_language(Some("en"));
|
||||
@@ -19,20 +133,16 @@ fn main() {
|
||||
params.set_print_realtime(false);
|
||||
params.set_print_timestamps(false);
|
||||
|
||||
let (header, data) = wav::read(&mut std::fs::File::open(
|
||||
&std::env::var("WAV").unwrap_or(String::from("../git.d/samples/jfk.wav"))
|
||||
).expect("failed to open $WAV")).expect("failed to decode $WAV");
|
||||
assert!(header.channel_count == 1);
|
||||
assert!(header.sampling_rate == 16000);
|
||||
let data16 = data.as_sixteen().expect("wav is not 32bit floats");
|
||||
let audio_data = &whisper_rs::convert_integer_to_float_audio(&data16);
|
||||
let mut state = self.ctx.create_state()?;
|
||||
state.full(params, &data[..])?;
|
||||
|
||||
state.full(params, &audio_data[..]).expect("failed to run model");
|
||||
|
||||
let num_segments = state.full_n_segments().expect("failed to get number of segments");
|
||||
let num_segments = state.full_n_segments()?;
|
||||
let mut result = "".to_string();
|
||||
for i in 0..num_segments {
|
||||
let segment = state.full_get_segment_text(i).expect("failed to get segment");
|
||||
print!("{} ", segment);
|
||||
let segment = state.full_get_segment_text(i)?;
|
||||
result = format!("{} {}", result, segment);
|
||||
}
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
println!("");
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user