o cgo is super inefficient fuck ok we rust

master
bel 2023-03-30 22:16:53 -06:00
parent 12f824a9a7
commit 7e93939f3a
1 changed files with 46 additions and 27 deletions

View File

@ -1,35 +1,25 @@
package main package main
import ( import (
"bytes"
"fmt" "fmt"
"io" "io"
"log"
"net/http"
"os" "os"
"strings"
"time"
"github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper" "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
"github.com/go-audio/wav" "github.com/go-audio/wav"
) )
func main() { func main() {
p := os.Args[1] modelName := "small.en"
f, err := os.Open(p) if v := os.Getenv("MODEL"); v != "" {
if err != nil { modelName = v
panic(err)
} }
defer f.Close() model, err := whisper.New("./models/ggml-" + modelName + ".bin")
var data []float32
dec := wav.NewDecoder(f)
if buf, err := dec.FullPCMBuffer(); err != nil {
panic(err)
} else if dec.SampleRate != whisper.SampleRate {
panic(dec.SampleRate)
} else if dec.NumChans != 1 {
panic(dec.NumChans)
} else {
data = buf.AsFloat32Buffer().Data
}
model, err := whisper.New("./models/ggml-small.en.bin")
if err != nil { if err != nil {
panic(err) panic(err)
} }
@ -39,22 +29,51 @@ func main() {
} }
context.SetThreads(4) context.SetThreads(4)
context.ResetTimings() if err := http.ListenAndServe(":8080", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if err := context.Process(data, func(segment whisper.Segment) { b, _ := io.ReadAll(r.Body)
//log.Printf("%+v", segment) if result, err := transcribe(context, bytes.NewReader(b)); err != nil {
}); err != nil { http.Error(w, err.Error(), http.StatusInternalServerError)
} else {
w.Write([]byte(result))
}
})); err != nil {
panic(err) panic(err)
} }
context.PrintTimings() }
func transcribe(context whisper.Context, r io.ReadSeeker) (string, error) {
start := time.Now()
defer func() {
if os.Getenv("DEBUG") == "true" {
log.Printf("%0.1f to transcribe", time.Since(start).Seconds())
}
}()
var data []float32
dec := wav.NewDecoder(r)
if buf, err := dec.FullPCMBuffer(); err != nil {
return "", err
} else if dec.SampleRate != whisper.SampleRate {
return "", fmt.Errorf("sample rate %v != %v", dec.SampleRate, whisper.SampleRate)
} else if dec.NumChans != 1 {
return "", fmt.Errorf("chans %v != %v", dec.NumChans, 1)
} else {
data = buf.AsFloat32Buffer().Data
}
if err := context.Process(data, nil); err != nil {
return "", err
}
result := []string{}
for { for {
segment, err := context.NextSegment() segment, err := context.NextSegment()
if err == io.EOF { if err == io.EOF {
break break
} else if err != nil { } else if err != nil {
panic(err) return "", err
} }
fmt.Printf("%s ", segment.Text) result = append(result, segment.Text)
} }
fmt.Printf("\n") return strings.Join(result, " "), nil
} }