go test -tags=ai -v -run=AI works with ollama which is cool and fast with llama3
parent
8557ddc522
commit
14de286415
126
ai.go
126
ai.go
|
|
@ -1,13 +1,8 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
nn "github.com/nikolaydubina/llama2.go/exp/nnfast"
|
||||
"github.com/nikolaydubina/llama2.go/llama2"
|
||||
"github.com/tmc/langchaingo/llms"
|
||||
"github.com/tmc/langchaingo/llms/ollama"
|
||||
)
|
||||
|
|
@ -46,124 +41,3 @@ func (ai AIOllama) Do(ctx context.Context, prompt string) (string, error) {
|
|||
}
|
||||
return llms.GenerateFromSinglePrompt(ctx, llm, prompt)
|
||||
}
|
||||
|
||||
type AILocal struct {
|
||||
checkpointPath string
|
||||
tokenizerPath string
|
||||
temperature float64
|
||||
steps int
|
||||
topp float64
|
||||
}
|
||||
|
||||
func NewAILocal(
|
||||
checkpointPath string,
|
||||
tokenizerPath string,
|
||||
temperature float64,
|
||||
steps int,
|
||||
topp float64,
|
||||
) AILocal {
|
||||
return AILocal{
|
||||
checkpointPath: checkpointPath,
|
||||
tokenizerPath: tokenizerPath,
|
||||
temperature: temperature,
|
||||
steps: steps,
|
||||
topp: topp,
|
||||
}
|
||||
}
|
||||
|
||||
// https://github.com/nikolaydubina/llama2.go/blob/master/main.go
|
||||
func (ai AILocal) Do(ctx context.Context, prompt string) (string, error) {
|
||||
checkpointFile, err := os.OpenFile(ai.checkpointPath, os.O_RDONLY, 0)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer checkpointFile.Close()
|
||||
|
||||
config, err := llama2.NewConfigFromCheckpoint(checkpointFile)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
isSharedWeights := config.VocabSize > 0
|
||||
if config.VocabSize < 0 {
|
||||
config.VocabSize = -config.VocabSize
|
||||
}
|
||||
|
||||
tokenizerFile, err := os.OpenFile(ai.tokenizerPath, os.O_RDONLY, 0)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer tokenizerFile.Close()
|
||||
|
||||
vocab := llama2.NewVocabFromFile(config.VocabSize, tokenizerFile)
|
||||
|
||||
w := llama2.NewTransformerWeightsFromCheckpoint(config, checkpointFile, isSharedWeights)
|
||||
|
||||
// right now we cannot run for more than config.SeqLen steps
|
||||
steps := ai.steps
|
||||
if steps <= 0 || steps > config.SeqLen {
|
||||
steps = config.SeqLen
|
||||
}
|
||||
|
||||
runState := llama2.NewRunState(config)
|
||||
|
||||
promptTokens := vocab.Encode(strings.ReplaceAll(prompt, "\n", "<0x0A>"))
|
||||
|
||||
out := bytes.NewBuffer(nil)
|
||||
|
||||
// the current position we are in
|
||||
var token int = 1 // 1 = BOS token in llama-2 sentencepiece
|
||||
var pos = 0
|
||||
|
||||
for pos < steps {
|
||||
// forward the transformer to get logits for the next token
|
||||
llama2.Transformer(token, pos, config, runState, w)
|
||||
|
||||
var next int
|
||||
if pos < len(promptTokens) {
|
||||
next = promptTokens[pos]
|
||||
} else {
|
||||
// sample the next token
|
||||
if ai.temperature == 0 {
|
||||
// greedy argmax sampling
|
||||
next = nn.ArgMax(runState.Logits)
|
||||
} else {
|
||||
// apply the temperature to the logits
|
||||
for q := 0; q < config.VocabSize; q++ {
|
||||
runState.Logits[q] /= float32(ai.temperature)
|
||||
}
|
||||
// apply softmax to the logits to the probabilities for next token
|
||||
nn.SoftMax(runState.Logits)
|
||||
// we now want to sample from this distribution to get the next token
|
||||
if ai.topp <= 0 || ai.topp >= 1 {
|
||||
// simply sample from the predicted probability distribution
|
||||
next = nn.Sample(runState.Logits)
|
||||
} else {
|
||||
// top-p (nucleus) sampling, clamping the least likely tokens to zero
|
||||
next = nn.SampleTopP(runState.Logits, float32(ai.topp))
|
||||
}
|
||||
}
|
||||
}
|
||||
pos++
|
||||
|
||||
// data-dependent terminating condition: the BOS (1) token delimits sequences
|
||||
if next == 1 {
|
||||
break
|
||||
}
|
||||
|
||||
// following BOS (1) token, sentencepiece decoder strips any leading whitespace
|
||||
var tokenStr string
|
||||
if token == 1 && vocab.Words[next][0] == ' ' {
|
||||
tokenStr = vocab.Words[next][1:]
|
||||
} else {
|
||||
tokenStr = vocab.Words[next]
|
||||
}
|
||||
out.Write([]byte(tokenStr))
|
||||
|
||||
// advance forward
|
||||
token = next
|
||||
}
|
||||
out.Write([]byte("\n"))
|
||||
|
||||
return strings.ReplaceAll(string(out.Bytes()), "<0x0A>", "\n"), nil
|
||||
}
|
||||
|
|
|
|||
106
ai_test.go
106
ai_test.go
|
|
@ -4,11 +4,6 @@ package main
|
|||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"path"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
|
@ -22,53 +17,7 @@ func TestAINoop(t *testing.T) {
|
|||
|
||||
func TestAIOllama(t *testing.T) {
|
||||
t.Parallel()
|
||||
ai := NewAIOllama("http://localhost:11434", "gemma:2b")
|
||||
|
||||
testAI(t, ai)
|
||||
}
|
||||
|
||||
func TestAILocal(t *testing.T) {
|
||||
t.Parallel()
|
||||
d := os.TempDir()
|
||||
checkpoints := "checkpoints"
|
||||
tokenizer := "tokenizer"
|
||||
for u, p := range map[string]*string{
|
||||
"https://huggingface.co/karpathy/tinyllamas/resolve/main/stories110M.bin": &checkpoints,
|
||||
"https://github.com/karpathy/llama2.c/raw/master/tokenizer.bin": &tokenizer,
|
||||
} {
|
||||
func() {
|
||||
*p = path.Base(u)
|
||||
if _, err := os.Stat(path.Join(d, *p)); os.IsNotExist(err) {
|
||||
t.Logf("downloading %s from %s", u, *p)
|
||||
|
||||
resp, err := http.Get(u)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
f, err := os.Create(path.Join(d, *p))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
if _, err := io.Copy(f, resp.Body); err != nil {
|
||||
f.Close()
|
||||
os.Remove(path.Join(d, *p))
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
ai := NewAILocal(
|
||||
path.Join(d, checkpoints),
|
||||
path.Join(d, tokenizer),
|
||||
0.9,
|
||||
256,
|
||||
0.9,
|
||||
)
|
||||
ai := NewAIOllama("http://localhost:11434", "llama3")
|
||||
|
||||
testAI(t, ai)
|
||||
}
|
||||
|
|
@ -78,7 +27,7 @@ func testAI(t *testing.T, ai AI) {
|
|||
defer can()
|
||||
|
||||
t.Run("mvp", func(t *testing.T) {
|
||||
if result, err := ai.Do(ctx, "hello world"); err != nil {
|
||||
if result, err := ai.Do(ctx, "Tell me a fun fact."); err != nil {
|
||||
t.Fatal(err)
|
||||
} else if len(result) < 3 {
|
||||
t.Error(result)
|
||||
|
|
@ -87,32 +36,33 @@ func testAI(t *testing.T, ai AI) {
|
|||
}
|
||||
})
|
||||
|
||||
t.Run("simulation", func(t *testing.T) {
|
||||
d := NewRAM()
|
||||
FillWithTestdata(ctx, d, renderAssetPattern, renderDatacenterPattern, renderEventNamePattern)
|
||||
s := NewStorage(d)
|
||||
/*
|
||||
t.Run("simulation", func(t *testing.T) {
|
||||
d := NewRAM()
|
||||
FillWithTestdata(ctx, d, renderAssetPattern, renderDatacenterPattern, renderEventNamePattern)
|
||||
s := NewStorage(d)
|
||||
|
||||
threads, err := s.Threads(ctx)
|
||||
if err != nil || len(threads) < 1 {
|
||||
t.Fatal(err)
|
||||
}
|
||||
threads, err := s.Threads(ctx)
|
||||
if err != nil || len(threads) < 1 {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
thread, err := s.Thread(ctx, threads[0])
|
||||
if err != nil || len(thread) < 1 {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
input := fmt.Sprintf(`
|
||||
Summarize the following forum converstion.
|
||||
---
|
||||
%s
|
||||
`, thread[0].Plaintext)
|
||||
t.Logf("\n\t%s", input)
|
||||
result, err := ai.Do(ctx, input)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Logf("\n\t%s\n->\n\t%s", input, result)
|
||||
})
|
||||
thread, err := s.Thread(ctx, threads[0])
|
||||
if err != nil || len(thread) < 1 {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
input := fmt.Sprintf(`
|
||||
Summarize the following forum converstion.
|
||||
---
|
||||
%s
|
||||
`, thread[0].Plaintext)
|
||||
t.Logf("\n\t%s", input)
|
||||
result, err := ai.Do(ctx, input)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Logf("\n\t%s\n->\n\t%s", input, result)
|
||||
})
|
||||
*/
|
||||
}
|
||||
|
|
|
|||
|
|
@ -25,8 +25,6 @@ type Config struct {
|
|||
FillWithTestdata bool
|
||||
OllamaURL string
|
||||
OllamaModel string
|
||||
LocalCheckpoint string
|
||||
LocalTokenizer string
|
||||
AssetPattern string
|
||||
DatacenterPattern string
|
||||
EventNamePattern string
|
||||
|
|
@ -132,8 +130,6 @@ func newConfigFromEnv(ctx context.Context, getEnv func(string) string) (Config,
|
|||
|
||||
if result.OllamaURL != "" {
|
||||
result.ai = NewAIOllama(result.OllamaURL, result.OllamaModel)
|
||||
} else if result.LocalCheckpoint != "" && result.LocalTokenizer != "" {
|
||||
result.ai = NewAILocal(result.LocalCheckpoint, result.LocalTokenizer, 0.9, 128, 0.9)
|
||||
} else {
|
||||
result.ai = NewAINoop()
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue