From f288a9b098b65da9ce9a7f60858f97031af92144 Mon Sep 17 00:00:00 2001 From: bel Date: Fri, 12 Apr 2024 22:44:36 -0600 Subject: [PATCH] all of langchaingo just calls apis ffff --- ai.go | 38 +++++++++++++++++++++++++++++++++++++- ai_test.go | 22 +++++++++++++++++++--- go.mod | 9 ++++++++- go.sum | 22 ++++++++++++++++------ 4 files changed, 80 insertions(+), 11 deletions(-) diff --git a/ai.go b/ai.go index c99e2f3..6f0d51b 100644 --- a/ai.go +++ b/ai.go @@ -8,13 +8,49 @@ import ( nn "github.com/nikolaydubina/llama2.go/exp/nnfast" "github.com/nikolaydubina/llama2.go/llama2" + "github.com/tmc/langchaingo/llms" + "github.com/tmc/langchaingo/llms/mistral" + "github.com/tmc/langchaingo/llms/ollama" ) type AI interface { Do(context.Context, string) (string, error) } -// TODO https://github.com/ollama/ollama/blob/main/server/routes.go#L153 +type AIMistral struct { + model string +} + +func NewAIMistral(model string) AIMistral { + return AIMistral{model: model} +} + +func (ai AIMistral) Do(ctx context.Context, prompt string) (string, error) { + llm, err := mistral.New(mistral.WithModel(ai.model)) + if err != nil { + return "", err + } + return llms.GenerateFromSinglePrompt(ctx, llm, prompt, + llms.WithTemperature(0.8), + llms.WithModel("mistral-small-latest"), + ) +} + +type AIOllama struct { + model string +} + +func NewAIOllama(model string) AIOllama { + return AIOllama{model: model} +} + +func (ai AIOllama) Do(ctx context.Context, prompt string) (string, error) { + llm, err := ollama.New(ollama.WithModel(ai.model)) + if err != nil { + return "", err + } + return llms.GenerateFromSinglePrompt(ctx, llm, prompt) +} type AILocal struct { checkpointPath string diff --git a/ai_test.go b/ai_test.go index 484a7e3..4fba4c1 100644 --- a/ai_test.go +++ b/ai_test.go @@ -13,10 +13,19 @@ import ( "time" ) -func TestAILocal(t *testing.T) { - ctx, can := context.WithTimeout(context.Background(), time.Minute) - defer can() +func TestAIMistral(t *testing.T) { + ai := NewAIMistral("open-mistral-7b") + testAI(t, ai) +} + +func TestAIOllama(t *testing.T) { + ai := NewAIOllama("gemma:2b") + + testAI(t, ai) +} + +func TestAILocal(t *testing.T) { d := os.TempDir() checkpoints := "checkpoints" tokenizer := "tokenizer" @@ -58,6 +67,13 @@ func TestAILocal(t *testing.T) { 0.9, ) + testAI(t, ai) +} + +func testAI(t *testing.T, ai AI) { + ctx, can := context.WithTimeout(context.Background(), time.Minute) + defer can() + t.Run("mvp", func(t *testing.T) { if result, err := ai.Do(ctx, "hello world"); err != nil { t.Fatal(err) diff --git a/go.mod b/go.mod index f48b54b..9c3fb93 100644 --- a/go.mod +++ b/go.mod @@ -6,7 +6,14 @@ require ( github.com/go-errors/errors v1.5.1 github.com/lib/pq v1.10.9 github.com/nikolaydubina/llama2.go v0.7.1 + github.com/tmc/langchaingo v0.1.8 go.etcd.io/bbolt v1.3.9 ) -require golang.org/x/sys v0.4.0 // indirect +require ( + github.com/dlclark/regexp2 v1.10.0 // indirect + github.com/gage-technologies/mistral-go v1.0.0 // indirect + github.com/google/uuid v1.6.0 // indirect + github.com/pkoukk/tiktoken-go v0.1.6 // indirect + golang.org/x/sys v0.16.0 // indirect +) diff --git a/go.sum b/go.sum index a496fba..b238cea 100644 --- a/go.sum +++ b/go.sum @@ -1,20 +1,30 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dlclark/regexp2 v1.10.0 h1:+/GIL799phkJqYW+3YbOd8LCcbHzT0Pbo8zl70MHsq0= +github.com/dlclark/regexp2 v1.10.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= +github.com/gage-technologies/mistral-go v1.0.0 h1:Hwk0uJO+Iq4kMX/EwbfGRUq9zkO36w7HZ/g53N4N73A= +github.com/gage-technologies/mistral-go v1.0.0/go.mod h1:tF++Xt7U975GcLlzhrjSQb8l/x+PrriO9QEdsgm9l28= github.com/go-errors/errors v1.5.1 h1:ZwEMSLRCapFLflTpT7NKaAc7ukJ8ZPEjzlxt8rPN8bk= github.com/go-errors/errors v1.5.1/go.mod h1:sIVyrIiJhuEF+Pj9Ebtd6P/rEYROXFi3BopGUQ5a5Og= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/nikolaydubina/llama2.go v0.7.1 h1:ORmH1XbwFYGIOPHprkjtUPOEovlVXhnmnMjbMckaSyE= github.com/nikolaydubina/llama2.go v0.7.1/go.mod h1:ggXhXOaDnEAgSSkcYsomqx/RLjInxe5ZAbcJ+/Y2mTM= +github.com/pkoukk/tiktoken-go v0.1.6 h1:JF0TlJzhTbrI30wCvFuiw6FzP2+/bR+FIxUdgEAcUsw= +github.com/pkoukk/tiktoken-go v0.1.6/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= -github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= +github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/tmc/langchaingo v0.1.8 h1:nrImgh0aWdu3stJTHz80N60WGwPWY8HXCK10gQny7bA= +github.com/tmc/langchaingo v0.1.8/go.mod h1:iNBfS9e6jxBKsJSPWnlqNhoVWgdA3D1g5cdFJjbIZNQ= go.etcd.io/bbolt v1.3.9 h1:8x7aARPEXiXbHmtUwAIv7eV2fQFHrLLavdiJ3uzJXoI= go.etcd.io/bbolt v1.3.9/go.mod h1:zaO32+Ti0PK1ivdPtgMESzuzL2VPoIG1PCQNvOdo/dE= -golang.org/x/sync v0.5.0 h1:60k92dhOjHxJkrqnwsfl8KuaHbn/5dl0lUPUklKo3qE= -golang.org/x/sync v0.5.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= -golang.org/x/sys v0.4.0 h1:Zr2JFtRQNX3BCZ8YtxRE9hNJYC8J6I1MVbMg6owUp18= -golang.org/x/sys v0.4.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sync v0.6.0 h1:5BMeUDZ7vkXGfEr1x9B4bRcTH4lpkTkpdh0T/J+qjbQ= +golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sys v0.16.0 h1:xWw16ngr6ZMtmxDyKyIgsE93KNKz5HKmMa3b8ALHidU= +golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=