From c23352b7e54267980955619ddd7cd9aec936c317 Mon Sep 17 00:00:00 2001
From: Bel LaPointe <153096461+breel-render@users.noreply.github.com>
Date: Thu, 21 Dec 2023 22:26:49 -0500
Subject: [PATCH] found an example that just works
---
.gitignore | 4 +
candle-wasm-examples-whisper/Cargo.toml | 52 ++
candle-wasm-examples-whisper/README.md | 68 +++
candle-wasm-examples-whisper/build-lib.sh | 2 +
candle-wasm-examples-whisper/index.html | 40 ++
candle-wasm-examples-whisper/lib-example.html | 351 +++++++++++++
candle-wasm-examples-whisper/main.js | 6 +
candle-wasm-examples-whisper/src/app.rs | 283 ++++++++++
candle-wasm-examples-whisper/src/audio.rs | 215 ++++++++
candle-wasm-examples-whisper/src/bin/app.rs | 4 +
candle-wasm-examples-whisper/src/bin/m.rs | 53 ++
.../src/bin/worker.rs | 4 +
candle-wasm-examples-whisper/src/languages.rs | 101 ++++
candle-wasm-examples-whisper/src/lib.rs | 29 ++
candle-wasm-examples-whisper/src/worker.rs | 492 ++++++++++++++++++
candle-wasm-examples-whisper/whisperWorker.js | 116 +++++
16 files changed, 1820 insertions(+)
create mode 100644 candle-wasm-examples-whisper/Cargo.toml
create mode 100644 candle-wasm-examples-whisper/README.md
create mode 100644 candle-wasm-examples-whisper/build-lib.sh
create mode 100644 candle-wasm-examples-whisper/index.html
create mode 100644 candle-wasm-examples-whisper/lib-example.html
create mode 100644 candle-wasm-examples-whisper/main.js
create mode 100644 candle-wasm-examples-whisper/src/app.rs
create mode 100644 candle-wasm-examples-whisper/src/audio.rs
create mode 100644 candle-wasm-examples-whisper/src/bin/app.rs
create mode 100644 candle-wasm-examples-whisper/src/bin/m.rs
create mode 100644 candle-wasm-examples-whisper/src/bin/worker.rs
create mode 100644 candle-wasm-examples-whisper/src/languages.rs
create mode 100644 candle-wasm-examples-whisper/src/lib.rs
create mode 100644 candle-wasm-examples-whisper/src/worker.rs
create mode 100644 candle-wasm-examples-whisper/whisperWorker.js
diff --git a/.gitignore b/.gitignore
index 4484247..c38dbcc 100644
--- a/.gitignore
+++ b/.gitignore
@@ -9,3 +9,7 @@ snowboy-2022/snowboy
**/*.wav
snowboy-2022/Dockerfile
**/target
+/candle-wasm-examples-whisper/mel_filters*
+/candle-wasm-examples-whisper/whisper-tiny*
+/candle-wasm-examples-whisper/quantized*
+/candle-wasm-examples-whisper/audios*
diff --git a/candle-wasm-examples-whisper/Cargo.toml b/candle-wasm-examples-whisper/Cargo.toml
new file mode 100644
index 0000000..63ecefd
--- /dev/null
+++ b/candle-wasm-examples-whisper/Cargo.toml
@@ -0,0 +1,52 @@
+[package]
+name = "candle-wasm-example-whisper"
+version.workspace = true
+edition.workspace = true
+description.workspace = true
+repository.workspace = true
+keywords.workspace = true
+categories.workspace = true
+license.workspace = true
+
+[dependencies]
+candle = { path = "../../candle-core", version = "0.3.2", package = "candle-core" }
+candle-nn = { path = "../../candle-nn", version = "0.3.2" }
+candle-transformers = { path = "../../candle-transformers", version = "0.3.2" }
+num-traits = { workspace = true }
+tokenizers = { workspace = true, features = ["unstable_wasm"] }
+
+# App crates.
+anyhow = { workspace = true }
+log = { workspace = true }
+rand = { workspace = true }
+serde = { workspace = true }
+serde_json = { workspace = true }
+wav = { workspace = true }
+safetensors = { workspace = true }
+
+# Wasm specific crates.
+getrandom = { version = "0.2", features = ["js"] }
+gloo = "0.8"
+js-sys = "0.3.64"
+wasm-bindgen = "0.2.87"
+wasm-bindgen-futures = "0.4.37"
+wasm-logger = "0.2"
+yew-agent = "0.2.0"
+yew = { version = "0.20.0", features = ["csr"] }
+
+[dependencies.web-sys]
+version = "0.3.64"
+features = [
+ 'Blob',
+ 'Document',
+ 'Element',
+ 'HtmlElement',
+ 'Node',
+ 'Window',
+ 'Request',
+ 'RequestCache',
+ 'RequestInit',
+ 'RequestMode',
+ 'Response',
+ 'Performance',
+]
diff --git a/candle-wasm-examples-whisper/README.md b/candle-wasm-examples-whisper/README.md
new file mode 100644
index 0000000..85a5234
--- /dev/null
+++ b/candle-wasm-examples-whisper/README.md
@@ -0,0 +1,68 @@
+## Running Whisper Examples
+
+Here, we provide two examples of how to run Whisper using a Candle-compiled WASM binary and runtimes.
+
+### Pure Rust UI
+
+To build and test the UI made in Rust you will need [Trunk](https://trunkrs.dev/#install).
+From the `candle-wasm-examples-whisper` directory run:
+
+Download assets:
+
+```bash
+# mel filters
+wget -c https://huggingface.co/spaces/lmz/candle-whisper/resolve/main/mel_filters.safetensors
+# Model and tokenizer tiny.en
+wget -c https://huggingface.co/openai/whisper-tiny.en/resolve/main/model.safetensors -P whisper-tiny.en
+wget -c https://huggingface.co/openai/whisper-tiny.en/raw/main/tokenizer.json -P whisper-tiny.en
+wget -c https://huggingface.co/openai/whisper-tiny.en/raw/main/config.json -P whisper-tiny.en
+# model and tokenizer tiny multilanguage
+wget -c https://huggingface.co/openai/whisper-tiny/resolve/main/model.safetensors -P whisper-tiny
+wget -c https://huggingface.co/openai/whisper-tiny/raw/main/tokenizer.json -P whisper-tiny
+wget -c https://huggingface.co/openai/whisper-tiny/raw/main/config.json -P whisper-tiny
+
+# quantized
+wget -c https://huggingface.co/lmz/candle-whisper/resolve/main/model-tiny-en-q80.gguf -P quantized
+wget -c https://huggingface.co/lmz/candle-whisper/raw/main/tokenizer-tiny-en.json -P quantized
+wget -c https://huggingface.co/lmz/candle-whisper/raw/main/config-tiny-en.json -P quantized
+
+# Audio samples
+wget -c https://huggingface.co/datasets/Narsil/candle-examples/resolve/main/samples_gb0.wav -P audios
+wget -c https://huggingface.co/datasets/Narsil/candle-examples/resolve/main/samples_a13.wav -P audios
+wget -c https://huggingface.co/datasets/Narsil/candle-examples/resolve/main/samples_gb1.wav -P audios
+wget -c https://huggingface.co/datasets/Narsil/candle-examples/resolve/main/samples_hp0.wav -P audios
+wget -c https://huggingface.co/datasets/Narsil/candle-examples/resolve/main/samples_jfk.wav -P audios
+wget -c https://huggingface.co/datasets/Narsil/candle-examples/resolve/main/samples_mm0.wav -P audios
+
+```
+
+Run hot reload server:
+
+```bash
+trunk serve --release --public-url / --port 8080
+```
+
+### Vanilla JS and WebWorkers
+
+To build and test the UI made in Vanilla JS and WebWorkers, first we need to build the WASM library:
+
+```bash
+sh build-lib.sh
+```
+
+This will bundle the library under `./build` and we can import it inside our WebWorker like a normal JS module:
+
+```js
+import init, { Decoder } from "./build/m.js";
+```
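+
+Once the library is loaded, the exported `Decoder` can also be driven directly, without the worker. A minimal sketch (the constructor argument order follows `src/bin/m.rs`; the relative asset URLs assume the files downloaded above are being served):
+
+```js
+await init();
+const fetchBytes = async (url) =>
+  new Uint8Array(await (await fetch(url)).arrayBuffer());
+const weights = await fetchBytes("whisper-tiny.en/model.safetensors");
+const tokenizer = await fetchBytes("whisper-tiny.en/tokenizer.json");
+const melFilters = await fetchBytes("mel_filters.safetensors");
+const config = await fetchBytes("whisper-tiny.en/config.json");
+// new Decoder(weights, tokenizer, mel_filters, config,
+//             quantized, is_multilingual, timestamps, task, language)
+const decoder = new Decoder(weights, tokenizer, melFilters, config, false, false, true, null, null);
+const wav = await fetchBytes("audios/samples_jfk.wav"); // 16 kHz, 16-bit wav
+const segments = JSON.parse(decoder.decode(wav)); // [{ start, duration, dr: { text, ... } }, ...]
+```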
+
+The full example can be found under `./lib-example.html`. All needed assets are fetched from the web, so no need to download anything.
+Finally, you can preview the example by running a local HTTP server. For example:
+
+```bash
+python -m http.server
+```
+
+Then open `http://localhost:8000/lib-example.html` in your browser.
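+
+If you wire the worker into your own page instead, the message shape it expects (see `whisperWorker.js`) looks roughly like the sketch below. The `modelID` here is an arbitrary label used as a cache key, but note that the worker enables quantized/multilingual handling when the ID contains those words:
+
+```js
+const worker = new Worker("./whisperWorker.js", { type: "module" });
+worker.postMessage({
+  weightsURL: "https://huggingface.co/openai/whisper-tiny.en/resolve/main/model.safetensors",
+  modelID: "tiny_en",
+  tokenizerURL: "https://huggingface.co/openai/whisper-tiny.en/raw/main/tokenizer.json",
+  configURL: "https://huggingface.co/openai/whisper-tiny.en/raw/main/config.json",
+  mel_filtersURL: "https://huggingface.co/spaces/lmz/candle-whisper/resolve/main/mel_filters.safetensors",
+  audioURL: "https://huggingface.co/datasets/Narsil/candle-examples/resolve/main/samples_jfk.wav",
+});
+worker.addEventListener("message", ({ data }) => {
+  if (data.error) console.error(data.error);
+  else if (data.status === "complete") console.log(data.output); // decoded segments
+  else console.log(data.status, data.message); // "loading" / "decoding" progress
+});
+```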
diff --git a/candle-wasm-examples-whisper/build-lib.sh b/candle-wasm-examples-whisper/build-lib.sh
new file mode 100644
index 0000000..b0ebb18
--- /dev/null
+++ b/candle-wasm-examples-whisper/build-lib.sh
@@ -0,0 +1,2 @@
+cargo build --target wasm32-unknown-unknown --release
+wasm-bindgen ../../target/wasm32-unknown-unknown/release/m.wasm --out-dir build --target web
diff --git a/candle-wasm-examples-whisper/index.html b/candle-wasm-examples-whisper/index.html
new file mode 100644
index 0000000..9842820
--- /dev/null
+++ b/candle-wasm-examples-whisper/index.html
@@ -0,0 +1,40 @@
+<html>
+  <head>
+    <meta charset="utf-8" />
+    <title>Welcome to Candle!</title>
+
+    <link data-trunk rel="copy-file" href="mel_filters.safetensors" />
+    <link data-trunk rel="copy-dir" href="whisper-tiny.en" />
+    <link data-trunk rel="copy-dir" href="whisper-tiny" />
+    <link data-trunk rel="copy-dir" href="quantized" />
+    <link data-trunk rel="copy-dir" href="audios" />
+
+    <link data-trunk rel="rust" href="Cargo.toml" data-bin="app" data-type="main" />
+    <link data-trunk rel="rust" href="Cargo.toml" data-bin="worker" data-type="worker" />
+
+    <link
+      rel="stylesheet"
+      href="https://fonts.googleapis.com/css?family=Roboto:300,300italic,700,700italic"
+    />
+    <link
+      rel="stylesheet"
+      href="https://cdnjs.cloudflare.com/ajax/libs/normalize/8.0.1/normalize.css"
+    />
+    <link
+      rel="stylesheet"
+      href="https://cdnjs.cloudflare.com/ajax/libs/milligram/1.4.1/milligram.css"
+    />
+  </head>
+  <body></body>
+</html>
diff --git a/candle-wasm-examples-whisper/lib-example.html b/candle-wasm-examples-whisper/lib-example.html
new file mode 100644
index 0000000..1154c48
--- /dev/null
+++ b/candle-wasm-examples-whisper/lib-example.html
@@ -0,0 +1,351 @@
+<!DOCTYPE html>
+<html>
+  <head>
+    <meta charset="utf-8" />
+    <meta name="viewport" content="width=device-width, initial-scale=1.0" />
+    <title>Candle Whisper Rust/WASM</title>
+    <script src="https://cdn.tailwindcss.com"></script>
+    <script type="module">
+      // Model/asset URLs (the same ones listed in the README); the worker fetches and caches them.
+      const MODELS = {
+        tiny_en: {
+          weightsURL: "https://huggingface.co/openai/whisper-tiny.en/resolve/main/model.safetensors",
+          tokenizerURL: "https://huggingface.co/openai/whisper-tiny.en/raw/main/tokenizer.json",
+          configURL: "https://huggingface.co/openai/whisper-tiny.en/raw/main/config.json",
+        },
+        tiny_multilingual: {
+          weightsURL: "https://huggingface.co/openai/whisper-tiny/resolve/main/model.safetensors",
+          tokenizerURL: "https://huggingface.co/openai/whisper-tiny/raw/main/tokenizer.json",
+          configURL: "https://huggingface.co/openai/whisper-tiny/raw/main/config.json",
+        },
+        tiny_quantized_en: {
+          weightsURL: "https://huggingface.co/lmz/candle-whisper/resolve/main/model-tiny-en-q80.gguf",
+          tokenizerURL: "https://huggingface.co/lmz/candle-whisper/raw/main/tokenizer-tiny-en.json",
+          configURL: "https://huggingface.co/lmz/candle-whisper/raw/main/config-tiny-en.json",
+        },
+      };
+      const MEL_FILTERS_URL =
+        "https://huggingface.co/spaces/lmz/candle-whisper/resolve/main/mel_filters.safetensors";
+      const AUDIO_BASE_URL =
+        "https://huggingface.co/datasets/Narsil/candle-examples/resolve/main/";
+
+      const whisperWorker = new Worker("./whisperWorker.js", { type: "module" });
+      const modelSelect = document.getElementById("model");
+      const runButton = document.getElementById("run");
+      const outputDiv = document.getElementById("output");
+      const fileUpload = document.getElementById("file-upload");
+      const dropArea = document.getElementById("drop-area");
+      let audioURL = null;
+      let isDecoding = false;
+
+      function setAudio(url, label) {
+        audioURL = url;
+        runButton.disabled = false;
+        outputDiv.textContent = `Selected: ${label}`;
+      }
+      fileUpload.addEventListener("change", (e) => {
+        const file = e.target.files[0];
+        if (file) setAudio(URL.createObjectURL(file), file.name);
+      });
+      dropArea.addEventListener("dragover", (e) => e.preventDefault());
+      dropArea.addEventListener("drop", (e) => {
+        e.preventDefault();
+        const file = e.dataTransfer.files[0];
+        if (file) setAudio(URL.createObjectURL(file), file.name);
+      });
+      for (const el of document.querySelectorAll("[data-audio]")) {
+        el.addEventListener("click", (e) => {
+          e.preventDefault();
+          setAudio(AUDIO_BASE_URL + el.dataset.audio, el.dataset.audio);
+        });
+      }
+
+      runButton.addEventListener("click", () => {
+        if (isDecoding || !audioURL) return;
+        isDecoding = true;
+        const modelID = modelSelect.value;
+        whisperWorker.postMessage({
+          ...MODELS[modelID],
+          modelID,
+          mel_filtersURL: MEL_FILTERS_URL,
+          audioURL,
+        });
+      });
+      whisperWorker.addEventListener("message", ({ data }) => {
+        if (data.error) {
+          isDecoding = false;
+          outputDiv.textContent = `Error: ${data.error}`;
+        } else if (data.status === "complete") {
+          isDecoding = false;
+          outputDiv.textContent = data.output
+            .map((s) => `${s.start.toFixed(2)}s: ${s.dr.text}`)
+            .join("\n");
+        } else {
+          outputDiv.textContent = `${data.status}: ${data.message}`;
+        }
+      });
+    </script>
+  </head>
+  <body class="container max-w-4xl mx-auto p-4">
+    <main class="grid grid-cols-1 gap-8 relative">
+      <span class="absolute text-5xl -ml-[1em]">🕯️</span>
+      <div>
+        <h1 class="text-5xl font-bold">Candle Whisper</h1>
+        <h2 class="text-2xl font-bold">Rust/WASM Demo</h2>
+        <p class="max-w-lg">
+          Transcribe audio in the browser using rust/wasm with an audio file.
+          This demo uses the
+          <a href="https://huggingface.co/openai/" target="_blank" class="underline hover:text-blue-500">
+            OpenAI Whisper models
+          </a>
+          and WASM runtime built with
+          <a href="https://github.com/huggingface/candle/" target="_blank" class="underline hover:text-blue-500">Candle</a>
+        </p>
+      </div>
+
+      <div>
+        <label for="model" class="font-medium">Models Options:</label>
+        <select id="model" class="border-2 border-gray-500 rounded-md font-light">
+          <option value="tiny_en" selected>tiny.en</option>
+          <option value="tiny_multilingual">tiny (multilingual)</option>
+          <option value="tiny_quantized_en">tiny.en (quantized q80)</option>
+        </select>
+      </div>
+
+      <div
+        id="drop-area"
+        class="flex flex-col items-center justify-center border-2 border-gray-700 border-dashed rounded-xl p-12">
+        <p class="text-sm text-gray-500">Drag and drop your audio here</p>
+        <p class="text-sm text-gray-500">or</p>
+        <label for="file-upload" class="text-sm underline cursor-pointer hover:text-blue-500">Click to upload</label>
+        <input id="file-upload" type="file" accept="audio/*" class="hidden" />
+      </div>
+
+      <div>
+        <h3 class="font-medium">Examples:</h3>
+        <ul class="flex flex-wrap gap-x-4 gap-y-1 text-sm">
+          <li><a href="#" data-audio="samples_jfk.wav" class="underline hover:text-blue-500">jfk.wav</a> (352 kB)</li>
+          <li><a href="#" data-audio="samples_a13.wav" class="underline hover:text-blue-500">a13.wav</a> (960 kB)</li>
+          <li><a href="#" data-audio="samples_mm0.wav" class="underline hover:text-blue-500">mm0.wav</a> (957 kB)</li>
+          <li><a href="#" data-audio="samples_gb0.wav" class="underline hover:text-blue-500">gb0.wav</a> (4.08 MB)</li>
+          <li><a href="#" data-audio="samples_gb1.wav" class="underline hover:text-blue-500">gb1.wav</a> (6.36 MB)</li>
+          <li><a href="#" data-audio="samples_hp0.wav" class="underline hover:text-blue-500">hp0.wav</a> (8.75 MB)</li>
+        </ul>
+      </div>
+
+      <div>
+        <button id="run" disabled class="bg-gray-700 hover:bg-gray-800 text-white font-normal py-2 px-4 rounded disabled:opacity-50">
+          Transcribe Audio
+        </button>
+      </div>
+
+      <div>
+        <h3 class="font-medium">Transcription:</h3>
+        <div id="output" class="min-h-[50px] border-2 border-gray-700 rounded-md p-2 text-sm whitespace-pre-line">
+          No transcription results yet
+        </div>
+      </div>
+    </main>
+  </body>
+</html>
diff --git a/candle-wasm-examples-whisper/main.js b/candle-wasm-examples-whisper/main.js
new file mode 100644
index 0000000..c27e0d6
--- /dev/null
+++ b/candle-wasm-examples-whisper/main.js
@@ -0,0 +1,6 @@
+import init, { run_app } from './pkg/candle_wasm_example_whisper.js';
+async function main() {
+ await init('/pkg/candle_wasm_example_whisper_bg.wasm');
+ run_app();
+}
+main()
diff --git a/candle-wasm-examples-whisper/src/app.rs b/candle-wasm-examples-whisper/src/app.rs
new file mode 100644
index 0000000..1cb3119
--- /dev/null
+++ b/candle-wasm-examples-whisper/src/app.rs
@@ -0,0 +1,283 @@
+use crate::console_log;
+use crate::worker::{ModelData, Segment, Worker, WorkerInput, WorkerOutput};
+use js_sys::Date;
+use wasm_bindgen::prelude::*;
+use wasm_bindgen_futures::JsFuture;
+use yew::{html, Component, Context, Html};
+use yew_agent::{Bridge, Bridged};
+
+const SAMPLE_NAMES: [&str; 6] = [
+ "audios/samples_jfk.wav",
+ "audios/samples_a13.wav",
+ "audios/samples_gb0.wav",
+ "audios/samples_gb1.wav",
+ "audios/samples_hp0.wav",
+ "audios/samples_mm0.wav",
+];
+
+async fn fetch_url(url: &str) -> Result<Vec<u8>, JsValue> {
+ use web_sys::{Request, RequestCache, RequestInit, RequestMode, Response};
+ let window = web_sys::window().ok_or("window")?;
+ let mut opts = RequestInit::new();
+ let opts = opts
+ .method("GET")
+ .mode(RequestMode::Cors)
+ .cache(RequestCache::NoCache);
+
+ let request = Request::new_with_str_and_init(url, opts)?;
+
+ let resp_value = JsFuture::from(window.fetch_with_request(&request)).await?;
+
+ // `resp_value` is a `Response` object.
+    assert!(resp_value.is_instance_of::<Response>());
+ let resp: Response = resp_value.dyn_into()?;
+ let data = JsFuture::from(resp.blob()?).await?;
+ let blob = web_sys::Blob::from(data);
+ let array_buffer = JsFuture::from(blob.array_buffer()).await?;
+ let data = js_sys::Uint8Array::new(&array_buffer).to_vec();
+ Ok(data)
+}
+
+pub enum Msg {
+ Run(usize),
+ UpdateStatus(String),
+ SetDecoder(ModelData),
+ WorkerInMsg(WorkerInput),
+    WorkerOutMsg(Result<WorkerOutput, String>),
+}
+
+pub struct CurrentDecode {
+    start_time: Option<f64>,
+}
+
+pub struct App {
+ status: String,
+ loaded: bool,
+    segments: Vec<Segment>,
+    current_decode: Option<CurrentDecode>,
+    worker: Box<dyn Bridge<Worker>>,
+}
+
+async fn model_data_load() -> Result<ModelData, JsValue> {
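+    // These flags are hard-coded; flip them to exercise the quantized or multilingual checkpoints.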
+ let quantized = false;
+ let is_multilingual = false;
+
+ let (tokenizer, mel_filters, weights, config) = if quantized {
+ console_log!("loading quantized weights");
+ let tokenizer = fetch_url("quantized/tokenizer-tiny-en.json").await?;
+ let mel_filters = fetch_url("mel_filters.safetensors").await?;
+ let weights = fetch_url("quantized/model-tiny-en-q80.gguf").await?;
+ let config = fetch_url("quantized/config-tiny-en.json").await?;
+ (tokenizer, mel_filters, weights, config)
+ } else {
+ console_log!("loading float weights");
+ if is_multilingual {
+ let mel_filters = fetch_url("mel_filters.safetensors").await?;
+ let tokenizer = fetch_url("whisper-tiny/tokenizer.json").await?;
+ let weights = fetch_url("whisper-tiny/model.safetensors").await?;
+ let config = fetch_url("whisper-tiny/config.json").await?;
+ (tokenizer, mel_filters, weights, config)
+ } else {
+ let mel_filters = fetch_url("mel_filters.safetensors").await?;
+ let tokenizer = fetch_url("whisper-tiny.en/tokenizer.json").await?;
+ let weights = fetch_url("whisper-tiny.en/model.safetensors").await?;
+ let config = fetch_url("whisper-tiny.en/config.json").await?;
+ (tokenizer, mel_filters, weights, config)
+ }
+ };
+
+ let timestamps = true;
+ let _task = Some("transcribe".to_string());
+ console_log!("{}", weights.len());
+ Ok(ModelData {
+ tokenizer,
+ mel_filters,
+ weights,
+ config,
+ quantized,
+ timestamps,
+ task: None,
+ is_multilingual,
+ language: None,
+ })
+}
+
+fn performance_now() -> Option<f64> {
+ let window = web_sys::window()?;
+ let performance = window.performance()?;
+ Some(performance.now() / 1000.)
+}
+
+impl Component for App {
+ type Message = Msg;
+ type Properties = ();
+
+    fn create(ctx: &Context<Self>) -> Self {
+ let status = "loading weights".to_string();
+ let cb = {
+ let link = ctx.link().clone();
+ move |e| link.send_message(Self::Message::WorkerOutMsg(e))
+ };
+ let worker = Worker::bridge(std::rc::Rc::new(cb));
+ Self {
+ status,
+ segments: vec![],
+ current_decode: None,
+ worker,
+ loaded: false,
+ }
+ }
+
+    fn rendered(&mut self, ctx: &Context<Self>, first_render: bool) {
+ if first_render {
+ ctx.link().send_future(async {
+ match model_data_load().await {
+ Err(err) => {
+ let status = format!("{err:?}");
+ Msg::UpdateStatus(status)
+ }
+ Ok(model_data) => Msg::SetDecoder(model_data),
+ }
+ });
+ }
+ }
+
+    fn update(&mut self, ctx: &Context<Self>, msg: Self::Message) -> bool {
+ match msg {
+ Msg::SetDecoder(md) => {
+ self.status = "weights loaded successfully!".to_string();
+ self.loaded = true;
+ console_log!("loaded weights");
+ self.worker.send(WorkerInput::ModelData(md));
+ true
+ }
+ Msg::Run(sample_index) => {
+ let sample = SAMPLE_NAMES[sample_index];
+ if self.current_decode.is_some() {
+ self.status = "already decoding some sample at the moment".to_string()
+ } else {
+ let start_time = performance_now();
+ self.current_decode = Some(CurrentDecode { start_time });
+ self.status = format!("decoding {sample}");
+ self.segments.clear();
+ ctx.link().send_future(async move {
+ match fetch_url(sample).await {
+ Err(err) => {
+ let output = Err(format!("decoding error: {err:?}"));
+                            // Mimic a worker output so as to release current_decode
+ Msg::WorkerOutMsg(output)
+ }
+ Ok(wav_bytes) => {
+ Msg::WorkerInMsg(WorkerInput::DecodeTask { wav_bytes })
+ }
+ }
+ })
+ }
+ //
+ true
+ }
+ Msg::WorkerOutMsg(output) => {
+ let dt = self.current_decode.as_ref().and_then(|current_decode| {
+ current_decode.start_time.and_then(|start_time| {
+ performance_now().map(|stop_time| stop_time - start_time)
+ })
+ });
+ self.current_decode = None;
+ match output {
+ Ok(WorkerOutput::WeightsLoaded) => self.status = "weights loaded!".to_string(),
+ Ok(WorkerOutput::Decoded(segments)) => {
+ self.status = match dt {
+ None => "decoding succeeded!".to_string(),
+ Some(dt) => format!("decoding succeeded in {:.2}s", dt),
+ };
+ self.segments = segments;
+ }
+ Err(err) => {
+ self.status = format!("decoding error {err:?}");
+ }
+ }
+ true
+ }
+ Msg::WorkerInMsg(inp) => {
+ self.worker.send(inp);
+ true
+ }
+ Msg::UpdateStatus(status) => {
+ self.status = status;
+ true
+ }
+ }
+ }
+
+    fn view(&self, ctx: &Context<Self>) -> Html {
+        html! {
+            <div>
+                <table>
+                <thead>
+                <tr>
+                  <th>{"Sample"}</th>
+                  <th></th>
+                  <th></th>
+                </tr>
+                </thead>
+                <tbody>
+                {
+                    SAMPLE_NAMES.iter().enumerate().map(|(i, name)| { html! {
+                <tr>
+                  <th>{name}</th>
+                  <th><audio controls=true src={format!("./{name}")}></audio></th>
+                  { if self.loaded {
+                      html!(<th><button class="button" onclick={ctx.link().callback(move |_| Msg::Run(i))}> { "run" }</button></th>)
+                  }else{html!()}
+                  }
+                </tr>
+                }
+                }).collect::<Html>()
+                }
+                </tbody>
+                </table>
+                <h2>
+                  {&self.status}
+                </h2>
+                {
+                    if !self.loaded{
+                        html! { <progress id="progress-bar" aria-label="loading weights…"></progress> }
+                    } else if self.current_decode.is_some() {
+                        html! { <progress id="progress-bar" aria-label="decoding…"></progress> }
+                    } else { html!{
+                <blockquote>
+                <p>
+                  {
+                      self.segments.iter().map(|segment| { html! {
+                          <>
+                          <i>
+                          {
+                              format!("{:.2}s-{:.2}s: (avg-logprob: {:.4}, no-speech-prob: {:.4})",
+                                  segment.start,
+                                  segment.start + segment.duration,
+                                  segment.dr.avg_logprob,
+                                  segment.dr.no_speech_prob,
+                              )
+                          }
+                          </i>
+                          <br/>
+                          {&segment.dr.text}
+                          <br/>
+                          </>
+                      } }).collect::<Html>()
+                  }
+                </p>
+                </blockquote>
+                }
+                }
+                }
+
+                // Display the current date and time the page was rendered
+                <p class="footer">
+                    { "Rendered: " }
+                    { String::from(Date::new_0().to_string()) }
+                </p>
+            </div>
+        }
+    }
+}
diff --git a/candle-wasm-examples-whisper/src/audio.rs b/candle-wasm-examples-whisper/src/audio.rs
new file mode 100644
index 0000000..b87f7df
--- /dev/null
+++ b/candle-wasm-examples-whisper/src/audio.rs
@@ -0,0 +1,215 @@
+// Audio processing code, adapted from whisper.cpp
+// https://github.com/ggerganov/whisper.cpp
+use super::worker;
+
+pub trait Float: num_traits::Float + num_traits::FloatConst + num_traits::NumAssign {}
+
+impl Float for f32 {}
+impl Float for f64 {}
+
+// https://github.com/ggerganov/whisper.cpp/blob/4774d2feb01a772a15de81ffc34b34a1f294f020/whisper.cpp#L2357
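+// Real-input FFT returning interleaved complex output: out[2k] = Re(X_k), out[2k + 1] = Im(X_k).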
+fn fft<T: Float>(inp: &[T]) -> Vec<T> {
+ let n = inp.len();
+ let zero = T::zero();
+ if n == 1 {
+ return vec![inp[0], zero];
+ }
+ if n % 2 == 1 {
+ return dft(inp);
+ }
+ let mut out = vec![zero; n * 2];
+
+ let mut even = Vec::with_capacity(n / 2);
+ let mut odd = Vec::with_capacity(n / 2);
+
+ for (i, &inp) in inp.iter().enumerate() {
+ if i % 2 == 0 {
+ even.push(inp)
+ } else {
+ odd.push(inp);
+ }
+ }
+
+ let even_fft = fft(&even);
+ let odd_fft = fft(&odd);
+
+ let two_pi = T::PI() + T::PI();
+ let n_t = T::from(n).unwrap();
+ for k in 0..n / 2 {
+ let k_t = T::from(k).unwrap();
+ let theta = two_pi * k_t / n_t;
+ let re = theta.cos();
+ let im = -theta.sin();
+
+ let re_odd = odd_fft[2 * k];
+ let im_odd = odd_fft[2 * k + 1];
+
+ out[2 * k] = even_fft[2 * k] + re * re_odd - im * im_odd;
+ out[2 * k + 1] = even_fft[2 * k + 1] + re * im_odd + im * re_odd;
+
+ out[2 * (k + n / 2)] = even_fft[2 * k] - re * re_odd + im * im_odd;
+ out[2 * (k + n / 2) + 1] = even_fft[2 * k + 1] - re * im_odd - im * re_odd;
+ }
+ out
+}
+
+// https://github.com/ggerganov/whisper.cpp/blob/4774d2feb01a772a15de81ffc34b34a1f294f020/whisper.cpp#L2337
+fn dft<T: Float>(inp: &[T]) -> Vec<T> {
+ let zero = T::zero();
+ let n = inp.len();
+ let two_pi = T::PI() + T::PI();
+
+ let mut out = Vec::with_capacity(2 * n);
+ let n_t = T::from(n).unwrap();
+ for k in 0..n {
+ let k_t = T::from(k).unwrap();
+ let mut re = zero;
+ let mut im = zero;
+
+ for (j, &inp) in inp.iter().enumerate() {
+ let j_t = T::from(j).unwrap();
+ let angle = two_pi * k_t * j_t / n_t;
+ re += inp * angle.cos();
+ im -= inp * angle.sin();
+ }
+
+ out.push(re);
+ out.push(im);
+ }
+ out
+}
+
+#[allow(clippy::too_many_arguments)]
+// https://github.com/ggerganov/whisper.cpp/blob/4774d2feb01a772a15de81ffc34b34a1f294f020/whisper.cpp#L2414
+fn log_mel_spectrogram_w<T: Float>(
+ ith: usize,
+ hann: &[T],
+ samples: &[T],
+ filters: &[T],
+ fft_size: usize,
+ fft_step: usize,
+ speed_up: bool,
+ n_len: usize,
+ n_mel: usize,
+ n_threads: usize,
+) -> Vec<T> {
+ let n_fft = if speed_up {
+ 1 + fft_size / 4
+ } else {
+ 1 + fft_size / 2
+ };
+
+ let zero = T::zero();
+ let half = T::from(0.5).unwrap();
+ let mut fft_in = vec![zero; fft_size];
+ let mut mel = vec![zero; n_len * n_mel];
+
+ for i in (ith..n_len).step_by(n_threads) {
+ let offset = i * fft_step;
+
+ // apply Hanning window
+ for j in 0..fft_size {
+ fft_in[j] = if offset + j < samples.len() {
+ hann[j] * samples[offset + j]
+ } else {
+ zero
+ }
+ }
+
+ // FFT -> mag^2
+        let mut fft_out: Vec<T> = fft(&fft_in);
+
+ for j in 0..fft_size {
+ fft_out[j] = fft_out[2 * j] * fft_out[2 * j] + fft_out[2 * j + 1] * fft_out[2 * j + 1];
+ }
+ for j in 1..fft_size / 2 {
+ let v = fft_out[fft_size - j];
+ fft_out[j] += v;
+ }
+
+ if speed_up {
+            // Scaling down in the frequency domain results in a speed-up in the time domain.
+ for j in 0..n_fft {
+ fft_out[j] = half * (fft_out[2 * j] + fft_out[2 * j + 1]);
+ }
+ }
+
+ // mel spectrogram
+ for j in 0..n_mel {
+ let mut sum = zero;
+ for k in 0..n_fft {
+ sum += fft_out[k] * filters[j * n_fft + k];
+ }
+ mel[j * n_len + i] = T::max(sum, T::from(1e-10).unwrap()).log10();
+ }
+ }
+ mel
+}
+
+fn log_mel_spectrogram_<T: Float>(
+ samples: &[T],
+ filters: &[T],
+ fft_size: usize,
+ fft_step: usize,
+ n_mel: usize,
+ speed_up: bool,
+) -> Vec<T> {
+ let zero = T::zero();
+ let two_pi = T::PI() + T::PI();
+ let half = T::from(0.5).unwrap();
+ let one = T::from(1.0).unwrap();
+ let four = T::from(4.0).unwrap();
+ let fft_size_t = T::from(fft_size).unwrap();
+
+    let hann: Vec<T> = (0..fft_size)
+ .map(|i| half * (one - ((two_pi * T::from(i).unwrap()) / fft_size_t).cos()))
+ .collect();
+ let n_len = samples.len() / fft_step;
+
+ // pad audio with at least one extra chunk of zeros
+ let pad = 100 * worker::m::CHUNK_LENGTH / 2;
+ let n_len = if n_len % pad != 0 {
+ (n_len / pad + 1) * pad
+ } else {
+ n_len
+ };
+ let n_len = n_len + pad;
+ let samples = {
+ let mut samples_padded = samples.to_vec();
+ let to_add = n_len * fft_step - samples.len();
+ samples_padded.extend(std::iter::repeat(zero).take(to_add));
+ samples_padded
+ };
+
+ // Use a single thread for now.
+ let mut mel = log_mel_spectrogram_w(
+ 0, &hann, &samples, filters, fft_size, fft_step, speed_up, n_len, n_mel, 1,
+ );
+ let mmax = mel
+ .iter()
+ .max_by(|&u, &v| u.partial_cmp(v).unwrap_or(std::cmp::Ordering::Greater))
+ .copied()
+ .unwrap_or(zero)
+ - T::from(8).unwrap();
+ for m in mel.iter_mut() {
+ let v = T::max(*m, mmax);
+ *m = v / four + one
+ }
+ mel
+}
+
+pub fn pcm_to_mel<T: Float>(
+    cfg: &worker::m::Config,
+    samples: &[T],
+    filters: &[T],
+) -> anyhow::Result<Vec<T>> {
+ let mel = log_mel_spectrogram_(
+ samples,
+ filters,
+ worker::m::N_FFT,
+ worker::m::HOP_LENGTH,
+ cfg.num_mel_bins,
+ false,
+ );
+ Ok(mel)
+}
diff --git a/candle-wasm-examples-whisper/src/bin/app.rs b/candle-wasm-examples-whisper/src/bin/app.rs
new file mode 100644
index 0000000..89efa7f
--- /dev/null
+++ b/candle-wasm-examples-whisper/src/bin/app.rs
@@ -0,0 +1,4 @@
+fn main() {
+ wasm_logger::init(wasm_logger::Config::new(log::Level::Trace));
+    yew::Renderer::<candle_wasm_example_whisper::App>::new().render();
+}
diff --git a/candle-wasm-examples-whisper/src/bin/m.rs b/candle-wasm-examples-whisper/src/bin/m.rs
new file mode 100644
index 0000000..67b7a18
--- /dev/null
+++ b/candle-wasm-examples-whisper/src/bin/m.rs
@@ -0,0 +1,53 @@
+use candle_wasm_example_whisper::worker::{Decoder as D, ModelData};
+use wasm_bindgen::prelude::*;
+
+#[wasm_bindgen]
+pub struct Decoder {
+ decoder: D,
+}
+
+#[wasm_bindgen]
+impl Decoder {
+ #[wasm_bindgen(constructor)]
+ #[allow(clippy::too_many_arguments)]
+    pub fn new(
+        weights: Vec<u8>,
+        tokenizer: Vec<u8>,
+        mel_filters: Vec<u8>,
+        config: Vec<u8>,
+        quantized: bool,
+        is_multilingual: bool,
+        timestamps: bool,
+        task: Option<String>,
+        language: Option<String>,
+    ) -> Result<Decoder, JsError> {
+ let decoder = D::load(ModelData {
+ tokenizer,
+ mel_filters,
+ config,
+ quantized,
+ weights,
+ is_multilingual,
+ timestamps,
+ task,
+ language,
+ });
+
+ match decoder {
+ Ok(decoder) => Ok(Self { decoder }),
+ Err(e) => Err(JsError::new(&e.to_string())),
+ }
+ }
+
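+    /// Decodes a 16 kHz, 16-bit PCM wav file (passed as raw bytes) and returns the segments as a JSON string.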
+ #[wasm_bindgen]
+    pub fn decode(&mut self, wav_input: Vec<u8>) -> Result<String, JsError> {
+ let segments = self
+ .decoder
+ .convert_and_run(&wav_input)
+ .map_err(|e| JsError::new(&e.to_string()))?;
+ let json = serde_json::to_string(&segments)?;
+ Ok(json)
+ }
+}
+
+fn main() {}
diff --git a/candle-wasm-examples-whisper/src/bin/worker.rs b/candle-wasm-examples-whisper/src/bin/worker.rs
new file mode 100644
index 0000000..b8c16b5
--- /dev/null
+++ b/candle-wasm-examples-whisper/src/bin/worker.rs
@@ -0,0 +1,4 @@
+use yew_agent::PublicWorker;
+fn main() {
+ candle_wasm_example_whisper::Worker::register();
+}
diff --git a/candle-wasm-examples-whisper/src/languages.rs b/candle-wasm-examples-whisper/src/languages.rs
new file mode 100644
index 0000000..fcbf9a7
--- /dev/null
+++ b/candle-wasm-examples-whisper/src/languages.rs
@@ -0,0 +1,101 @@
+pub const LANGUAGES: [(&str, &str); 99] = [
+ ("en", "english"),
+ ("zh", "chinese"),
+ ("de", "german"),
+ ("es", "spanish"),
+ ("ru", "russian"),
+ ("ko", "korean"),
+ ("fr", "french"),
+ ("ja", "japanese"),
+ ("pt", "portuguese"),
+ ("tr", "turkish"),
+ ("pl", "polish"),
+ ("ca", "catalan"),
+ ("nl", "dutch"),
+ ("ar", "arabic"),
+ ("sv", "swedish"),
+ ("it", "italian"),
+ ("id", "indonesian"),
+ ("hi", "hindi"),
+ ("fi", "finnish"),
+ ("vi", "vietnamese"),
+ ("he", "hebrew"),
+ ("uk", "ukrainian"),
+ ("el", "greek"),
+ ("ms", "malay"),
+ ("cs", "czech"),
+ ("ro", "romanian"),
+ ("da", "danish"),
+ ("hu", "hungarian"),
+ ("ta", "tamil"),
+ ("no", "norwegian"),
+ ("th", "thai"),
+ ("ur", "urdu"),
+ ("hr", "croatian"),
+ ("bg", "bulgarian"),
+ ("lt", "lithuanian"),
+ ("la", "latin"),
+ ("mi", "maori"),
+ ("ml", "malayalam"),
+ ("cy", "welsh"),
+ ("sk", "slovak"),
+ ("te", "telugu"),
+ ("fa", "persian"),
+ ("lv", "latvian"),
+ ("bn", "bengali"),
+ ("sr", "serbian"),
+ ("az", "azerbaijani"),
+ ("sl", "slovenian"),
+ ("kn", "kannada"),
+ ("et", "estonian"),
+ ("mk", "macedonian"),
+ ("br", "breton"),
+ ("eu", "basque"),
+ ("is", "icelandic"),
+ ("hy", "armenian"),
+ ("ne", "nepali"),
+ ("mn", "mongolian"),
+ ("bs", "bosnian"),
+ ("kk", "kazakh"),
+ ("sq", "albanian"),
+ ("sw", "swahili"),
+ ("gl", "galician"),
+ ("mr", "marathi"),
+ ("pa", "punjabi"),
+ ("si", "sinhala"),
+ ("km", "khmer"),
+ ("sn", "shona"),
+ ("yo", "yoruba"),
+ ("so", "somali"),
+ ("af", "afrikaans"),
+ ("oc", "occitan"),
+ ("ka", "georgian"),
+ ("be", "belarusian"),
+ ("tg", "tajik"),
+ ("sd", "sindhi"),
+ ("gu", "gujarati"),
+ ("am", "amharic"),
+ ("yi", "yiddish"),
+ ("lo", "lao"),
+ ("uz", "uzbek"),
+ ("fo", "faroese"),
+ ("ht", "haitian creole"),
+ ("ps", "pashto"),
+ ("tk", "turkmen"),
+ ("nn", "nynorsk"),
+ ("mt", "maltese"),
+ ("sa", "sanskrit"),
+ ("lb", "luxembourgish"),
+ ("my", "myanmar"),
+ ("bo", "tibetan"),
+ ("tl", "tagalog"),
+ ("mg", "malagasy"),
+ ("as", "assamese"),
+ ("tt", "tatar"),
+ ("haw", "hawaiian"),
+ ("ln", "lingala"),
+ ("ha", "hausa"),
+ ("ba", "bashkir"),
+ ("jw", "javanese"),
+ ("su", "sundanese"),
+];
diff --git a/candle-wasm-examples-whisper/src/lib.rs b/candle-wasm-examples-whisper/src/lib.rs
new file mode 100644
index 0000000..f183201
--- /dev/null
+++ b/candle-wasm-examples-whisper/src/lib.rs
@@ -0,0 +1,29 @@
+pub const WITH_TIMER: bool = true;
+
+struct Timer {
+ label: &'static str,
+}
+
+// impl Timer {
+// fn new(label: &'static str) -> Self {
+// if WITH_TIMER {
+// web_sys::console::time_with_label(label);
+// }
+// Self { label }
+// }
+// }
+
+impl Drop for Timer {
+ fn drop(&mut self) {
+ if WITH_TIMER {
+ web_sys::console::time_end_with_label(self.label)
+ }
+ }
+}
+
+mod app;
+mod audio;
+pub mod languages;
+pub mod worker;
+pub use app::App;
+pub use worker::Worker;
diff --git a/candle-wasm-examples-whisper/src/worker.rs b/candle-wasm-examples-whisper/src/worker.rs
new file mode 100644
index 0000000..fd91fa8
--- /dev/null
+++ b/candle-wasm-examples-whisper/src/worker.rs
@@ -0,0 +1,492 @@
+use crate::languages::LANGUAGES;
+use anyhow::Error as E;
+use candle::{safetensors::Load, DType, Device, IndexOp, Tensor, D};
+use candle_nn::{ops::softmax, VarBuilder};
+pub use candle_transformers::models::whisper::{self as m, Config};
+use rand::{distributions::Distribution, rngs::StdRng, SeedableRng};
+use serde::{Deserialize, Serialize};
+use tokenizers::Tokenizer;
+use wasm_bindgen::prelude::*;
+use yew_agent::{HandlerId, Public, WorkerLink};
+
+#[wasm_bindgen]
+extern "C" {
+ // Use `js_namespace` here to bind `console.log(..)` instead of just
+ // `log(..)`
+ #[wasm_bindgen(js_namespace = console)]
+ pub fn log(s: &str);
+}
+
+#[macro_export]
+macro_rules! console_log {
+    // Note that this uses the `log` function imported above.
+ ($($t:tt)*) => ($crate::worker::log(&format_args!($($t)*).to_string()))
+}
+
+pub const DTYPE: DType = DType::F32;
+
+pub enum Model {
+ Normal(m::model::Whisper),
+ Quantized(m::quantized_model::Whisper),
+}
+
+// Maybe we should use some traits rather than doing the dispatch for all these.
+impl Model {
+ pub fn config(&self) -> &Config {
+ match self {
+ Self::Normal(m) => &m.config,
+ Self::Quantized(m) => &m.config,
+ }
+ }
+
+    pub fn encoder_forward(&mut self, x: &Tensor, flush: bool) -> candle::Result<Tensor> {
+ match self {
+ Self::Normal(m) => m.encoder.forward(x, flush),
+ Self::Quantized(m) => m.encoder.forward(x, flush),
+ }
+ }
+
+ pub fn decoder_forward(
+ &mut self,
+ x: &Tensor,
+ xa: &Tensor,
+ flush: bool,
+    ) -> candle::Result<Tensor> {
+ match self {
+ Self::Normal(m) => m.decoder.forward(x, xa, flush),
+ Self::Quantized(m) => m.decoder.forward(x, xa, flush),
+ }
+ }
+
+    pub fn decoder_final_linear(&self, x: &Tensor) -> candle::Result<Tensor> {
+ match self {
+ Self::Normal(m) => m.decoder.final_linear(x),
+ Self::Quantized(m) => m.decoder.final_linear(x),
+ }
+ }
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct DecodingResult {
+    pub tokens: Vec<u32>,
+ pub text: String,
+ pub avg_logprob: f64,
+ pub no_speech_prob: f64,
+ temperature: f64,
+ compression_ratio: f64,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct Segment {
+ pub start: f64,
+ pub duration: f64,
+ pub dr: DecodingResult,
+}
+
+pub struct Decoder {
+ model: Model,
+ rng: rand::rngs::StdRng,
+    task: Option<Task>,
+    language: Option<String>,
+    is_multilingual: bool,
+    mel_filters: Vec<f32>,
+ timestamps: bool,
+ tokenizer: Tokenizer,
+ suppress_tokens: Tensor,
+ sot_token: u32,
+ transcribe_token: u32,
+ translate_token: u32,
+ eot_token: u32,
+ no_speech_token: u32,
+ no_timestamps_token: u32,
+}
+
+impl Decoder {
+ #[allow(clippy::too_many_arguments)]
+ fn new(
+ model: Model,
+ tokenizer: Tokenizer,
+ mel_filters: Vec,
+ device: &Device,
+        task: Option<Task>,
+        language: Option<String>,
+ is_multilingual: bool,
+ timestamps: bool,
+    ) -> anyhow::Result<Self> {
+        let suppress_tokens: Vec<f32> = (0..model.config().vocab_size as u32)
+ .map(|i| {
+ if model.config().suppress_tokens.contains(&i) {
+ f32::NEG_INFINITY
+ } else {
+ 0f32
+ }
+ })
+ .collect();
+ let no_timestamps_token = token_id(&tokenizer, m::NO_TIMESTAMPS_TOKEN)?;
+ let suppress_tokens = Tensor::new(suppress_tokens.as_slice(), device)?;
+ let sot_token = token_id(&tokenizer, m::SOT_TOKEN)?;
+ let transcribe_token = token_id(&tokenizer, m::TRANSCRIBE_TOKEN)?;
+ let translate_token = token_id(&tokenizer, m::TRANSLATE_TOKEN)?;
+ let eot_token = token_id(&tokenizer, m::EOT_TOKEN)?;
+ let no_speech_token = m::NO_SPEECH_TOKENS
+ .iter()
+ .find_map(|token| token_id(&tokenizer, token).ok());
+ let no_speech_token = match no_speech_token {
+ None => anyhow::bail!("unable to find any non-speech token"),
+ Some(n) => n,
+ };
+ let seed = 299792458;
+ Ok(Self {
+ model,
+ rng: StdRng::seed_from_u64(seed),
+ tokenizer,
+ mel_filters,
+ task,
+ timestamps,
+ language,
+ is_multilingual,
+ suppress_tokens,
+ sot_token,
+ transcribe_token,
+ translate_token,
+ eot_token,
+ no_speech_token,
+ no_timestamps_token,
+ })
+ }
+
+    fn decode(&mut self, mel: &Tensor, t: f64) -> anyhow::Result<DecodingResult> {
+ let model = &mut self.model;
+ let language_token = match (self.is_multilingual, &self.language) {
+ (true, None) => Some(detect_language(model, &self.tokenizer, mel)?),
+ (false, None) => None,
+            (true, Some(language)) => {
+                match token_id(&self.tokenizer, &format!("<|{language}|>")) {
+ Ok(token_id) => Some(token_id),
+ Err(_) => anyhow::bail!("language {language} is not supported"),
+ }
+ }
+ (false, Some(_)) => {
+ anyhow::bail!("a language cannot be set for non-multilingual models")
+ }
+ };
+
+ let audio_features = model.encoder_forward(mel, true)?;
+        console_log!("audio features: {:?}", audio_features.dims());
+ let sample_len = model.config().max_target_positions / 2;
+ let mut sum_logprob = 0f64;
+ let mut no_speech_prob = f64::NAN;
+ let mut tokens = vec![self.sot_token];
+ if let Some(language_token) = language_token {
+ tokens.push(language_token);
+ }
+ match self.task {
+ None | Some(Task::Transcribe) => tokens.push(self.transcribe_token),
+ Some(Task::Translate) => tokens.push(self.translate_token),
+ }
+ if !self.timestamps {
+ tokens.push(self.no_timestamps_token);
+ }
+ for i in 0..sample_len {
+ let tokens_t = Tensor::new(tokens.as_slice(), mel.device())?;
+
+ // The model expects a batch dim but this inference loop does not handle
+ // it so we add it at this point.
+ let tokens_t = tokens_t.unsqueeze(0)?;
+ let ys = model.decoder_forward(&tokens_t, &audio_features, i == 0)?;
+
+ // Extract the no speech probability on the first iteration by looking at the first
+ // token logits and the probability for the according token.
+ if i == 0 {
+ let logits = model.decoder_final_linear(&ys.i(..1)?)?.i(0)?.i(0)?;
+ no_speech_prob = softmax(&logits, 0)?
+ .i(self.no_speech_token as usize)?
+                    .to_scalar::<f32>()? as f64;
+ }
+
+ let (_, seq_len, _) = ys.dims3()?;
+ let logits = model
+ .decoder_final_linear(&ys.i((..1, seq_len - 1..))?)?
+ .i(0)?
+ .i(0)?;
+ // TODO: Besides suppress tokens, we should apply the heuristics from
+ // ApplyTimestampRules, i.e.:
+ // - Timestamps come in pairs, except before EOT.
+ // - Timestamps should be non-decreasing.
+ // - If the sum of the probabilities of timestamps is higher than any other tokens,
+ // only consider timestamps when sampling.
+ // https://github.com/openai/whisper/blob/e8622f9afc4eba139bf796c210f5c01081000472/whisper/decoding.py#L439
+ let logits = logits.broadcast_add(&self.suppress_tokens)?;
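+            // With a positive temperature, sample from the scaled softmax; at t == 0, take the greedy argmax.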
+ let next_token = if t > 0f64 {
+ let prs = softmax(&(&logits / t)?, 0)?;
+                let logits_v: Vec<f32> = prs.to_vec1()?;
+ let distr = rand::distributions::WeightedIndex::new(&logits_v)?;
+ distr.sample(&mut self.rng) as u32
+ } else {
+                let logits_v: Vec<f32> = logits.to_vec1()?;
+ logits_v
+ .iter()
+ .enumerate()
+ .max_by(|(_, u), (_, v)| u.total_cmp(v))
+ .map(|(i, _)| i as u32)
+ .unwrap()
+ };
+ tokens.push(next_token);
+ let prob = softmax(&logits, candle::D::Minus1)?
+ .i(next_token as usize)?
+                .to_scalar::<f32>()? as f64;
+ if next_token == self.eot_token || tokens.len() > model.config().max_target_positions {
+ break;
+ }
+ sum_logprob += prob.ln();
+ }
+ let text = self.tokenizer.decode(&tokens, true).map_err(E::msg)?;
+ let avg_logprob = sum_logprob / tokens.len() as f64;
+
+ Ok(DecodingResult {
+ tokens,
+ text,
+ avg_logprob,
+ no_speech_prob,
+ temperature: t,
+ compression_ratio: f64::NAN,
+ })
+ }
+
+    fn decode_with_fallback(&mut self, segment: &Tensor) -> anyhow::Result<DecodingResult> {
+ for (i, &t) in m::TEMPERATURES.iter().enumerate() {
+            let dr: anyhow::Result<DecodingResult> = self.decode(segment, t);
+ if i == m::TEMPERATURES.len() - 1 {
+ return dr;
+ }
+ // On errors, we try again with a different temperature.
+ match dr {
+ Ok(dr) => {
+ let needs_fallback = dr.compression_ratio > m::COMPRESSION_RATIO_THRESHOLD
+ || dr.avg_logprob < m::LOGPROB_THRESHOLD;
+ if !needs_fallback || dr.no_speech_prob > m::NO_SPEECH_THRESHOLD {
+ return Ok(dr);
+ }
+ }
+ Err(err) => {
+ console_log!("Error running at {t}: {err}")
+ }
+ }
+ }
+ unreachable!()
+ }
+
+    fn run(&mut self, mel: &Tensor) -> anyhow::Result<Vec<Segment>> {
+ let (_, _, content_frames) = mel.dims3()?;
+ let mut seek = 0;
+ let mut segments = vec![];
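+        // Decode the mel spectrogram in windows of up to N_FRAMES frames (30s of audio) at a time.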
+ while seek < content_frames {
+ let time_offset = (seek * m::HOP_LENGTH) as f64 / m::SAMPLE_RATE as f64;
+ let segment_size = usize::min(content_frames - seek, m::N_FRAMES);
+ let mel_segment = mel.narrow(2, seek, segment_size)?;
+ let segment_duration = (segment_size * m::HOP_LENGTH) as f64 / m::SAMPLE_RATE as f64;
+ let dr = self.decode_with_fallback(&mel_segment)?;
+ seek += segment_size;
+ if dr.no_speech_prob > m::NO_SPEECH_THRESHOLD && dr.avg_logprob < m::LOGPROB_THRESHOLD {
+ console_log!("no speech detected, skipping {seek} {dr:?}");
+ continue;
+ }
+ let segment = Segment {
+ start: time_offset,
+ duration: segment_duration,
+ dr,
+ };
+ console_log!("{seek}: {segment:?}");
+ segments.push(segment)
+ }
+ Ok(segments)
+ }
+
+    pub fn load(md: ModelData) -> anyhow::Result<Self> {
+ let device = Device::Cpu;
+ let tokenizer = Tokenizer::from_bytes(&md.tokenizer).map_err(E::msg)?;
+
+ let mel_filters = safetensors::tensor::SafeTensors::deserialize(&md.mel_filters)?;
+ let mel_filters = mel_filters.tensor("mel_80")?.load(&device)?;
+ console_log!("loaded mel filters {:?}", mel_filters.shape());
+        let mel_filters = mel_filters.flatten_all()?.to_vec1::<f32>()?;
+ let config: Config = serde_json::from_slice(&md.config)?;
+ let model = if md.quantized {
+ let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf_buffer(
+ &md.weights,
+ )?;
+ Model::Quantized(m::quantized_model::Whisper::load(&vb, config)?)
+ } else {
+ let vb = VarBuilder::from_buffered_safetensors(md.weights, m::DTYPE, &device)?;
+ Model::Normal(m::model::Whisper::load(&vb, config)?)
+ };
+ console_log!("done loading model");
+
+ let task = match md.task.as_deref() {
+ Some("translate") => Some(Task::Translate),
+ _ => Some(Task::Transcribe),
+ };
+
+ let decoder = Self::new(
+ model,
+ tokenizer,
+ mel_filters,
+ &device,
+ task,
+ md.language,
+ md.is_multilingual,
+ md.timestamps,
+ )?;
+ Ok(decoder)
+ }
+
+    pub fn convert_and_run(&mut self, wav_input: &[u8]) -> anyhow::Result<Vec<Segment>> {
+ let device = Device::Cpu;
+ let mut wav_input = std::io::Cursor::new(wav_input);
+ let (header, data) = wav::read(&mut wav_input)?;
+ console_log!("loaded wav data: {header:?}");
+ if header.sampling_rate != m::SAMPLE_RATE as u32 {
+ anyhow::bail!("wav file must have a {} sampling rate", m::SAMPLE_RATE);
+ }
+ let data = data.as_sixteen().expect("expected 16 bit wav file");
+ let pcm_data: Vec<_> = data[..data.len() / header.channel_count as usize]
+ .iter()
+ .map(|v| *v as f32 / 32768.)
+ .collect();
+ console_log!("pcm data loaded {}", pcm_data.len());
+ let mel = crate::audio::pcm_to_mel(self.model.config(), &pcm_data, &self.mel_filters)?;
+ let mel_len = mel.len();
+ let n_mels = self.model.config().num_mel_bins;
+ let mel = Tensor::from_vec(mel, (1, n_mels, mel_len / n_mels), &device)?;
+ console_log!("loaded mel: {:?}", mel.dims());
+ let segments = self.run(&mel)?;
+ Ok(segments)
+ }
+}
+
+/// Returns the token id for the selected language.
+pub fn detect_language(model: &mut Model, tokenizer: &Tokenizer, mel: &Tensor) -> anyhow::Result<u32> {
+ console_log!("detecting language");
+ let (_bsize, _, seq_len) = mel.dims3()?;
+ let mel = mel.narrow(
+ 2,
+ 0,
+ usize::min(seq_len, model.config().max_source_positions),
+ )?;
+ let device = mel.device();
+
+ let language_token_ids = LANGUAGES
+ .iter()
+ .map(|(t, _)| token_id(tokenizer, &format!("<|{t}|>")))
+ .map(|e| e.map_err(E::msg))
+        .collect::<Result<Vec<_>, E>>()?;
+
+ let sot_token = token_id(tokenizer, m::SOT_TOKEN)?;
+ let audio_features = model.encoder_forward(&mel, true)?;
+ let tokens = Tensor::new(&[[sot_token]], device)?;
+ let language_token_ids = Tensor::new(language_token_ids.as_slice(), device)?;
+ let ys = model.decoder_forward(&tokens, &audio_features, true)?;
+ let logits = model.decoder_final_linear(&ys.i(..1)?)?.i(0)?.i(0)?;
+ let logits = logits.index_select(&language_token_ids, 0)?;
+ let probs = candle_nn::ops::softmax(&logits, D::Minus1)?;
+    let probs = probs.to_vec1::<f32>()?;
+    let mut probs = LANGUAGES.iter().zip(probs.iter()).collect::<Vec<_>>();
+ probs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1));
+ for ((_, language), p) in probs.iter().take(5) {
+        console_log!("{language}: {p}")
+ }
+ let token = &format!("<|{}|>", probs[0].0 .0);
+ let language = token_id(tokenizer, token)?;
+ console_log!("detected language: {language} {token}");
+ Ok(language)
+}
+pub fn token_id(tokenizer: &Tokenizer, token: &str) -> candle::Result<u32> {
+ match tokenizer.token_to_id(token) {
+ None => candle::bail!("no token-id for {token}"),
+ Some(id) => Ok(id),
+ }
+}
+#[derive(Serialize, Deserialize, Clone, Copy, Debug)]
+pub enum Task {
+ Transcribe,
+ Translate,
+}
+
+// Communication with the worker happens through bincode; the model weights and configs are fetched
+// on the main thread and transferred via the following structure.
+#[derive(Serialize, Deserialize)]
+pub struct ModelData {
+    pub weights: Vec<u8>,
+    pub tokenizer: Vec<u8>,
+    pub mel_filters: Vec<u8>,
+    pub config: Vec<u8>,
+    pub quantized: bool,
+    pub timestamps: bool,
+    pub is_multilingual: bool,
+    pub language: Option<String>,
+    pub task: Option<String>,
+}
+
+pub struct Worker {
+    link: WorkerLink<Self>,
+    decoder: Option<Decoder>,
+}
+
+#[derive(Serialize, Deserialize)]
+pub enum WorkerInput {
+ ModelData(ModelData),
+    DecodeTask { wav_bytes: Vec<u8> },
+}
+
+#[derive(Serialize, Deserialize)]
+pub enum WorkerOutput {
+    Decoded(Vec<Segment>),
+ WeightsLoaded,
+}
+
+impl yew_agent::Worker for Worker {
+ type Input = WorkerInput;
+ type Message = ();
+    type Output = Result<WorkerOutput, String>;
+    type Reach = Public<Self>;
+
+    fn create(link: WorkerLink<Self>) -> Self {
+ Self {
+ link,
+ decoder: None,
+ }
+ }
+
+ fn update(&mut self, _msg: Self::Message) {
+ // no messaging
+ }
+
+ fn handle_input(&mut self, msg: Self::Input, id: HandlerId) {
+ let output = match msg {
+ WorkerInput::ModelData(md) => match Decoder::load(md) {
+ Ok(decoder) => {
+ self.decoder = Some(decoder);
+ Ok(WorkerOutput::WeightsLoaded)
+ }
+ Err(err) => Err(format!("model creation error {err:?}")),
+ },
+ WorkerInput::DecodeTask { wav_bytes } => match &mut self.decoder {
+ None => Err("model has not been set".to_string()),
+ Some(decoder) => decoder
+ .convert_and_run(&wav_bytes)
+ .map(WorkerOutput::Decoded)
+ .map_err(|e| e.to_string()),
+ },
+ };
+ self.link.respond(id, output);
+ }
+
+ fn name_of_resource() -> &'static str {
+ "worker.js"
+ }
+
+ fn resource_path_is_relative() -> bool {
+ true
+ }
+}
diff --git a/candle-wasm-examples-whisper/whisperWorker.js b/candle-wasm-examples-whisper/whisperWorker.js
new file mode 100644
index 0000000..bd44f62
--- /dev/null
+++ b/candle-wasm-examples-whisper/whisperWorker.js
@@ -0,0 +1,116 @@
+// Load the candle Whisper decoder WASM module.
+import init, { Decoder } from "./build/m.js";
+
+async function fetchArrayBuffer(url) {
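+  // Look up the URL in the Cache Storage API first; on a miss, fetch it and populate the cache.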
+ const cacheName = "whisper-candle-cache";
+ const cache = await caches.open(cacheName);
+ const cachedResponse = await cache.match(url);
+ if (cachedResponse) {
+ const data = await cachedResponse.arrayBuffer();
+ return new Uint8Array(data);
+ }
+ const res = await fetch(url, { cache: "force-cache" });
+ cache.put(url, res.clone());
+ return new Uint8Array(await res.arrayBuffer());
+}
+class Whisper {
+ static instance = {};
+ // Retrieve the Whisper model. When called for the first time,
+ // this will load the model and save it for future use.
+ static async getInstance(params) {
+ const {
+ weightsURL,
+ modelID,
+ tokenizerURL,
+ mel_filtersURL,
+ configURL,
+ quantized,
+ is_multilingual,
+ timestamps,
+ task,
+ language,
+ } = params;
+ // load individual modelID only once
+ if (!this.instance[modelID]) {
+ await init();
+
+ self.postMessage({ status: "loading", message: "Loading Model" });
+ const [
+ weightsArrayU8,
+ tokenizerArrayU8,
+ mel_filtersArrayU8,
+ configArrayU8,
+ ] = await Promise.all([
+ fetchArrayBuffer(weightsURL),
+ fetchArrayBuffer(tokenizerURL),
+ fetchArrayBuffer(mel_filtersURL),
+ fetchArrayBuffer(configURL),
+ ]);
+
+ this.instance[modelID] = new Decoder(
+ weightsArrayU8,
+ tokenizerArrayU8,
+ mel_filtersArrayU8,
+ configArrayU8,
+ quantized,
+ is_multilingual,
+ timestamps,
+ task,
+ language
+ );
+ } else {
+ self.postMessage({ status: "loading", message: "Model Already Loaded" });
+ }
+ return this.instance[modelID];
+ }
+}
+
+self.addEventListener("message", async (event) => {
+ const {
+ weightsURL,
+ modelID,
+ tokenizerURL,
+ configURL,
+ mel_filtersURL,
+ audioURL,
+ } = event.data;
+ try {
+ self.postMessage({ status: "decoding", message: "Starting Decoder" });
+ let quantized = false;
+ if (modelID.includes("quantized")) {
+ quantized = true;
+ }
+ let is_multilingual = false;
+ if (modelID.includes("multilingual")) {
+ is_multilingual = true;
+ }
+ let timestamps = true;
+ const decoder = await Whisper.getInstance({
+ weightsURL,
+ modelID,
+ tokenizerURL,
+ mel_filtersURL,
+ configURL,
+ quantized,
+ is_multilingual,
+ timestamps,
+ task: null,
+ language: null,
+ });
+
+ self.postMessage({ status: "decoding", message: "Loading Audio" });
+ const audioArrayU8 = await fetchArrayBuffer(audioURL);
+
+ self.postMessage({ status: "decoding", message: "Running Decoder..." });
+ const segments = decoder.decode(audioArrayU8);
+
+    // Send the segments back to the main thread as JSON
+ self.postMessage({
+ status: "complete",
+ message: "complete",
+ output: JSON.parse(segments),
+ });
+ } catch (e) {
+ self.postMessage({ error: e });
+ }
+});