package main

import (
	"bytes"
	"context"
	"os"
	"strings"

	nn "github.com/nikolaydubina/llama2.go/exp/nnfast"
	"github.com/nikolaydubina/llama2.go/llama2"
)

type AI interface {
	Do(context.Context, string) (string, error)
}

// TODO: https://github.com/ollama/ollama/blob/main/server/routes.go#L153
type AILocal struct {
	checkpointPath string
	tokenizerPath  string
	temperature    float64
	steps          int
	topp           float64
}

func NewAILocal(
	checkpointPath string,
	tokenizerPath string,
	temperature float64,
	steps int,
	topp float64,
) AILocal {
	return AILocal{
		checkpointPath: checkpointPath,
		tokenizerPath:  tokenizerPath,
		temperature:    temperature,
		steps:          steps,
		topp:           topp,
	}
}

// Do runs a single completion, following the reference loop in
// https://github.com/nikolaydubina/llama2.go/blob/master/main.go
func (ai AILocal) Do(ctx context.Context, prompt string) (string, error) {
	checkpointFile, err := os.OpenFile(ai.checkpointPath, os.O_RDONLY, 0)
	if err != nil {
		return "", err
	}
	defer checkpointFile.Close()

	config, err := llama2.NewConfigFromCheckpoint(checkpointFile)
	if err != nil {
		return "", err
	}

	// a negative vocab size is how the llama2.c format signals unshared weights
	isSharedWeights := config.VocabSize > 0
	if config.VocabSize < 0 {
		config.VocabSize = -config.VocabSize
	}

	tokenizerFile, err := os.OpenFile(ai.tokenizerPath, os.O_RDONLY, 0)
	if err != nil {
		return "", err
	}
	defer tokenizerFile.Close()

	vocab := llama2.NewVocabFromFile(config.VocabSize, tokenizerFile)

	w := llama2.NewTransformerWeightsFromCheckpoint(config, checkpointFile, isSharedWeights)

	// we cannot run for more than config.SeqLen steps
	steps := ai.steps
	if steps <= 0 || steps > config.SeqLen {
		steps = config.SeqLen
	}

	runState := llama2.NewRunState(config)

	promptTokens := vocab.Encode(strings.ReplaceAll(prompt, "\n", "<0x0A>"))

	out := bytes.NewBuffer(nil)

	token := 1 // 1 = BOS token in llama-2 sentencepiece
	pos := 0   // current position in the sequence
	for pos < steps {
		// forward the transformer to get logits for the next token
		llama2.Transformer(token, pos, config, runState, w)

		var next int
		if pos < len(promptTokens) {
			// still consuming the prompt: force-feed the next prompt token
			next = promptTokens[pos]
		} else {
			// sample the next token
			if ai.temperature == 0 {
				// greedy argmax sampling
				next = nn.ArgMax(runState.Logits)
			} else {
				// apply the temperature to the logits
				for q := 0; q < config.VocabSize; q++ {
					runState.Logits[q] /= float32(ai.temperature)
				}
				// apply softmax to the logits to get the probabilities for the next token
				nn.SoftMax(runState.Logits)
				// sample from this distribution to get the next token
				if ai.topp <= 0 || ai.topp >= 1 {
					// simply sample from the predicted probability distribution
					next = nn.Sample(runState.Logits)
				} else {
					// top-p (nucleus) sampling, clamping the least likely tokens to zero
					next = nn.SampleTopP(runState.Logits, float32(ai.topp))
				}
			}
		}
		pos++

		// data-dependent terminating condition: the BOS (1) token delimits sequences
		if next == 1 {
			break
		}

		// following a BOS (1) token, the sentencepiece decoder strips any leading whitespace
		tokenStr := vocab.Words[next]
		if token == 1 && tokenStr[0] == ' ' {
			tokenStr = tokenStr[1:]
		}
		out.WriteString(tokenStr)

		// advance forward
		token = next
	}
	out.WriteString("\n")

	return strings.ReplaceAll(out.String(), "<0x0A>", "\n"), nil
}
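
// Usage sketch (hedged: the checkpoint and tokenizer file names below are
// placeholder assumptions, not files shipped with this code; any model
// exported in the llama2.c checkpoint format should work):
//
//	ai := NewAILocal("stories110M.bin", "tokenizer.bin", 0.9, 256, 0.9)
//	answer, err := ai.Do(context.Background(), "Once upon a time")
//	if err != nil {
//		log.Fatal(err)
//	}
//	fmt.Print(answer)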