46 Commits

Author SHA1 Message Date
Bel LaPointe
fffea2ddf0 no render mac 2025-09-10 11:20:01 -06:00
Bel LaPointe
12dbf12299 k 2024-09-21 21:33:40 -04:00
Bel LaPointe
f04a55590f fixed 2024-09-21 21:33:40 -04:00
Bel LaPointe
2254afcbfb wav to mkv with subtitles scripting 2024-09-21 21:33:40 -04:00
bel
5fdc60e32c stem words for destuttering 2024-01-03 20:40:35 -07:00
bel
4c80247ab9 accept lower sample rates if 16k not avail 2024-01-03 17:18:07 -07:00
bel
53e675b9a0 no panic on unusable mic 2024-01-03 17:09:27 -07:00
Bel LaPointe
9780c6f2ef todo 2024-01-03 08:50:59 -07:00
Bel LaPointe
7f902af26f default update 2024-01-03 08:40:13 -07:00
Bel LaPointe
9bc009996c oop 2024-01-03 08:38:24 -07:00
Bel LaPointe
cbc8a4f9fd cargo run -- --stream-step 8 --stream-retain 4 --stream-head=2 --stream-tail=0 2> /dev/null 2024-01-03 08:37:27 -07:00
Bel LaPointe
a8c8140d18 functionize at least 2024-01-03 08:28:22 -07:00
Bel LaPointe
5bc3209070 x=2; cargo run -- --wav $HOME/Downloads/41A6C472-6E4D-4953-9A90-2497D2DAD8C9.wav --stream-step $((x*4)) --stream-retain $((x*2)) --stream-{head,tail}=$((x)) 2> /dev/null 2024-01-03 08:22:45 -07:00
Bel LaPointe
8b5c18e65e todo 2024-01-03 08:22:15 -07:00
Bel LaPointe
ec47d8142a destutter with stopwords impl 2024-01-03 07:54:21 -07:00
bel
03659164ba wip 2024-01-02 21:14:02 -07:00
bel
709dd1dba3 tod 2024-01-02 21:12:33 -07:00
bel
26595396cf tod 2024-01-02 21:01:01 -07:00
Bel LaPointe
fb7892b52b todo 2024-01-02 18:38:13 -07:00
Bel LaPointe
b08e055dac todo 2024-01-02 18:23:03 -07:00
Bel LaPointe
9d993cfc8a update destutterer to do punctuation-free words 2024-01-02 18:20:46 -07:00
Bel LaPointe
f4f8ea429a merge 2024-01-02 17:51:29 -07:00
Bel LaPointe
38bea3735f todo 2024-01-02 17:51:14 -07:00
Bel LaPointe
1c48026690 need to overlap without ANY puctuation, which i can do by breaking into words 2024-01-02 17:49:47 -07:00
Bel LaPointe
a57312786a gr 2024-01-02 17:48:17 -07:00
Bel LaPointe
55e3bf0a26 update defaults 2024-01-02 17:47:00 -07:00
Bel LaPointe
743c8c5f67 time cargo run -- --wav $HOME/Downloads/41A6C472-6E4D-4953-9A90-2497D2DAD8C9.wav --stream-step 30 --stream-retain 25 --stream-{head,tail}=1 2> /dev/null 2024-01-02 16:45:04 -07:00
Bel LaPointe
d32f7a4c40 destutterer doesnt drop stutter for prev 2024-01-02 16:36:39 -07:00
Bel LaPointe
d94cbd6927 baked-lib wav_channel to base 2024-01-02 16:30:07 -07:00
Bel LaPointe
7a5db3b2ac no callback on empty str 2024-01-02 16:29:58 -07:00
Bel LaPointe
d0dc9571d7 rust-whisper-baked listen also destutters 2024-01-02 16:12:07 -07:00
Bel LaPointe
6082f7e446 rust-whisper-baked de-stutters wav_channel output 2024-01-02 16:10:11 -07:00
Bel LaPointe
dd6f980266 rust-whisper-baked with --wav to wav_channel 2024-01-02 14:18:37 -07:00
Bel LaPointe
601fe517d7 rust-whisper-baked-lib::wav_channel() 2024-01-02 14:17:57 -07:00
Bel LaPointe
cd339de334 rust-whisper-lib::wav_channel() 2024-01-02 14:15:53 -07:00
bel
393100973c baked streams even wav files 2024-01-01 20:56:53 -07:00
bel
97c025f04d rust-whisper-baked works with WAV 2024-01-01 19:06:14 -07:00
bel
871efd9b8c drop unused 2024-01-01 18:41:49 -07:00
Bel LaPointe
86bb1769f3 del unused 2023-12-21 22:30:43 -05:00
Bel LaPointe
c23352b7e5 found an example that just works 2023-12-21 22:26:49 -05:00
Bel LaPointe
1cceb773f2 reduce whisper-rs space 2023-12-21 22:15:00 -05:00
Bel LaPointe
29ff174f5d Revert "found burn depends on zstd that doesnt wasm"
This reverts commit 1c9d646a50.
2023-12-21 22:13:25 -05:00
Bel LaPointe
1c9d646a50 found burn depends on zstd that doesnt wasm 2023-12-21 22:13:23 -05:00
Bel LaPointe
d63d05505c drop excess packages 2023-12-21 22:02:25 -05:00
Bel LaPointe
5fb1308b30 rust-whisper-lib imports whisper-burn@432ef86ec1cb9d57406b9fab4e2a13155b10d74b 2023-12-21 22:00:51 -05:00
Bel LaPointe
19b0a4b385 baked channel uses tiny 2023-12-21 15:09:37 -05:00
31 changed files with 2499 additions and 883 deletions

4
.gitignore vendored
View File

@@ -9,3 +9,7 @@ snowboy-2022/snowboy
**/*.wav
snowboy-2022/Dockerfile
**/target
/candle-wasm-examples-whisper/mel_filters*
/candle-wasm-examples-whisper/whisper-tiny*
/candle-wasm-examples-whisper/quantized*
/candle-wasm-examples-whisper/audios*

2
.gitmodules vendored
View File

@@ -1,3 +1,3 @@
[submodule "rust-whisper.d/gitea-whisper-rs"]
[submodule "gitea-whisper-rs"]
path = gitea-whisper-rs
url = https://gitea.inhome.blapointe.com/bel/whisper-rs.git

View File

@@ -0,0 +1,52 @@
[package]
name = "candle-wasm-example-whisper"
version.workspace = true
edition.workspace = true
description.workspace = true
repository.workspace = true
keywords.workspace = true
categories.workspace = true
license.workspace = true
[dependencies]
candle = { path = "../../candle-core", version = "0.3.2", package = "candle-core" }
candle-nn = { path = "../../candle-nn", version = "0.3.2" }
candle-transformers = { path = "../../candle-transformers", version = "0.3.2" }
num-traits = { workspace = true }
tokenizers = { workspace = true, features = ["unstable_wasm"] }
# App crates.
anyhow = { workspace = true }
log = { workspace = true }
rand = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
wav = { workspace = true }
safetensors = { workspace = true }
# Wasm specific crates.
getrandom = { version = "0.2", features = ["js"] }
gloo = "0.8"
js-sys = "0.3.64"
wasm-bindgen = "0.2.87"
wasm-bindgen-futures = "0.4.37"
wasm-logger = "0.2"
yew-agent = "0.2.0"
yew = { version = "0.20.0", features = ["csr"] }
[dependencies.web-sys]
version = "0.3.64"
features = [
'Blob',
'Document',
'Element',
'HtmlElement',
'Node',
'Window',
'Request',
'RequestCache',
'RequestInit',
'RequestMode',
'Response',
'Performance',
]

View File

