Compare commits

...

3 Commits

Author SHA1 Message Date
bel
d5b09db0c6 ai not lookin great 2024-04-12 23:18:17 -06:00
bel
f288a9b098 all of langchaingo just calls apis ffff 2024-04-12 22:44:36 -06:00
bel
ecf29d54b8 fix newlines in ai 2024-04-12 21:30:57 -06:00
5 changed files with 107 additions and 23 deletions

38
ai.go
View File

@@ -4,15 +4,49 @@ import (
"bytes" "bytes"
"context" "context"
"os" "os"
"strings"
nn "github.com/nikolaydubina/llama2.go/exp/nnfast" nn "github.com/nikolaydubina/llama2.go/exp/nnfast"
"github.com/nikolaydubina/llama2.go/llama2" "github.com/nikolaydubina/llama2.go/llama2"
"github.com/tmc/langchaingo/llms"
"github.com/tmc/langchaingo/llms/ollama"
) )
type AI interface { type AI interface {
Do(context.Context, string) (string, error) Do(context.Context, string) (string, error)
} }
type AINoop struct {
}
func NewAINoop() AINoop {
return AINoop{}
}
func (ai AINoop) Do(ctx context.Context, prompt string) (string, error) {
return ":shrug:", nil
}
type AIOllama struct {
model string
url string
}
func NewAIOllama(url, model string) AIOllama {
return AIOllama{url: url, model: model}
}
func (ai AIOllama) Do(ctx context.Context, prompt string) (string, error) {
llm, err := ollama.New(
ollama.WithModel(ai.model),
ollama.WithServerURL(ai.url),
)
if err != nil {
return "", err
}
return llms.GenerateFromSinglePrompt(ctx, llm, prompt)
}
type AILocal struct { type AILocal struct {
checkpointPath string checkpointPath string
tokenizerPath string tokenizerPath string
@@ -73,7 +107,7 @@ func (ai AILocal) Do(ctx context.Context, prompt string) (string, error) {
runState := llama2.NewRunState(config) runState := llama2.NewRunState(config)
promptTokens := vocab.Encode(prompt) promptTokens := vocab.Encode(strings.ReplaceAll(prompt, "\n", "<0x0A>"))
out := bytes.NewBuffer(nil) out := bytes.NewBuffer(nil)
@@ -131,5 +165,5 @@ func (ai AILocal) Do(ctx context.Context, prompt string) (string, error) {
} }
out.Write([]byte("\n")) out.Write([]byte("\n"))
return string(out.Bytes()), nil return strings.ReplaceAll(string(out.Bytes()), "<0x0A>", "\n"), nil
} }

View File

