From 163f894f3a94b302770ae6944880c99dd80accc7 Mon Sep 17 00:00:00 2001 From: Bel LaPointe <153096461+breel-render@users.noreply.github.com> Date: Fri, 12 Apr 2024 14:36:58 -0600 Subject: [PATCH] wip ai --- ai.go | 56 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ ai_test.go | 28 +++++++++++++++++++++++++++ go.mod | 1 + go.sum | 2 ++ 4 files changed, 87 insertions(+) create mode 100644 ai.go create mode 100644 ai_test.go diff --git a/ai.go b/ai.go new file mode 100644 index 0000000..9a018d1 --- /dev/null +++ b/ai.go @@ -0,0 +1,56 @@ +package main + +import ( + "bytes" + "context" + "errors" + "os" + + "github.com/nikolaydubina/llama2.go/llama2" +) + +type AI interface { + Do(context.Context, string) (string, error) +} + +type AILocal struct { + checkpointPath string + tokenizerPath string + temperature float64 + steps int + topp float64 +} + +func NewAILocal( + checkpointPath string, + tokenizerPath string, + temperature float64, + steps int, + topp float64, +) AILocal { + return AILocal{ + checkpointPath: checkpointPath, + tokenizerPath: tokenizerPath, + temperature: temperature, + steps: steps, + topp: topp, + } +} + +// https://github.com/nikolaydubina/llama2.go/blob/master/main.go +func (ai AILocal) Do(ctx context.Context, prompt string) (string, error) { + checkpointFile, err := os.OpenFile(ai.checkpointPath, os.O_RDONLY, 0) + if err != nil { + return "", err + } + defer checkpointFile.Close() + + config, err := llama2.NewConfigFromCheckpoint(checkpointFile) + if err != nil { + return "", err + } + + out := bytes.NewBuffer(nil) + + return string(out.Bytes()), errors.New("not impl") +} diff --git a/ai_test.go b/ai_test.go new file mode 100644 index 0000000..b7a7d85 --- /dev/null +++ b/ai_test.go @@ -0,0 +1,28 @@ +//go:build ai + +package main + +import ( + "context" + "testing" + "time" +) + +func TestAILocal(t *testing.T) { + ctx, can := context.WithTimeout(context.Background(), time.Minute) + defer can() + + ai := NewAILocal( + "checkpointPath", + "tokenizerPath", + 0.9, + 256, + 0.9, + ) + + if result, err := ai.Do(ctx, "hello world"); err != nil { + t.Fatal(err) + } else { + t.Logf("%s", result) + } +} diff --git a/go.mod b/go.mod index 5b55d72..4f2035a 100644 --- a/go.mod +++ b/go.mod @@ -9,5 +9,6 @@ require ( require ( github.com/lib/pq v1.10.9 // indirect + github.com/nikolaydubina/llama2.go v0.7.1 // indirect golang.org/x/sys v0.4.0 // indirect ) diff --git a/go.sum b/go.sum index ddfc625..da4a17a 100644 --- a/go.sum +++ b/go.sum @@ -2,6 +2,8 @@ github.com/go-errors/errors v1.5.1 h1:ZwEMSLRCapFLflTpT7NKaAc7ukJ8ZPEjzlxt8rPN8b github.com/go-errors/errors v1.5.1/go.mod h1:sIVyrIiJhuEF+Pj9Ebtd6P/rEYROXFi3BopGUQ5a5Og= github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= +github.com/nikolaydubina/llama2.go v0.7.1 h1:ORmH1XbwFYGIOPHprkjtUPOEovlVXhnmnMjbMckaSyE= +github.com/nikolaydubina/llama2.go v0.7.1/go.mod h1:ggXhXOaDnEAgSSkcYsomqx/RLjInxe5ZAbcJ+/Y2mTM= go.etcd.io/bbolt v1.3.9 h1:8x7aARPEXiXbHmtUwAIv7eV2fQFHrLLavdiJ3uzJXoI= go.etcd.io/bbolt v1.3.9/go.mod h1:zaO32+Ti0PK1ivdPtgMESzuzL2VPoIG1PCQNvOdo/dE= golang.org/x/sys v0.4.0 h1:Zr2JFtRQNX3BCZ8YtxRE9hNJYC8J6I1MVbMg6owUp18=