fix newlines in ai
parent
5f31a2c572
commit
ecf29d54b8
7
ai.go
7
ai.go
|
|
@ -4,6 +4,7 @@ import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"os"
|
"os"
|
||||||
|
"strings"
|
||||||
|
|
||||||
nn "github.com/nikolaydubina/llama2.go/exp/nnfast"
|
nn "github.com/nikolaydubina/llama2.go/exp/nnfast"
|
||||||
"github.com/nikolaydubina/llama2.go/llama2"
|
"github.com/nikolaydubina/llama2.go/llama2"
|
||||||
|
|
@ -13,6 +14,8 @@ type AI interface {
|
||||||
Do(context.Context, string) (string, error)
|
Do(context.Context, string) (string, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO https://github.com/ollama/ollama/blob/main/server/routes.go#L153
|
||||||
|
|
||||||
type AILocal struct {
|
type AILocal struct {
|
||||||
checkpointPath string
|
checkpointPath string
|
||||||
tokenizerPath string
|
tokenizerPath string
|
||||||
|
|
@ -73,7 +76,7 @@ func (ai AILocal) Do(ctx context.Context, prompt string) (string, error) {
|
||||||
|
|
||||||
runState := llama2.NewRunState(config)
|
runState := llama2.NewRunState(config)
|
||||||
|
|
||||||
promptTokens := vocab.Encode(prompt)
|
promptTokens := vocab.Encode(strings.ReplaceAll(prompt, "\n", "<0x0A>"))
|
||||||
|
|
||||||
out := bytes.NewBuffer(nil)
|
out := bytes.NewBuffer(nil)
|
||||||
|
|
||||||
|
|
@ -131,5 +134,5 @@ func (ai AILocal) Do(ctx context.Context, prompt string) (string, error) {
|
||||||
}
|
}
|
||||||
out.Write([]byte("\n"))
|
out.Write([]byte("\n"))
|
||||||
|
|
||||||
return string(out.Bytes()), nil
|
return strings.ReplaceAll(string(out.Bytes()), "<0x0A>", "\n"), nil
|
||||||
}
|
}
|
||||||
|
|
|
||||||
21
ai_test.go
21
ai_test.go
|
|
@ -18,13 +18,16 @@ func TestAILocal(t *testing.T) {
|
||||||
defer can()
|
defer can()
|
||||||
|
|
||||||
d := os.TempDir()
|
d := os.TempDir()
|
||||||
for k, u := range map[string]string{
|
checkpoints := "checkpoints"
|
||||||
"checkpoints": "https://huggingface.co/karpathy/tinyllamas/resolve/main/stories110M.bin",
|
tokenizer := "tokenizer"
|
||||||
"tokenizer": "https://github.com/karpathy/llama2.c/raw/master/tokenizer.bin",
|
for u, p := range map[string]*string{
|
||||||
|
"https://huggingface.co/karpathy/tinyllamas/resolve/main/stories110M.bin": &checkpoints,
|
||||||
|
"https://github.com/karpathy/llama2.c/raw/master/tokenizer.bin": &tokenizer,
|
||||||
} {
|
} {
|
||||||
func() {
|
func() {
|
||||||
if _, err := os.Stat(path.Join(d, k)); os.IsNotExist(err) {
|
*p = path.Base(u)
|
||||||
t.Logf("downloading %s from %s", u, k)
|
if _, err := os.Stat(path.Join(d, *p)); os.IsNotExist(err) {
|
||||||
|
t.Logf("downloading %s from %s", u, *p)
|
||||||
|
|
||||||
resp, err := http.Get(u)
|
resp, err := http.Get(u)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -32,7 +35,7 @@ func TestAILocal(t *testing.T) {
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
|
|
||||||
f, err := os.Create(path.Join(d, k))
|
f, err := os.Create(path.Join(d, *p))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
@ -40,7 +43,7 @@ func TestAILocal(t *testing.T) {
|
||||||
|
|
||||||
if _, err := io.Copy(f, resp.Body); err != nil {
|
if _, err := io.Copy(f, resp.Body); err != nil {
|
||||||
f.Close()
|
f.Close()
|
||||||
os.Remove(path.Join(d, k))
|
os.Remove(path.Join(d, *p))
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -48,8 +51,8 @@ func TestAILocal(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
ai := NewAILocal(
|
ai := NewAILocal(
|
||||||
path.Join(d, "checkpoints"),
|
path.Join(d, checkpoints),
|
||||||
path.Join(d, "tokenizer"),
|
path.Join(d, tokenizer),
|
||||||
0.9,
|
0.9,
|
||||||
256,
|
256,
|
||||||
0.9,
|
0.9,
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue