spoc-bot-vr/ai_test.go

116 lines
2.1 KiB
Go

//go:build ai
package main
import (
"context"
"fmt"
"io"
"net/http"
"os"
"path"
"testing"
"time"
)
func TestAIMistral(t *testing.T) {
ai := NewAIMistral("open-mistral-7b")
testAI(t, ai)
}
func TestAIOllama(t *testing.T) {
ai := NewAIOllama("gemma:2b")
testAI(t, ai)
}
// TestAILocal runs the shared AI checks (see testAI) against the local model
// implementation. The checkpoint and tokenizer files are downloaded into
// os.TempDir() on first use and reused by later runs.
func TestAILocal(t *testing.T) {
	d := os.TempDir()
	checkpoints := fetchTestAsset(t, d, "https://huggingface.co/karpathy/tinyllamas/resolve/main/stories110M.bin")
	tokenizer := fetchTestAsset(t, d, "https://github.com/karpathy/llama2.c/raw/master/tokenizer.bin")
	ai := NewAILocal(
		path.Join(d, checkpoints),
		path.Join(d, tokenizer),
		0.9,
		256,
		0.9,
	)
	testAI(t, ai)
}

// fetchTestAsset makes sure a file named path.Base(u) exists in dir and
// returns that base name. If the file is missing it is downloaded from u.
// Any other outcome of os.Stat (including non-"not exist" errors) leaves the
// existing file alone, as before.
//
// The body is streamed into a temporary file in dir and renamed into place
// only after the copy and close both succeed. A failed or interrupted
// download therefore never leaves a truncated file behind for the os.Stat
// check to treat as a valid cache. Non-200 responses fail the test instead of
// being saved as if they were the asset.
func fetchTestAsset(t *testing.T, dir, u string) string {
	t.Helper()
	name := path.Base(u)
	dst := path.Join(dir, name)
	if _, err := os.Stat(dst); !os.IsNotExist(err) {
		return name
	}

	t.Logf("downloading %s from %s", name, u)
	resp, err := http.Get(u)
	if err != nil {
		t.Fatal(err)
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		t.Fatalf("downloading %s: unexpected status %s", u, resp.Status)
	}

	tmp, err := os.CreateTemp(dir, name+".*.part")
	if err != nil {
		t.Fatal(err)
	}
	// Clean up the partial file on every failure path (t.Fatal still runs
	// deferred calls). After a successful rename the file is gone, and the
	// resulting error is deliberately ignored.
	defer os.Remove(tmp.Name())

	if _, err := io.Copy(tmp, resp.Body); err != nil {
		tmp.Close()
		t.Fatal(err)
	}
	// Close can surface delayed write errors, so it must be checked before
	// the file is published under its final name.
	if err := tmp.Close(); err != nil {
		t.Fatal(err)
	}
	if err := os.Rename(tmp.Name(), dst); err != nil {
		t.Fatal(err)
	}
	return name
}
// testAI runs the shared behavioral checks against ai as two subtests:
//
//   - "mvp" sends the prompt "hello world" and fails if the reply is shorter
//     than 250 bytes, reporting the short reply as the failure message.
//   - "simulation" loads the test data into an in-memory store, asks ai to
//     summarize the plaintext of the first entry of the first thread, and
//     logs the prompt and reply. The summary content itself is not asserted.
//
// Each subtest gets its own one-minute deadline, so a slow "mvp" call cannot
// eat into the time budget of "simulation".
func testAI(t *testing.T, ai AI) {
	t.Run("mvp", func(t *testing.T) {
		ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
		defer cancel()

		result, err := ai.Do(ctx, "hello world")
		if err != nil {
			t.Fatal(err)
		}
		// NOTE(review): 250 bytes is a heuristic floor for "produced a real
		// answer"; tune if a backend legitimately replies more tersely.
		if len(result) < 250 {
			t.Error(result)
			return
		}
		t.Logf("%s", result)
	})

	t.Run("simulation", func(t *testing.T) {
		ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
		defer cancel()

		d := NewRAM()
		// NOTE(review): any result of FillWithTestdata is discarded here;
		// if it can fail, check its return value.
		FillWithTestdata(ctx, d)
		s := NewStorage(d)

		threads, err := s.Threads(ctx)
		if err != nil {
			t.Fatal(err)
		}
		if len(threads) < 1 {
			t.Fatal("no threads in test data")
		}
		thread, err := s.Thread(ctx, threads[0])
		if err != nil {
			t.Fatal(err)
		}
		if len(thread) < 1 {
			t.Fatalf("thread %v is empty", threads[0])
		}

		input := fmt.Sprintf(`
Summarize the following forum conversation.
---
%s
`, thread[0].Plaintext)
		// Log the prompt before calling Do so it is visible even if the call
		// hangs until the deadline.
		t.Logf("\n\t%s", input)
		result, err := ai.Do(ctx, input)
		if err != nil {
			t.Fatal(err)
		}
		t.Logf("\n\t%s\n->\n\t%s", input, result)
	})
}