57 lines
1.0 KiB
Go
57 lines
1.0 KiB
Go
package main
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"errors"
|
|
"os"
|
|
|
|
"github.com/nikolaydubina/llama2.go/llama2"
|
|
)
|
|
|
|
// AI turns a prompt into a textual response.
//
// Implementations take the caller's context as the first argument and return
// the produced text together with any error encountered while producing it.
type AI interface {
	Do(ctx context.Context, prompt string) (string, error)
}
|
|
|
|
// AILocal holds the settings for running a model locally through
// github.com/nikolaydubina/llama2.go. Construct it with NewAILocal.
type AILocal struct {
	// File paths for the model checkpoint and the tokenizer. Only the
	// checkpoint path is opened by Do at the moment.
	checkpointPath, tokenizerPath string

	// Generation settings; presumably sampling temperature, number of steps
	// (tokens) to generate, and nucleus top-p — TODO confirm once Do uses them.
	temperature float64
	steps       int
	topp        float64
}
|
|
|
|
func NewAILocal(
|
|
checkpointPath string,
|
|
tokenizerPath string,
|
|
temperature float64,
|
|
steps int,
|
|
topp float64,
|
|
) AILocal {
|
|
return AILocal{
|
|
checkpointPath: checkpointPath,
|
|
tokenizerPath: tokenizerPath,
|
|
temperature: temperature,
|
|
steps: steps,
|
|
topp: topp,
|
|
}
|
|
}
|
|
|
|
// https://github.com/nikolaydubina/llama2.go/blob/master/main.go
|
|
func (ai AILocal) Do(ctx context.Context, prompt string) (string, error) {
|
|
checkpointFile, err := os.OpenFile(ai.checkpointPath, os.O_RDONLY, 0)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
defer checkpointFile.Close()
|
|
|
|
config, err := llama2.NewConfigFromCheckpoint(checkpointFile)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
out := bytes.NewBuffer(nil)
|
|
|
|
return string(out.Bytes()), errors.New("not impl")
|
|
}
|