WE GOT AI

main
bel 2024-04-12 15:58:47 -06:00
parent a48518fea6
commit 046dc0e1ba
3 changed files with 38 additions and 4 deletions

3
ai.go
View File

@ -3,7 +3,6 @@ package main
import (
"bytes"
"context"
"errors"
"os"
nn "github.com/nikolaydubina/llama2.go/exp/nnfast"
@ -132,5 +131,5 @@ func (ai AILocal) Do(ctx context.Context, prompt string) (string, error) {
}
out.Write([]byte("\n"))
return string(out.Bytes()), errors.New("not impl")
return string(out.Bytes()), nil
}

View File

@ -4,6 +4,10 @@ package main
import (
"context"
"io"
"net/http"
"os"
"path"
"testing"
"time"
)
@ -12,9 +16,39 @@ func TestAILocal(t *testing.T) {
ctx, can := context.WithTimeout(context.Background(), time.Minute)
defer can()
d := os.TempDir()
for k, u := range map[string]string{
"checkpoints": "https://huggingface.co/karpathy/tinyllamas/resolve/main/stories110M.bin",
"tokenizer": "https://github.com/karpathy/llama2.c/raw/master/tokenizer.bin",
} {
func() {
if _, err := os.Stat(path.Join(d, k)); os.IsNotExist(err) {
t.Logf("downloading %s from %s", u, k)
resp, err := http.Get(u)
if err != nil {
t.Fatal(err)
}
defer resp.Body.Close()
f, err := os.Create(path.Join(d, k))
if err != nil {
t.Fatal(err)
}
defer f.Close()
if _, err := io.Copy(f, resp.Body); err != nil {
f.Close()
os.Remove(path.Join(d, k))
t.Fatal(err)
}
}
}()
}
ai := NewAILocal(
"checkpointPath",
"tokenizerPath",
path.Join(d, "checkpoints"),
path.Join(d, "tokenizer"),
0.9,
256,
0.9,

View File

@ -38,6 +38,7 @@ func (s Storage) ThreadsSince(ctx context.Context, t time.Time) ([]string, error
for k := range threads {
result = append(result, k)
}
sort.Strings(result)
return result, nil
}