fix newlines in ai

main
bel 2024-04-12 21:30:57 -06:00
parent 5f31a2c572
commit ecf29d54b8
2 changed files with 17 additions and 11 deletions

7
ai.go
View File

@ -4,6 +4,7 @@ import (
"bytes" "bytes"
"context" "context"
"os" "os"
"strings"
nn "github.com/nikolaydubina/llama2.go/exp/nnfast" nn "github.com/nikolaydubina/llama2.go/exp/nnfast"
"github.com/nikolaydubina/llama2.go/llama2" "github.com/nikolaydubina/llama2.go/llama2"
@ -13,6 +14,8 @@ type AI interface {
Do(context.Context, string) (string, error) Do(context.Context, string) (string, error)
} }
// TODO https://github.com/ollama/ollama/blob/main/server/routes.go#L153
type AILocal struct { type AILocal struct {
checkpointPath string checkpointPath string
tokenizerPath string tokenizerPath string
@ -73,7 +76,7 @@ func (ai AILocal) Do(ctx context.Context, prompt string) (string, error) {
runState := llama2.NewRunState(config) runState := llama2.NewRunState(config)
promptTokens := vocab.Encode(prompt) promptTokens := vocab.Encode(strings.ReplaceAll(prompt, "\n", "<0x0A>"))
out := bytes.NewBuffer(nil) out := bytes.NewBuffer(nil)
@ -131,5 +134,5 @@ func (ai AILocal) Do(ctx context.Context, prompt string) (string, error) {
} }
out.Write([]byte("\n")) out.Write([]byte("\n"))
return string(out.Bytes()), nil return strings.ReplaceAll(string(out.Bytes()), "<0x0A>", "\n"), nil
} }

View File

@ -18,13 +18,16 @@ func TestAILocal(t *testing.T) {
defer can() defer can()
d := os.TempDir() d := os.TempDir()
for k, u := range map[string]string{ checkpoints := "checkpoints"
"checkpoints": "https://huggingface.co/karpathy/tinyllamas/resolve/main/stories110M.bin", tokenizer := "tokenizer"
"tokenizer": "https://github.com/karpathy/llama2.c/raw/master/tokenizer.bin", for u, p := range map[string]*string{
"https://huggingface.co/karpathy/tinyllamas/resolve/main/stories110M.bin": &checkpoints,
"https://github.com/karpathy/llama2.c/raw/master/tokenizer.bin": &tokenizer,
} { } {
func() { func() {
if _, err := os.Stat(path.Join(d, k)); os.IsNotExist(err) { *p = path.Base(u)
t.Logf("downloading %s from %s", u, k) if _, err := os.Stat(path.Join(d, *p)); os.IsNotExist(err) {
t.Logf("downloading %s from %s", u, *p)
resp, err := http.Get(u) resp, err := http.Get(u)
if err != nil { if err != nil {
@ -32,7 +35,7 @@ func TestAILocal(t *testing.T) {
} }
defer resp.Body.Close() defer resp.Body.Close()
f, err := os.Create(path.Join(d, k)) f, err := os.Create(path.Join(d, *p))
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -40,7 +43,7 @@ func TestAILocal(t *testing.T) {
if _, err := io.Copy(f, resp.Body); err != nil { if _, err := io.Copy(f, resp.Body); err != nil {
f.Close() f.Close()
os.Remove(path.Join(d, k)) os.Remove(path.Join(d, *p))
t.Fatal(err) t.Fatal(err)
} }
} }
@ -48,8 +51,8 @@ func TestAILocal(t *testing.T) {
} }
ai := NewAILocal( ai := NewAILocal(
path.Join(d, "checkpoints"), path.Join(d, checkpoints),
path.Join(d, "tokenizer"), path.Join(d, tokenizer),
0.9, 0.9,
256, 256,
0.9, 0.9,