package main

import (
	"bytes"
	"context"
	"os"
	"strings"

	nn "github.com/nikolaydubina/llama2.go/exp/nnfast"
	"github.com/nikolaydubina/llama2.go/llama2"

	"github.com/tmc/langchaingo/llms"
	"github.com/tmc/langchaingo/llms/mistral"
	"github.com/tmc/langchaingo/llms/ollama"
)
// AI is the common interface over the remote (Mistral, Ollama) and local
// (llama2.go) text-generation backends: given a prompt, return a completion.
type AI interface {
	Do(context.Context, string) (string, error)
}
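// Compile-time checks, added for illustration, that each backend below
// satisfies the AI interface.
var (
	_ AI = AIMistral{}
	_ AI = AIOllama{}
	_ AI = AILocal{}
)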
// AIMistral generates completions through the Mistral API via langchaingo.
type AIMistral struct {
	model string
}

func NewAIMistral(model string) AIMistral {
	return AIMistral{model: model}
}

func (ai AIMistral) Do(ctx context.Context, prompt string) (string, error) {
	llm, err := mistral.New(mistral.WithModel(ai.model))
	if err != nil {
		return "", err
	}
	// pass the configured model here too: a call-time llms.WithModel
	// overrides the one set in mistral.New, so hardcoding a name here
	// would silently ignore ai.model
	return llms.GenerateFromSinglePrompt(ctx, llm, prompt,
		llms.WithTemperature(0.8),
		llms.WithModel(ai.model),
	)
}
// AIOllama generates completions through a local Ollama server via langchaingo.
type AIOllama struct {
	model string
}

func NewAIOllama(model string) AIOllama {
	return AIOllama{model: model}
}

func (ai AIOllama) Do(ctx context.Context, prompt string) (string, error) {
	llm, err := ollama.New(ollama.WithModel(ai.model))
	if err != nil {
		return "", err
	}
	return llms.GenerateFromSinglePrompt(ctx, llm, prompt)
}
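// Note (added): ollama.New without a host option talks to the default local
// Ollama server (http://localhost:11434), and the model passed to NewAIOllama
// must already be pulled there, e.g. with `ollama pull llama2`.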
// AILocal runs llama2.go inference in-process from a checkpoint and a
// tokenizer file on disk.
type AILocal struct {
	checkpointPath string
	tokenizerPath  string
	temperature    float64
	steps          int
	topp           float64
}

func NewAILocal(
	checkpointPath string,
	tokenizerPath string,
	temperature float64,
	steps int,
	topp float64,
) AILocal {
	return AILocal{
		checkpointPath: checkpointPath,
		tokenizerPath:  tokenizerPath,
		temperature:    temperature,
		steps:          steps,
		topp:           topp,
	}
}
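// sampleTopPSketch is an illustrative, unused sketch (not in the original
// file) of nucleus (top-p) sampling, approximating what nn.SampleTopP is
// presumed to do: sort probabilities descending, keep the smallest prefix
// whose cumulative mass exceeds p, and sample from that renormalized prefix.
// rnd is a uniform draw in [0, 1); the naive selection sort keeps the sketch
// free of extra imports. nn.SampleTopP's exact behavior may differ.
func sampleTopPSketch(probs []float32, p float32, rnd float32) int {
	// pair probabilities with their token indices
	idx := make([]int, len(probs))
	for i := range idx {
		idx[i] = i
	}
	// naive selection sort, descending by probability
	for i := 0; i < len(idx); i++ {
		max := i
		for j := i + 1; j < len(idx); j++ {
			if probs[idx[j]] > probs[idx[max]] {
				max = j
			}
		}
		idx[i], idx[max] = idx[max], idx[i]
	}
	// find the smallest prefix with cumulative mass > p
	var cum float32
	cut := len(idx)
	for i, id := range idx {
		cum += probs[id]
		if cum > p {
			cut = i + 1
			break
		}
	}
	// sample from the truncated distribution, renormalized by its mass
	target := rnd * cum
	var acc float32
	for _, id := range idx[:cut] {
		acc += probs[id]
		if acc >= target {
			return id
		}
	}
	return idx[cut-1]
}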
// Adapted from https://github.com/nikolaydubina/llama2.go/blob/master/main.go
func (ai AILocal) Do(ctx context.Context, prompt string) (string, error) {
	checkpointFile, err := os.Open(ai.checkpointPath)
	if err != nil {
		return "", err
	}
	defer checkpointFile.Close()

	config, err := llama2.NewConfigFromCheckpoint(checkpointFile)
	if err != nil {
		return "", err
	}

	// a negative vocab size signals unshared weights; either way the real
	// vocab size is the absolute value
	isSharedWeights := config.VocabSize > 0
	if config.VocabSize < 0 {
		config.VocabSize = -config.VocabSize
	}

	tokenizerFile, err := os.Open(ai.tokenizerPath)
	if err != nil {
		return "", err
	}
	defer tokenizerFile.Close()

	vocab := llama2.NewVocabFromFile(config.VocabSize, tokenizerFile)

	w := llama2.NewTransformerWeightsFromCheckpoint(config, checkpointFile, isSharedWeights)

	// we cannot run for more than config.SeqLen steps
	steps := ai.steps
	if steps <= 0 || steps > config.SeqLen {
		steps = config.SeqLen
	}

	runState := llama2.NewRunState(config)

	promptTokens := vocab.Encode(strings.ReplaceAll(prompt, "\n", "<0x0A>"))

	out := bytes.NewBuffer(nil)

	var token int = 1 // 1 = BOS token in llama-2 sentencepiece
	var pos = 0       // the current position in the sequence

	for pos < steps {
		// stop early if the caller cancelled the context
		if err := ctx.Err(); err != nil {
			return "", err
		}

		// forward the transformer to get logits for the next token
		llama2.Transformer(token, pos, config, runState, w)

		var next int
		if pos < len(promptTokens) {
			// still inside the prompt: force-feed the prompt tokens
			next = promptTokens[pos]
		} else {
			// sample the next token
			if ai.temperature == 0 {
				// greedy argmax sampling
				next = nn.ArgMax(runState.Logits)
			} else {
				// apply the temperature to the logits
				for q := 0; q < config.VocabSize; q++ {
					runState.Logits[q] /= float32(ai.temperature)
				}
				// apply softmax to the logits to get the probabilities for the next token
				nn.SoftMax(runState.Logits)
				// we now want to sample from this distribution to get the next token
				if ai.topp <= 0 || ai.topp >= 1 {
					// simply sample from the predicted probability distribution
					next = nn.Sample(runState.Logits)
				} else {
					// top-p (nucleus) sampling, clamping the least likely tokens to zero
					next = nn.SampleTopP(runState.Logits, float32(ai.topp))
				}
			}
		}
		pos++

		// data-dependent terminating condition: the BOS (1) token delimits sequences
		if next == 1 {
			break
		}

		// following a BOS (1) token, the sentencepiece decoder strips any leading whitespace
		var tokenStr string
		if token == 1 && vocab.Words[next][0] == ' ' {
			tokenStr = vocab.Words[next][1:]
		} else {
			tokenStr = vocab.Words[next]
		}
		out.WriteString(tokenStr)

		// advance forward
		token = next
	}
	out.WriteString("\n")

	return strings.ReplaceAll(out.String(), "<0x0A>", "\n"), nil
}
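// A minimal usage sketch, not part of the original listing, showing the three
// backends behind the one AI interface. Assumptions: the Mistral client reads
// its API key from the environment (MISTRAL_API_KEY by default in the
// underlying client), an Ollama daemon is running locally with the named model
// pulled, and the checkpoint/tokenizer paths are placeholders for real
// llama2.c-format files.
func main() {
	ctx := context.Background()

	backends := []AI{
		NewAIMistral("mistral-small-latest"),
		NewAIOllama("llama2"),
		NewAILocal("stories110M.bin", "tokenizer.bin", 0.8, 256, 0.9),
	}

	for _, ai := range backends {
		out, err := ai.Do(ctx, "Why is the sky blue?")
		if err != nil {
			os.Stderr.WriteString("error: " + err.Error() + "\n")
			continue
		}
		os.Stdout.WriteString(out + "\n")
	}
}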