wip ai
parent
109899f9f8
commit
163f894f3a
|
|
@ -0,0 +1,56 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"os"
|
||||
|
||||
"github.com/nikolaydubina/llama2.go/llama2"
|
||||
)
|
||||
|
||||
type AI interface {
|
||||
Do(context.Context, string) (string, error)
|
||||
}
|
||||
|
||||
type AILocal struct {
|
||||
checkpointPath string
|
||||
tokenizerPath string
|
||||
temperature float64
|
||||
steps int
|
||||
topp float64
|
||||
}
|
||||
|
||||
func NewAILocal(
|
||||
checkpointPath string,
|
||||
tokenizerPath string,
|
||||
temperature float64,
|
||||
steps int,
|
||||
topp float64,
|
||||
) AILocal {
|
||||
return AILocal{
|
||||
checkpointPath: checkpointPath,
|
||||
tokenizerPath: tokenizerPath,
|
||||
temperature: temperature,
|
||||
steps: steps,
|
||||
topp: topp,
|
||||
}
|
||||
}
|
||||
|
||||
// https://github.com/nikolaydubina/llama2.go/blob/master/main.go
|
||||
func (ai AILocal) Do(ctx context.Context, prompt string) (string, error) {
|
||||
checkpointFile, err := os.OpenFile(ai.checkpointPath, os.O_RDONLY, 0)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer checkpointFile.Close()
|
||||
|
||||
config, err := llama2.NewConfigFromCheckpoint(checkpointFile)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
out := bytes.NewBuffer(nil)
|
||||
|
||||
return string(out.Bytes()), errors.New("not impl")
|
||||
}
|
||||
|
|
@ -0,0 +1,28 @@
|
|||
//go:build ai
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestAILocal(t *testing.T) {
|
||||
ctx, can := context.WithTimeout(context.Background(), time.Minute)
|
||||
defer can()
|
||||
|
||||
ai := NewAILocal(
|
||||
"checkpointPath",
|
||||
"tokenizerPath",
|
||||
0.9,
|
||||
256,
|
||||
0.9,
|
||||
)
|
||||
|
||||
if result, err := ai.Do(ctx, "hello world"); err != nil {
|
||||
t.Fatal(err)
|
||||
} else {
|
||||
t.Logf("%s", result)
|
||||
}
|
||||
}
|
||||
1
go.mod
1
go.mod
|
|
@ -9,5 +9,6 @@ require (
|
|||
|
||||
require (
|
||||
github.com/lib/pq v1.10.9 // indirect
|
||||
github.com/nikolaydubina/llama2.go v0.7.1 // indirect
|
||||
golang.org/x/sys v0.4.0 // indirect
|
||||
)
|
||||
|
|
|
|||
2
go.sum
2
go.sum
|
|
@ -2,6 +2,8 @@ github.com/go-errors/errors v1.5.1 h1:ZwEMSLRCapFLflTpT7NKaAc7ukJ8ZPEjzlxt8rPN8b
|
|||
github.com/go-errors/errors v1.5.1/go.mod h1:sIVyrIiJhuEF+Pj9Ebtd6P/rEYROXFi3BopGUQ5a5Og=
|
||||
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
|
||||
github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
|
||||
github.com/nikolaydubina/llama2.go v0.7.1 h1:ORmH1XbwFYGIOPHprkjtUPOEovlVXhnmnMjbMckaSyE=
|
||||
github.com/nikolaydubina/llama2.go v0.7.1/go.mod h1:ggXhXOaDnEAgSSkcYsomqx/RLjInxe5ZAbcJ+/Y2mTM=
|
||||
go.etcd.io/bbolt v1.3.9 h1:8x7aARPEXiXbHmtUwAIv7eV2fQFHrLLavdiJ3uzJXoI=
|
||||
go.etcd.io/bbolt v1.3.9/go.mod h1:zaO32+Ti0PK1ivdPtgMESzuzL2VPoIG1PCQNvOdo/dE=
|
||||
golang.org/x/sys v0.4.0 h1:Zr2JFtRQNX3BCZ8YtxRE9hNJYC8J6I1MVbMg6owUp18=
|
||||
|
|
|
|||
Loading…
Reference in New Issue