//go:build ai

package main

import (
	"context"
	"fmt"
	"io"
	"net/http"
	"os"
	"path"
	"testing"
	"time"
)

func TestAINoop(t *testing.T) {
	ai := NewAINoop()
	testAI(t, ai)
}

func TestAIOllama(t *testing.T) {
	ai := NewAIOllama("http://localhost:11434", "gemma:2b")
	testAI(t, ai)
}

func TestAILocal(t *testing.T) {
	d := os.TempDir()
	checkpoints := "checkpoints"
	tokenizer := "tokenizer"
	// Download the test model checkpoint and tokenizer into the temp dir,
	// skipping files already cached from a previous run.
	for u, p := range map[string]*string{
		"https://huggingface.co/karpathy/tinyllamas/resolve/main/stories110M.bin": &checkpoints,
		"https://github.com/karpathy/llama2.c/raw/master/tokenizer.bin":           &tokenizer,
	} {
		func() {
			*p = path.Base(u)
			if _, err := os.Stat(path.Join(d, *p)); os.IsNotExist(err) {
				t.Logf("downloading %s to %s", u, *p)
				resp, err := http.Get(u)
				if err != nil {
					t.Fatal(err)
				}
				defer resp.Body.Close()
				f, err := os.Create(path.Join(d, *p))
				if err != nil {
					t.Fatal(err)
				}
				defer f.Close()
				if _, err := io.Copy(f, resp.Body); err != nil {
					f.Close()
					os.Remove(path.Join(d, *p))
					t.Fatal(err)
				}
			}
		}()
	}
	ai := NewAILocal(
		path.Join(d, checkpoints),
		path.Join(d, tokenizer),
		0.9,
		256,
		0.9,
	)
	testAI(t, ai)
}

func testAI(t *testing.T, ai AI) {
	ctx, can := context.WithTimeout(context.Background(), time.Minute)
	defer can()
	t.Run("mvp", func(t *testing.T) {
		if result, err := ai.Do(ctx, "hello world"); err != nil {
			t.Fatal(err)
		} else if len(result) < 3 {
			t.Error(result)
		} else {
			t.Logf("%s", result)
		}
	})
	t.Run("simulation", func(t *testing.T) {
		d := NewRAM()
		FillWithTestdata(ctx, d, renderAssetPattern, renderDatacenterPattern)
		s := NewStorage(d)
		threads, err := s.Threads(ctx)
		if err != nil || len(threads) < 1 {
			t.Fatal(err)
		}
		thread, err := s.Thread(ctx, threads[0])
		if err != nil || len(thread) < 1 {
			t.Fatal(err)
		}
		input := fmt.Sprintf(`
Summarize the following forum conversation.
---
%s
`, thread[0].Plaintext)
		t.Logf("\n\t%s", input)
		result, err := ai.Do(ctx, input)
		if err != nil {
			t.Fatal(err)
		}
		t.Logf("\n\t%s\n->\n\t%s", input, result)
	})
}