116 lines
2.1 KiB
Go
116 lines
2.1 KiB
Go
//go:build ai
|
|
|
|
package main
|
|
|
|
import (
	"context"
	"fmt"
	"io"
	"net/http"
	"os"
	"path"
	"path/filepath"
	"testing"
	"time"
)
|
|
|
|
func TestAINoop(t *testing.T) {
|
|
ai := NewAINoop()
|
|
|
|
testAI(t, ai)
|
|
}
|
|
|
|
func TestAIOllama(t *testing.T) {
|
|
ai := NewAIOllama("http://localhost:11434", "gemma:2b")
|
|
|
|
testAI(t, ai)
|
|
}
|
|
|
|
func TestAILocal(t *testing.T) {
|
|
d := os.TempDir()
|
|
checkpoints := "checkpoints"
|
|
tokenizer := "tokenizer"
|
|
for u, p := range map[string]*string{
|
|
"https://huggingface.co/karpathy/tinyllamas/resolve/main/stories110M.bin": &checkpoints,
|
|
"https://github.com/karpathy/llama2.c/raw/master/tokenizer.bin": &tokenizer,
|
|
} {
|
|
func() {
|
|
*p = path.Base(u)
|
|
if _, err := os.Stat(path.Join(d, *p)); os.IsNotExist(err) {
|
|
t.Logf("downloading %s from %s", u, *p)
|
|
|
|
resp, err := http.Get(u)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
f, err := os.Create(path.Join(d, *p))
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
defer f.Close()
|
|
|
|
if _, err := io.Copy(f, resp.Body); err != nil {
|
|
f.Close()
|
|
os.Remove(path.Join(d, *p))
|
|
t.Fatal(err)
|
|
}
|
|
}
|
|
}()
|
|
}
|
|
|
|
ai := NewAILocal(
|
|
path.Join(d, checkpoints),
|
|
path.Join(d, tokenizer),
|
|
0.9,
|
|
256,
|
|
0.9,
|
|
)
|
|
|
|
testAI(t, ai)
|
|
}
|
|
|
|
func testAI(t *testing.T, ai AI) {
|
|
ctx, can := context.WithTimeout(context.Background(), time.Minute)
|
|
defer can()
|
|
|
|
t.Run("mvp", func(t *testing.T) {
|
|
if result, err := ai.Do(ctx, "hello world"); err != nil {
|
|
t.Fatal(err)
|
|
} else if len(result) < 3 {
|
|
t.Error(result)
|
|
} else {
|
|
t.Logf("%s", result)
|
|
}
|
|
})
|
|
|
|
t.Run("simulation", func(t *testing.T) {
|
|
d := NewRAM()
|
|
FillWithTestdata(ctx, d)
|
|
s := NewStorage(d)
|
|
|
|
threads, err := s.Threads(ctx)
|
|
if err != nil || len(threads) < 1 {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
thread, err := s.Thread(ctx, threads[0])
|
|
if err != nil || len(thread) < 1 {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
input := fmt.Sprintf(`
|
|
Summarize the following forum converstion.
|
|
---
|
|
%s
|
|
`, thread[0].Plaintext)
|
|
t.Logf("\n\t%s", input)
|
|
result, err := ai.Do(ctx, input)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
t.Logf("\n\t%s\n->\n\t%s", input, result)
|
|
})
|
|
|
|
}
|