From ecf29d54b869ecbe49fcbf50c5fd5d033b277fac Mon Sep 17 00:00:00 2001 From: bel Date: Fri, 12 Apr 2024 21:30:57 -0600 Subject: [PATCH] fix newlines in ai --- ai.go | 7 +++++-- ai_test.go | 21 ++++++++++++--------- 2 files changed, 17 insertions(+), 11 deletions(-) diff --git a/ai.go b/ai.go index 5c17116..c99e2f3 100644 --- a/ai.go +++ b/ai.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "os" + "strings" nn "github.com/nikolaydubina/llama2.go/exp/nnfast" "github.com/nikolaydubina/llama2.go/llama2" @@ -13,6 +14,8 @@ type AI interface { Do(context.Context, string) (string, error) } +// TODO https://github.com/ollama/ollama/blob/main/server/routes.go#L153 + type AILocal struct { checkpointPath string tokenizerPath string @@ -73,7 +76,7 @@ func (ai AILocal) Do(ctx context.Context, prompt string) (string, error) { runState := llama2.NewRunState(config) - promptTokens := vocab.Encode(prompt) + promptTokens := vocab.Encode(strings.ReplaceAll(prompt, "\n", "<0x0A>")) out := bytes.NewBuffer(nil) @@ -131,5 +134,5 @@ func (ai AILocal) Do(ctx context.Context, prompt string) (string, error) { } out.Write([]byte("\n")) - return string(out.Bytes()), nil + return strings.ReplaceAll(string(out.Bytes()), "<0x0A>", "\n"), nil } diff --git a/ai_test.go b/ai_test.go index 8da2b79..484a7e3 100644 --- a/ai_test.go +++ b/ai_test.go @@ -18,13 +18,16 @@ func TestAILocal(t *testing.T) { defer can() d := os.TempDir() - for k, u := range map[string]string{ - "checkpoints": "https://huggingface.co/karpathy/tinyllamas/resolve/main/stories110M.bin", - "tokenizer": "https://github.com/karpathy/llama2.c/raw/master/tokenizer.bin", + checkpoints := "checkpoints" + tokenizer := "tokenizer" + for u, p := range map[string]*string{ + "https://huggingface.co/karpathy/tinyllamas/resolve/main/stories110M.bin": &checkpoints, + "https://github.com/karpathy/llama2.c/raw/master/tokenizer.bin": &tokenizer, } { func() { - if _, err := os.Stat(path.Join(d, k)); os.IsNotExist(err) { - t.Logf("downloading %s from %s", u, k) + *p = path.Base(u) + if _, err := os.Stat(path.Join(d, *p)); os.IsNotExist(err) { + t.Logf("downloading %s from %s", u, *p) resp, err := http.Get(u) if err != nil { @@ -32,7 +35,7 @@ func TestAILocal(t *testing.T) { } defer resp.Body.Close() - f, err := os.Create(path.Join(d, k)) + f, err := os.Create(path.Join(d, *p)) if err != nil { t.Fatal(err) } @@ -40,7 +43,7 @@ func TestAILocal(t *testing.T) { if _, err := io.Copy(f, resp.Body); err != nil { f.Close() - os.Remove(path.Join(d, k)) + os.Remove(path.Join(d, *p)) t.Fatal(err) } } @@ -48,8 +51,8 @@ func TestAILocal(t *testing.T) { } ai := NewAILocal( - path.Join(d, "checkpoints"), - path.Join(d, "tokenizer"), + path.Join(d, checkpoints), + path.Join(d, tokenizer), 0.9, 256, 0.9,