From 046dc0e1ba14fb636731a7868e786399269b11a8 Mon Sep 17 00:00:00 2001 From: bel Date: Fri, 12 Apr 2024 15:58:47 -0600 Subject: [PATCH] WE GOT AI --- ai.go | 3 +-- ai_test.go | 38 ++++++++++++++++++++++++++++++++++++-- storage.go | 1 + 3 files changed, 38 insertions(+), 4 deletions(-) diff --git a/ai.go b/ai.go index 081147e..5c17116 100644 --- a/ai.go +++ b/ai.go @@ -3,7 +3,6 @@ package main import ( "bytes" "context" - "errors" "os" nn "github.com/nikolaydubina/llama2.go/exp/nnfast" @@ -132,5 +131,5 @@ func (ai AILocal) Do(ctx context.Context, prompt string) (string, error) { } out.Write([]byte("\n")) - return string(out.Bytes()), errors.New("not impl") + return string(out.Bytes()), nil } diff --git a/ai_test.go b/ai_test.go index b7a7d85..7fece21 100644 --- a/ai_test.go +++ b/ai_test.go @@ -4,6 +4,10 @@ package main import ( "context" + "io" + "net/http" + "os" + "path" "testing" "time" ) @@ -12,9 +16,39 @@ func TestAILocal(t *testing.T) { ctx, can := context.WithTimeout(context.Background(), time.Minute) defer can() + d := os.TempDir() + for k, u := range map[string]string{ + "checkpoints": "https://huggingface.co/karpathy/tinyllamas/resolve/main/stories110M.bin", + "tokenizer": "https://github.com/karpathy/llama2.c/raw/master/tokenizer.bin", + } { + func() { + if _, err := os.Stat(path.Join(d, k)); os.IsNotExist(err) { + t.Logf("downloading %s from %s", u, k) + + resp, err := http.Get(u) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + + f, err := os.Create(path.Join(d, k)) + if err != nil { + t.Fatal(err) + } + defer f.Close() + + if _, err := io.Copy(f, resp.Body); err != nil { + f.Close() + os.Remove(path.Join(d, k)) + t.Fatal(err) + } + } + }() + } + ai := NewAILocal( - "checkpointPath", - "tokenizerPath", + path.Join(d, "checkpoints"), + path.Join(d, "tokenizer"), 0.9, 256, 0.9, diff --git a/storage.go b/storage.go index 96bec5e..5442a3c 100644 --- a/storage.go +++ b/storage.go @@ -38,6 +38,7 @@ func (s Storage) ThreadsSince(ctx context.Context, t time.Time) ([]string, error for k := range threads { result = append(result, k) } + sort.Strings(result) return result, nil }