@@ -13,18 +13,30 @@ import (
"time" "time"
) )
func TestAILocal(t *testing.T) { func TestAINoop(t *testing.T) {
ctx, can := context.WithTimeout(context.Background(), time.Minute) ai := NewAINoop()
defer can()
testAI(t, ai)
}
func TestAIOllama(t *testing.T) {
ai := NewAIOllama("http://localhost:11434", "gemma:2b")
testAI(t, ai)
}
func TestAILocal(t *testing.T) {
d := os.TempDir() d := os.TempDir()
for k, u := range map[string]string{ checkpoints := "checkpoints"
"checkpoints": "https://huggingface.co/karpathy/tinyllamas/resolve/main/stories110M.bin", tokenizer := "tokenizer"
"tokenizer": "https://github.com/karpathy/llama2.c/raw/master/tokenizer.bin", for u, p := range map[string]*string{
"https://huggingface.co/karpathy/tinyllamas/resolve/main/stories110M.bin": &checkpoints,
"https://github.com/karpathy/llama2.c/raw/master/tokenizer.bin": &tokenizer,
} { } {
func() { func() {
if _, err := os.Stat(path.Join(d, k)); os.IsNotExist(err) { *p = path.Base(u)
t.Logf("downloading %s from %s", u, k) if _, err := os.Stat(path.Join(d, *p)); os.IsNotExist(err) {
t.Logf("downloading %s from %s", u, *p)
resp, err := http.Get(u) resp, err := http.Get(u)
if err != nil { if err != nil {
@@ -32,7 +44,7 @@ func TestAILocal(t *testing.T) {
} }
defer resp.Body.Close() defer resp.Body.Close()
f, err := os.Create(path.Join(d, k)) f, err := os.Create(path.Join(d, *p))
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -40,7 +52,7 @@ func TestAILocal(t *testing.T) {
if _, err := io.Copy(f, resp.Body); err != nil { if _, err := io.Copy(f, resp.Body); err != nil {
f.Close() f.Close()
os.Remove(path.Join(d, k)) os.Remove(path.Join(d, *p))
t.Fatal(err) t.Fatal(err)
} }
} }
@@ -48,17 +60,24 @@ func TestAILocal(t *testing.T) {
} }
ai := NewAILocal( ai := NewAILocal(
path.Join(d, "checkpoints"), path.Join(d, checkpoints),
path.Join(d, "tokenizer"), path.Join(d, tokenizer),
0.9, 0.9,
256, 256,
0.9, 0.9,
) )
testAI(t, ai)
}
func testAI(t *testing.T, ai AI) {
ctx, can := context.WithTimeout(context.Background(), time.Minute)
defer can()
t.Run("mvp", func(t *testing.T) { t.Run("mvp", func(t *testing.T) {
if result, err := ai.Do(ctx, "hello world"); err != nil { if result, err := ai.Do(ctx, "hello world"); err != nil {
t.Fatal(err) t.Fatal(err)
} else if len(result) < 250 { } else if len(result) < 3 {
t.Error(result) t.Error(result)
} else { } else {
t.Logf("%s", result) t.Logf("%s", result)

View File

@@ -21,9 +21,14 @@ type Config struct {
BasicAuthUser string BasicAuthUser string
BasicAuthPassword string BasicAuthPassword string
FillWithTestdata bool FillWithTestdata bool
OllamaURL string
OllamaModel string
LocalCheckpoint string
LocalTokenizer string
storage Storage storage Storage
queue Queue queue Queue
driver Driver driver Driver
ai AI
} }
func newConfig(ctx context.Context) (Config, error) { func newConfig(ctx context.Context) (Config, error) {
@@ -32,7 +37,8 @@ func newConfig(ctx context.Context) (Config, error) {
func newConfigFromEnv(ctx context.Context, getEnv func(string) string) (Config, error) { func newConfigFromEnv(ctx context.Context, getEnv func(string) string) (Config, error) {
def := Config{ def := Config{
Port: 8080, Port: 8080,
OllamaModel: "gemma:2b",
} }
var m map[string]any var m map[string]any
@@ -104,5 +110,13 @@ func newConfigFromEnv(ctx context.Context, getEnv func(string) string) (Config,
result.storage = NewStorage(result.driver) result.storage = NewStorage(result.driver)
result.queue = NewQueue(result.driver) result.queue = NewQueue(result.driver)
if result.OllamaURL != "" {
result.ai = NewAIOllama(result.OllamaURL, result.OllamaModel)
} else if result.LocalCheckpoint != "" && result.LocalTokenizer != "" {
result.ai = NewAILocal(result.LocalCheckpoint, result.LocalTokenizer, 0.9, 128, 0.9)
} else {
result.ai = NewAINoop()
}
return result, nil return result, nil
} }

9
go.mod
View File

@@ -6,7 +6,14 @@ require (
github.com/go-errors/errors v1.5.1 github.com/go-errors/errors v1.5.1
github.com/lib/pq v1.10.9 github.com/lib/pq v1.10.9
github.com/nikolaydubina/llama2.go v0.7.1 github.com/nikolaydubina/llama2.go v0.7.1
github.com/tmc/langchaingo v0.1.8
go.etcd.io/bbolt v1.3.9 go.etcd.io/bbolt v1.3.9
) )
require golang.org/x/sys v0.4.0 // indirect require (
github.com/dlclark/regexp2 v1.10.0 // indirect
github.com/gage-technologies/mistral-go v1.0.0 // indirect
github.com/google/uuid v1.6.0 // indirect
github.com/pkoukk/tiktoken-go v0.1.6 // indirect
golang.org/x/sys v0.16.0 // indirect
)

22
go.sum
View File

@@ -1,20 +1,30 @@
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dlclark/regexp2 v1.10.0 h1:+/GIL799phkJqYW+3YbOd8LCcbHzT0Pbo8zl70MHsq0=
github.com/dlclark/regexp2 v1.10.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
github.com/gage-technologies/mistral-go v1.0.0 h1:Hwk0uJO+Iq4kMX/EwbfGRUq9zkO36w7HZ/g53N4N73A=
github.com/gage-technologies/mistral-go v1.0.0/go.mod h1:tF++Xt7U975GcLlzhrjSQb8l/x+PrriO9QEdsgm9l28=
github.com/go-errors/errors v1.5.1 h1:ZwEMSLRCapFLflTpT7NKaAc7ukJ8ZPEjzlxt8rPN8bk= github.com/go-errors/errors v1.5.1 h1:ZwEMSLRCapFLflTpT7NKaAc7ukJ8ZPEjzlxt8rPN8bk=
github.com/go-errors/errors v1.5.1/go.mod h1:sIVyrIiJhuEF+Pj9Ebtd6P/rEYROXFi3BopGUQ5a5Og= github.com/go-errors/errors v1.5.1/go.mod h1:sIVyrIiJhuEF+Pj9Ebtd6P/rEYROXFi3BopGUQ5a5Og=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
github.com/nikolaydubina/llama2.go v0.7.1 h1:ORmH1XbwFYGIOPHprkjtUPOEovlVXhnmnMjbMckaSyE= github.com/nikolaydubina/llama2.go v0.7.1 h1:ORmH1XbwFYGIOPHprkjtUPOEovlVXhnmnMjbMckaSyE=
github.com/nikolaydubina/llama2.go v0.7.1/go.mod h1:ggXhXOaDnEAgSSkcYsomqx/RLjInxe5ZAbcJ+/Y2mTM= github.com/nikolaydubina/llama2.go v0.7.1/go.mod h1:ggXhXOaDnEAgSSkcYsomqx/RLjInxe5ZAbcJ+/Y2mTM=
github.com/pkoukk/tiktoken-go v0.1.6 h1:JF0TlJzhTbrI30wCvFuiw6FzP2+/bR+FIxUdgEAcUsw=
github.com/pkoukk/tiktoken-go v0.1.6/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/tmc/langchaingo v0.1.8 h1:nrImgh0aWdu3stJTHz80N60WGwPWY8HXCK10gQny7bA=
github.com/tmc/langchaingo v0.1.8/go.mod h1:iNBfS9e6jxBKsJSPWnlqNhoVWgdA3D1g5cdFJjbIZNQ=
go.etcd.io/bbolt v1.3.9 h1:8x7aARPEXiXbHmtUwAIv7eV2fQFHrLLavdiJ3uzJXoI= go.etcd.io/bbolt v1.3.9 h1:8x7aARPEXiXbHmtUwAIv7eV2fQFHrLLavdiJ3uzJXoI=
go.etcd.io/bbolt v1.3.9/go.mod h1:zaO32+Ti0PK1ivdPtgMESzuzL2VPoIG1PCQNvOdo/dE= go.etcd.io/bbolt v1.3.9/go.mod h1:zaO32+Ti0PK1ivdPtgMESzuzL2VPoIG1PCQNvOdo/dE=
golang.org/x/sync v0.5.0 h1:60k92dhOjHxJkrqnwsfl8KuaHbn/5dl0lUPUklKo3qE= golang.org/x/sync v0.6.0 h1:5BMeUDZ7vkXGfEr1x9B4bRcTH4lpkTkpdh0T/J+qjbQ=
golang.org/x/sync v0.5.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sys v0.4.0 h1:Zr2JFtRQNX3BCZ8YtxRE9hNJYC8J6I1MVbMg6owUp18= golang.org/x/sys v0.16.0 h1:xWw16ngr6ZMtmxDyKyIgsE93KNKz5HKmMa3b8ALHidU=
golang.org/x/sys v0.4.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=