//go:build ai package main import ( "context" "fmt" "io" "net/http" "os" "path" "testing" "time" ) func TestAILocal(t *testing.T) { ctx, can := context.WithTimeout(context.Background(), time.Minute) defer can() d := os.TempDir() for k, u := range map[string]string{ "checkpoints": "https://huggingface.co/karpathy/tinyllamas/resolve/main/stories110M.bin", "tokenizer": "https://github.com/karpathy/llama2.c/raw/master/tokenizer.bin", } { func() { if _, err := os.Stat(path.Join(d, k)); os.IsNotExist(err) { t.Logf("downloading %s from %s", u, k) resp, err := http.Get(u) if err != nil { t.Fatal(err) } defer resp.Body.Close() f, err := os.Create(path.Join(d, k)) if err != nil { t.Fatal(err) } defer f.Close() if _, err := io.Copy(f, resp.Body); err != nil { f.Close() os.Remove(path.Join(d, k)) t.Fatal(err) } } }() } ai := NewAILocal( path.Join(d, "checkpoints"), path.Join(d, "tokenizer"), 0.9, 256, 0.9, ) t.Run("mvp", func(t *testing.T) { if result, err := ai.Do(ctx, "hello world"); err != nil { t.Fatal(err) } else if len(result) < 250 { t.Error(result) } else { t.Logf("%s", result) } }) t.Run("simulation", func(t *testing.T) { d := NewRAM() FillWithTestdata(ctx, d) s := NewStorage(d) threads, err := s.Threads(ctx) if err != nil || len(threads) < 1 { t.Fatal(err) } thread, err := s.Thread(ctx, threads[0]) if err != nil || len(thread) < 1 { t.Fatal(err) } input := fmt.Sprintf(` Summarize the following forum converstion. --- %s `, thread[0].Plaintext) t.Logf("\n\t%s", input) result, err := ai.Do(ctx, input) if err != nil { t.Fatal(err) } t.Logf("\n\t%s\n->\n\t%s", input, result) }) }