// spoc-bot-vr/.ai.go

package main

import (
	"bytes"
	"context"
	"os"
	"strings"

	nn "github.com/nikolaydubina/llama2.go/exp/nnfast"
	"github.com/nikolaydubina/llama2.go/llama2"
	"github.com/tmc/langchaingo/llms"
	"github.com/tmc/langchaingo/llms/ollama"
)
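
// AI generates a reply for a single prompt.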
type AI interface {
	Do(context.Context, string) (string, error)
}
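
// AINoop is a stub backend that ignores the prompt and always replies ":shrug:".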
type AINoop struct{}

func NewAINoop() AINoop {
	return AINoop{}
}

func (ai AINoop) Do(ctx context.Context, prompt string) (string, error) {
	return ":shrug:", nil
}
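
// AIOllama generates replies by calling an Ollama server through langchaingo.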
type AIOllama struct {
	model string
	url   string
}

func NewAIOllama(url, model string) AIOllama {
	return AIOllama{url: url, model: model}
}

func (ai AIOllama) Do(ctx context.Context, prompt string) (string, error) {
	llm, err := ollama.New(
		ollama.WithModel(ai.model),
		ollama.WithServerURL(ai.url),
	)
	if err != nil {
		return "", err
	}
	return llms.GenerateFromSinglePrompt(ctx, llm, prompt)
}
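
// AILocal generates replies in-process by running a llama2.c-format model
// checkpoint and tokenizer from disk with llama2.go.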
type AILocal struct {
	checkpointPath string
	tokenizerPath  string
	temperature    float64
	steps          int
	topp           float64
}

func NewAILocal(
	checkpointPath string,
	tokenizerPath string,
	temperature float64,
	steps int,
	topp float64,
) AILocal {
	return AILocal{
		checkpointPath: checkpointPath,
		tokenizerPath:  tokenizerPath,
		temperature:    temperature,
		steps:          steps,
		topp:           topp,
	}
}
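
// Do loads the checkpoint and tokenizer, then generates up to steps tokens
// using greedy, temperature, or top-p (nucleus) sampling, adapted from: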
// https://github.com/nikolaydubina/llama2.go/blob/master/main.go
func (ai AILocal) Do(ctx context.Context, prompt string) (string, error) {
	checkpointFile, err := os.OpenFile(ai.checkpointPath, os.O_RDONLY, 0)
	if err != nil {
		return "", err
	}
	defer checkpointFile.Close()

	config, err := llama2.NewConfigFromCheckpoint(checkpointFile)
	if err != nil {
		return "", err
	}
	// a negative vocab size in the checkpoint header signals unshared classifier weights
	isSharedWeights := config.VocabSize > 0
	if config.VocabSize < 0 {
		config.VocabSize = -config.VocabSize
	}

	tokenizerFile, err := os.OpenFile(ai.tokenizerPath, os.O_RDONLY, 0)
	if err != nil {
		return "", err
	}
	defer tokenizerFile.Close()
	vocab := llama2.NewVocabFromFile(config.VocabSize, tokenizerFile)

	w := llama2.NewTransformerWeightsFromCheckpoint(config, checkpointFile, isSharedWeights)

	// right now we cannot run for more than config.SeqLen steps
	steps := ai.steps
	if steps <= 0 || steps > config.SeqLen {
		steps = config.SeqLen
	}

	runState := llama2.NewRunState(config)
	promptTokens := vocab.Encode(strings.ReplaceAll(prompt, "\n", "<0x0A>"))
	out := bytes.NewBuffer(nil)

	// token is the current input token; pos is the current position in the sequence
	var token int = 1 // 1 = BOS token in llama-2 sentencepiece
	var pos = 0
	for pos < steps {
		// forward the transformer to get logits for the next token
		llama2.Transformer(token, pos, config, runState, w)

		var next int
		if pos < len(promptTokens) {
			// while still inside the prompt, force-feed the prompt tokens
			next = promptTokens[pos]
		} else {
			// sample the next token
			if ai.temperature == 0 {
				// greedy argmax sampling
				next = nn.ArgMax(runState.Logits)
			} else {
				// apply the temperature to the logits
				for q := 0; q < config.VocabSize; q++ {
					runState.Logits[q] /= float32(ai.temperature)
				}
				// apply softmax to the logits to get the probabilities for the next token
				nn.SoftMax(runState.Logits)
				// we now want to sample from this distribution to get the next token
				if ai.topp <= 0 || ai.topp >= 1 {
					// simply sample from the predicted probability distribution
					next = nn.Sample(runState.Logits)
				} else {
					// top-p (nucleus) sampling, clamping the least likely tokens to zero
					next = nn.SampleTopP(runState.Logits, float32(ai.topp))
				}
			}
		}
		pos++

		// data-dependent terminating condition: the BOS (1) token delimits sequences
		if next == 1 {
			break
		}

		// following a BOS (1) token, the sentencepiece decoder strips any leading whitespace
		var tokenStr string
		if token == 1 && vocab.Words[next][0] == ' ' {
			tokenStr = vocab.Words[next][1:]
		} else {
			tokenStr = vocab.Words[next]
		}
		out.Write([]byte(tokenStr))

		// advance forward
		token = next
	}
	out.WriteString("\n")
	return strings.ReplaceAll(out.String(), "<0x0A>", "\n"), nil
}
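
// Illustrative wiring (a sketch; the backend selection, env var, and model
// paths below are assumptions, not part of this file):
//
//	var ai AI
//	switch os.Getenv("AI_BACKEND") {
//	case "ollama":
//		ai = NewAIOllama("http://localhost:11434", "llama2")
//	case "local":
//		ai = NewAILocal("model.bin", "tokenizer.bin", 0.8, 256, 0.9)
//	default:
//		ai = NewAINoop()
//	}
//	reply, err := ai.Do(context.Background(), "hello")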