found an example that just works

master
Bel LaPointe 2023-12-21 22:26:49 -05:00
parent 1cceb773f2
commit c23352b7e5
16 changed files with 1820 additions and 0 deletions

4
.gitignore vendored
View File

@ -9,3 +9,7 @@ snowboy-2022/snowboy
**/*.wav
snowboy-2022/Dockerfile
**/target
/candle-wasm-examples-whisper/mel_filters*
/candle-wasm-examples-whisper/whisper-tiny*
/candle-wasm-examples-whisper/quantized*
/candle-wasm-examples-whisper/audios*

View File

@ -0,0 +1,52 @@
[package]
name = "candle-wasm-example-whisper"
version.workspace = true
edition.workspace = true
description.workspace = true
repository.workspace = true
keywords.workspace = true
categories.workspace = true
license.workspace = true
[dependencies]
candle = { path = "../../candle-core", version = "0.3.2", package = "candle-core" }
candle-nn = { path = "../../candle-nn", version = "0.3.2" }
candle-transformers = { path = "../../candle-transformers", version = "0.3.2" }
num-traits = { workspace = true }
tokenizers = { workspace = true, features = ["unstable_wasm"] }
# App crates.
anyhow = { workspace = true }
log = { workspace = true }
rand = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
wav = { workspace = true }
safetensors = { workspace = true }
# Wasm specific crates.
getrandom = { version = "0.2", features = ["js"] }
gloo = "0.8"
js-sys = "0.3.64"
wasm-bindgen = "0.2.87"
wasm-bindgen-futures = "0.4.37"
wasm-logger = "0.2"
yew-agent = "0.2.0"
yew = { version = "0.20.0", features = ["csr"] }
[dependencies.web-sys]
version = "0.3.64"
features = [
'Blob',
'Document',
'Element',
'HtmlElement',
'Node',
'Window',
'Request',
'RequestCache',
'RequestInit',
'RequestMode',
'Response',
'Performance',
]

View File

