main
Bel LaPointe 2024-04-12 14:36:58 -06:00
parent 109899f9f8
commit 163f894f3a
4 changed files with 87 additions and 0 deletions

56
ai.go Normal file
View File

@ -0,0 +1,56 @@
package main
import (
"bytes"
"context"
"errors"
"os"
"github.com/nikolaydubina/llama2.go/llama2"
)
type AI interface {
Do(context.Context, string) (string, error)
}
type AILocal struct {
checkpointPath string
tokenizerPath string
temperature float64
steps int
topp float64
}
func NewAILocal(
checkpointPath string,
tokenizerPath string,
temperature float64,
steps int,
topp float64,
) AILocal {
return AILocal{
checkpointPath: checkpointPath,
tokenizerPath: tokenizerPath,
temperature: temperature,
steps: steps,
topp: topp,
}
}
// https://github.com/nikolaydubina/llama2.go/blob/master/main.go
func (ai AILocal) Do(ctx context.Context, prompt string) (string, error) {
checkpointFile, err := os.OpenFile(ai.checkpointPath, os.O_RDONLY, 0)
if err != nil {
return "", err
}
defer checkpointFile.Close()
config, err := llama2.NewConfigFromCheckpoint(checkpointFile)
if err != nil {
return "", err
}
out := bytes.NewBuffer(nil)
return string(out.Bytes()), errors.New("not impl")
}

28
ai_test.go Normal file
View File

@ -0,0 +1,28 @@
//go:build ai
package main
import (
"context"
"testing"
"time"
)
func TestAILocal(t *testing.T) {
ctx, can := context.WithTimeout(context.Background(), time.Minute)
defer can()
ai := NewAILocal(
"checkpointPath",
"tokenizerPath",
0.9,
256,
0.9,
)
if result, err := ai.Do(ctx, "hello world"); err != nil {
t.Fatal(err)
} else {
t.Logf("%s", result)
}
}

1
go.mod
View File

@ -9,5 +9,6 @@ require (
require ( require (
github.com/lib/pq v1.10.9 // indirect github.com/lib/pq v1.10.9 // indirect
github.com/nikolaydubina/llama2.go v0.7.1 // indirect
golang.org/x/sys v0.4.0 // indirect golang.org/x/sys v0.4.0 // indirect
) )

2
go.sum
View File

@ -2,6 +2,8 @@ github.com/go-errors/errors v1.5.1 h1:ZwEMSLRCapFLflTpT7NKaAc7ukJ8ZPEjzlxt8rPN8b
github.com/go-errors/errors v1.5.1/go.mod h1:sIVyrIiJhuEF+Pj9Ebtd6P/rEYROXFi3BopGUQ5a5Og= github.com/go-errors/errors v1.5.1/go.mod h1:sIVyrIiJhuEF+Pj9Ebtd6P/rEYROXFi3BopGUQ5a5Og=
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
github.com/nikolaydubina/llama2.go v0.7.1 h1:ORmH1XbwFYGIOPHprkjtUPOEovlVXhnmnMjbMckaSyE=
github.com/nikolaydubina/llama2.go v0.7.1/go.mod h1:ggXhXOaDnEAgSSkcYsomqx/RLjInxe5ZAbcJ+/Y2mTM=
go.etcd.io/bbolt v1.3.9 h1:8x7aARPEXiXbHmtUwAIv7eV2fQFHrLLavdiJ3uzJXoI= go.etcd.io/bbolt v1.3.9 h1:8x7aARPEXiXbHmtUwAIv7eV2fQFHrLLavdiJ3uzJXoI=
go.etcd.io/bbolt v1.3.9/go.mod h1:zaO32+Ti0PK1ivdPtgMESzuzL2VPoIG1PCQNvOdo/dE= go.etcd.io/bbolt v1.3.9/go.mod h1:zaO32+Ti0PK1ivdPtgMESzuzL2VPoIG1PCQNvOdo/dE=
golang.org/x/sys v0.4.0 h1:Zr2JFtRQNX3BCZ8YtxRE9hNJYC8J6I1MVbMg6owUp18= golang.org/x/sys v0.4.0 h1:Zr2JFtRQNX3BCZ8YtxRE9hNJYC8J6I1MVbMg6owUp18=