@@ -0,0 +1,68 @@
## Running Whisper Examples
Here, we provide two examples of how to run Whisper using a Candle-compiled WASM binary and runtimes.
### Pure Rust UI
To build and test the UI made in Rust you will need [Trunk](https://trunkrs.dev/#install)
From the `candle-wasm-examples/whisper` directory run:
Download assets:
```bash
# mel filters
wget -c https://huggingface.co/spaces/lmz/candle-whisper/resolve/main/mel_filters.safetensors
# Model and tokenizer tiny.en
wget -c https://huggingface.co/openai/whisper-tiny.en/resolve/main/model.safetensors -P whisper-tiny.en
wget -c https://huggingface.co/openai/whisper-tiny.en/raw/main/tokenizer.json -P whisper-tiny.en
wget -c https://huggingface.co/openai/whisper-tiny.en/raw/main/config.json -P whisper-tiny.en
# model and tokenizer tiny multilanguage
wget -c https://huggingface.co/openai/whisper-tiny/resolve/main/model.safetensors -P whisper-tiny
wget -c https://huggingface.co/openai/whisper-tiny/raw/main/tokenizer.json -P whisper-tiny
wget -c https://huggingface.co/openai/whisper-tiny/raw/main/config.json -P whisper-tiny
#quantized
wget -c https://huggingface.co/lmz/candle-whisper/resolve/main/model-tiny-en-q80.gguf -P quantized
wget -c https://huggingface.co/lmz/candle-whisper/raw/main/tokenizer-tiny-en.json -P quantized
wget -c https://huggingface.co/lmz/candle-whisper/raw/main/config-tiny-en.json -P quantized
# Audio samples
wget -c https://huggingface.co/datasets/Narsil/candle-examples/resolve/main/samples_gb0.wav -P audios
wget -c https://huggingface.co/datasets/Narsil/candle-examples/resolve/main/samples_a13.wav -P audios
wget -c https://huggingface.co/datasets/Narsil/candle-examples/resolve/main/samples_gb1.wav -P audios
wget -c https://huggingface.co/datasets/Narsil/candle-examples/resolve/main/samples_hp0.wav -P audios
wget -c https://huggingface.co/datasets/Narsil/candle-examples/resolve/main/samples_jfk.wav -P audios
wget -c https://huggingface.co/datasets/Narsil/candle-examples/resolve/main/samples_mm0.wav -P audios
```
Run hot reload server:
```bash
trunk serve --release --public-url / --port 8080
```
### Vanilla JS and WebWorkers
To build and test the UI made in Vanilla JS and WebWorkers, first we need to build the WASM library:
```bash
sh build-lib.sh
```
This will bundle the library under `./build` and we can import it inside our WebWorker like a normal JS module:
```js
import init, { Decoder } from "./build/m.js";
```
The full example can be found under `./lib-example.html`. All needed assets are fetched from the web, so no need to download anything.
Finally, you can preview the example by running a local HTTP server. For example:
```bash
python -m http.server
```
Then open `http://localhost:8000/lib-example.html` in your browser.

View File

@@ -0,0 +1,2 @@
cargo build --target wasm32-unknown-unknown --release
wasm-bindgen ../../target/wasm32-unknown-unknown/release/m.wasm --out-dir build --target web

View File

@@ -0,0 +1,40 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="utf-8" />
<title>Welcome to Candle!</title>
<link data-trunk rel="copy-file" href="mel_filters.safetensors" />
<!-- samples -->
<link data-trunk rel="copy-dir" href="audios" />
<!-- tiny.en -->
<link data-trunk rel="copy-dir" href="whisper-tiny.en" />
<!-- tiny -->
<link data-trunk rel="copy-dir" href="whisper-tiny" />
<!-- quantized -->
<link data-trunk rel="copy-dir" href="quantized" />
<link
data-trunk
rel="rust"
href="Cargo.toml"
data-bin="app"
data-type="main" />
<link
data-trunk
rel="rust"
href="Cargo.toml"
data-bin="worker"
data-type="worker" />
<link
rel="stylesheet"
href="https://fonts.googleapis.com/css?family=Roboto:300,300italic,700,700italic" />
<link
rel="stylesheet"
href="https://cdnjs.cloudflare.com/ajax/libs/normalize/8.0.1/normalize.css" />
<link
rel="stylesheet"
href="https://cdnjs.cloudflare.com/ajax/libs/milligram/1.4.1/milligram.css" />
</head>
<body></body>
</html>

View File

@@ -0,0 +1,351 @@
<html>
<head>
<meta content="text/html;charset=utf-8" http-equiv="Content-Type" />
<title>Candle Whisper Rust/WASM</title>
</head>
<body></body>
</html>
<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<style>
@import url("https://fonts.googleapis.com/css2?family=Source+Code+Pro:wght@200;300;400&family=Source+Sans+3:wght@100;200;300;400;500;600;700;800;900&display=swap");
html,
body {
font-family: "Source Sans 3", sans-serif;
}
</style>
<script src="https://cdn.tailwindcss.com"></script>
<script type="module">
// base url for audio examples
const AUDIO_BASE_URL =
"https://huggingface.co/datasets/Narsil/candle-examples/resolve/main/";
// models base url
const MODELS = {
tiny_multilingual: {
base_url: "https://huggingface.co/openai/whisper-tiny/resolve/main/",
model: "model.safetensors",
tokenizer: "tokenizer.json",
config: "config.json",
size: "151 MB",
},
tiny_en: {
base_url:
"https://huggingface.co/openai/whisper-tiny.en/resolve/main/",
model: "model.safetensors",
tokenizer: "tokenizer.json",
config: "config.json",
size: "151 MB",
},
tiny_quantized_multilingual_q80: {
base_url: "https://huggingface.co/lmz/candle-whisper/resolve/main/",
model: "model-tiny-q80.gguf",
tokenizer: "tokenizer-tiny.json",
config: "config-tiny.json",
size: "41.5 MB",
},
tiny_en_quantized_q80: {
base_url: "https://huggingface.co/lmz/candle-whisper/resolve/main/",
model: "model-tiny-q80.gguf",
tokenizer: "tokenizer-tiny-en.json",
config: "config-tiny-en.json",
size: "41.8 MB",
},
distil_medium_en: {
base_url:
"https://huggingface.co/distil-whisper/distil-medium.en/resolve/main/",
model: "model.safetensors",
tokenizer: "tokenizer.json",
config: "config.json",
size: "789 MB",
},
};
const modelEl = document.querySelector("#model");
Object.keys(MODELS).forEach((modelID) => {
const model = MODELS[modelID];
const option = document.createElement("option");
option.value = modelID;
option.textContent = `${modelID} (${model.size})`;
modelEl.appendChild(option);
});
const whisperWorker = new Worker("./whisperWorker.js", {
type: "module",
});
async function classifyAudio(
weightsURL, // URL to the weights file
modelID, // model ID
tokenizerURL, // URL to the tokenizer file
configURL, // model config URL
mel_filtersURL, // URL to the mel filters file
audioURL, // URL to the audio file
updateStatus // function to update the status
) {
return new Promise((resolve, reject) => {
whisperWorker.postMessage({
weightsURL,
modelID,
tokenizerURL,
configURL,
mel_filtersURL,
audioURL,
});
function messageHandler(event) {
console.log(event.data);
if ("status" in event.data) {
updateStatus(event.data);
}
if ("error" in event.data) {
whisperWorker.removeEventListener("message", messageHandler);
reject(new Error(event.data.error));
}
if (event.data.status === "complete") {
whisperWorker.removeEventListener("message", messageHandler);
resolve(event.data);
}
}
whisperWorker.addEventListener("message", messageHandler);
});
}
// keep track of the audio URL
let audioURL = null;
function setAudio(src) {
const audio = document.querySelector("#audio");
audio.src = src;
audio.controls = true;
audio.hidden = false;
document.querySelector("#detect").disabled = false;
audioURL = src;
}
// add event listener to audio buttons
document.querySelectorAll("#audios-select > button").forEach((target) => {
target.addEventListener("click", (e) => {
const value = target.dataset.value;
const href = AUDIO_BASE_URL + value;
setAudio(href);
});
});
//add event listener to file input
document.querySelector("#file-upload").addEventListener("change", (e) => {
const target = e.target;
if (target.files.length > 0) {
const href = URL.createObjectURL(target.files[0]);
setAudio(href);
}
});
// add event listener to drop-area
const dropArea = document.querySelector("#drop-area");
dropArea.addEventListener("dragenter", (e) => {
e.preventDefault();
dropArea.classList.add("border-blue-700");
});
dropArea.addEventListener("dragleave", (e) => {
e.preventDefault();
dropArea.classList.remove("border-blue-700");
});
dropArea.addEventListener("dragover", (e) => {
e.preventDefault();
dropArea.classList.add("border-blue-700");
});
dropArea.addEventListener("drop", (e) => {
e.preventDefault();
dropArea.classList.remove("border-blue-700");
const url = e.dataTransfer.getData("text/uri-list");
const files = e.dataTransfer.files;
if (files.length > 0) {
const href = URL.createObjectURL(files[0]);
setAudio(href);
} else if (url) {
setAudio(url);
}
});
// add event listener to detect button
document.querySelector("#detect").addEventListener("click", async () => {
if (audioURL === null) {
return;
}
const modelID = modelEl.value;
const model = MODELS[modelID];
const modelURL = model.base_url + model.model;
const tokenizerURL = model.base_url + model.tokenizer;
const configURL = model.base_url + model.config;
classifyAudio(
modelURL,
modelID,
tokenizerURL,
configURL,
"mel_filters.safetensors",
audioURL,
updateStatus
)
.then((result) => {
console.log("RESULT", result);
const { output } = result;
const text = output.map((segment) => segment.dr.text).join(" ");
console.log(text);
document.querySelector("#output-status").hidden = true;
document.querySelector("#output-generation").hidden = false;
document.querySelector("#output-generation").textContent = text;
})
.catch((error) => {
console.error(error);
});
});
function updateStatus(data) {
const { status, message } = data;
const button = document.querySelector("#detect");
if (status === "decoding" || status === "loading") {
button.disabled = true;
button.textContent = message;
} else if (status === "complete") {
button.disabled = false;
button.textContent = "Transcribe Audio";
}
}
</script>
</head>
<body class="container max-w-4xl mx-auto p-4">
<main class="grid grid-cols-1 gap-8 relative">
<span class="absolute text-5xl -ml-[1em]"> 🕯️ </span>
<div>
<h1 class="text-5xl font-bold">Candle Whisper</h1>
<h2 class="text-2xl font-bold">Rust/WASM Demo</h2>
<p class="max-w-lg">
Transcribe audio in the browser using rust/wasm with an audio file.
This demo uses the
<a
href="https://huggingface.co/openai/"
target="_blank"
class="underline hover:text-blue-500 hover:no-underline">
OpenAI Whisper models
</a>
and WASM runtime built with
<a
href="https://github.com/huggingface/candle/"
target="_blank"
class="underline hover:text-blue-500 hover:no-underline"
>Candle
</a>
</p>
</div>
<div>
<label for="model" class="font-medium">Models Options: </label>
<select
id="model"
class="border-2 border-gray-500 rounded-md font-light">
</select>
</div>
<!-- drag and drop area -->
<div class="relative">
<div
id="drop-area"
class="flex flex-col items-center justify-center border-2 border-gray-300 border-dashed rounded-xl relative h-48 w-full overflow-hidden">
<div
class="flex flex-col items-center justify-center space-y-1 text-center">
<svg
width="25"
height="25"
viewBox="0 0 25 25"
fill="none"
xmlns="http://www.w3.org/2000/svg">
<path
d="M3.5 24.3a3 3 0 0 1-1.9-.8c-.5-.5-.8-1.2-.8-1.9V2.9c0-.7.3-1.3.8-1.9.6-.5 1.2-.7 2-.7h18.6c.7 0 1.3.2 1.9.7.5.6.7 1.2.7 2v18.6c0 .7-.2 1.4-.7 1.9a3 3 0 0 1-2 .8H3.6Zm0-2.7h18.7V2.9H3.5v18.7Zm2.7-2.7h13.3c.3 0 .5 0 .6-.3v-.7l-3.7-5a.6.6 0 0 0-.6-.2c-.2 0-.4 0-.5.3l-3.5 4.6-2.4-3.3a.6.6 0 0 0-.6-.3c-.2 0-.4.1-.5.3l-2.7 3.6c-.1.2-.2.4 0 .7.1.2.3.3.6.3Z"
fill="#000" />
</svg>
<div class="flex text-sm text-gray-600">
<label
for="file-upload"
class="relative cursor-pointer bg-white rounded-md font-medium text-blue-950 hover:text-blue-700">
<span>Drag and drop your audio here</span>
<span class="block text-xs">or</span>
<span class="block text-xs">Click to upload</span>
</label>
</div>
<input
id="file-upload"
name="file-upload"
type="file"
accept="audio/*"
class="sr-only" />
</div>
<audio
id="audio"
hidden
controls
class="w-full p-2 select-none"></audio>
</div>
</div>
<div>
<div class="flex flex-wrap gap-3 items-center" id="audios-select">
<h3 class="font-medium">Examples:</h3>
<button
data-value="samples_jfk.wav"
class="text-gray-500 border border-gray-500 rounded-md p-2 underline hover:no-underline">
<span>jfk.wav</span>
<span class="text-xs block"> (352 kB)</span>
</button>
<button
data-value="samples_a13.wav"
class="text-gray-500 border border-gray-500 rounded-md p-2 underline hover:no-underline">
<span>a13.wav</span>
<span class="text-xs block"> (960 kB)</span>
</button>
<button
data-value="samples_mm0.wav"
class="text-gray-500 border border-gray-500 rounded-md p-2 underline hover:no-underline">
<span>mm0.wav</span>
<span class="text-xs block new"> (957 kB)</span>
</button>
<button
data-value="samples_gb0.wav"
class="text-gray-500 border border-gray-500 rounded-md p-2 underline hover:no-underline">
<span>gb0.wav </span>
<span class="text-xs block">(4.08 MB)</span>
</button>
<button
data-value="samples_gb1.wav"
class="text-gray-500 border border-gray-500 rounded-md p-2 underline hover:no-underline">
<span>gb1.wav </span>
<span class="text-xs block">(6.36 MB)</span>
</button>
<button
data-value="samples_hp0.wav"
class="text-gray-500 border border-gray-500 rounded-md p-2 underline hover:no-underline">
<span>hp0.wav </span>
<span class="text-xs block">(8.75 MB)</span>
</button>
</div>
</div>
<div>
<button
id="detect"
disabled
class="bg-gray-700 hover:bg-gray-800 text-white font-normal py-2 px-4 rounded disabled:bg-gray-300 disabled:cursor-not-allowed">
Transcribe Audio
</button>
</div>
<div>
<h3 class="font-medium">Transcription:</h3>
<div
class="min-h-[250px] bg-slate-100 text-gray-500 p-4 rounded-md flex flex-col gap-2">
<p hidden id="output-generation" class="grid-rows-2"></p>
<span id="output-status" class="m-auto font-light"
>No transcription results yet</span
>
</div>
</div>
</main>
</body>
</html>

View File

@@ -0,0 +1,6 @@
import init, { run_app } from './pkg/candle_wasm_example_whisper.js';
async function main() {
await init('/pkg/candle_wasm_example_whisper_bg.wasm');
run_app();
}
main()

View File

@@ -0,0 +1,283 @@
use crate::console_log;
use crate::worker::{ModelData, Segment, Worker, WorkerInput, WorkerOutput};
use js_sys::Date;
use wasm_bindgen::prelude::*;
use wasm_bindgen_futures::JsFuture;
use yew::{html, Component, Context, Html};
use yew_agent::{Bridge, Bridged};
const SAMPLE_NAMES: [&str; 6] = [
"audios/samples_jfk.wav",
"audios/samples_a13.wav",
"audios/samples_gb0.wav",
"audios/samples_gb1.wav",
"audios/samples_hp0.wav",
"audios/samples_mm0.wav",
];
async fn fetch_url(url: &str) -> Result<Vec<u8>, JsValue> {
use web_sys::{Request, RequestCache, RequestInit, RequestMode, Response};
let window = web_sys::window().ok_or("window")?;
let mut opts = RequestInit::new();
let opts = opts
.method("GET")
.mode(RequestMode::Cors)
.cache(RequestCache::NoCache);
let request = Request::new_with_str_and_init(url, opts)?;
let resp_value = JsFuture::from(window.fetch_with_request(&request)).await?;
// `resp_value` is a `Response` object.
assert!(resp_value.is_instance_of::<Response>());
let resp: Response = resp_value.dyn_into()?;
let data = JsFuture::from(resp.blob()?).await?;
let blob = web_sys::Blob::from(data);
let array_buffer = JsFuture::from(blob.array_buffer()).await?;
let data = js_sys::Uint8Array::new(&array_buffer).to_vec();
Ok(data)
}
pub enum Msg {
Run(usize),
UpdateStatus(String),
SetDecoder(ModelData),
WorkerInMsg(WorkerInput),
WorkerOutMsg(Result<WorkerOutput, String>),
}
pub struct CurrentDecode {
start_time: Option<f64>,
}
pub struct App {
status: String,
loaded: bool,
segments: Vec<Segment>,
current_decode: Option<CurrentDecode>,
worker: Box<dyn Bridge<Worker>>,
}
async fn model_data_load() -> Result<ModelData, JsValue> {
let quantized = false;
let is_multilingual = false;
let (tokenizer, mel_filters, weights, config) = if quantized {
console_log!("loading quantized weights");
let tokenizer = fetch_url("quantized/tokenizer-tiny-en.json").await?;
let mel_filters = fetch_url("mel_filters.safetensors").await?;
let weights = fetch_url("quantized/model-tiny-en-q80.gguf").await?;
let config = fetch_url("quantized/config-tiny-en.json").await?;
(tokenizer, mel_filters, weights, config)
} else {
console_log!("loading float weights");
if is_multilingual {
let mel_filters = fetch_url("mel_filters.safetensors").await?;
let tokenizer = fetch_url("whisper-tiny/tokenizer.json").await?;
let weights = fetch_url("whisper-tiny/model.safetensors").await?;
let config = fetch_url("whisper-tiny/config.json").await?;
(tokenizer, mel_filters, weights, config)
} else {
let mel_filters = fetch_url("mel_filters.safetensors").await?;
let tokenizer = fetch_url("whisper-tiny.en/tokenizer.json").await?;
let weights = fetch_url("whisper-tiny.en/model.safetensors").await?;
let config = fetch_url("whisper-tiny.en/config.json").await?;
(tokenizer, mel_filters, weights, config)
}
};
let timestamps = true;
let _task = Some("transcribe".to_string());
console_log!("{}", weights.len());
Ok(ModelData {
tokenizer,
mel_filters,
weights,
config,
quantized,
timestamps,
task: None,
is_multilingual,
language: None,
})
}
fn performance_now() -> Option<f64> {
let window = web_sys::window()?;
let performance = window.performance()?;
Some(performance.now() / 1000.)
}
impl Component for App {
type Message = Msg;
type Properties = ();
fn create(ctx: &Context<Self>) -> Self {
let status = "loading weights".to_string();
let cb = {
let link = ctx.link().clone();
move |e| link.send_message(Self::Message::WorkerOutMsg(e))
};
let worker = Worker::bridge(std::rc::Rc::new(cb));
Self {
status,
segments: vec![],
current_decode: None,
worker,
loaded: false,
}
}
fn rendered(&mut self, ctx: &Context<Self>, first_render: bool) {
if first_render {
ctx.link().send_future(async {
match model_data_load().await {
Err(err) => {
let status = format!("{err:?}");
Msg::UpdateStatus(status)
}
Ok(model_data) => Msg::SetDecoder(model_data),
}
});
}
}
fn update(&mut self, ctx: &Context<Self>, msg: Self::Message) -> bool {
match msg {
Msg::SetDecoder(md) => {
self.status = "weights loaded successfully!".to_string();
self.loaded = true;
console_log!("loaded weights");
self.worker.send(WorkerInput::ModelData(md));
true
}
Msg::Run(sample_index) => {
let sample = SAMPLE_NAMES[sample_index];
if self.current_decode.is_some() {
self.status = "already decoding some sample at the moment".to_string()
} else {
let start_time = performance_now();
self.current_decode = Some(CurrentDecode { start_time });
self.status = format!("decoding {sample}");
self.segments.clear();
ctx.link().send_future(async move {
match fetch_url(sample).await {
Err(err) => {
let output = Err(format!("decoding error: {err:?}"));
// Mimic a worker output to so as to release current_decode
Msg::WorkerOutMsg(output)
}
Ok(wav_bytes) => {
Msg::WorkerInMsg(WorkerInput::DecodeTask { wav_bytes })
}
}
})
}
//
true
}
Msg::WorkerOutMsg(output) => {
let dt = self.current_decode.as_ref().and_then(|current_decode| {
current_decode.start_time.and_then(|start_time| {
performance_now().map(|stop_time| stop_time - start_time)
})
});
self.current_decode = None;
match output {
Ok(WorkerOutput::WeightsLoaded) => self.status = "weights loaded!".to_string(),
Ok(WorkerOutput::Decoded(segments)) => {
self.status = match dt {
None => "decoding succeeded!".to_string(),
Some(dt) => format!("decoding succeeded in {:.2}s", dt),
};
self.segments = segments;
}
Err(err) => {
self.status = format!("decoding error {err:?}");
}
}
true
}
Msg::WorkerInMsg(inp) => {
self.worker.send(inp);
true
}
Msg::UpdateStatus(status) => {
self.status = status;
true
}
}
}
fn view(&self, ctx: &Context<Self>) -> Html {
html! {
<div>
<table>
<thead>
<tr>
<th>{"Sample"}</th>
<th></th>
<th></th>
</tr>
</thead>
<tbody>
{
SAMPLE_NAMES.iter().enumerate().map(|(i, name)| { html! {
<tr>
<th>{name}</th>
<th><audio controls=true src={format!("./{name}")}></audio></th>
{ if self.loaded {
html!(<th><button class="button" onclick={ctx.link().callback(move |_| Msg::Run(i))}> { "run" }</button></th>)
}else{html!()}
}
</tr>
}
}).collect::<Html>()
}
</tbody>
</table>
<h2>
{&self.status}
</h2>
{
if !self.loaded{
html! { <progress id="progress-bar" aria-label="loading weights…"></progress> }
} else if self.current_decode.is_some() {
html! { <progress id="progress-bar" aria-label="decoding…"></progress> }
} else { html!{
<blockquote>
<p>
{
self.segments.iter().map(|segment| { html! {
<>
<i>
{
format!("{:.2}s-{:.2}s: (avg-logprob: {:.4}, no-speech-prob: {:.4})",
segment.start,
segment.start + segment.duration,
segment.dr.avg_logprob,
segment.dr.no_speech_prob,
)
}
</i>
<br/ >
{&segment.dr.text}
<br/ >
</>
} }).collect::<Html>()
}
</p>
</blockquote>
}
}
}
// Display the current date and time the page was rendered
<p class="footer">
{ "Rendered: " }
{ String::from(Date::new_0().to_string()) }
</p>
</div>
}
}
}

View File

@@ -0,0 +1,215 @@
// Audio processing code, adapted from whisper.cpp
// https://github.com/ggerganov/whisper.cpp
use super::worker;
pub trait Float: num_traits::Float + num_traits::FloatConst + num_traits::NumAssign {}
impl Float for f32 {}
impl Float for f64 {}
// https://github.com/ggerganov/whisper.cpp/blob/4774d2feb01a772a15de81ffc34b34a1f294f020/whisper.cpp#L2357
fn fft<T: Float>(inp: &[T]) -> Vec<T> {
let n = inp.len();
let zero = T::zero();
if n == 1 {
return vec![inp[0], zero];
}
if n % 2 == 1 {
return dft(inp);
}
let mut out = vec![zero; n * 2];
let mut even = Vec::with_capacity(n / 2);
let mut odd = Vec::with_capacity(n / 2);
for (i, &inp) in inp.iter().enumerate() {
if i % 2 == 0 {
even.push(inp)
} else {
odd.push(inp);
}
}
let even_fft = fft(&even);
let odd_fft = fft(&odd);
let two_pi = T::PI() + T::PI();
let n_t = T::from(n).unwrap();
for k in 0..n / 2 {
let k_t = T::from(k).unwrap();
let theta = two_pi * k_t / n_t;
let re = theta.cos();
let im = -theta.sin();
let re_odd = odd_fft[2 * k];
let im_odd = odd_fft[2 * k + 1];
out[2 * k] = even_fft[2 * k] + re * re_odd - im * im_odd;
out[2 * k + 1] = even_fft[2 * k + 1] + re * im_odd + im * re_odd;
out[2 * (k + n / 2)] = even_fft[2 * k] - re * re_odd + im * im_odd;
out[2 * (k + n / 2) + 1] = even_fft[2 * k + 1] - re * im_odd - im * re_odd;
}
out
}
// https://github.com/ggerganov/whisper.cpp/blob/4774d2feb01a772a15de81ffc34b34a1f294f020/whisper.cpp#L2337
fn dft<T: Float>(inp: &[T]) -> Vec<T> {
let zero = T::zero();
let n = inp.len();
let two_pi = T::PI() + T::PI();
let mut out = Vec::with_capacity(2 * n);
let n_t = T::from(n).unwrap();
for k in 0..n {
let k_t = T::from(k).unwrap();
let mut re = zero;
let mut im = zero;
for (j, &inp) in inp.iter().enumerate() {
let j_t = T::from(j).unwrap();
let angle = two_pi * k_t * j_t / n_t;
re += inp * angle.cos();
im -= inp * angle.sin();
}
out.push(re);
out.push(im);
}
out
}
#[allow(clippy::too_many_arguments)]
// https://github.com/ggerganov/whisper.cpp/blob/4774d2feb01a772a15de81ffc34b34a1f294f020/whisper.cpp#L2414
fn log_mel_spectrogram_w<T: Float>(
ith: usize,
hann: &[T],
samples: &[T],
filters: &[T],
fft_size: usize,
fft_step: usize,
speed_up: bool,
n_len: usize,
n_mel: usize,
n_threads: usize,
) -> Vec<T> {
let n_fft = if speed_up {
1 + fft_size / 4
} else {
1 + fft_size / 2
};
let zero = T::zero();
let half = T::from(0.5).unwrap();
let mut fft_in = vec![zero; fft_size];
let mut mel = vec![zero; n_len * n_mel];
for i in (ith..n_len).step_by(n_threads) {
let offset = i * fft_step;
// apply Hanning window
for j in 0..fft_size {
fft_in[j] = if offset + j < samples.len() {
hann[j] * samples[offset + j]
} else {
zero
}
}
// FFT -> mag^2
let mut fft_out: Vec<T> = fft(&fft_in);
for j in 0..fft_size {
fft_out[j] = fft_out[2 * j] * fft_out[2 * j] + fft_out[2 * j + 1] * fft_out[2 * j + 1];
}
for j in 1..fft_size / 2 {
let v = fft_out[fft_size - j];
fft_out[j] += v;
}
if speed_up {
// scale down in the frequency domain results in a speed up in the time domain
for j in 0..n_fft {
fft_out[j] = half * (fft_out[2 * j] + fft_out[2 * j + 1]);
}
}
// mel spectrogram
for j in 0..n_mel {
let mut sum = zero;
for k in 0..n_fft {
sum += fft_out[k] * filters[j * n_fft + k];
}
mel[j * n_len + i] = T::max(sum, T::from(1e-10).unwrap()).log10();
}
}
mel
}
fn log_mel_spectrogram_<T: Float + std::fmt::Display>(
samples: &[T],
filters: &[T],
fft_size: usize,
fft_step: usize,
n_mel: usize,
speed_up: bool,
) -> Vec<T> {
let zero = T::zero();
let two_pi = T::PI() + T::PI();
let half = T::from(0.5).unwrap();
let one = T::from(1.0).unwrap();
let four = T::from(4.0).unwrap();
let fft_size_t = T::from(fft_size).unwrap();
let hann: Vec<T> = (0..fft_size)
.map(|i| half * (one - ((two_pi * T::from(i).unwrap()) / fft_size_t).cos()))
.collect();
let n_len = samples.len() / fft_step;
// pad audio with at least one extra chunk of zeros
let pad = 100 * worker::m::CHUNK_LENGTH / 2;
let n_len = if n_len % pad != 0 {
(n_len / pad + 1) * pad
} else {
n_len
};
let n_len = n_len + pad;
let samples = {
let mut samples_padded = samples.to_vec();
let to_add = n_len * fft_step - samples.len();
samples_padded.extend(std::iter::repeat(zero).take(to_add));
samples_padded
};
// Use a single thread for now.
let mut mel = log_mel_spectrogram_w(
0, &hann, &samples, filters, fft_size, fft_step, speed_up, n_len, n_mel, 1,
);
let mmax = mel
.iter()
.max_by(|&u, &v| u.partial_cmp(v).unwrap_or(std::cmp::Ordering::Greater))
.copied()
.unwrap_or(zero)
- T::from(8).unwrap();
for m in mel.iter_mut() {
let v = T::max(*m, mmax);
*m = v / four + one
}
mel
}
pub fn pcm_to_mel<T: Float + std::fmt::Display>(
cfg: &worker::m::Config,
samples: &[T],
filters: &[T],
) -> anyhow::Result<Vec<T>> {
let mel = log_mel_spectrogram_(
samples,
filters,
worker::m::N_FFT,
worker::m::HOP_LENGTH,
cfg.num_mel_bins,
false,
);
Ok(mel)
}

View File

@@ -0,0 +1,4 @@
fn main() {
wasm_logger::init(wasm_logger::Config::new(log::Level::Trace));
yew::Renderer::<candle_wasm_example_whisper::App>::new().render();
}

View File

@@ -0,0 +1,53 @@
use candle_wasm_example_whisper::worker::{Decoder as D, ModelData};
use wasm_bindgen::prelude::*;
#[wasm_bindgen]
pub struct Decoder {
decoder: D,
}
#[wasm_bindgen]
impl Decoder {
#[wasm_bindgen(constructor)]
#[allow(clippy::too_many_arguments)]
pub fn new(
weights: Vec<u8>,
tokenizer: Vec<u8>,
mel_filters: Vec<u8>,
config: Vec<u8>,
quantized: bool,
is_multilingual: bool,
timestamps: bool,
task: Option<String>,
language: Option<String>,
) -> Result<Decoder, JsError> {
let decoder = D::load(ModelData {
tokenizer,
mel_filters,
config,
quantized,
weights,
is_multilingual,
timestamps,
task,
language,
});
match decoder {
Ok(decoder) => Ok(Self { decoder }),
Err(e) => Err(JsError::new(&e.to_string())),
}
}
#[wasm_bindgen]
pub fn decode(&mut self, wav_input: Vec<u8>) -> Result<String, JsError> {
let segments = self
.decoder
.convert_and_run(&wav_input)
.map_err(|e| JsError::new(&e.to_string()))?;
let json = serde_json::to_string(&segments)?;
Ok(json)
}
}
fn main() {}

View File

@@ -0,0 +1,4 @@
use yew_agent::PublicWorker;
fn main() {
candle_wasm_example_whisper::Worker::register();
}

View File

@@ -0,0 +1,101 @@
pub const LANGUAGES: [(&str, &str); 99] = [
("en", "english"),
("zh", "chinese"),
("de", "german"),
("es", "spanish"),
("ru", "russian"),
("ko", "korean"),
("fr", "french"),
("ja", "japanese"),
("pt", "portuguese"),
("tr", "turkish"),
("pl", "polish"),
("ca", "catalan"),
("nl", "dutch"),
("ar", "arabic"),
("sv", "swedish"),
("it", "italian"),
("id", "indonesian"),
("hi", "hindi"),
("fi", "finnish"),
("vi", "vietnamese"),
("he", "hebrew"),
("uk", "ukrainian"),
("el", "greek"),
("ms", "malay"),
("cs", "czech"),
("ro", "romanian"),
("da", "danish"),
("hu", "hungarian"),
("ta", "tamil"),
("no", "norwegian"),
("th", "thai"),
("ur", "urdu"),
("hr", "croatian"),
("bg", "bulgarian"),
("lt", "lithuanian"),
("la", "latin"),
("mi", "maori"),
("ml", "malayalam"),
("cy", "welsh"),
("sk", "slovak"),
("te", "telugu"),
("fa", "persian"),
("lv", "latvian"),
("bn", "bengali"),
("sr", "serbian"),
("az", "azerbaijani"),
("sl", "slovenian"),
("kn", "kannada"),
("et", "estonian"),
("mk", "macedonian"),
("br", "breton"),
("eu", "basque"),
("is", "icelandic"),
("hy", "armenian"),
("ne", "nepali"),
("mn", "mongolian"),
("bs", "bosnian"),
("kk", "kazakh"),
("sq", "albanian"),
("sw", "swahili"),
("gl", "galician"),
("mr", "marathi"),
("pa", "punjabi"),
("si", "sinhala"),
("km", "khmer"),
("sn", "shona"),
("yo", "yoruba"),
("so", "somali"),
("af", "afrikaans"),
("oc", "occitan"),
("ka", "georgian"),
("be", "belarusian"),
("tg", "tajik"),
("sd", "sindhi"),
("gu", "gujarati"),
("am", "amharic"),
("yi", "yiddish"),
("lo", "lao"),
("uz", "uzbek"),
("fo", "faroese"),
("ht", "haitian creole"),
("ps", "pashto"),
("tk", "turkmen"),
("nn", "nynorsk"),
("mt", "maltese"),
("sa", "sanskrit"),
("lb", "luxembourgish"),
("my", "myanmar"),
("bo", "tibetan"),
("tl", "tagalog"),
("mg", "malagasy"),
("as", "assamese"),
("tt", "tatar"),
("haw", "hawaiian"),
("ln", "lingala"),
("ha", "hausa"),
("ba", "bashkir"),
("jw", "javanese"),
("su", "sundanese"),
];

View File

@@ -0,0 +1,29 @@
pub const WITH_TIMER: bool = true;
struct Timer {
label: &'static str,
}
// impl Timer {
// fn new(label: &'static str) -> Self {
// if WITH_TIMER {
// web_sys::console::time_with_label(label);
// }
// Self { label }
// }
// }
impl Drop for Timer {
fn drop(&mut self) {
if WITH_TIMER {
web_sys::console::time_end_with_label(self.label)
}
}
}
mod app;
mod audio;
pub mod languages;
pub mod worker;
pub use app::App;
pub use worker::Worker;

View File

@@ -0,0 +1,492 @@
use crate::languages::LANGUAGES;
use anyhow::Error as E;
use candle::{safetensors::Load, DType, Device, IndexOp, Tensor, D};
use candle_nn::{ops::softmax, VarBuilder};
pub use candle_transformers::models::whisper::{self as m, Config};
use rand::{distributions::Distribution, rngs::StdRng, SeedableRng};
use serde::{Deserialize, Serialize};
use tokenizers::Tokenizer;
use wasm_bindgen::prelude::*;
use yew_agent::{HandlerId, Public, WorkerLink};
#[wasm_bindgen]
extern "C" {
// Use `js_namespace` here to bind `console.log(..)` instead of just
// `log(..)`
#[wasm_bindgen(js_namespace = console)]
pub fn log(s: &str);
}
#[macro_export]
macro_rules! console_log {
// Note that this is using the `log` function imported above during
// `bare_bones`
($($t:tt)*) => ($crate::worker::log(&format_args!($($t)*).to_string()))
}
pub const DTYPE: DType = DType::F32;
pub enum Model {
Normal(m::model::Whisper),
Quantized(m::quantized_model::Whisper),
}
// Maybe we should use some traits rather than doing the dispatch for all these.
impl Model {
pub fn config(&self) -> &Config {
match self {
Self::Normal(m) => &m.config,
Self::Quantized(m) => &m.config,
}
}
pub fn encoder_forward(&mut self, x: &Tensor, flush: bool) -> candle::Result<Tensor> {
match self {
Self::Normal(m) => m.encoder.forward(x, flush),
Self::Quantized(m) => m.encoder.forward(x, flush),
}
}
pub fn decoder_forward(
&mut self,
x: &Tensor,
xa: &Tensor,
flush: bool,
) -> candle::Result<Tensor> {
match self {
Self::Normal(m) => m.decoder.forward(x, xa, flush),
Self::Quantized(m) => m.decoder.forward(x, xa, flush),
}
}
pub fn decoder_final_linear(&self, x: &Tensor) -> candle::Result<Tensor> {
match self {
Self::Normal(m) => m.decoder.final_linear(x),
Self::Quantized(m) => m.decoder.final_linear(x),
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DecodingResult {
pub tokens: Vec<u32>,
pub text: String,
pub avg_logprob: f64,
pub no_speech_prob: f64,
temperature: f64,
compression_ratio: f64,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Segment {
pub start: f64,
pub duration: f64,
pub dr: DecodingResult,
}
pub struct Decoder {
model: Model,
rng: rand::rngs::StdRng,
task: Option<Task>,
language: Option<String>,
is_multilingual: bool,
mel_filters: Vec<f32>,
timestamps: bool,
tokenizer: Tokenizer,
suppress_tokens: Tensor,
sot_token: u32,
transcribe_token: u32,
translate_token: u32,
eot_token: u32,
no_speech_token: u32,
no_timestamps_token: u32,
}
impl Decoder {
#[allow(clippy::too_many_arguments)]
fn new(
model: Model,
tokenizer: Tokenizer,
mel_filters: Vec<f32>,
device: &Device,
task: Option<Task>,
language: Option<String>,
is_multilingual: bool,
timestamps: bool,
) -> anyhow::Result<Self> {
let suppress_tokens: Vec<f32> = (0..model.config().vocab_size as u32)
.map(|i| {
if model.config().suppress_tokens.contains(&i) {
f32::NEG_INFINITY
} else {
0f32
}
})
.collect();
let no_timestamps_token = token_id(&tokenizer, m::NO_TIMESTAMPS_TOKEN)?;
let suppress_tokens = Tensor::new(suppress_tokens.as_slice(), device)?;
let sot_token = token_id(&tokenizer, m::SOT_TOKEN)?;
let transcribe_token = token_id(&tokenizer, m::TRANSCRIBE_TOKEN)?;
let translate_token = token_id(&tokenizer, m::TRANSLATE_TOKEN)?;
let eot_token = token_id(&tokenizer, m::EOT_TOKEN)?;
let no_speech_token = m::NO_SPEECH_TOKENS
.iter()
.find_map(|token| token_id(&tokenizer, token).ok());
let no_speech_token = match no_speech_token {
None => anyhow::bail!("unable to find any non-speech token"),
Some(n) => n,
};
let seed = 299792458;
Ok(Self {
model,
rng: StdRng::seed_from_u64(seed),
tokenizer,
mel_filters,
task,
timestamps,
language,
is_multilingual,
suppress_tokens,
sot_token,
transcribe_token,
translate_token,
eot_token,
no_speech_token,
no_timestamps_token,
})
}
fn decode(&mut self, mel: &Tensor, t: f64) -> anyhow::Result<DecodingResult> {
let model = &mut self.model;
let language_token = match (self.is_multilingual, &self.language) {
(true, None) => Some(detect_language(model, &self.tokenizer, mel)?),
(false, None) => None,
(true, Some(language)) => {
match token_id(&self.tokenizer, &format!("<|{:?}|>", self.language)) {
Ok(token_id) => Some(token_id),
Err(_) => anyhow::bail!("language {language} is not supported"),
}
}
(false, Some(_)) => {
anyhow::bail!("a language cannot be set for non-multilingual models")
}
};
let audio_features = model.encoder_forward(mel, true)?;
println!("audio features: {:?}", audio_features.dims());
let sample_len = model.config().max_target_positions / 2;
let mut sum_logprob = 0f64;
let mut no_speech_prob = f64::NAN;
let mut tokens = vec![self.sot_token];
if let Some(language_token) = language_token {
tokens.push(language_token);
}
match self.task {
None | Some(Task::Transcribe) => tokens.push(self.transcribe_token),
Some(Task::Translate) => tokens.push(self.translate_token),
}
if !self.timestamps {
tokens.push(self.no_timestamps_token);
}
for i in 0..sample_len {
let tokens_t = Tensor::new(tokens.as_slice(), mel.device())?;
// The model expects a batch dim but this inference loop does not handle
// it so we add it at this point.
let tokens_t = tokens_t.unsqueeze(0)?;
let ys = model.decoder_forward(&tokens_t, &audio_features, i == 0)?;
// Extract the no speech probability on the first iteration by looking at the first
// token logits and the probability for the according token.
if i == 0 {
let logits = model.decoder_final_linear(&ys.i(..1)?)?.i(0)?.i(0)?;
no_speech_prob = softmax(&logits, 0)?
.i(self.no_speech_token as usize)?
.to_scalar::<f32>()? as f64;
}
let (_, seq_len, _) = ys.dims3()?;
let logits = model
.decoder_final_linear(&ys.i((..1, seq_len - 1..))?)?
.i(0)?
.i(0)?;
// TODO: Besides suppress tokens, we should apply the heuristics from
// ApplyTimestampRules, i.e.:
// - Timestamps come in pairs, except before EOT.
// - Timestamps should be non-decreasing.
// - If the sum of the probabilities of timestamps is higher than any other tokens,
// only consider timestamps when sampling.
// https://github.com/openai/whisper/blob/e8622f9afc4eba139bf796c210f5c01081000472/whisper/decoding.py#L439
let logits = logits.broadcast_add(&self.suppress_tokens)?;
let next_token = if t > 0f64 {
let prs = softmax(&(&logits / t)?, 0)?;
let logits_v: Vec<f32> = prs.to_vec1()?;
let distr = rand::distributions::WeightedIndex::new(&logits_v)?;
distr.sample(&mut self.rng) as u32
} else {
let logits_v: Vec<f32> = logits.to_vec1()?;
logits_v
.iter()
.enumerate()
.max_by(|(_, u), (_, v)| u.total_cmp(v))
.map(|(i, _)| i as u32)
.unwrap()
};
tokens.push(next_token);
let prob = softmax(&logits, candle::D::Minus1)?
.i(next_token as usize)?
.to_scalar::<f32>()? as f64;
if next_token == self.eot_token || tokens.len() > model.config().max_target_positions {
break;
}
sum_logprob += prob.ln();
}
let text = self.tokenizer.decode(&tokens, true).map_err(E::msg)?;
let avg_logprob = sum_logprob / tokens.len() as f64;
Ok(DecodingResult {
tokens,
text,
avg_logprob,
no_speech_prob,
temperature: t,
compression_ratio: f64::NAN,
})
}
fn decode_with_fallback(&mut self, segment: &Tensor) -> anyhow::Result<DecodingResult> {
for (i, &t) in m::TEMPERATURES.iter().enumerate() {
let dr: Result<DecodingResult, _> = self.decode(segment, t);
if i == m::TEMPERATURES.len() - 1 {
return dr;
}
// On errors, we try again with a different temperature.
match dr {
Ok(dr) => {
let needs_fallback = dr.compression_ratio > m::COMPRESSION_RATIO_THRESHOLD
|| dr.avg_logprob < m::LOGPROB_THRESHOLD;
if !needs_fallback || dr.no_speech_prob > m::NO_SPEECH_THRESHOLD {
return Ok(dr);
}
}
Err(err) => {
console_log!("Error running at {t}: {err}")
}
}
}
unreachable!()
}
fn run(&mut self, mel: &Tensor) -> anyhow::Result<Vec<Segment>> {
let (_, _, content_frames) = mel.dims3()?;
let mut seek = 0;
let mut segments = vec![];
while seek < content_frames {
let time_offset = (seek * m::HOP_LENGTH) as f64 / m::SAMPLE_RATE as f64;
let segment_size = usize::min(content_frames - seek, m::N_FRAMES);
let mel_segment = mel.narrow(2, seek, segment_size)?;
let segment_duration = (segment_size * m::HOP_LENGTH) as f64 / m::SAMPLE_RATE as f64;
let dr = self.decode_with_fallback(&mel_segment)?;
seek += segment_size;
if dr.no_speech_prob > m::NO_SPEECH_THRESHOLD && dr.avg_logprob < m::LOGPROB_THRESHOLD {
console_log!("no speech detected, skipping {seek} {dr:?}");
continue;
}
let segment = Segment {
start: time_offset,
duration: segment_duration,
dr,
};
console_log!("{seek}: {segment:?}");
segments.push(segment)
}
Ok(segments)
}
pub fn load(md: ModelData) -> anyhow::Result<Self> {
let device = Device::Cpu;
let tokenizer = Tokenizer::from_bytes(&md.tokenizer).map_err(E::msg)?;
let mel_filters = safetensors::tensor::SafeTensors::deserialize(&md.mel_filters)?;
let mel_filters = mel_filters.tensor("mel_80")?.load(&device)?;
console_log!("loaded mel filters {:?}", mel_filters.shape());
let mel_filters = mel_filters.flatten_all()?.to_vec1::<f32>()?;
let config: Config = serde_json::from_slice(&md.config)?;
let model = if md.quantized {
let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf_buffer(
&md.weights,
)?;
Model::Quantized(m::quantized_model::Whisper::load(&vb, config)?)
} else {
let vb = VarBuilder::from_buffered_safetensors(md.weights, m::DTYPE, &device)?;
Model::Normal(m::model::Whisper::load(&vb, config)?)
};
console_log!("done loading model");
let task = match md.task.as_deref() {
Some("translate") => Some(Task::Translate),
_ => Some(Task::Transcribe),
};
let decoder = Self::new(
model,
tokenizer,
mel_filters,
&device,
task,
md.language,
md.is_multilingual,
md.timestamps,
)?;
Ok(decoder)
}
pub fn convert_and_run(&mut self, wav_input: &[u8]) -> anyhow::Result<Vec<Segment>> {
let device = Device::Cpu;
let mut wav_input = std::io::Cursor::new(wav_input);
let (header, data) = wav::read(&mut wav_input)?;
console_log!("loaded wav data: {header:?}");
if header.sampling_rate != m::SAMPLE_RATE as u32 {
anyhow::bail!("wav file must have a {} sampling rate", m::SAMPLE_RATE);
}
let data = data.as_sixteen().expect("expected 16 bit wav file");
let pcm_data: Vec<_> = data[..data.len() / header.channel_count as usize]
.iter()
.map(|v| *v as f32 / 32768.)
.collect();
console_log!("pcm data loaded {}", pcm_data.len());
let mel = crate::audio::pcm_to_mel(self.model.config(), &pcm_data, &self.mel_filters)?;
let mel_len = mel.len();
let n_mels = self.model.config().num_mel_bins;
let mel = Tensor::from_vec(mel, (1, n_mels, mel_len / n_mels), &device)?;
console_log!("loaded mel: {:?}", mel.dims());
let segments = self.run(&mel)?;
Ok(segments)
}
}
/// Returns the token id for the selected language.
pub fn detect_language(model: &mut Model, tokenizer: &Tokenizer, mel: &Tensor) -> Result<u32, E> {
console_log!("detecting language");
let (_bsize, _, seq_len) = mel.dims3()?;
let mel = mel.narrow(
2,
0,
usize::min(seq_len, model.config().max_source_positions),
)?;
let device = mel.device();
let language_token_ids = LANGUAGES
.iter()
.map(|(t, _)| token_id(tokenizer, &format!("<|{t}|>")))
.map(|e| e.map_err(E::msg))
.collect::<Result<Vec<_>, E>>()?;
let sot_token = token_id(tokenizer, m::SOT_TOKEN)?;
let audio_features = model.encoder_forward(&mel, true)?;
let tokens = Tensor::new(&[[sot_token]], device)?;
let language_token_ids = Tensor::new(language_token_ids.as_slice(), device)?;
let ys = model.decoder_forward(&tokens, &audio_features, true)?;
let logits = model.decoder_final_linear(&ys.i(..1)?)?.i(0)?.i(0)?;
let logits = logits.index_select(&language_token_ids, 0)?;
let probs = candle_nn::ops::softmax(&logits, D::Minus1)?;
let probs = probs.to_vec1::<f32>()?;
let mut probs = LANGUAGES.iter().zip(probs.iter()).collect::<Vec<_>>();
probs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1));
for ((_, language), p) in probs.iter().take(5) {
println!("{language}: {p}")
}
let token = &format!("<|{}|>", probs[0].0 .0);
let language = token_id(tokenizer, token)?;
console_log!("detected language: {language} {token}");
Ok(language)
}
pub fn token_id(tokenizer: &Tokenizer, token: &str) -> candle::Result<u32> {
match tokenizer.token_to_id(token) {
None => candle::bail!("no token-id for {token}"),
Some(id) => Ok(id),
}
}
#[derive(Serialize, Deserialize, Clone, Copy, Debug)]
pub enum Task {
Transcribe,
Translate,
}
// Communication to the worker happens through bincode, the model weights and configs are fetched
// on the main thread and transferred via the following structure.
#[derive(Serialize, Deserialize)]
pub struct ModelData {
pub weights: Vec<u8>,
pub tokenizer: Vec<u8>,
pub mel_filters: Vec<u8>,
pub config: Vec<u8>,
pub quantized: bool,
pub timestamps: bool,
pub is_multilingual: bool,
pub language: Option<String>,
pub task: Option<String>,
}
pub struct Worker {
link: WorkerLink<Self>,
decoder: Option<Decoder>,
}
#[derive(Serialize, Deserialize)]
pub enum WorkerInput {
ModelData(ModelData),
DecodeTask { wav_bytes: Vec<u8> },
}
#[derive(Serialize, Deserialize)]
pub enum WorkerOutput {
Decoded(Vec<Segment>),
WeightsLoaded,
}
impl yew_agent::Worker for Worker {
type Input = WorkerInput;
type Message = ();
type Output = Result<WorkerOutput, String>;
type Reach = Public<Self>;
fn create(link: WorkerLink<Self>) -> Self {
Self {
link,
decoder: None,
}
}
fn update(&mut self, _msg: Self::Message) {
// no messaging
}
fn handle_input(&mut self, msg: Self::Input, id: HandlerId) {
let output = match msg {
WorkerInput::ModelData(md) => match Decoder::load(md) {
Ok(decoder) => {
self.decoder = Some(decoder);
Ok(WorkerOutput::WeightsLoaded)
}
Err(err) => Err(format!("model creation error {err:?}")),
},
WorkerInput::DecodeTask { wav_bytes } => match &mut self.decoder {
None => Err("model has not been set".to_string()),
Some(decoder) => decoder
.convert_and_run(&wav_bytes)
.map(WorkerOutput::Decoded)
.map_err(|e| e.to_string()),
},
};
self.link.respond(id, output);
}
fn name_of_resource() -> &'static str {
"worker.js"
}
fn resource_path_is_relative() -> bool {
true
}
}

View File

@@ -0,0 +1,116 @@
//load the candle Whisper decoder wasm module
import init, { Decoder } from "./build/m.js";
async function fetchArrayBuffer(url) {
const cacheName = "whisper-candle-cache";
const cache = await caches.open(cacheName);
const cachedResponse = await cache.match(url);
if (cachedResponse) {
const data = await cachedResponse.arrayBuffer();
return new Uint8Array(data);
}
const res = await fetch(url, { cache: "force-cache" });
cache.put(url, res.clone());
return new Uint8Array(await res.arrayBuffer());
}
class Whisper {
static instance = {};
// Retrieve the Whisper model. When called for the first time,
// this will load the model and save it for future use.
static async getInstance(params) {
const {
weightsURL,
modelID,
tokenizerURL,
mel_filtersURL,
configURL,
quantized,
is_multilingual,
timestamps,
task,
language,
} = params;
// load individual modelID only once
if (!this.instance[modelID]) {
await init();
self.postMessage({ status: "loading", message: "Loading Model" });
const [
weightsArrayU8,
tokenizerArrayU8,
mel_filtersArrayU8,
configArrayU8,
] = await Promise.all([
fetchArrayBuffer(weightsURL),
fetchArrayBuffer(tokenizerURL),
fetchArrayBuffer(mel_filtersURL),
fetchArrayBuffer(configURL),
]);
this.instance[modelID] = new Decoder(
weightsArrayU8,
tokenizerArrayU8,
mel_filtersArrayU8,
configArrayU8,
quantized,
is_multilingual,
timestamps,
task,
language
);
} else {
self.postMessage({ status: "loading", message: "Model Already Loaded" });
}
return this.instance[modelID];
}
}
self.addEventListener("message", async (event) => {
const {
weightsURL,
modelID,
tokenizerURL,
configURL,
mel_filtersURL,
audioURL,
} = event.data;
try {
self.postMessage({ status: "decoding", message: "Starting Decoder" });
let quantized = false;
if (modelID.includes("quantized")) {
quantized = true;
}
let is_multilingual = false;
if (modelID.includes("multilingual")) {
is_multilingual = true;
}
let timestamps = true;
const decoder = await Whisper.getInstance({
weightsURL,
modelID,
tokenizerURL,
mel_filtersURL,
configURL,
quantized,
is_multilingual,
timestamps,
task: null,
language: null,
});
self.postMessage({ status: "decoding", message: "Loading Audio" });
const audioArrayU8 = await fetchArrayBuffer(audioURL);
self.postMessage({ status: "decoding", message: "Running Decoder..." });
const segments = decoder.decode(audioArrayU8);
// Send the segment back to the main thread as JSON
self.postMessage({
status: "complete",
message: "complete",
output: JSON.parse(segments),
});
} catch (e) {
self.postMessage({ error: e });
}
});

View File

@@ -13,6 +13,7 @@ if ! which rust-whisper-baked; then
fi >&2
cat <<EOF
rust-whisper-baked --stream-device pulse_monitor --stream-step 16 --stream-retain 8 --stream-{head,tail}=0.25 2> /dev/null
rust-whisper-baked --stream-device 'BlackHole 2ch' --stream-step 30 --stream-retain 1 --stream-{head,tail}=0.25 --threads 9 2> /dev/null
| tee -a "$HOME/Sync/drawful/DnD/bdoob/__log.d/$(date +%Y.%m.%d).transcript.txt"
| tee -a "$HOME/Sync/drawful/DnD/nessira.d/_log.d/$(date +%Y.%m.%d).transcript.txt"

View File

@@ -25,7 +25,11 @@ pub fn devices() -> Vec<String> {
fn _devices() -> Result<Vec<cpal::Device>, String> {
match cpal::default_host().devices() {
Ok(devices) => Ok(devices.filter(|device| {
device.supported_input_configs().unwrap().count() > 0
let input_configs = device.supported_input_configs();
if !input_configs.is_ok() {
return false;
}
input_configs.unwrap().count() > 0
}).collect()),
Err(msg) => Err(format!("failed to get devices: {}", msg)),
}
@@ -92,13 +96,22 @@ impl Listener {
filter(|device| device.name().unwrap() == self.device_name).
collect::<Vec<_>>();
let device = devices.first().unwrap();
let cfg = device.supported_input_configs()
let mut sample_rate = 15_500;
let mut cfgs: Vec<_> = device.supported_input_configs()
.unwrap()
.filter(|x| x.sample_format() == cpal::SampleFormat::F32)
.filter(|x| x.min_sample_rate() >= cpal::SampleRate(15_500))
.nth(0)
.filter(|x| x.min_sample_rate() >= cpal::SampleRate(sample_rate))
.collect();
while cfgs.len() == 0 && sample_rate > 0 {
sample_rate /= 2;
cfgs = device.supported_input_configs()
.unwrap()
.with_max_sample_rate();
.filter(|x| x.sample_format() == cpal::SampleFormat::F32)
.filter(|x| x.min_sample_rate() >= cpal::SampleRate(sample_rate))
.collect();
}
assert!(cfgs.len() > 0);
let cfg = cfgs[0].clone().with_max_sample_rate();
let downsample_ratio = cfg.channels() as f32 * (cfg.sample_rate().0 as f32 / 16_000.0);
let stream = device.build_input_stream(

View File

@@ -2,21 +2,6 @@
# It is not intended for manual editing.
version = 3
[[package]]
name = "addr2line"
version = "0.21.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8a30b2e23b9e17a9f90641c7ab1549cd9b44f296d3ccbf309d2863cfe398a0cb"
dependencies = [
"gimli",
]
[[package]]
name = "adler"
version = "1.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe"
[[package]]
name = "aho-corasick"
version = "1.1.2"
@@ -26,21 +11,6 @@ dependencies = [
"memchr",
]
[[package]]
name = "android-tzdata"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e999941b234f3131b00bc13c22d06e8c5ff726d1b6318ac7eb276997bbb4fef0"
[[package]]
name = "android_system_properties"
version = "0.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "819e7219dbd41043ac279b19830f2efc897156490d7fd6ea916720117ee66311"
dependencies = [
"libc",
]
[[package]]
name = "anstream"
version = "0.6.5"
@@ -89,27 +59,6 @@ dependencies = [
"windows-sys",
]
[[package]]
name = "autocfg"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa"
[[package]]
name = "backtrace"
version = "0.3.69"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2089b7e3f35b9dd2d0ed921ead4f6d318c27680d4a5bd167b3ee120edb105837"
dependencies = [
"addr2line",
"cc",
"cfg-if",
"libc",
"miniz_oxide",
"object",
"rustc-demangle",
]
[[package]]
name = "bindgen"
version = "0.68.1"
@@ -139,18 +88,6 @@ version = "2.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "327762f6e5a765692301e5bb513e0d9fef63be86bbc14528052b1cd3e6f03e07"
[[package]]
name = "bumpalo"
version = "3.14.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7f30e7476521f6f8af1a1c4c0b8cc94f0bee37d91763d0ca2665f299b6cd8aec"
[[package]]
name = "byteorder"
version = "1.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b"
[[package]]
name = "cc"
version = "1.0.83"
@@ -175,20 +112,6 @@ version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
[[package]]
name = "chrono"
version = "0.4.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7f2c685bad3eb3d45a01354cedb7d5faa66194d1d58ba6e267a8de788f79db38"
dependencies = [
"android-tzdata",
"iana-time-zone",
"js-sys",
"num-traits",
"wasm-bindgen",
"windows-targets 0.48.5",
]
[[package]]
name = "clang-sys"
version = "1.6.1"
@@ -255,12 +178,6 @@ version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "acbf1af155f9b9ef647e42cdc158db4b64a1b61f743629225fde6f3e0be2a7c7"
[[package]]
name = "core-foundation-sys"
version = "0.8.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "06ea2b9bc92be3c2baa9334a323ebca2d6f074ff852cd1d7b11064035cd3868f"
[[package]]
name = "either"
version = "1.9.0"
@@ -283,12 +200,6 @@ version = "1.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "42703706b716c37f96a77aea830392ad231f44c9e9a67872fa5548707e11b11c"
[[package]]
name = "gimli"
version = "0.28.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4271d37baee1b8c7e4b708028c57d816cf9d2434acb33a549475f78c181f6253"
[[package]]
name = "glob"
version = "0.3.1"
@@ -310,38 +221,6 @@ dependencies = [
"windows-sys",
]
[[package]]
name = "iana-time-zone"
version = "0.1.58"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8326b86b6cff230b97d0d312a6c40a60726df3332e721f72a1b035f451663b20"
dependencies = [
"android_system_properties",
"core-foundation-sys",
"iana-time-zone-haiku",
"js-sys",
"wasm-bindgen",
"windows-core",
]
[[package]]
name = "iana-time-zone-haiku"
version = "0.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f31827a206f56af32e590ba56d5d2d085f558508192593743f16b2306495269f"
dependencies = [
"cc",
]
[[package]]
name = "js-sys"
version = "0.3.66"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cee9c64da59eae3b50095c18d3e74f8b73c0b86d2792824ff01bbce68ba229ca"
dependencies = [
"wasm-bindgen",
]
[[package]]
name = "lazy_static"
version = "1.4.0"
@@ -394,15 +273,6 @@ version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a"
[[package]]
name = "miniz_oxide"
version = "0.7.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e7810e0be55b428ada41041c41f32c9f1a42817901b4ccf45fa3d4b6561e74c7"
dependencies = [
"adler",
]
[[package]]
name = "nom"
version = "7.1.3"
@@ -413,24 +283,6 @@ dependencies = [
"minimal-lexical",
]
[[package]]
name = "num-traits"
version = "0.2.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "39e3200413f237f41ab11ad6d161bc7239c84dcb631773ccd7de3dfe4b5c267c"
dependencies = [
"autocfg",
]
[[package]]
name = "object"
version = "0.32.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9cf5f9dd3933bd50a9e1f149ec995f39ae2c496d31fd772c1fd45ebc27e902b0"
dependencies = [
"memchr",
]
[[package]]
name = "once_cell"
version = "1.19.0"
@@ -443,12 +295,6 @@ version = "0.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "19b17cddbe7ec3f8bc800887bab5e717348c95ea2ca0b1bf0837fb964dc67099"
[[package]]
name = "pin-project-lite"
version = "0.2.13"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8afb450f006bf6385ca15ef45d71d2288452bc3683ce2e2cacc0d18e4be60b58"
[[package]]
name = "prettyplease"
version = "0.2.15"
@@ -523,20 +369,11 @@ dependencies = [
name = "rust-whisper-lib"
version = "0.1.0"
dependencies = [
"byteorder",
"chrono",
"clap",
"tokio",
"wav",
"whisper-rs",
]
[[package]]
name = "rustc-demangle"
version = "0.1.23"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d626bb9dae77e28219937af045c257c28bfd3f69333c512553507f5f9798cb76"
[[package]]
name = "rustc-hash"
version = "1.1.0"
@@ -579,16 +416,6 @@ dependencies = [
"unicode-ident",
]
[[package]]
name = "tokio"
version = "1.35.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c89b4efa943be685f629b149f53829423f8f5531ea21249408e8e2f8671ec104"
dependencies = [
"backtrace",
"pin-project-lite",
]
[[package]]
name = "unicode-ident"
version = "1.0.12"
@@ -601,60 +428,6 @@ version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "711b9620af191e0cdc7468a8d14e709c3dcdb115b36f838e601583af800a370a"
[[package]]
name = "wasm-bindgen"
version = "0.2.89"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0ed0d4f68a3015cc185aff4db9506a015f4b96f95303897bfa23f846db54064e"
dependencies = [
"cfg-if",
"wasm-bindgen-macro",
]
[[package]]
name = "wasm-bindgen-backend"
version = "0.2.89"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1b56f625e64f3a1084ded111c4d5f477df9f8c92df113852fa5a374dbda78826"
dependencies = [
"bumpalo",
"log",
"once_cell",
"proc-macro2",
"quote",
"syn",
"wasm-bindgen-shared",
]
[[package]]
name = "wasm-bindgen-macro"
version = "0.2.89"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0162dbf37223cd2afce98f3d0785506dcb8d266223983e4b5b525859e6e182b2"
dependencies = [
"quote",
"wasm-bindgen-macro-support",
]
[[package]]
name = "wasm-bindgen-macro-support"
version = "0.2.89"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f0eb82fcb7930ae6219a7ecfd55b217f5f0893484b7a13022ebb2b2bf20b5283"
dependencies = [
"proc-macro2",
"quote",
"syn",
"wasm-bindgen-backend",
"wasm-bindgen-shared",
]
[[package]]
name = "wasm-bindgen-shared"
version = "0.2.89"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7ab9b36309365056cd639da3134bf87fa8f3d86008abf99e612384a6eecd459f"
[[package]]
name = "wav"
version = "1.0.0"
@@ -715,37 +488,13 @@ version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f"
[[package]]
name = "windows-core"
version = "0.51.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f1f8cf84f35d2db49a46868f947758c7a1138116f7fac3bc844f43ade1292e64"
dependencies = [
"windows-targets 0.48.5",
]
[[package]]
name = "windows-sys"
version = "0.52.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d"
dependencies = [
"windows-targets 0.52.0",
]
[[package]]
name = "windows-targets"
version = "0.48.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9a2fa6e2155d7247be68c096456083145c183cbbbc2764150dda45a87197940c"
dependencies = [
"windows_aarch64_gnullvm 0.48.5",
"windows_aarch64_msvc 0.48.5",
"windows_i686_gnu 0.48.5",
"windows_i686_msvc 0.48.5",
"windows_x86_64_gnu 0.48.5",
"windows_x86_64_gnullvm 0.48.5",
"windows_x86_64_msvc 0.48.5",
"windows-targets",
]
[[package]]
@@ -754,93 +503,51 @@ version = "0.52.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8a18201040b24831fbb9e4eb208f8892e1f50a37feb53cc7ff887feb8f50e7cd"
dependencies = [
"windows_aarch64_gnullvm 0.52.0",
"windows_aarch64_msvc 0.52.0",
"windows_i686_gnu 0.52.0",
"windows_i686_msvc 0.52.0",
"windows_x86_64_gnu 0.52.0",
"windows_x86_64_gnullvm 0.52.0",
"windows_x86_64_msvc 0.52.0",
"windows_aarch64_gnullvm",
"windows_aarch64_msvc",
"windows_i686_gnu",
"windows_i686_msvc",
"windows_x86_64_gnu",
"windows_x86_64_gnullvm",
"windows_x86_64_msvc",
]
[[package]]
name = "windows_aarch64_gnullvm"
version = "0.48.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2b38e32f0abccf9987a4e3079dfb67dcd799fb61361e53e2882c3cbaf0d905d8"
[[package]]
name = "windows_aarch64_gnullvm"
version = "0.52.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cb7764e35d4db8a7921e09562a0304bf2f93e0a51bfccee0bd0bb0b666b015ea"
[[package]]
name = "windows_aarch64_msvc"
version = "0.48.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dc35310971f3b2dbbf3f0690a219f40e2d9afcf64f9ab7cc1be722937c26b4bc"
[[package]]
name = "windows_aarch64_msvc"
version = "0.52.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bbaa0368d4f1d2aaefc55b6fcfee13f41544ddf36801e793edbbfd7d7df075ef"
[[package]]
name = "windows_i686_gnu"
version = "0.48.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a75915e7def60c94dcef72200b9a8e58e5091744960da64ec734a6c6e9b3743e"
[[package]]
name = "windows_i686_gnu"
version = "0.52.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a28637cb1fa3560a16915793afb20081aba2c92ee8af57b4d5f28e4b3e7df313"
[[package]]
name = "windows_i686_msvc"
version = "0.48.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8f55c233f70c4b27f66c523580f78f1004e8b5a8b659e05a4eb49d4166cca406"
[[package]]
name = "windows_i686_msvc"
version = "0.52.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ffe5e8e31046ce6230cc7215707b816e339ff4d4d67c65dffa206fd0f7aa7b9a"
[[package]]
name = "windows_x86_64_gnu"
version = "0.48.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "53d40abd2583d23e4718fddf1ebec84dbff8381c07cae67ff7768bbf19c6718e"
[[package]]
name = "windows_x86_64_gnu"
version = "0.52.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3d6fa32db2bc4a2f5abeacf2b69f7992cd09dca97498da74a151a3132c26befd"
[[package]]
name = "windows_x86_64_gnullvm"
version = "0.48.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0b7b52767868a23d5bab768e390dc5f5c55825b6d30b86c844ff2dc7414044cc"
[[package]]
name = "windows_x86_64_gnullvm"
version = "0.52.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1a657e1e9d3f514745a572a6846d3c7aa7dbe1658c056ed9c3344c4109a6949e"
[[package]]
name = "windows_x86_64_msvc"
version = "0.48.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538"
[[package]]
name = "windows_x86_64_msvc"
version = "0.52.0"

View File

@@ -6,7 +6,7 @@ pub fn channel<F>(
stream: std::sync::mpsc::Receiver<Vec<f32>>,
) where F: FnMut(Result<rust_whisper_lib::Transcribed, String>) + Send + 'static {
flags.model_path = None;
flags.model_buffer = Some(include_bytes!("../../models/ggml-small.en.bin").to_vec());
flags.model_buffer = Some(get_fast());
rust_whisper_lib::channel(flags.clone(), handler_fn, stream);
}
@@ -15,7 +15,27 @@ pub fn wav<F>(
handler_fn: F
) where F: FnMut(Result<rust_whisper_lib::Transcribed, String>) + Send + 'static {
flags.model_path = None;
flags.model_buffer = Some(include_bytes!("../../models/ggml-distil-medium.en.bin").to_vec());
flags.model_buffer = Some(get_good());
rust_whisper_lib::wav(flags.clone(), handler_fn, flags.wav.unwrap());
}
pub fn wav_channel<F>(
mut flags: rust_whisper_lib::Flags,
handler_fn: F
) where F: FnMut(Result<rust_whisper_lib::Transcribed, String>) + Send + 'static {
flags.model_path = None;
flags.model_buffer = Some(get_good());
rust_whisper_lib::wav_channel(flags, handler_fn);
}
pub fn f32_from_wav_file(path: &String) -> Result<Vec<f32>, String> {
rust_whisper_lib::f32_from_wav_file(path)
}
fn get_fast() -> Vec<u8> {
include_bytes!("../../models/ggml-small.en.bin").to_vec()
}
fn get_good() -> Vec<u8> {
include_bytes!("../../models/ggml-distil-medium.en.bin").to_vec()
}

View File

@@ -2,21 +2,6 @@
# It is not intended for manual editing.
version = 3
[[package]]
name = "addr2line"
version = "0.21.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8a30b2e23b9e17a9f90641c7ab1549cd9b44f296d3ccbf309d2863cfe398a0cb"
dependencies = [
"gimli",
]
[[package]]
name = "adler"
version = "1.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe"
[[package]]
name = "aho-corasick"
version = "1.1.2"
@@ -48,21 +33,6 @@ dependencies = [
"pkg-config",
]
[[package]]
name = "android-tzdata"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e999941b234f3131b00bc13c22d06e8c5ff726d1b6318ac7eb276997bbb4fef0"
[[package]]
name = "android_system_properties"
version = "0.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "819e7219dbd41043ac279b19830f2efc897156490d7fd6ea916720117ee66311"
dependencies = [
"libc",
]
[[package]]
name = "anstream"
version = "0.6.5"
@@ -117,21 +87,6 @@ version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa"
[[package]]
name = "backtrace"
version = "0.3.69"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2089b7e3f35b9dd2d0ed921ead4f6d318c27680d4a5bd167b3ee120edb105837"
dependencies = [
"addr2line",
"cc",
"cfg-if",
"libc",
"miniz_oxide",
"object",
"rustc-demangle",
]
[[package]]
name = "bindgen"
version = "0.68.1"
@@ -193,12 +148,6 @@ version = "3.14.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7f30e7476521f6f8af1a1c4c0b8cc94f0bee37d91763d0ca2665f299b6cd8aec"
[[package]]
name = "byteorder"
version = "1.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b"
[[package]]
name = "bytes"
version = "1.5.0"
@@ -236,20 +185,6 @@ version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
[[package]]
name = "chrono"
version = "0.4.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7f2c685bad3eb3d45a01354cedb7d5faa66194d1d58ba6e267a8de788f79db38"
dependencies = [
"android-tzdata",
"iana-time-zone",
"js-sys",
"num-traits",
"wasm-bindgen",
"windows-targets 0.48.5",
]
[[package]]
name = "clang-sys"
version = "1.6.1"
@@ -411,12 +346,6 @@ version = "1.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "42703706b716c37f96a77aea830392ad231f44c9e9a67872fa5548707e11b11c"
[[package]]
name = "gimli"
version = "0.28.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4271d37baee1b8c7e4b708028c57d816cf9d2434acb33a549475f78c181f6253"
[[package]]
name = "glob"
version = "0.3.1"
@@ -444,29 +373,6 @@ dependencies = [
"windows-sys",
]
[[package]]
name = "iana-time-zone"
version = "0.1.58"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8326b86b6cff230b97d0d312a6c40a60726df3332e721f72a1b035f451663b20"
dependencies = [
"android_system_properties",
"core-foundation-sys",
"iana-time-zone-haiku",
"js-sys",
"wasm-bindgen",
"windows-core",
]
[[package]]
name = "iana-time-zone-haiku"
version = "0.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f31827a206f56af32e590ba56d5d2d085f558508192593743f16b2306495269f"
dependencies = [
"cc",
]
[[package]]
name = "indexmap"
version = "2.1.0"
@@ -477,6 +383,12 @@ dependencies = [
"hashbrown",
]
[[package]]
name = "itoa"
version = "1.0.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b1a46d1a171d865aa5f83f92695765caa047a9b4cbae2cbf37dbd613a793fd4c"
[[package]]
name = "jni"
version = "0.19.0"
@@ -608,15 +520,6 @@ version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a"
[[package]]
name = "miniz_oxide"
version = "0.7.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e7810e0be55b428ada41041c41f32c9f1a42817901b4ccf45fa3d4b6561e74c7"
dependencies = [
"adler",
]
[[package]]
name = "ndk"
version = "0.7.0"
@@ -708,15 +611,6 @@ dependencies = [
"syn 1.0.109",
]
[[package]]
name = "object"
version = "0.32.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9cf5f9dd3933bd50a9e1f149ec995f39ae2c496d31fd772c1fd45ebc27e902b0"
dependencies = [
"memchr",
]
[[package]]
name = "oboe"
version = "0.5.0"
@@ -775,12 +669,6 @@ version = "0.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "19b17cddbe7ec3f8bc800887bab5e717348c95ea2ca0b1bf0837fb964dc67099"
[[package]]
name = "pin-project-lite"
version = "0.2.13"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8afb450f006bf6385ca15ef45d71d2288452bc3683ce2e2cacc0d18e4be60b58"
[[package]]
name = "pkg-config"
version = "0.3.27"
@@ -875,14 +763,26 @@ version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b9b1a3d5f46d53f4a3478e2be4a5a5ce5108ea58b100dcd139830eae7f79a3a1"
[[package]]
name = "rust-stemmers"
version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e46a2036019fdb888131db7a4c847a1063a7493f971ed94ea82c67eada63ca54"
dependencies = [
"serde",
"serde_derive",
]
[[package]]
name = "rust-whisper-baked"
version = "0.1.0"
dependencies = [
"clap",
"listen-lib",
"rust-stemmers",
"rust-whisper-baked-lib",
"rust-whisper-lib",
"stop-words",
]
[[package]]
@@ -896,20 +796,11 @@ dependencies = [
name = "rust-whisper-lib"
version = "0.1.0"
dependencies = [
"byteorder",
"chrono",
"clap",
"tokio",
"wav",
"whisper-rs",
]
[[package]]
name = "rustc-demangle"
version = "0.1.23"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d626bb9dae77e28219937af045c257c28bfd3f69333c512553507f5f9798cb76"
[[package]]
name = "rustc-hash"
version = "1.1.0"
@@ -929,6 +820,12 @@ dependencies = [
"windows-sys",
]
[[package]]
name = "ryu"
version = "1.0.16"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f98d2aa92eebf49b69786be48e4477826b256916e84a57ff2a4f21923b48eb4c"
[[package]]
name = "same-file"
version = "1.0.6"
@@ -944,6 +841,37 @@ version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49"
[[package]]
name = "serde"
version = "1.0.193"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "25dd9975e68d0cb5aa1120c288333fc98731bd1dd12f561e468ea4728c042b89"
dependencies = [
"serde_derive",
]
[[package]]
name = "serde_derive"
version = "1.0.193"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "43576ca501357b9b071ac53cdc7da8ef0cbd9493d8df094cd821777ea6e894d3"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.41",
]
[[package]]
name = "serde_json"
version = "1.0.109"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cb0652c533506ad7a2e353cce269330d6afd8bdfb6d75e0ace5b35aacbd7b9e9"
dependencies = [
"itoa",
"ryu",
"serde",
]
[[package]]
name = "shlex"
version = "1.2.0"
@@ -975,6 +903,15 @@ version = "1.11.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4dccd0940a2dcdf68d092b8cbab7dc0ad8fa938bf95787e1b916b0e3d0e8e970"
[[package]]
name = "stop-words"
version = "0.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8500024d809de02ecbf998472b7bed3c4fca380df2be68917f6a473bdb28ddcc"
dependencies = [
"serde_json",
]
[[package]]
name = "strsim"
version = "0.10.0"
@@ -1023,16 +960,6 @@ dependencies = [
"syn 2.0.41",
]
[[package]]
name = "tokio"
version = "1.35.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c89b4efa943be685f629b149f53829423f8f5531ea21249408e8e2f8671ec104"
dependencies = [
"backtrace",
"pin-project-lite",
]
[[package]]
name = "toml_datetime"
version = "0.6.5"
@@ -1226,15 +1153,6 @@ dependencies = [
"windows-targets 0.42.2",
]
[[package]]
name = "windows-core"
version = "0.51.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f1f8cf84f35d2db49a46868f947758c7a1138116f7fac3bc844f43ade1292e64"
dependencies = [
"windows-targets 0.48.5",
]
[[package]]
name = "windows-sys"
version = "0.52.0"

View File

@@ -10,3 +10,5 @@ rust-whisper-lib = { path = "../rust-whisper-lib" }
rust-whisper-baked-lib = { path = "../rust-whisper-baked-lib" }
listen-lib = { path = "../listen-lib" }
clap = { version = "4.4.10", features = ["derive"] }
stop-words = "0.8.0"
rust-stemmers = "1.2.0"

View File

@@ -1,21 +1,70 @@
use rust_whisper_lib;
use rust_whisper_baked_lib;
use clap::Parser;
use listen_lib;
use rust_whisper_baked_lib;
use rust_whisper_lib;
use std::thread;
fn main() {
let flags = rust_whisper_lib::Flags::parse();
match flags.wav.clone() {
Some(_) => wav_channel(flags),
None => channel(flags),
};
}
fn wav_channel(flags: rust_whisper_lib::Flags) {
let mut w = new_destutterer();
rust_whisper_baked_lib::wav_channel(
flags.clone(),
move |result: Result<rust_whisper_lib::Transcribed, String>| {
match result {
Ok(transcribed) => {
let s = w.step(transcribed.to_string());
println!("{}", s);
}
Err(msg) => {
eprintln!("error: {}", msg);
}
};
},
);
}
fn wav(flags: rust_whisper_lib::Flags, _path: String) {
let mut w = new_destutterer();
rust_whisper_baked_lib::wav(
flags,
move |result: Result<rust_whisper_lib::Transcribed, String>| {
match result {
Ok(transcribed) => {
let s = w.step(transcribed.to_string());
println!("{}", s);
}
Err(msg) => {
eprintln!("error: {}", msg);
}
};
},
);
}
fn channel(flags: rust_whisper_lib::Flags) {
let (send, recv) = std::sync::mpsc::sync_channel(100);
eprintln!("rust whisper baked lib channel...");
thread::spawn(move || {
let mut w = new_destutterer();
rust_whisper_baked_lib::channel(
flags.clone(),
|result: Result<rust_whisper_lib::Transcribed, String>| {
move |result: Result<rust_whisper_lib::Transcribed, String>| {
match result {
Ok(transcribed) => { println!("{}", transcribed.to_string()); },
Err(msg) => { eprintln!("error: {}", msg); },
Ok(transcribed) => {
let s = w.step(transcribed.to_string());
println!("{}", s);
}
Err(msg) => {
eprintln!("error: {}", msg);
}
};
},
recv,
@@ -26,17 +75,25 @@ fn main() {
let flags = rust_whisper_lib::Flags::parse();
match flags.stream_device {
Some(device_name) => {
if device_name == "" {
eprintln!("with device ({}) '{}'", device_name.len(), &device_name);
if device_name.len() == 0 {
let mut i = 0;
for device in listen_lib::devices() {
eprintln!("{}", device);
eprintln!("[{}] {}", i, device);
i += 1;
}
eprintln!("found {} devices", i);
} else {
listen_lib::main_with(|data| {
listen_lib::main_with(
|data| {
send.send(data).unwrap();
}, device_name);
}
},
device_name,
);
}
}
None => {
eprintln!("without any device");
listen_lib::main(|data| {
send.send(data).unwrap();
});
@@ -44,3 +101,184 @@ fn main() {
}
eprintln!("/listen lib main...");
}
struct Destutterer {
prev: Words,
}
fn new_destutterer() -> Destutterer {
Destutterer { prev: new_words() }
}
impl Destutterer {
fn step(&mut self, next: String) -> String {
if next.len() == 0 {
return next;
}
let next_words = Words::from_string(next.clone());
let mut n = self
.prev
.comparable_len()
.clamp(0, next_words.comparable_len());
//println!("n={} prev='{:?}' next='{:?}'", n, self.prev.to_comparable_words(), next_words.to_comparable_words());
while n > 0 {
let (prev_s, _) = self.prev.last_n_comparable_to_string(n);
let (next_s, next_idx) = next_words.first_n_comparable_to_string(n);
if prev_s == next_s {
self.prev = next_words;
return self.prev.skip(next_idx + 1).to_string();
}
n -= 1;
}
self.prev = next_words;
self.prev.to_string()
}
}
#[derive(Clone, Debug)]
struct Words {
raw: Vec<String>,
}
fn new_words() -> Words {
Words { raw: vec![] }
}
impl Words {
fn from_string(s: String) -> Words {
let mut result = Words { raw: vec![] };
for word in s.split(" ") {
let word = word.trim();
if word.len() > 0 {
result.raw.push(word.to_string());
}
}
result
}
fn skip(&self, n: usize) -> Words {
Words {
raw: self.raw.iter().skip(n).map(|x| x.clone()).collect(),
}
}
fn last_n_comparable_to_string(&self, n: usize) -> (String, usize) {
let v = self.to_comparable_words();
let v = v[(v.len() - n).clamp(0, v.len())..].to_vec();
return (
v.iter()
.map(|x| x.s.clone().unwrap())
.collect::<Vec<String>>()
.join(" "),
v[0].idx,
);
}
fn first_n_comparable_to_string(&self, n: usize) -> (String, usize) {
let v = self.to_comparable_words();
let v = v[0..n.clamp(0, v.len())].to_vec();
return (
v.iter()
.map(|x| x.s.clone().unwrap())
.collect::<Vec<String>>()
.join(" "),
v[v.len() - 1].idx,
);
}
fn comparable_len(&self) -> usize {
self.to_comparable_words().len()
}
fn to_comparable_words(&self) -> Vec<Word> {
self.to_words()
.iter()
.filter(|x| x.s.is_some())
.map(|x| x.clone())
.collect()
}
fn to_words(&self) -> Vec<Word> {
let skips = stop_words::get("en");
let stemmer = rust_stemmers::Stemmer::create(rust_stemmers::Algorithm::English);
let strs = self
.raw
.iter()
.map(|w| w.to_lowercase())
.map(|w| {
w.chars()
.filter(|c| c.is_ascii_alphanumeric())
.collect::<String>()
})
.map(|w| stemmer.stem(&w).into_owned())
.collect::<Vec<String>>();
let mut result = vec![];
for i in 0..strs.len() {
result.push(Word {
s: match skips.contains(&strs[i]) {
true => None,
false => Some(strs[i].clone()),
},
idx: i as usize,
});
}
result
}
fn to_string(&self) -> String {
self.raw
.iter()
.map(|x| x.clone())
.collect::<Vec<String>>()
.join(" ")
}
}
#[derive(Debug, Clone)]
struct Word {
s: Option<String>,
idx: usize,
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_destutterer_stop_words() {
let mut w = new_destutterer();
assert_eq!(
"welcome to the internet".to_string(),
w.step("welcome to the internet".to_string())
);
assert_eq!(
"have a look around".to_string(),
w.step("welcome to the a internet; have a look around".to_string())
);
}
#[test]
fn test_destutterer_punctuation() {
let mut w = new_destutterer();
assert_eq!(
"cat, dog. cow? moose!".to_string(),
w.step("cat, dog. cow? moose!".to_string())
);
assert_eq!(
"elephant! fez gator".to_string(),
w.step("moose, elephant! fez gator".to_string())
);
assert_eq!("hij".to_string(), w.step("fez gator hij".to_string()));
}
#[test]
fn test_destutterer_basic() {
let mut w = new_destutterer();
assert_eq!(
"cat dog cow".to_string(),
w.step(" cat dog cow ".to_string())
);
assert_eq!("moose".to_string(), w.step(" dog cow moose ".to_string()));
}
}

View File

@@ -3,44 +3,14 @@
version = 3
[[package]]
name = "addr2line"
version = "0.21.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8a30b2e23b9e17a9f90641c7ab1549cd9b44f296d3ccbf309d2863cfe398a0cb"
dependencies = [
"gimli",
]
[[package]]
name = "adler"
name = "aho-corasick"
version = "1.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe"
[[package]]
name = "aho-corasick"
version = "1.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b2969dcb958b36655471fc61f7e416fa76033bdd4bfed0678d8fee1e2d07a1f0"
checksum = "43f6cb1bf222025340178f382c426f13757b2960e89779dfcb319c32542a5a41"
dependencies = [
"memchr",
]
[[package]]
name = "android-tzdata"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e999941b234f3131b00bc13c22d06e8c5ff726d1b6318ac7eb276997bbb4fef0"
[[package]]
name = "android_system_properties"
version = "0.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "819e7219dbd41043ac279b19830f2efc897156490d7fd6ea916720117ee66311"
dependencies = [
"libc",
]
[[package]]
name = "anstream"
version = "0.6.5"
@@ -89,27 +59,6 @@ dependencies = [
"windows-sys",
]
[[package]]
name = "autocfg"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa"
[[package]]
name = "backtrace"
version = "0.3.69"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2089b7e3f35b9dd2d0ed921ead4f6d318c27680d4a5bd167b3ee120edb105837"
dependencies = [
"addr2line",
"cc",
"cfg-if",
"libc",
"miniz_oxide",
"object",
"rustc-demangle",
]
[[package]]
name = "bindgen"
version = "0.68.1"
@@ -135,30 +84,15 @@ dependencies = [
[[package]]
name = "bitflags"
version = "2.4.1"
version = "2.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "327762f6e5a765692301e5bb513e0d9fef63be86bbc14528052b1cd3e6f03e07"
[[package]]
name = "bumpalo"
version = "3.14.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7f30e7476521f6f8af1a1c4c0b8cc94f0bee37d91763d0ca2665f299b6cd8aec"
[[package]]
name = "byteorder"
version = "1.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b"
checksum = "630be753d4e58660abd17930c71b647fe46c27ea6b63cc59e1e3851406972e42"
[[package]]
name = "cc"
version = "1.0.83"
version = "1.0.79"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f1174fb0b6ec23863f8b971027804a42614e347eafb0a95bf0b12cdae21fc4d0"
dependencies = [
"libc",
]
checksum = "50d30906286121d95be3d479533b458f87493b30a4b5f79a607db8f5d11aa91f"
[[package]]
name = "cexpr"
@@ -175,20 +109,6 @@ version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
[[package]]
name = "chrono"
version = "0.4.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7f2c685bad3eb3d45a01354cedb7d5faa66194d1d58ba6e267a8de788f79db38"
dependencies = [
"android-tzdata",
"iana-time-zone",
"js-sys",
"num-traits",
"wasm-bindgen",
"windows-targets 0.48.5",
]
[[package]]
name = "clang-sys"
version = "1.6.1"
@@ -255,40 +175,18 @@ version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "acbf1af155f9b9ef647e42cdc158db4b64a1b61f743629225fde6f3e0be2a7c7"
[[package]]
name = "core-foundation-sys"
version = "0.8.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "06ea2b9bc92be3c2baa9334a323ebca2d6f074ff852cd1d7b11064035cd3868f"
[[package]]
name = "either"
version = "1.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a26ae43d7bcc3b814de94796a5e736d4029efb0ee900c12e2d54c993ad1a1e07"
[[package]]
name = "errno"
version = "0.3.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a258e46cdc063eb8519c00b9fc845fc47bcfca4130e2f08e88665ceda8474245"
dependencies = [
"libc",
"windows-sys",
]
[[package]]
name = "fs_extra"
version = "1.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "42703706b716c37f96a77aea830392ad231f44c9e9a67872fa5548707e11b11c"
[[package]]
name = "gimli"
version = "0.28.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4271d37baee1b8c7e4b708028c57d816cf9d2434acb33a549475f78c181f6253"
[[package]]
name = "glob"
version = "0.3.1"
@@ -301,47 +199,6 @@ version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8"
[[package]]
name = "home"
version = "0.5.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e3d1354bf6b7235cb4a0576c2619fd4ed18183f689b12b006a0ee7329eeff9a5"
dependencies = [
"windows-sys",
]
[[package]]
name = "iana-time-zone"
version = "0.1.58"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8326b86b6cff230b97d0d312a6c40a60726df3332e721f72a1b035f451663b20"
dependencies = [
"android_system_properties",
"core-foundation-sys",
"iana-time-zone-haiku",
"js-sys",
"wasm-bindgen",
"windows-core",
]
[[package]]
name = "iana-time-zone-haiku"
version = "0.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f31827a206f56af32e590ba56d5d2d085f558508192593743f16b2306495269f"
dependencies = [
"cc",
]
[[package]]
name = "js-sys"
version = "0.3.66"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cee9c64da59eae3b50095c18d3e74f8b73c0b86d2792824ff01bbce68ba229ca"
dependencies = [
"wasm-bindgen",
]
[[package]]
name = "lazy_static"
version = "1.4.0"
@@ -356,9 +213,9 @@ checksum = "830d08ce1d1d941e6b30645f1a0eb5643013d835ce3779a5fc208261dbe10f55"
[[package]]
name = "libc"
version = "0.2.151"
version = "0.2.147"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "302d7ab3130588088d277783b1e2d2e10c9e9e4a16dd9050e6ec93fb3e7048f4"
checksum = "b4668fb0ea861c1df094127ac5f1da3409a82116a4ba74fca2e58ef927159bb3"
[[package]]
name = "libloading"
@@ -370,23 +227,17 @@ dependencies = [
"winapi",
]
[[package]]
name = "linux-raw-sys"
version = "0.4.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c4cd1a83af159aa67994778be9070f0ae1bd732942279cabb14f86f986a21456"
[[package]]
name = "log"
version = "0.4.20"
version = "0.4.19"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b5e6163cb8c49088c2c36f57875e58ccd8c87c7427f7fbd50ea6710b2f3f2e8f"
checksum = "b06a4cde4c0f271a446782e3eff8de789548ce57dbc8eca9292c27f4a42004b4"
[[package]]
name = "memchr"
version = "2.6.4"
version = "2.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f665ee40bc4a3c5590afb1e9677db74a508659dfd71e126420da8274909a0167"
checksum = "2dffe52ecf27772e601905b7522cb4ef790d2cc203488bbd0e2fe85fcb74566d"
[[package]]
name = "minimal-lexical"
@@ -394,15 +245,6 @@ version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a"
[[package]]
name = "miniz_oxide"
version = "0.7.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e7810e0be55b428ada41041c41f32c9f1a42817901b4ccf45fa3d4b6561e74c7"
dependencies = [
"adler",
]
[[package]]
name = "nom"
version = "7.1.3"
@@ -413,29 +255,11 @@ dependencies = [
"minimal-lexical",
]
[[package]]
name = "num-traits"
version = "0.2.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "39e3200413f237f41ab11ad6d161bc7239c84dcb631773ccd7de3dfe4b5c267c"
dependencies = [
"autocfg",
]
[[package]]
name = "object"
version = "0.32.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9cf5f9dd3933bd50a9e1f149ec995f39ae2c496d31fd772c1fd45ebc27e902b0"
dependencies = [
"memchr",
]
[[package]]
name = "once_cell"
version = "1.19.0"
version = "1.18.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92"
checksum = "dd8b5dd2ae5ed71462c540258bedcb51965123ad7e7ccf4b9a8cafaa4a63576d"
[[package]]
name = "peeking_take_while"
@@ -443,12 +267,6 @@ version = "0.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "19b17cddbe7ec3f8bc800887bab5e717348c95ea2ca0b1bf0837fb964dc67099"
[[package]]
name = "pin-project-lite"
version = "0.2.13"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8afb450f006bf6385ca15ef45d71d2288452bc3683ce2e2cacc0d18e4be60b58"
[[package]]
name = "prettyplease"
version = "0.2.15"
@@ -461,27 +279,27 @@ dependencies = [
[[package]]
name = "proc-macro2"
version = "1.0.70"
version = "1.0.66"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "39278fbbf5fb4f646ce651690877f89d1c5811a3d4acb27700c1cb3cdb78fd3b"
checksum = "18fb31db3f9bddb2ea821cde30a9f70117e3f119938b5ee630b7403aa6e2ead9"
dependencies = [
"unicode-ident",
]
[[package]]
name = "quote"
version = "1.0.33"
version = "1.0.32"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5267fca4496028628a95160fc423a33e8b2e6af8a5302579e322e4b520293cae"
checksum = "50f3b39ccfb720540debaa0164757101c08ecb8d326b15358ce76a62c7e85965"
dependencies = [
"proc-macro2",
]
[[package]]
name = "regex"
version = "1.10.2"
version = "1.9.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "380b951a9c5e80ddfd6136919eef32310721aa4aacd4889a8d39124b026ab343"
checksum = "b2eae68fc220f7cf2532e4494aded17545fce192d59cd996e0fe7887f4ceb575"
dependencies = [
"aho-corasick",
"memchr",
@@ -491,9 +309,9 @@ dependencies = [
[[package]]
name = "regex-automata"
version = "0.4.3"
version = "0.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5f804c7828047e88b2d32e2d7fe5a105da8ee3264f01902f796c8e067dc2483f"
checksum = "39354c10dd07468c2e73926b23bb9c2caca74c5501e38a35da70406f1d923310"
dependencies = [
"aho-corasick",
"memchr",
@@ -502,9 +320,9 @@ dependencies = [
[[package]]
name = "regex-syntax"
version = "0.8.2"
version = "0.7.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c08c74e62047bb2de4ff487b251e4a92e24f48745648451635cec7d591162d9f"
checksum = "e5ea92a5b6195c6ef2a0295ea818b312502c6fc94dde986c5553242e18fd4ce2"
[[package]]
name = "riff"
@@ -516,39 +334,17 @@ checksum = "b9b1a3d5f46d53f4a3478e2be4a5a5ce5108ea58b100dcd139830eae7f79a3a1"
name = "rust-whisper-lib"
version = "0.1.0"
dependencies = [
"byteorder",
"chrono",
"clap",
"tokio",
"wav",
"whisper-rs",
]
[[package]]
name = "rustc-demangle"
version = "0.1.23"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d626bb9dae77e28219937af045c257c28bfd3f69333c512553507f5f9798cb76"
[[package]]
name = "rustc-hash"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2"
[[package]]
name = "rustix"
version = "0.38.28"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "72e572a5e8ca657d7366229cdde4bd14c4eb5499a9573d4d366fe1b599daa316"
dependencies = [
"bitflags",
"errno",
"libc",
"linux-raw-sys",
"windows-sys",
]
[[package]]
name = "shlex"
version = "1.2.0"
@@ -563,30 +359,20 @@ checksum = "73473c0e59e6d5812c5dfe2a064a6444949f089e20eec9a2e5506596494e4623"
[[package]]
name = "syn"
version = "2.0.41"
version = "2.0.32"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "44c8b28c477cc3bf0e7966561e3460130e1255f7a1cf71931075f1c5e7a7e269"
checksum = "239814284fd6f1a4ffe4ca893952cdd93c224b6a1571c9a9eadd670295c0c9e2"
dependencies = [
"proc-macro2",
"quote",
"unicode-ident",
]
[[package]]
name = "tokio"
version = "1.35.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c89b4efa943be685f629b149f53829423f8f5531ea21249408e8e2f8671ec104"
dependencies = [
"backtrace",
"pin-project-lite",
]
[[package]]
name = "unicode-ident"
version = "1.0.12"
version = "1.0.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b"
checksum = "301abaae475aa91687eb82514b328ab47a211a533026cb25fc3e519b86adfc3c"
[[package]]
name = "utf8parse"
@@ -594,60 +380,6 @@ version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "711b9620af191e0cdc7468a8d14e709c3dcdb115b36f838e601583af800a370a"
[[package]]
name = "wasm-bindgen"
version = "0.2.89"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0ed0d4f68a3015cc185aff4db9506a015f4b96f95303897bfa23f846db54064e"
dependencies = [
"cfg-if",
"wasm-bindgen-macro",
]
[[package]]
name = "wasm-bindgen-backend"
version = "0.2.89"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1b56f625e64f3a1084ded111c4d5f477df9f8c92df113852fa5a374dbda78826"
dependencies = [
"bumpalo",
"log",
"once_cell",
"proc-macro2",
"quote",
"syn",
"wasm-bindgen-shared",
]
[[package]]
name = "wasm-bindgen-macro"
version = "0.2.89"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0162dbf37223cd2afce98f3d0785506dcb8d266223983e4b5b525859e6e182b2"
dependencies = [
"quote",
"wasm-bindgen-macro-support",
]
[[package]]
name = "wasm-bindgen-macro-support"
version = "0.2.89"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f0eb82fcb7930ae6219a7ecfd55b217f5f0893484b7a13022ebb2b2bf20b5283"
dependencies = [
"proc-macro2",
"quote",
"syn",
"wasm-bindgen-backend",
"wasm-bindgen-shared",
]
[[package]]
name = "wasm-bindgen-shared"
version = "0.2.89"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7ab9b36309365056cd639da3134bf87fa8f3d86008abf99e612384a6eecd459f"
[[package]]
name = "wav"
version = "1.0.0"
@@ -659,14 +391,13 @@ dependencies = [
[[package]]
name = "which"
version = "4.4.2"
version = "4.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "87ba24419a2078cd2b0f2ede2691b6c66d8e47836da3b6db8265ebad47afbfc7"
checksum = "2441c784c52b289a054b7201fc93253e288f094e2f4be9058343127c4226a269"
dependencies = [
"either",
"home",
"libc",
"once_cell",
"rustix",
]
[[package]]
@@ -708,37 +439,13 @@ version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f"
[[package]]
name = "windows-core"
version = "0.51.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f1f8cf84f35d2db49a46868f947758c7a1138116f7fac3bc844f43ade1292e64"
dependencies = [
"windows-targets 0.48.5",
]
[[package]]
name = "windows-sys"
version = "0.52.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d"
dependencies = [
"windows-targets 0.52.0",
]
[[package]]
name = "windows-targets"
version = "0.48.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9a2fa6e2155d7247be68c096456083145c183cbbbc2764150dda45a87197940c"
dependencies = [
"windows_aarch64_gnullvm 0.48.5",
"windows_aarch64_msvc 0.48.5",
"windows_i686_gnu 0.48.5",
"windows_i686_msvc 0.48.5",
"windows_x86_64_gnu 0.48.5",
"windows_x86_64_gnullvm 0.48.5",
"windows_x86_64_msvc 0.48.5",
"windows-targets",
]
[[package]]
@@ -747,93 +454,51 @@ version = "0.52.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8a18201040b24831fbb9e4eb208f8892e1f50a37feb53cc7ff887feb8f50e7cd"
dependencies = [
"windows_aarch64_gnullvm 0.52.0",
"windows_aarch64_msvc 0.52.0",
"windows_i686_gnu 0.52.0",
"windows_i686_msvc 0.52.0",
"windows_x86_64_gnu 0.52.0",
"windows_x86_64_gnullvm 0.52.0",
"windows_x86_64_msvc 0.52.0",
"windows_aarch64_gnullvm",
"windows_aarch64_msvc",
"windows_i686_gnu",
"windows_i686_msvc",
"windows_x86_64_gnu",
"windows_x86_64_gnullvm",
"windows_x86_64_msvc",
]
[[package]]
name = "windows_aarch64_gnullvm"
version = "0.48.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2b38e32f0abccf9987a4e3079dfb67dcd799fb61361e53e2882c3cbaf0d905d8"
[[package]]
name = "windows_aarch64_gnullvm"
version = "0.52.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cb7764e35d4db8a7921e09562a0304bf2f93e0a51bfccee0bd0bb0b666b015ea"
[[package]]
name = "windows_aarch64_msvc"
version = "0.48.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dc35310971f3b2dbbf3f0690a219f40e2d9afcf64f9ab7cc1be722937c26b4bc"
[[package]]
name = "windows_aarch64_msvc"
version = "0.52.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bbaa0368d4f1d2aaefc55b6fcfee13f41544ddf36801e793edbbfd7d7df075ef"
[[package]]
name = "windows_i686_gnu"
version = "0.48.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a75915e7def60c94dcef72200b9a8e58e5091744960da64ec734a6c6e9b3743e"
[[package]]
name = "windows_i686_gnu"
version = "0.52.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a28637cb1fa3560a16915793afb20081aba2c92ee8af57b4d5f28e4b3e7df313"
[[package]]
name = "windows_i686_msvc"
version = "0.48.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8f55c233f70c4b27f66c523580f78f1004e8b5a8b659e05a4eb49d4166cca406"
[[package]]
name = "windows_i686_msvc"
version = "0.52.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ffe5e8e31046ce6230cc7215707b816e339ff4d4d67c65dffa206fd0f7aa7b9a"
[[package]]
name = "windows_x86_64_gnu"
version = "0.48.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "53d40abd2583d23e4718fddf1ebec84dbff8381c07cae67ff7768bbf19c6718e"
[[package]]
name = "windows_x86_64_gnu"
version = "0.52.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3d6fa32db2bc4a2f5abeacf2b69f7992cd09dca97498da74a151a3132c26befd"
[[package]]
name = "windows_x86_64_gnullvm"
version = "0.48.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0b7b52767868a23d5bab768e390dc5f5c55825b6d30b86c844ff2dc7414044cc"
[[package]]
name = "windows_x86_64_gnullvm"
version = "0.52.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1a657e1e9d3f514745a572a6846d3c7aa7dbe1658c056ed9c3344c4109a6949e"
[[package]]
name = "windows_x86_64_msvc"
version = "0.48.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538"
[[package]]
name = "windows_x86_64_msvc"
version = "0.52.0"

View File

@@ -7,8 +7,5 @@ edition = "2021"
[dependencies]
whisper-rs = { path = "../gitea-whisper-rs", version = "0.8.0" }
wav = "1"
tokio = "1.27"
byteorder = "1.5.0"
chrono = "0.4.31"
clap = { version = "4.4.10", features = ["derive"] }
wav = "1"

View File

@@ -14,13 +14,13 @@ pub struct Flags {
#[arg(long, default_value = "8")]
pub threads: i32,
#[arg(long, default_value = "5")]
#[arg(long, default_value = "8")]
pub stream_step: u64,
#[arg(long, default_value = "0.6")]
#[arg(long, default_value = "4.0")]
pub stream_retain: f32,
#[arg(long, default_value = "0.3")]
#[arg(long, default_value = "2.0")]
pub stream_head: f32,
#[arg(long, default_value = "0.3")]
#[arg(long, default_value = "0.0")]
pub stream_tail: f32,
#[arg(long, default_value = "false")]
@@ -44,7 +44,7 @@ pub fn wav<F>(flags: Flags, handler_fn: F, wav_path: String) where F: FnMut(Resu
w.transcribe(&f32_from_wav_file(&wav_path).unwrap())
}
fn f32_from_wav_file(path: &String) -> Result<Vec<f32>, String> {
pub fn f32_from_wav_file(path: &String) -> Result<Vec<f32>, String> {
let f = std::fs::File::open(path);
if let Some(err) = f.as_ref().err() {
return Err(format!("failed to open wav file: {}", err));
@@ -61,12 +61,41 @@ fn f32_from_wav_file(path: &String) -> Result<Vec<f32>, String> {
return Err("!= 16_000 hz".to_string());
}
match data.as_sixteen() {
Some(data16) => Ok(whisper_rs::convert_integer_to_float_audio(&data16)),
Some(data16) => {
let mut floats = Vec::with_capacity(data16.len());
for sample in data16 {
floats.push(*sample as f32 / 32768.0);
}
Ok(floats)
},
None => Err(format!("couldnt translate wav to 16s")),
}
}
pub fn channel<F>(flags: Flags, handler_fn: F, stream: std::sync::mpsc::Receiver<Vec<f32>>) where F: FnMut(Result<Transcribed, String>) + Send + 'static {
pub fn wav_channel<F>(flags: Flags, handler_fn: F) where F: FnMut(Result<Transcribed, String>) + Send + 'static {
let path = flags.wav.as_ref().unwrap();
let mut audio = f32_from_wav_file(&path).unwrap();
let mut iter = vec![];
let n = audio.len() / match audio.len() % 100 {
0 => 100,
_ => 99,
};
for _ in 0..100 {
iter.push(audio.drain(0..n.clamp(0, audio.len())).collect());
}
let (fin_send, fin_recv) = std::sync::mpsc::sync_channel::<Option<i32>>(1);
channel_and_close(flags.clone(), handler_fn, iter, move || { fin_send.send(None).unwrap(); });
match fin_recv.recv() {
Ok(_) => {},
Err(x) => panic!("failed to receive: {}", x),
};
}
pub fn channel<F, I>(flags: Flags, handler_fn: F, stream: I) where F: FnMut(Result<Transcribed, String>) + Send + 'static, I: IntoIterator<Item = Vec<f32>> {
channel_and_close(flags, handler_fn, stream, || {});
}
fn channel_and_close<F, I, G>(flags: Flags, handler_fn: F, stream: I, mut close_fn: G) where F: FnMut(Result<Transcribed, String>) + Send + 'static, I: IntoIterator<Item = Vec<f32>>, G: FnMut() + Send + 'static {
let w = new_service(
flags.model_path,
flags.model_buffer,
@@ -81,7 +110,7 @@ pub fn channel<F>(flags: Flags, handler_fn: F, stream: std::sync::mpsc::Receiver
false => {},
};
let mut buffer = vec![];
for data in stream.iter() {
for data in stream {
data.iter().for_each(|x| buffer.push(*x));
if buffer.len() >= (flags.stream_step * 16_000) as usize {
w.transcribe_async(&buffer).unwrap();
@@ -106,6 +135,10 @@ pub fn channel<F>(flags: Flags, handler_fn: F, stream: std::sync::mpsc::Receiver
buffer.truncate(stream_retain);
}
}
if buffer.len() > 0 {
w.transcribe(&buffer);
}
close_fn();
}
struct Service {
@@ -197,8 +230,10 @@ impl Impl {
let result = whispered
.after(&(self.stream_head * 100.0))
.before(&(self.stream_tail * 100.0));
if result.to_string().trim().len() > 0 {
(self.handler_fn.as_mut().unwrap())(Ok(result));
}
}
fn on_error(&mut self, msg: String) {
(self.handler_fn.as_mut().unwrap())(Err(format!("failed to transcribe: {}", &msg)));
@@ -333,6 +368,28 @@ impl Transcribed {
mod tests {
use super::*;
#[test]
fn test_transcribe_tiny_jfk_wav_whisper_rs_wav_channel() {
wav_channel(
Flags {
model_path: None,
model_buffer: Some(include_bytes!("../../models/ggml-tiny.en.bin").to_vec()),
threads: 8,
stream_step: 30,
stream_retain: 0.0,
stream_head: 0.0,
stream_tail: 0.0,
wav: Some("../gitea-whisper-rs/sys/whisper.cpp/bindings/go/samples/jfk.wav".to_string()),
debug: false,
stream_device: None,
},
move | result | {
assert!(result.is_ok());
assert_eq!(result.unwrap().to_string(), " And so my fellow Americans ask not what your country can do for you ask what you can do for your country.");
},
);
}
#[test]
fn test_transcribe_tiny_jfk_wav_whisper_rs() {
wav(

12
todo.yaml Executable file
View File

@@ -0,0 +1,12 @@
todo:
- wav to subtitles
- compound words like checkmark vs check mark should destutter
- whisper trims outside silence so head and tail never get hit
- split on silence-ish instead of duration
- rust-whisper warn when transcription time ~ input time
scheduled: []
done:
- todo: need to overlap without ANY puctuation, which i can do by breaking into words
ts: Tue Jan 2 13:23:00 EST 2024
- todo: overlap without stop words
ts: Wed Jan 3 03:22:14 EST 2024

BIN
wav_to_mkv.d/sc.jpg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 11 KiB

View File

@@ -0,0 +1,100 @@
// This example is not going to build in this folder.
// You need to copy this code into your project and add the dependencies whisper_rs and hound in your cargo.toml
use hound;
use std::fs::File;
use std::io::Write;
use whisper_rs::{FullParams, SamplingStrategy, WhisperContext};
/// Loads a context and model, processes an audio file, and prints the resulting transcript to stdout.
fn main() -> Result<(), &'static str> {
let args: Vec<String> = std::env::args().collect();
// Load a context and model.
let ctx = WhisperContext::new(&args[1])
.expect("failed to load model");
// Create a state
let mut state = ctx.create_state().expect("failed to create key");
// Create a params object for running the model.
// The number of past samples to consider defaults to 0.
let mut params = FullParams::new(SamplingStrategy::Greedy { best_of: 0 });
// Edit params as needed.
// Set the number of threads to use to 1.
//params.set_n_threads(1);
// Enable translation.
params.set_translate(true);
// Set the language to translate to to English.
params.set_language(Some("en"));
// Disable anything that prints to stdout.
params.set_print_special(false);
params.set_print_progress(false);
params.set_print_realtime(false);
params.set_print_timestamps(false);
// Open the audio file.
let mut reader = hound::WavReader::open(&args[2]).expect("failed to open file");
#[allow(unused_variables)]
let hound::WavSpec {
channels,
sample_rate,
bits_per_sample,
..
} = reader.spec();
// Convert the audio to floating point samples.
let mut audio = whisper_rs::convert_integer_to_float_audio(
&reader
.samples::<i16>()
.map(|s| s.expect("invalid sample"))
.collect::<Vec<_>>(),
);
// Convert audio to 16KHz mono f32 samples, as required by the model.
// These utilities are provided for convenience, but can be replaced with custom conversion logic.
// SIMD variants of these functions are also available on nightly Rust (see the docs).
if channels == 2 {
audio = whisper_rs::convert_stereo_to_mono_audio(&audio)?;
} else if channels != 1 {
panic!(">2 channels unsupported");
}
if sample_rate != 16000 {
panic!("sample rate must be 16KHz");
}
// Run the model.
state.full(params, &audio[..]).expect("failed to run model");
// Create a file to write the transcript to.
let mut file = File::create("transcript.txt").expect("failed to create file");
// Iterate through the segments of the transcript.
let num_segments = state
.full_n_segments()
.expect("failed to get number of segments");
for i in 0..num_segments {
// Get the transcribed text and timestamps for the current segment.
let segment = state
.full_get_segment_text(i)
.expect("failed to get segment");
let start_timestamp = state
.full_get_segment_t0(i)
.expect("failed to get start timestamp");
let end_timestamp = state
.full_get_segment_t1(i)
.expect("failed to get end timestamp");
// Print the segment to stdout.
println!("[{} - {}]: {}", start_timestamp, end_timestamp, segment);
// Format the segment information as a string.
let line = format!("[{} - {}]: {}\n", start_timestamp, end_timestamp, segment);
// Write the segment information to the file.
file.write_all(line.as_bytes())
.expect("failed to write to file");
}
Ok(())
}

View File

@@ -0,0 +1,66 @@
#! /bin/bash
main() {
set -euo pipefail
input_wav="$(realpath "$1")"
model="$(realpath "${2:-../models/ggml-small.en.bin}")"
already_transcribed="${3:-false}"
sanitized_wav="${input_wav%.*}.mono-16khz.wav"
ffmpeg -y -i "$input_wav" -ac 1 -ar 16k "$sanitized_wav"
if ! $already_transcribed; then
pushd "$(dirname "$(realpath "$BASH_SOURCE")")"
cd ../gitea-whisper-rs/
cargo run --example wav_subtitles -- "$model" "$sanitized_wav"
popd
fi
out_to_srt ../gitea-whisper-rs/transcript.txt > "${input_wav%.*}.srt"
ffmpeg -y \
-loop 1 -i sc.jpg \
-i "$input_wav" \
-i "${input_wav%.*}.srt" \
-c:v libx264 \
-tune stillimage \
-pix_fmt yuv420p -shortest \
"${input_wav%.*}.mkv"
ls "${input_wav%.*}.mkv"
}
out_to_srt() {
cs_to_ts() {
echo "$1" | awk '{
printf "%02d:%02d:%02d,000",
int(($1/100.0)/60/60),
int(($1/100.0)/60%60),
int(($1/100.0)%60)
}'
}
cat "$1" \
| (
i=0
while read -r line; do
((i+=1))
echo "$i"
echo "$(cs_to_ts "$(
echo "${line%%:] *}" \
| tr -d '[' \
| awk '{print $1}'
)") --> $(cs_to_ts "$(
echo "${line%%:] *}" \
| tr -d '[' \
| awk '{print $3}'
)")"
echo "${line#*: }"
echo
done
)
}
if [ "$0" == "$BASH_SOURCE" ]; then
main "$@"
fi