@ -0,0 +1,68 @@
## Running Whisper Examples
Here, we provide two examples of how to run Whisper using a Candle-compiled WASM binary and runtimes.
### Pure Rust UI
To build and test the UI made in Rust you will need [Trunk](https://trunkrs.dev/#install)
From the `candle-wasm-examples/whisper` directory run:
Download assets:
```bash
# mel filters
wget -c https://huggingface.co/spaces/lmz/candle-whisper/resolve/main/mel_filters.safetensors
# Model and tokenizer tiny.en
wget -c https://huggingface.co/openai/whisper-tiny.en/resolve/main/model.safetensors -P whisper-tiny.en
wget -c https://huggingface.co/openai/whisper-tiny.en/raw/main/tokenizer.json -P whisper-tiny.en
wget -c https://huggingface.co/openai/whisper-tiny.en/raw/main/config.json -P whisper-tiny.en
# model and tokenizer tiny multilanguage
wget -c https://huggingface.co/openai/whisper-tiny/resolve/main/model.safetensors -P whisper-tiny
wget -c https://huggingface.co/openai/whisper-tiny/raw/main/tokenizer.json -P whisper-tiny
wget -c https://huggingface.co/openai/whisper-tiny/raw/main/config.json -P whisper-tiny
#quantized
wget -c https://huggingface.co/lmz/candle-whisper/resolve/main/model-tiny-en-q80.gguf -P quantized
wget -c https://huggingface.co/lmz/candle-whisper/raw/main/tokenizer-tiny-en.json -P quantized
wget -c https://huggingface.co/lmz/candle-whisper/raw/main/config-tiny-en.json -P quantized
# Audio samples
wget -c https://huggingface.co/datasets/Narsil/candle-examples/resolve/main/samples_gb0.wav -P audios
wget -c https://huggingface.co/datasets/Narsil/candle-examples/resolve/main/samples_a13.wav -P audios
wget -c https://huggingface.co/datasets/Narsil/candle-examples/resolve/main/samples_gb1.wav -P audios
wget -c https://huggingface.co/datasets/Narsil/candle-examples/resolve/main/samples_hp0.wav -P audios
wget -c https://huggingface.co/datasets/Narsil/candle-examples/resolve/main/samples_jfk.wav -P audios
wget -c https://huggingface.co/datasets/Narsil/candle-examples/resolve/main/samples_mm0.wav -P audios
```
Run hot reload server:
```bash
trunk serve --release --public-url / --port 8080
```
### Vanilla JS and WebWorkers
To build and test the UI made in Vanilla JS and WebWorkers, first we need to build the WASM library:
```bash
sh build-lib.sh
```
This will bundle the library under `./build` and we can import it inside our WebWorker like a normal JS module:
```js
import init, { Decoder } from "./build/m.js";
```
The full example can be found under `./lib-example.html`. All needed assets are fetched from the web, so no need to download anything.
Finally, you can preview the example by running a local HTTP server. For example:
```bash
python -m http.server
```
Then open `http://localhost:8000/lib-example.html` in your browser.

View File

@ -0,0 +1,2 @@
cargo build --target wasm32-unknown-unknown --release
wasm-bindgen ../../target/wasm32-unknown-unknown/release/m.wasm --out-dir build --target web

View File

@ -0,0 +1,40 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="utf-8" />
<title>Welcome to Candle!</title>
<link data-trunk rel="copy-file" href="mel_filters.safetensors" />
<!-- samples -->
<link data-trunk rel="copy-dir" href="audios" />
<!-- tiny.en -->
<link data-trunk rel="copy-dir" href="whisper-tiny.en" />
<!-- tiny -->
<link data-trunk rel="copy-dir" href="whisper-tiny" />
<!-- quantized -->
<link data-trunk rel="copy-dir" href="quantized" />
<link
data-trunk
rel="rust"
href="Cargo.toml"
data-bin="app"
data-type="main" />
<link
data-trunk
rel="rust"
href="Cargo.toml"
data-bin="worker"
data-type="worker" />
<link
rel="stylesheet"
href="https://fonts.googleapis.com/css?family=Roboto:300,300italic,700,700italic" />
<link
rel="stylesheet"
href="https://cdnjs.cloudflare.com/ajax/libs/normalize/8.0.1/normalize.css" />
<link
rel="stylesheet"
href="https://cdnjs.cloudflare.com/ajax/libs/milligram/1.4.1/milligram.css" />
</head>
<body></body>
</html>

View File

@ -0,0 +1,351 @@
<html>
<head>
<meta content="text/html;charset=utf-8" http-equiv="Content-Type" />
<title>Candle Whisper Rust/WASM</title>
</head>
<body></body>
</html>
<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<style>
@import url("https://fonts.googleapis.com/css2?family=Source+Code+Pro:wght@200;300;400&family=Source+Sans+3:wght@100;200;300;400;500;600;700;800;900&display=swap");
html,
body {
font-family: "Source Sans 3", sans-serif;
}
</style>
<script src="https://cdn.tailwindcss.com"></script>
<script type="module">
// base url for audio examples
const AUDIO_BASE_URL =
"https://huggingface.co/datasets/Narsil/candle-examples/resolve/main/";
// models base url
const MODELS = {
tiny_multilingual: {
base_url: "https://huggingface.co/openai/whisper-tiny/resolve/main/",
model: "model.safetensors",
tokenizer: "tokenizer.json",
config: "config.json",
size: "151 MB",
},
tiny_en: {
base_url:
"https://huggingface.co/openai/whisper-tiny.en/resolve/main/",
model: "model.safetensors",
tokenizer: "tokenizer.json",
config: "config.json",
size: "151 MB",
},
tiny_quantized_multilingual_q80: {
base_url: "https://huggingface.co/lmz/candle-whisper/resolve/main/",
model: "model-tiny-q80.gguf",
tokenizer: "tokenizer-tiny.json",
config: "config-tiny.json",
size: "41.5 MB",
},
tiny_en_quantized_q80: {
base_url: "https://huggingface.co/lmz/candle-whisper/resolve/main/",
model: "model-tiny-q80.gguf",
tokenizer: "tokenizer-tiny-en.json",
config: "config-tiny-en.json",
size: "41.8 MB",
},
distil_medium_en: {
base_url:
"https://huggingface.co/distil-whisper/distil-medium.en/resolve/main/",
model: "model.safetensors",
tokenizer: "tokenizer.json",
config: "config.json",
size: "789 MB",
},
};
const modelEl = document.querySelector("#model");
Object.keys(MODELS).forEach((modelID) => {
const model = MODELS[modelID];
const option = document.createElement("option");
option.value = modelID;
option.textContent = `${modelID} (${model.size})`;
modelEl.appendChild(option);
});
const whisperWorker = new Worker("./whisperWorker.js", {
type: "module",
});
async function classifyAudio(
weightsURL, // URL to the weights file
modelID, // model ID
tokenizerURL, // URL to the tokenizer file
configURL, // model config URL
mel_filtersURL, // URL to the mel filters file
audioURL, // URL to the audio file
updateStatus // function to update the status
) {
return new Promise((resolve, reject) => {
whisperWorker.postMessage({
weightsURL,
modelID,
tokenizerURL,
configURL,
mel_filtersURL,
audioURL,
});
function messageHandler(event) {
console.log(event.data);
if ("status" in event.data) {
updateStatus(event.data);
}
if ("error" in event.data) {
whisperWorker.removeEventListener("message", messageHandler);
reject(new Error(event.data.error));
}
if (event.data.status === "complete") {
whisperWorker.removeEventListener("message", messageHandler);
resolve(event.data);
}
}
whisperWorker.addEventListener("message", messageHandler);
});
}
// keep track of the audio URL
let audioURL = null;
function setAudio(src) {
const audio = document.querySelector("#audio");
audio.src = src;
audio.controls = true;
audio.hidden = false;
document.querySelector("#detect").disabled = false;
audioURL = src;
}
// add event listener to audio buttons
document.querySelectorAll("#audios-select > button").forEach((target) => {
target.addEventListener("click", (e) => {
const value = target.dataset.value;
const href = AUDIO_BASE_URL + value;
setAudio(href);
});
});
//add event listener to file input
document.querySelector("#file-upload").addEventListener("change", (e) => {
const target = e.target;
if (target.files.length > 0) {
const href = URL.createObjectURL(target.files[0]);
setAudio(href);
}
});
// add event listener to drop-area
const dropArea = document.querySelector("#drop-area");
dropArea.addEventListener("dragenter", (e) => {
e.preventDefault();
dropArea.classList.add("border-blue-700");
});
dropArea.addEventListener("dragleave", (e) => {
e.preventDefault();
dropArea.classList.remove("border-blue-700");
});
dropArea.addEventListener("dragover", (e) => {
e.preventDefault();
dropArea.classList.add("border-blue-700");
});
dropArea.addEventListener("drop", (e) => {
e.preventDefault();
dropArea.classList.remove("border-blue-700");
const url = e.dataTransfer.getData("text/uri-list");
const files = e.dataTransfer.files;
if (files.length > 0) {
const href = URL.createObjectURL(files[0]);
setAudio(href);
} else if (url) {
setAudio(url);
}
});
// add event listener to detect button
document.querySelector("#detect").addEventListener("click", async () => {
if (audioURL === null) {
return;
}
const modelID = modelEl.value;
const model = MODELS[modelID];
const modelURL = model.base_url + model.model;
const tokenizerURL = model.base_url + model.tokenizer;
const configURL = model.base_url + model.config;
classifyAudio(
modelURL,
modelID,
tokenizerURL,
configURL,
"mel_filters.safetensors",
audioURL,
updateStatus
)
.then((result) => {
console.log("RESULT", result);
const { output } = result;
const text = output.map((segment) => segment.dr.text).join(" ");
console.log(text);
document.querySelector("#output-status").hidden = true;
document.querySelector("#output-generation").hidden = false;
document.querySelector("#output-generation").textContent = text;
})
.catch((error) => {
console.error(error);
});
});
function updateStatus(data) {
const { status, message } = data;
const button = document.querySelector("#detect");
if (status === "decoding" || status === "loading") {
button.disabled = true;
button.textContent = message;
} else if (status === "complete") {
button.disabled = false;
button.textContent = "Transcribe Audio";
}
}
</script>
</head>
<body class="container max-w-4xl mx-auto p-4">
<main class="grid grid-cols-1 gap-8 relative">
<span class="absolute text-5xl -ml-[1em]"> 🕯️ </span>
<div>
<h1 class="text-5xl font-bold">Candle Whisper</h1>
<h2 class="text-2xl font-bold">Rust/WASM Demo</h2>
<p class="max-w-lg">
Transcribe audio in the browser using rust/wasm with an audio file.
This demo uses the
<a
href="https://huggingface.co/openai/"
target="_blank"
class="underline hover:text-blue-500 hover:no-underline">
OpenAI Whisper models
</a>
and WASM runtime built with
<a
href="https://github.com/huggingface/candle/"
target="_blank"
class="underline hover:text-blue-500 hover:no-underline"
>Candle
</a>
</p>
</div>
<div>
<label for="model" class="font-medium">Models Options: </label>
<select
id="model"
class="border-2 border-gray-500 rounded-md font-light">
</select>
</div>
<!-- drag and drop area -->
<div class="relative">
<div
id="drop-area"
class="flex flex-col items-center justify-center border-2 border-gray-300 border-dashed rounded-xl relative h-48 w-full overflow-hidden">
<div
class="flex flex-col items-center justify-center space-y-1 text-center">
<svg
width="25"
height="25"
viewBox="0 0 25 25"
fill="none"
xmlns="http://www.w3.org/2000/svg">
<path
d="M3.5 24.3a3 3 0 0 1-1.9-.8c-.5-.5-.8-1.2-.8-1.9V2.9c0-.7.3-1.3.8-1.9.6-.5 1.2-.7 2-.7h18.6c.7 0 1.3.2 1.9.7.5.6.7 1.2.7 2v18.6c0 .7-.2 1.4-.7 1.9a3 3 0 0 1-2 .8H3.6Zm0-2.7h18.7V2.9H3.5v18.7Zm2.7-2.7h13.3c.3 0 .5 0 .6-.3v-.7l-3.7-5a.6.6 0 0 0-.6-.2c-.2 0-.4 0-.5.3l-3.5 4.6-2.4-3.3a.6.6 0 0 0-.6-.3c-.2 0-.4.1-.5.3l-2.7 3.6c-.1.2-.2.4 0 .7.1.2.3.3.6.3Z"
fill="#000" />
</svg>
<div class="flex text-sm text-gray-600">
<label
for="file-upload"
class="relative cursor-pointer bg-white rounded-md font-medium text-blue-950 hover:text-blue-700">
<span>Drag and drop your audio here</span>
<span class="block text-xs">or</span>
<span class="block text-xs">Click to upload</span>
</label>
</div>
<input
id="file-upload"
name="file-upload"
type="file"
accept="audio/*"
class="sr-only" />
</div>
<audio
id="audio"
hidden
controls
class="w-full p-2 select-none"></audio>
</div>
</div>
<div>
<div class="flex flex-wrap gap-3 items-center" id="audios-select">
<h3 class="font-medium">Examples:</h3>
<button
data-value="samples_jfk.wav"
class="text-gray-500 border border-gray-500 rounded-md p-2 underline hover:no-underline">
<span>jfk.wav</span>
<span class="text-xs block"> (352 kB)</span>
</button>
<button
data-value="samples_a13.wav"
class="text-gray-500 border border-gray-500 rounded-md p-2 underline hover:no-underline">
<span>a13.wav</span>
<span class="text-xs block"> (960 kB)</span>
</button>
<button
data-value="samples_mm0.wav"
class="text-gray-500 border border-gray-500 rounded-md p-2 underline hover:no-underline">
<span>mm0.wav</span>
<span class="text-xs block new"> (957 kB)</span>
</button>
<button
data-value="samples_gb0.wav"
class="text-gray-500 border border-gray-500 rounded-md p-2 underline hover:no-underline">
<span>gb0.wav </span>
<span class="text-xs block">(4.08 MB)</span>
</button>
<button
data-value="samples_gb1.wav"
class="text-gray-500 border border-gray-500 rounded-md p-2 underline hover:no-underline">
<span>gb1.wav </span>
<span class="text-xs block">(6.36 MB)</span>
</button>
<button
data-value="samples_hp0.wav"
class="text-gray-500 border border-gray-500 rounded-md p-2 underline hover:no-underline">
<span>hp0.wav </span>
<span class="text-xs block">(8.75 MB)</span>
</button>
</div>
</div>
<div>
<button
id="detect"
disabled
class="bg-gray-700 hover:bg-gray-800 text-white font-normal py-2 px-4 rounded disabled:bg-gray-300 disabled:cursor-not-allowed">
Transcribe Audio
</button>
</div>
<div>
<h3 class="font-medium">Transcription:</h3>
<div
class="min-h-[250px] bg-slate-100 text-gray-500 p-4 rounded-md flex flex-col gap-2">
<p hidden id="output-generation" class="grid-rows-2"></p>
<span id="output-status" class="m-auto font-light"
>No transcription results yet</span
>
</div>
</div>
</main>
</body>
</html>

View File

@ -0,0 +1,6 @@
import init, { run_app } from './pkg/candle_wasm_example_whisper.js';
async function main() {
await init('/pkg/candle_wasm_example_whisper_bg.wasm');
run_app();
}
main()

View File

@ -0,0 +1,283 @@
use crate::console_log;
use crate::worker::{ModelData, Segment, Worker, WorkerInput, WorkerOutput};
use js_sys::Date;
use wasm_bindgen::prelude::*;
use wasm_bindgen_futures::JsFuture;
use yew::{html, Component, Context, Html};
use yew_agent::{Bridge, Bridged};
const SAMPLE_NAMES: [&str; 6] = [
"audios/samples_jfk.wav",
"audios/samples_a13.wav",
"audios/samples_gb0.wav",
"audios/samples_gb1.wav",
"audios/samples_hp0.wav",
"audios/samples_mm0.wav",
];
async fn fetch_url(url: &str) -> Result<Vec<u8>, JsValue> {
use web_sys::{Request, RequestCache, RequestInit, RequestMode, Response};
let window = web_sys::window().ok_or("window")?;
let mut opts = RequestInit::new();
let opts = opts
.method("GET")
.mode(RequestMode::Cors)
.cache(RequestCache::NoCache);
let request = Request::new_with_str_and_init(url, opts)?;
let resp_value = JsFuture::from(window.fetch_with_request(&request)).await?;
// `resp_value` is a `Response` object.
assert!(resp_value.is_instance_of::<Response>());
let resp: Response = resp_value.dyn_into()?;
let data = JsFuture::from(resp.blob()?).await?;
let blob = web_sys::Blob::from(data);
let array_buffer = JsFuture::from(blob.array_buffer()).await?;
let data = js_sys::Uint8Array::new(&array_buffer).to_vec();
Ok(data)
}
pub enum Msg {
Run(usize),
UpdateStatus(String),
SetDecoder(ModelData),
WorkerInMsg(WorkerInput),
WorkerOutMsg(Result<WorkerOutput, String>),
}
pub struct CurrentDecode {
start_time: Option<f64>,
}
pub struct App {
status: String,
loaded: bool,
segments: Vec<Segment>,
current_decode: Option<CurrentDecode>,
worker: Box<dyn Bridge<Worker>>,
}
async fn model_data_load() -> Result<ModelData, JsValue> {
let quantized = false;
let is_multilingual = false;
let (tokenizer, mel_filters, weights, config) = if quantized {
console_log!("loading quantized weights");
let tokenizer = fetch_url("quantized/tokenizer-tiny-en.json").await?;
let mel_filters = fetch_url("mel_filters.safetensors").await?;
let weights = fetch_url("quantized/model-tiny-en-q80.gguf").await?;
let config = fetch_url("quantized/config-tiny-en.json").await?;
(tokenizer, mel_filters, weights, config)
} else {
console_log!("loading float weights");
if is_multilingual {
let mel_filters = fetch_url("mel_filters.safetensors").await?;
let tokenizer = fetch_url("whisper-tiny/tokenizer.json").await?;
let weights = fetch_url("whisper-tiny/model.safetensors").await?;
let config = fetch_url("whisper-tiny/config.json").await?;
(tokenizer, mel_filters, weights, config)
} else {
let mel_filters = fetch_url("mel_filters.safetensors").await?;
let tokenizer = fetch_url("whisper-tiny.en/tokenizer.json").await?;
let weights = fetch_url("whisper-tiny.en/model.safetensors").await?;
let config = fetch_url("whisper-tiny.en/config.json").await?;
(tokenizer, mel_filters, weights, config)
}
};
let timestamps = true;
let _task = Some("transcribe".to_string());
console_log!("{}", weights.len());
Ok(ModelData {
tokenizer,
mel_filters,
weights,
config,
quantized,
timestamps,
task: None,
is_multilingual,
language: None,
})
}
fn performance_now() -> Option<f64> {
let window = web_sys::window()?;
let performance = window.performance()?;
Some(performance.now() / 1000.)
}
impl Component for App {
type Message = Msg;
type Properties = ();
fn create(ctx: &Context<Self>) -> Self {
let status = "loading weights".to_string();
let cb = {
let link = ctx.link().clone();
move |e| link.send_message(Self::Message::WorkerOutMsg(e))
};
let worker = Worker::bridge(std::rc::Rc::new(cb));
Self {
status,
segments: vec![],
current_decode: None,
worker,
loaded: false,
}
}
fn rendered(&mut self, ctx: &Context<Self>, first_render: bool) {
if first_render {
ctx.link().send_future(async {
match model_data_load().await {
Err(err) => {
let status = format!("{err:?}");
Msg::UpdateStatus(status)
}
Ok(model_data) => Msg::SetDecoder(model_data),
}
});
}
}
fn update(&mut self, ctx: &Context<Self>, msg: Self::Message) -> bool {
match msg {
Msg::SetDecoder(md) => {
self.status = "weights loaded successfully!".to_string();
self.loaded = true;
console_log!("loaded weights");
self.worker.send(WorkerInput::ModelData(md));
true
}
Msg::Run(sample_index) => {
let sample = SAMPLE_NAMES[sample_index];
if self.current_decode.is_some() {
self.status = "already decoding some sample at the moment".to_string()
} else {
let start_time = performance_now();
self.current_decode = Some(CurrentDecode { start_time });
self.status = format!("decoding {sample}");
self.segments.clear();
ctx.link().send_future(async move {
match fetch_url(sample).await {
Err(err) => {
let output = Err(format!("decoding error: {err:?}"));
// Mimic a worker output to so as to release current_decode
Msg::WorkerOutMsg(output)
}
Ok(wav_bytes) => {
Msg::WorkerInMsg(WorkerInput::DecodeTask { wav_bytes })
}
}
})
}
//
true
}
Msg::WorkerOutMsg(output) => {
let dt = self.current_decode.as_ref().and_then(|current_decode| {
current_decode.start_time.and_then(|start_time| {
performance_now().map(|stop_time| stop_time - start_time)
})
});
self.current_decode = None;
match output {
Ok(WorkerOutput::WeightsLoaded) => self.status = "weights loaded!".to_string(),
Ok(WorkerOutput::Decoded(segments)) => {
self.status = match dt {
None => "decoding succeeded!".to_string(),
Some(dt) => format!("decoding succeeded in {:.2}s", dt),
};
self.segments = segments;
}
Err(err) => {
self.status = format!("decoding error {err:?}");
}
}
true
}
Msg::WorkerInMsg(inp) => {
self.worker.send(inp);
true
}
Msg::UpdateStatus(status) => {
self.status = status;
true
}
}
}
fn view(&self, ctx: &Context<Self>) -> Html {
html! {
<div>
<table>
<thead>
<tr>
<th>{"Sample"}</th>
<th></th>
<th></th>
</tr>
</thead>
<tbody>
{
SAMPLE_NAMES.iter().enumerate().map(|(i, name)| { html! {
<tr>
<th>{name}</th>
<th><audio controls=true src={format!("./{name}")}></audio></th>
{ if self.loaded {
html!(<th><button class="button" onclick={ctx.link().callback(move |_| Msg::Run(i))}> { "run" }</button></th>)
}else{html!()}
}
</tr>
}
}).collect::<Html>()
}
</tbody>
</table>
<h2>
{&self.status}
</h2>
{
if !self.loaded{
html! { <progress id="progress-bar" aria-label="loading weights…"></progress> }
} else if self.current_decode.is_some() {
html! { <progress id="progress-bar" aria-label="decoding…"></progress> }
} else { html!{
<blockquote>
<p>
{
self.segments.iter().map(|segment| { html! {
<>
<i>
{
format!("{:.2}s-{:.2}s: (avg-logprob: {:.4}, no-speech-prob: {:.4})",
segment.start,
segment.start + segment.duration,
segment.dr.avg_logprob,
segment.dr.no_speech_prob,
)
}
</i>
<br/ >
{&segment.dr.text}
<br/ >
</>
} }).collect::<Html>()
}
</p>
</blockquote>
}
}
}
// Display the current date and time the page was rendered
<p class="footer">
{ "Rendered: " }
{ String::from(Date::new_0().to_string()) }
</p>
</div>
}
}
}

View File

@ -0,0 +1,215 @@
// Audio processing code, adapted from whisper.cpp
// https://github.com/ggerganov/whisper.cpp
use super::worker;
pub trait Float: num_traits::Float + num_traits::FloatConst + num_traits::NumAssign {}
impl Float for f32 {}
impl Float for f64 {}
// https://github.com/ggerganov/whisper.cpp/blob/4774d2feb01a772a15de81ffc34b34a1f294f020/whisper.cpp#L2357
fn fft<T: Float>(inp: &[T]) -> Vec<T> {
let n = inp.len();
let zero = T::zero();
if n == 1 {
return vec![inp[0], zero];
}
if n % 2 == 1 {
return dft(inp);
}
let mut out = vec![zero; n * 2];
let mut even = Vec::with_capacity(n / 2);
let mut odd = Vec::with_capacity(n / 2);
for (i, &inp) in inp.iter().enumerate() {
if i % 2 == 0 {
even.push(inp)
} else {
odd.push(inp);
}
}
let even_fft = fft(&even);
let odd_fft = fft(&odd);
let two_pi = T::PI() + T::PI();
let n_t = T::from(n).unwrap();
for k in 0..n / 2 {
let k_t = T::from(k).unwrap();
let theta = two_pi * k_t / n_t;
let re = theta.cos();
let im = -theta.sin();
let re_odd = odd_fft[2 * k];
let im_odd = odd_fft[2 * k + 1];
out[2 * k] = even_fft[2 * k] + re * re_odd - im * im_odd;
out[2 * k + 1] = even_fft[2 * k + 1] + re * im_odd + im * re_odd;
out[2 * (k + n / 2)] = even_fft[2 * k] - re * re_odd + im * im_odd;
out[2 * (k + n / 2) + 1] = even_fft[2 * k + 1] - re * im_odd - im * re_odd;
}
out
}
// https://github.com/ggerganov/whisper.cpp/blob/4774d2feb01a772a15de81ffc34b34a1f294f020/whisper.cpp#L2337
fn dft<T: Float>(inp: &[T]) -> Vec<T> {
let zero = T::zero();
let n = inp.len();
let two_pi = T::PI() + T::PI();
let mut out = Vec::with_capacity(2 * n);
let n_t = T::from(n).unwrap();
for k in 0..n {
let k_t = T::from(k).unwrap();
let mut re = zero;
let mut im = zero;
for (j, &inp) in inp.iter().enumerate() {
let j_t = T::from(j).unwrap();
let angle = two_pi * k_t * j_t / n_t;
re += inp * angle.cos();
im -= inp * angle.sin();
}
out.push(re);
out.push(im);
}
out
}
#[allow(clippy::too_many_arguments)]
// https://github.com/ggerganov/whisper.cpp/blob/4774d2feb01a772a15de81ffc34b34a1f294f020/whisper.cpp#L2414
fn log_mel_spectrogram_w<T: Float>(
ith: usize,
hann: &[T],
samples: &[T],
filters: &[T],
fft_size: usize,
fft_step: usize,
speed_up: bool,
n_len: usize,
n_mel: usize,
n_threads: usize,
) -> Vec<T> {
let n_fft = if speed_up {
1 + fft_size / 4
} else {
1 + fft_size / 2
};
let zero = T::zero();
let half = T::from(0.5).unwrap();
let mut fft_in = vec![zero; fft_size];
let mut mel = vec![zero; n_len * n_mel];
for i in (ith..n_len).step_by(n_threads) {
let offset = i * fft_step;
// apply Hanning window
for j in 0..fft_size {
fft_in[j] = if offset + j < samples.len() {
hann[j] * samples[offset + j]
} else {
zero
}
}
// FFT -> mag^2
let mut fft_out: Vec<T> = fft(&fft_in);
for j in 0..fft_size {
fft_out[j] = fft_out[2 * j] * fft_out[2 * j] + fft_out[2 * j + 1] * fft_out[2 * j + 1];
}
for j in 1..fft_size / 2 {
let v = fft_out[fft_size - j];
fft_out[j] += v;
}
if speed_up {
// scale down in the frequency domain results in a speed up in the time domain
for j in 0..n_fft {
fft_out[j] = half * (fft_out[2 * j] + fft_out[2 * j + 1]);
}
}
// mel spectrogram
for j in 0..n_mel {
let mut sum = zero;
for k in 0..n_fft {
sum += fft_out[k] * filters[j * n_fft + k];
}
mel[j * n_len + i] = T::max(sum, T::from(1e-10).unwrap()).log10();
}
}
mel
}
fn log_mel_spectrogram_<T: Float + std::fmt::Display>(
samples: &[T],
filters: &[T],
fft_size: usize,
fft_step: usize,
n_mel: usize,
speed_up: bool,
) -> Vec<T> {
let zero = T::zero();
let two_pi = T::PI() + T::PI();
let half = T::from(0.5).unwrap();
let one = T::from(1.0).unwrap();
let four = T::from(4.0).unwrap();
let fft_size_t = T::from(fft_size).unwrap();
let hann: Vec<T> = (0..fft_size)
.map(|i| half * (one - ((two_pi * T::from(i).unwrap()) / fft_size_t).cos()))
.collect();
let n_len = samples.len() / fft_step;
// pad audio with at least one extra chunk of zeros
let pad = 100 * worker::m::CHUNK_LENGTH / 2;
let n_len = if n_len % pad != 0 {
(n_len / pad + 1) * pad
} else {
n_len
};
let n_len = n_len + pad;
let samples = {
let mut samples_padded = samples.to_vec();
let to_add = n_len * fft_step - samples.len();
samples_padded.extend(std::iter::repeat(zero).take(to_add));
samples_padded
};
// Use a single thread for now.
let mut mel = log_mel_spectrogram_w(
0, &hann, &samples, filters, fft_size, fft_step, speed_up, n_len, n_mel, 1,
);
let mmax = mel
.iter()
.max_by(|&u, &v| u.partial_cmp(v).unwrap_or(std::cmp::Ordering::Greater))
.copied()
.unwrap_or(zero)
- T::from(8).unwrap();
for m in mel.iter_mut() {
let v = T::max(*m, mmax);
*m = v / four + one
}
mel
}
pub fn pcm_to_mel<T: Float + std::fmt::Display>(
cfg: &worker::m::Config,
samples: &[T],
filters: &[T],
) -> anyhow::Result<Vec<T>> {
let mel = log_mel_spectrogram_(
samples,
filters,
worker::m::N_FFT,
worker::m::HOP_LENGTH,
cfg.num_mel_bins,
false,
);
Ok(mel)
}

View File

@ -0,0 +1,4 @@
fn main() {
wasm_logger::init(wasm_logger::Config::new(log::Level::Trace));
yew::Renderer::<candle_wasm_example_whisper::App>::new().render();
}

View File

@ -0,0 +1,53 @@
use candle_wasm_example_whisper::worker::{Decoder as D, ModelData};
use wasm_bindgen::prelude::*;
#[wasm_bindgen]
pub struct Decoder {
decoder: D,
}
#[wasm_bindgen]
impl Decoder {
#[wasm_bindgen(constructor)]
#[allow(clippy::too_many_arguments)]
pub fn new(
weights: Vec<u8>,
tokenizer: Vec<u8>,
mel_filters: Vec<u8>,
config: Vec<u8>,
quantized: bool,
is_multilingual: bool,
timestamps: bool,
task: Option<String>,
language: Option<String>,
) -> Result<Decoder, JsError> {
let decoder = D::load(ModelData {
tokenizer,
mel_filters,
config,
quantized,
weights,
is_multilingual,
timestamps,
task,
language,
});
match decoder {
Ok(decoder) => Ok(Self { decoder }),
Err(e) => Err(JsError::new(&e.to_string())),
}
}
#[wasm_bindgen]
pub fn decode(&mut self, wav_input: Vec<u8>) -> Result<String, JsError> {
let segments = self
.decoder
.convert_and_run(&wav_input)
.map_err(|e| JsError::new(&e.to_string()))?;
let json = serde_json::to_string(&segments)?;
Ok(json)
}
}
fn main() {}

View File

@ -0,0 +1,4 @@
use yew_agent::PublicWorker;
fn main() {
candle_wasm_example_whisper::Worker::register();
}

View File

@ -0,0 +1,101 @@
pub const LANGUAGES: [(&str, &str); 99] = [
("en", "english"),
("zh", "chinese"),
("de", "german"),
("es", "spanish"),
("ru", "russian"),
("ko", "korean"),
("fr", "french"),
("ja", "japanese"),
("pt", "portuguese"),
("tr", "turkish"),
("pl", "polish"),
("ca", "catalan"),
("nl", "dutch"),
("ar", "arabic"),
("sv", "swedish"),
("it", "italian"),
("id", "indonesian"),
("hi", "hindi"),
("fi", "finnish"),
("vi", "vietnamese"),
("he", "hebrew"),
("uk", "ukrainian"),
("el", "greek"),
("ms", "malay"),
("cs", "czech"),
("ro", "romanian"),
("da", "danish"),
("hu", "hungarian"),
("ta", "tamil"),
("no", "norwegian"),
("th", "thai"),
("ur", "urdu"),
("hr", "croatian"),
("bg", "bulgarian"),
("lt", "lithuanian"),
("la", "latin"),
("mi", "maori"),
("ml", "malayalam"),
("cy", "welsh"),
("sk", "slovak"),
("te", "telugu"),
("fa", "persian"),
("lv", "latvian"),
("bn", "bengali"),
("sr", "serbian"),
("az", "azerbaijani"),
("sl", "slovenian"),
("kn", "kannada"),
("et", "estonian"),
("mk", "macedonian"),
("br", "breton"),
("eu", "basque"),
("is", "icelandic"),
("hy", "armenian"),
("ne", "nepali"),
("mn", "mongolian"),
("bs", "bosnian"),
("kk", "kazakh"),
("sq", "albanian"),
("sw", "swahili"),
("gl", "galician"),
("mr", "marathi"),
("pa", "punjabi"),
("si", "sinhala"),
("km", "khmer"),
("sn", "shona"),
("yo", "yoruba"),
("so", "somali"),
("af", "afrikaans"),
("oc", "occitan"),
("ka", "georgian"),
("be", "belarusian"),
("tg", "tajik"),
("sd", "sindhi"),
("gu", "gujarati"),
("am", "amharic"),
("yi", "yiddish"),
("lo", "lao"),
("uz", "uzbek"),
("fo", "faroese"),
("ht", "haitian creole"),
("ps", "pashto"),
("tk", "turkmen"),
("nn", "nynorsk"),
("mt", "maltese"),
("sa", "sanskrit"),
("lb", "luxembourgish"),
("my", "myanmar"),
("bo", "tibetan"),
("tl", "tagalog"),
("mg", "malagasy"),
("as", "assamese"),
("tt", "tatar"),
("haw", "hawaiian"),
("ln", "lingala"),
("ha", "hausa"),
("ba", "bashkir"),
("jw", "javanese"),
("su", "sundanese"),
];

View File

@ -0,0 +1,29 @@
pub const WITH_TIMER: bool = true;
struct Timer {
label: &'static str,
}
// impl Timer {
// fn new(label: &'static str) -> Self {
// if WITH_TIMER {
// web_sys::console::time_with_label(label);
// }
// Self { label }
// }
// }
impl Drop for Timer {
fn drop(&mut self) {
if WITH_TIMER {
web_sys::console::time_end_with_label(self.label)
}
}
}
mod app;
mod audio;
pub mod languages;
pub mod worker;
pub use app::App;
pub use worker::Worker;

View File

@ -0,0 +1,492 @@
use crate::languages::LANGUAGES;
use anyhow::Error as E;
use candle::{safetensors::Load, DType, Device, IndexOp, Tensor, D};
use candle_nn::{ops::softmax, VarBuilder};
pub use candle_transformers::models::whisper::{self as m, Config};
use rand::{distributions::Distribution, rngs::StdRng, SeedableRng};
use serde::{Deserialize, Serialize};
use tokenizers::Tokenizer;
use wasm_bindgen::prelude::*;
use yew_agent::{HandlerId, Public, WorkerLink};
#[wasm_bindgen]
extern "C" {
// Use `js_namespace` here to bind `console.log(..)` instead of just
// `log(..)`
#[wasm_bindgen(js_namespace = console)]
pub fn log(s: &str);
}
#[macro_export]
macro_rules! console_log {
// Note that this is using the `log` function imported above during
// `bare_bones`
($($t:tt)*) => ($crate::worker::log(&format_args!($($t)*).to_string()))
}
pub const DTYPE: DType = DType::F32;
pub enum Model {
Normal(m::model::Whisper),
Quantized(m::quantized_model::Whisper),
}
// Maybe we should use some traits rather than doing the dispatch for all these.
impl Model {
pub fn config(&self) -> &Config {
match self {
Self::Normal(m) => &m.config,
Self::Quantized(m) => &m.config,
}
}
pub fn encoder_forward(&mut self, x: &Tensor, flush: bool) -> candle::Result<Tensor> {
match self {
Self::Normal(m) => m.encoder.forward(x, flush),
Self::Quantized(m) => m.encoder.forward(x, flush),
}
}
pub fn decoder_forward(
&mut self,
x: &Tensor,
xa: &Tensor,
flush: bool,
) -> candle::Result<Tensor> {
match self {
Self::Normal(m) => m.decoder.forward(x, xa, flush),
Self::Quantized(m) => m.decoder.forward(x, xa, flush),
}
}
pub fn decoder_final_linear(&self, x: &Tensor) -> candle::Result<Tensor> {
match self {
Self::Normal(m) => m.decoder.final_linear(x),
Self::Quantized(m) => m.decoder.final_linear(x),
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DecodingResult {
pub tokens: Vec<u32>,
pub text: String,
pub avg_logprob: f64,
pub no_speech_prob: f64,
temperature: f64,
compression_ratio: f64,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Segment {
pub start: f64,
pub duration: f64,
pub dr: DecodingResult,
}
pub struct Decoder {
model: Model,
rng: rand::rngs::StdRng,
task: Option<Task>,
language: Option<String>,
is_multilingual: bool,
mel_filters: Vec<f32>,
timestamps: bool,
tokenizer: Tokenizer,
suppress_tokens: Tensor,
sot_token: u32,
transcribe_token: u32,
translate_token: u32,
eot_token: u32,
no_speech_token: u32,
no_timestamps_token: u32,
}
impl Decoder {
#[allow(clippy::too_many_arguments)]
fn new(
model: Model,
tokenizer: Tokenizer,
mel_filters: Vec<f32>,
device: &Device,
task: Option<Task>,
language: Option<String>,
is_multilingual: bool,
timestamps: bool,
) -> anyhow::Result<Self> {
let suppress_tokens: Vec<f32> = (0..model.config().vocab_size as u32)
.map(|i| {
if model.config().suppress_tokens.contains(&i) {
f32::NEG_INFINITY
} else {
0f32
}
})
.collect();
let no_timestamps_token = token_id(&tokenizer, m::NO_TIMESTAMPS_TOKEN)?;
let suppress_tokens = Tensor::new(suppress_tokens.as_slice(), device)?;
let sot_token = token_id(&tokenizer, m::SOT_TOKEN)?;
let transcribe_token = token_id(&tokenizer, m::TRANSCRIBE_TOKEN)?;
let translate_token = token_id(&tokenizer, m::TRANSLATE_TOKEN)?;
let eot_token = token_id(&tokenizer, m::EOT_TOKEN)?;
let no_speech_token = m::NO_SPEECH_TOKENS
.iter()
.find_map(|token| token_id(&tokenizer, token).ok());
let no_speech_token = match no_speech_token {
None => anyhow::bail!("unable to find any non-speech token"),
Some(n) => n,
};
let seed = 299792458;
Ok(Self {
model,
rng: StdRng::seed_from_u64(seed),
tokenizer,
mel_filters,
task,
timestamps,
language,
is_multilingual,
suppress_tokens,
sot_token,
transcribe_token,
translate_token,
eot_token,
no_speech_token,
no_timestamps_token,
})
}
fn decode(&mut self, mel: &Tensor, t: f64) -> anyhow::Result<DecodingResult> {
let model = &mut self.model;
let language_token = match (self.is_multilingual, &self.language) {
(true, None) => Some(detect_language(model, &self.tokenizer, mel)?),
(false, None) => None,
(true, Some(language)) => {
match token_id(&self.tokenizer, &format!("<|{:?}|>", self.language)) {
Ok(token_id) => Some(token_id),
Err(_) => anyhow::bail!("language {language} is not supported"),
}
}
(false, Some(_)) => {
anyhow::bail!("a language cannot be set for non-multilingual models")
}
};
let audio_features = model.encoder_forward(mel, true)?;
println!("audio features: {:?}", audio_features.dims());
let sample_len = model.config().max_target_positions / 2;
let mut sum_logprob = 0f64;
let mut no_speech_prob = f64::NAN;
let mut tokens = vec![self.sot_token];
if let Some(language_token) = language_token {
tokens.push(language_token);
}
match self.task {
None | Some(Task::Transcribe) => tokens.push(self.transcribe_token),
Some(Task::Translate) => tokens.push(self.translate_token),
}
if !self.timestamps {
tokens.push(self.no_timestamps_token);
}
for i in 0..sample_len {
let tokens_t = Tensor::new(tokens.as_slice(), mel.device())?;
// The model expects a batch dim but this inference loop does not handle
// it so we add it at this point.
let tokens_t = tokens_t.unsqueeze(0)?;
let ys = model.decoder_forward(&tokens_t, &audio_features, i == 0)?;
// Extract the no speech probability on the first iteration by looking at the first
// token logits and the probability for the according token.
if i == 0 {
let logits = model.decoder_final_linear(&ys.i(..1)?)?.i(0)?.i(0)?;
no_speech_prob = softmax(&logits, 0)?
.i(self.no_speech_token as usize)?
.to_scalar::<f32>()? as f64;
}
let (_, seq_len, _) = ys.dims3()?;
let logits = model
.decoder_final_linear(&ys.i((..1, seq_len - 1..))?)?
.i(0)?
.i(0)?;
// TODO: Besides suppress tokens, we should apply the heuristics from
// ApplyTimestampRules, i.e.:
// - Timestamps come in pairs, except before EOT.
// - Timestamps should be non-decreasing.
// - If the sum of the probabilities of timestamps is higher than any other tokens,
// only consider timestamps when sampling.
// https://github.com/openai/whisper/blob/e8622f9afc4eba139bf796c210f5c01081000472/whisper/decoding.py#L439
let logits = logits.broadcast_add(&self.suppress_tokens)?;
let next_token = if t > 0f64 {
let prs = softmax(&(&logits / t)?, 0)?;
let logits_v: Vec<f32> = prs.to_vec1()?;
let distr = rand::distributions::WeightedIndex::new(&logits_v)?;
distr.sample(&mut self.rng) as u32
} else {
let logits_v: Vec<f32> = logits.to_vec1()?;
logits_v
.iter()
.enumerate()
.max_by(|(_, u), (_, v)| u.total_cmp(v))
.map(|(i, _)| i as u32)
.unwrap()
};
tokens.push(next_token);
let prob = softmax(&logits, candle::D::Minus1)?
.i(next_token as usize)?
.to_scalar::<f32>()? as f64;
if next_token == self.eot_token || tokens.len() > model.config().max_target_positions {
break;
}
sum_logprob += prob.ln();
}
let text = self.tokenizer.decode(&tokens, true).map_err(E::msg)?;
let avg_logprob = sum_logprob / tokens.len() as f64;
Ok(DecodingResult {
tokens,
text,
avg_logprob,
no_speech_prob,
temperature: t,
compression_ratio: f64::NAN,
})
}
fn decode_with_fallback(&mut self, segment: &Tensor) -> anyhow::Result<DecodingResult> {
for (i, &t) in m::TEMPERATURES.iter().enumerate() {
let dr: Result<DecodingResult, _> = self.decode(segment, t);
if i == m::TEMPERATURES.len() - 1 {
return dr;
}
// On errors, we try again with a different temperature.
match dr {
Ok(dr) => {
let needs_fallback = dr.compression_ratio > m::COMPRESSION_RATIO_THRESHOLD
|| dr.avg_logprob < m::LOGPROB_THRESHOLD;
if !needs_fallback || dr.no_speech_prob > m::NO_SPEECH_THRESHOLD {
return Ok(dr);
}
}
Err(err) => {
console_log!("Error running at {t}: {err}")
}
}
}
unreachable!()
}
fn run(&mut self, mel: &Tensor) -> anyhow::Result<Vec<Segment>> {
let (_, _, content_frames) = mel.dims3()?;
let mut seek = 0;
let mut segments = vec![];
while seek < content_frames {
let time_offset = (seek * m::HOP_LENGTH) as f64 / m::SAMPLE_RATE as f64;
let segment_size = usize::min(content_frames - seek, m::N_FRAMES);
let mel_segment = mel.narrow(2, seek, segment_size)?;
let segment_duration = (segment_size * m::HOP_LENGTH) as f64 / m::SAMPLE_RATE as f64;
let dr = self.decode_with_fallback(&mel_segment)?;
seek += segment_size;
if dr.no_speech_prob > m::NO_SPEECH_THRESHOLD && dr.avg_logprob < m::LOGPROB_THRESHOLD {
console_log!("no speech detected, skipping {seek} {dr:?}");
continue;
}
let segment = Segment {
start: time_offset,
duration: segment_duration,
dr,
};
console_log!("{seek}: {segment:?}");
segments.push(segment)
}
Ok(segments)
}
pub fn load(md: ModelData) -> anyhow::Result<Self> {
let device = Device::Cpu;
let tokenizer = Tokenizer::from_bytes(&md.tokenizer).map_err(E::msg)?;
let mel_filters = safetensors::tensor::SafeTensors::deserialize(&md.mel_filters)?;
let mel_filters = mel_filters.tensor("mel_80")?.load(&device)?;
console_log!("loaded mel filters {:?}", mel_filters.shape());
let mel_filters = mel_filters.flatten_all()?.to_vec1::<f32>()?;
let config: Config = serde_json::from_slice(&md.config)?;
let model = if md.quantized {
let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf_buffer(
&md.weights,
)?;
Model::Quantized(m::quantized_model::Whisper::load(&vb, config)?)
} else {
let vb = VarBuilder::from_buffered_safetensors(md.weights, m::DTYPE, &device)?;
Model::Normal(m::model::Whisper::load(&vb, config)?)
};
console_log!("done loading model");
let task = match md.task.as_deref() {
Some("translate") => Some(Task::Translate),
_ => Some(Task::Transcribe),
};
let decoder = Self::new(
model,
tokenizer,
mel_filters,
&device,
task,
md.language,
md.is_multilingual,
md.timestamps,
)?;
Ok(decoder)
}
pub fn convert_and_run(&mut self, wav_input: &[u8]) -> anyhow::Result<Vec<Segment>> {
let device = Device::Cpu;
let mut wav_input = std::io::Cursor::new(wav_input);
let (header, data) = wav::read(&mut wav_input)?;
console_log!("loaded wav data: {header:?}");
if header.sampling_rate != m::SAMPLE_RATE as u32 {
anyhow::bail!("wav file must have a {} sampling rate", m::SAMPLE_RATE);
}
let data = data.as_sixteen().expect("expected 16 bit wav file");
let pcm_data: Vec<_> = data[..data.len() / header.channel_count as usize]
.iter()
.map(|v| *v as f32 / 32768.)
.collect();
console_log!("pcm data loaded {}", pcm_data.len());
let mel = crate::audio::pcm_to_mel(self.model.config(), &pcm_data, &self.mel_filters)?;
let mel_len = mel.len();
let n_mels = self.model.config().num_mel_bins;
let mel = Tensor::from_vec(mel, (1, n_mels, mel_len / n_mels), &device)?;
console_log!("loaded mel: {:?}", mel.dims());
let segments = self.run(&mel)?;
Ok(segments)
}
}
/// Returns the token id for the selected language.
pub fn detect_language(model: &mut Model, tokenizer: &Tokenizer, mel: &Tensor) -> Result<u32, E> {
console_log!("detecting language");
let (_bsize, _, seq_len) = mel.dims3()?;
let mel = mel.narrow(
2,
0,
usize::min(seq_len, model.config().max_source_positions),
)?;
let device = mel.device();
let language_token_ids = LANGUAGES
.iter()
.map(|(t, _)| token_id(tokenizer, &format!("<|{t}|>")))
.map(|e| e.map_err(E::msg))
.collect::<Result<Vec<_>, E>>()?;
let sot_token = token_id(tokenizer, m::SOT_TOKEN)?;
let audio_features = model.encoder_forward(&mel, true)?;
let tokens = Tensor::new(&[[sot_token]], device)?;
let language_token_ids = Tensor::new(language_token_ids.as_slice(), device)?;
let ys = model.decoder_forward(&tokens, &audio_features, true)?;
let logits = model.decoder_final_linear(&ys.i(..1)?)?.i(0)?.i(0)?;
let logits = logits.index_select(&language_token_ids, 0)?;
let probs = candle_nn::ops::softmax(&logits, D::Minus1)?;
let probs = probs.to_vec1::<f32>()?;
let mut probs = LANGUAGES.iter().zip(probs.iter()).collect::<Vec<_>>();
probs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1));
for ((_, language), p) in probs.iter().take(5) {
println!("{language}: {p}")
}
let token = &format!("<|{}|>", probs[0].0 .0);
let language = token_id(tokenizer, token)?;
console_log!("detected language: {language} {token}");
Ok(language)
}
pub fn token_id(tokenizer: &Tokenizer, token: &str) -> candle::Result<u32> {
match tokenizer.token_to_id(token) {
None => candle::bail!("no token-id for {token}"),
Some(id) => Ok(id),
}
}
#[derive(Serialize, Deserialize, Clone, Copy, Debug)]
pub enum Task {
Transcribe,
Translate,
}
// Communication to the worker happens through bincode, the model weights and configs are fetched
// on the main thread and transferred via the following structure.
#[derive(Serialize, Deserialize)]
pub struct ModelData {
pub weights: Vec<u8>,
pub tokenizer: Vec<u8>,
pub mel_filters: Vec<u8>,
pub config: Vec<u8>,
pub quantized: bool,
pub timestamps: bool,
pub is_multilingual: bool,
pub language: Option<String>,
pub task: Option<String>,
}
pub struct Worker {
link: WorkerLink<Self>,
decoder: Option<Decoder>,
}
#[derive(Serialize, Deserialize)]
pub enum WorkerInput {
ModelData(ModelData),
DecodeTask { wav_bytes: Vec<u8> },
}
#[derive(Serialize, Deserialize)]
pub enum WorkerOutput {
Decoded(Vec<Segment>),
WeightsLoaded,
}
impl yew_agent::Worker for Worker {
type Input = WorkerInput;
type Message = ();
type Output = Result<WorkerOutput, String>;
type Reach = Public<Self>;
fn create(link: WorkerLink<Self>) -> Self {
Self {
link,
decoder: None,
}
}
fn update(&mut self, _msg: Self::Message) {
// no messaging
}
fn handle_input(&mut self, msg: Self::Input, id: HandlerId) {
let output = match msg {
WorkerInput::ModelData(md) => match Decoder::load(md) {
Ok(decoder) => {
self.decoder = Some(decoder);
Ok(WorkerOutput::WeightsLoaded)
}
Err(err) => Err(format!("model creation error {err:?}")),
},
WorkerInput::DecodeTask { wav_bytes } => match &mut self.decoder {
None => Err("model has not been set".to_string()),
Some(decoder) => decoder
.convert_and_run(&wav_bytes)
.map(WorkerOutput::Decoded)
.map_err(|e| e.to_string()),
},
};
self.link.respond(id, output);
}
fn name_of_resource() -> &'static str {
"worker.js"
}
fn resource_path_is_relative() -> bool {
true
}
}

View File

@ -0,0 +1,116 @@
//load the candle Whisper decoder wasm module
import init, { Decoder } from "./build/m.js";
async function fetchArrayBuffer(url) {
const cacheName = "whisper-candle-cache";
const cache = await caches.open(cacheName);
const cachedResponse = await cache.match(url);
if (cachedResponse) {
const data = await cachedResponse.arrayBuffer();
return new Uint8Array(data);
}
const res = await fetch(url, { cache: "force-cache" });
cache.put(url, res.clone());
return new Uint8Array(await res.arrayBuffer());
}
class Whisper {
static instance = {};
// Retrieve the Whisper model. When called for the first time,
// this will load the model and save it for future use.
static async getInstance(params) {
const {
weightsURL,
modelID,
tokenizerURL,
mel_filtersURL,
configURL,
quantized,
is_multilingual,
timestamps,
task,
language,
} = params;
// load individual modelID only once
if (!this.instance[modelID]) {
await init();
self.postMessage({ status: "loading", message: "Loading Model" });
const [
weightsArrayU8,
tokenizerArrayU8,
mel_filtersArrayU8,
configArrayU8,
] = await Promise.all([
fetchArrayBuffer(weightsURL),
fetchArrayBuffer(tokenizerURL),
fetchArrayBuffer(mel_filtersURL),
fetchArrayBuffer(configURL),
]);
this.instance[modelID] = new Decoder(
weightsArrayU8,
tokenizerArrayU8,
mel_filtersArrayU8,
configArrayU8,
quantized,
is_multilingual,
timestamps,
task,
language
);
} else {
self.postMessage({ status: "loading", message: "Model Already Loaded" });
}
return this.instance[modelID];
}
}
self.addEventListener("message", async (event) => {
const {
weightsURL,
modelID,
tokenizerURL,
configURL,
mel_filtersURL,
audioURL,
} = event.data;
try {
self.postMessage({ status: "decoding", message: "Starting Decoder" });
let quantized = false;
if (modelID.includes("quantized")) {
quantized = true;
}
let is_multilingual = false;
if (modelID.includes("multilingual")) {
is_multilingual = true;
}
let timestamps = true;
const decoder = await Whisper.getInstance({
weightsURL,
modelID,
tokenizerURL,
mel_filtersURL,
configURL,
quantized,
is_multilingual,
timestamps,
task: null,
language: null,
});
self.postMessage({ status: "decoding", message: "Loading Audio" });
const audioArrayU8 = await fetchArrayBuffer(audioURL);
self.postMessage({ status: "decoding", message: "Running Decoder..." });
const segments = decoder.decode(audioArrayU8);
// Send the segment back to the main thread as JSON
self.postMessage({
status: "complete",
message: "complete",
output: JSON.parse(segments),
});
} catch (e) {
self.postMessage({ error: e });
}
});