From 14de28641546959861a83906521e5421729efca2 Mon Sep 17 00:00:00 2001
From: Bel LaPointe <153096461+breel-render@users.noreply.github.com>
Date: Thu, 18 Apr 2024 14:08:28 -0600
Subject: [PATCH] go test -tags=ai -v -run=AI works with ollama which is cool and fast with llama3

---
 ai.go      | 126 -----------------------------------------------------
 ai_test.go | 106 ++++++++++++--------------------------------
 config.go  |   4 --
 3 files changed, 28 insertions(+), 208 deletions(-)

diff --git a/ai.go b/ai.go
index 4116b5f..ba103aa 100644
--- a/ai.go
+++ b/ai.go
@@ -1,13 +1,8 @@
 package main
 
 import (
-	"bytes"
 	"context"
-	"os"
-	"strings"
 
-	nn "github.com/nikolaydubina/llama2.go/exp/nnfast"
-	"github.com/nikolaydubina/llama2.go/llama2"
 	"github.com/tmc/langchaingo/llms"
 	"github.com/tmc/langchaingo/llms/ollama"
 )
@@ -46,124 +41,3 @@ func (ai AIOllama) Do(ctx context.Context, prompt string) (string, error) {
 	}
 	return llms.GenerateFromSinglePrompt(ctx, llm, prompt)
 }
-
-type AILocal struct {
-	checkpointPath string
-	tokenizerPath  string
-	temperature    float64
-	steps          int
-	topp           float64
-}
-
-func NewAILocal(
-	checkpointPath string,
-	tokenizerPath string,
-	temperature float64,
-	steps int,
-	topp float64,
-) AILocal {
-	return AILocal{
-		checkpointPath: checkpointPath,
-		tokenizerPath:  tokenizerPath,
-		temperature:    temperature,
-		steps:          steps,
-		topp:           topp,
-	}
-}
-
-// https://github.com/nikolaydubina/llama2.go/blob/master/main.go
-func (ai AILocal) Do(ctx context.Context, prompt string) (string, error) {
-	checkpointFile, err := os.OpenFile(ai.checkpointPath, os.O_RDONLY, 0)
-	if err != nil {
-		return "", err
-	}
-	defer checkpointFile.Close()
-
-	config, err := llama2.NewConfigFromCheckpoint(checkpointFile)
-	if err != nil {
-		return "", err
-	}
-
-	isSharedWeights := config.VocabSize > 0
-	if config.VocabSize < 0 {
-		config.VocabSize = -config.VocabSize
-	}
-
-	tokenizerFile, err := os.OpenFile(ai.tokenizerPath, os.O_RDONLY, 0)
-	if err != nil {
-		return "", err
-	}
-	defer tokenizerFile.Close()
-
-	vocab := llama2.NewVocabFromFile(config.VocabSize, tokenizerFile)
-
-	w := llama2.NewTransformerWeightsFromCheckpoint(config, checkpointFile, isSharedWeights)
-
-	// right now we cannot run for more than config.SeqLen steps
-	steps := ai.steps
-	if steps <= 0 || steps > config.SeqLen {
-		steps = config.SeqLen
-	}
-
-	runState := llama2.NewRunState(config)
-
-	promptTokens := vocab.Encode(strings.ReplaceAll(prompt, "\n", "<0x0A>"))
-
-	out := bytes.NewBuffer(nil)
-
-	// the current position we are in
-	var token int = 1 // 1 = BOS token in llama-2 sentencepiece
-	var pos = 0
-
-	for pos < steps {
-		// forward the transformer to get logits for the next token
-		llama2.Transformer(token, pos, config, runState, w)
-
-		var next int
-		if pos < len(promptTokens) {
-			next = promptTokens[pos]
-		} else {
-			// sample the next token
-			if ai.temperature == 0 {
-				// greedy argmax sampling
-				next = nn.ArgMax(runState.Logits)
-			} else {
-				// apply the temperature to the logits
-				for q := 0; q < config.VocabSize; q++ {
-					runState.Logits[q] /= float32(ai.temperature)
-				}
-				// apply softmax to the logits to the probabilities for next token
-				nn.SoftMax(runState.Logits)
-				// we now want to sample from this distribution to get the next token
-				if ai.topp <= 0 || ai.topp >= 1 {
-					// simply sample from the predicted probability distribution
-					next = nn.Sample(runState.Logits)
-				} else {
-					// top-p (nucleus) sampling, clamping the least likely tokens to zero
-					next = nn.SampleTopP(runState.Logits, float32(ai.topp))
-				}
-			}
-		}
-		pos++
-
-		// data-dependent terminating condition: the BOS (1) token delimits sequences
-		if next == 1 {
-			break
-		}
-
-		// following BOS (1) token, sentencepiece decoder strips any leading whitespace
-		var tokenStr string
-		if token == 1 && vocab.Words[next][0] == ' ' {
-			tokenStr = vocab.Words[next][1:]
-		} else {
-			tokenStr = vocab.Words[next]
-		}
-		out.Write([]byte(tokenStr))
-
-		// advance forward
-		token = next
-	}
-	out.Write([]byte("\n"))
-
-	return strings.ReplaceAll(string(out.Bytes()), "<0x0A>", "\n"), nil
-}
diff --git a/ai_test.go b/ai_test.go
index e81d13e..0b665f1 100644
--- a/ai_test.go
+++ b/ai_test.go
@@ -4,11 +4,6 @@ package main
 
 import (
 	"context"
-	"fmt"
-	"io"
-	"net/http"
-	"os"
-	"path"
 	"testing"
 	"time"
 )
@@ -22,53 +17,7 @@ func TestAINoop(t *testing.T) {
 
 func TestAIOllama(t *testing.T) {
 	t.Parallel()
-	ai := NewAIOllama("http://localhost:11434", "gemma:2b")
-
-	testAI(t, ai)
-}
-
-func TestAILocal(t *testing.T) {
-	t.Parallel()
-	d := os.TempDir()
-	checkpoints := "checkpoints"
-	tokenizer := "tokenizer"
-	for u, p := range map[string]*string{
-		"https://huggingface.co/karpathy/tinyllamas/resolve/main/stories110M.bin": &checkpoints,
-		"https://github.com/karpathy/llama2.c/raw/master/tokenizer.bin":           &tokenizer,
-	} {
-		func() {
-			*p = path.Base(u)
-			if _, err := os.Stat(path.Join(d, *p)); os.IsNotExist(err) {
-				t.Logf("downloading %s from %s", u, *p)
-
-				resp, err := http.Get(u)
-				if err != nil {
-					t.Fatal(err)
-				}
-				defer resp.Body.Close()
-
-				f, err := os.Create(path.Join(d, *p))
-				if err != nil {
-					t.Fatal(err)
-				}
-				defer f.Close()
-
-				if _, err := io.Copy(f, resp.Body); err != nil {
-					f.Close()
-					os.Remove(path.Join(d, *p))
-					t.Fatal(err)
-				}
-			}
-		}()
-	}
-
-	ai := NewAILocal(
-		path.Join(d, checkpoints),
-		path.Join(d, tokenizer),
-		0.9,
-		256,
-		0.9,
-	)
+	ai := NewAIOllama("http://localhost:11434", "llama3")
 
 	testAI(t, ai)
 }
@@ -78,7 +27,7 @@ func testAI(t *testing.T, ai AI) {
 	defer can()
 
 	t.Run("mvp", func(t *testing.T) {
-		if result, err := ai.Do(ctx, "hello world"); err != nil {
+		if result, err := ai.Do(ctx, "Tell me a fun fact."); err != nil {
 			t.Fatal(err)
 		} else if len(result) < 3 {
 			t.Error(result)
@@ -87,32 +36,33 @@ func testAI(t *testing.T, ai AI) {
 		}
 	})
 
-	t.Run("simulation", func(t *testing.T) {
-		d := NewRAM()
-		FillWithTestdata(ctx, d, renderAssetPattern, renderDatacenterPattern, renderEventNamePattern)
-		s := NewStorage(d)
+	/*
+		t.Run("simulation", func(t *testing.T) {
+			d := NewRAM()
+			FillWithTestdata(ctx, d, renderAssetPattern, renderDatacenterPattern, renderEventNamePattern)
+			s := NewStorage(d)
 
-		threads, err := s.Threads(ctx)
-		if err != nil || len(threads) < 1 {
-			t.Fatal(err)
-		}
+			threads, err := s.Threads(ctx)
+			if err != nil || len(threads) < 1 {
+				t.Fatal(err)
+			}
 
-		thread, err := s.Thread(ctx, threads[0])
-		if err != nil || len(thread) < 1 {
-			t.Fatal(err)
-		}
-
-		input := fmt.Sprintf(`
-		Summarize the following forum converstion.
-		---
-		%s
-		`, thread[0].Plaintext)
-		t.Logf("\n\t%s", input)
-		result, err := ai.Do(ctx, input)
-		if err != nil {
-			t.Fatal(err)
-		}
-		t.Logf("\n\t%s\n->\n\t%s", input, result)
-	})
+			thread, err := s.Thread(ctx, threads[0])
+			if err != nil || len(thread) < 1 {
+				t.Fatal(err)
+			}
+			input := fmt.Sprintf(`
+			Summarize the following forum conversation.
+			---
+			%s
+			`, thread[0].Plaintext)
+			t.Logf("\n\t%s", input)
+			result, err := ai.Do(ctx, input)
+			if err != nil {
+				t.Fatal(err)
+			}
+			t.Logf("\n\t%s\n->\n\t%s", input, result)
+		})
+	*/
 
 }
diff --git a/config.go b/config.go
index 2bf4a2f..0bcaa16 100644
--- a/config.go
+++ b/config.go
@@ -25,8 +25,6 @@ type Config struct {
 	FillWithTestdata  bool
 	OllamaURL         string
 	OllamaModel       string
-	LocalCheckpoint   string
-	LocalTokenizer    string
 	AssetPattern      string
 	DatacenterPattern string
 	EventNamePattern  string
@@ -132,8 +130,6 @@ func newConfigFromEnv(ctx context.Context, getEnv func(string) string) (Config,
 
 	if result.OllamaURL != "" {
 		result.ai = NewAIOllama(result.OllamaURL, result.OllamaModel)
-	} else if result.LocalCheckpoint != "" && result.LocalTokenizer != "" {
-		result.ai = NewAILocal(result.LocalCheckpoint, result.LocalTokenizer, 0.9, 128, 0.9)
 	} else {
 		result.ai = NewAINoop()
 	}
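Note: after this patch the only real backend is the Ollama-backed AIOllama in ai.go. Below is a minimal sketch of driving it the same way ai_test.go does. It assumes the snippet lives in this repo's package main (where NewAIOllama and Do are defined), that an ollama server is listening on http://localhost:11434 with the llama3 model already pulled (ollama pull llama3), and the exampleOllama name and one-minute timeout are illustrative choices, not part of the commit.

	package main

	import (
		"context"
		"fmt"
		"time"
	)

	// exampleOllama builds the client with the same endpoint and model the
	// test uses, sends a single prompt, and prints the completion.
	func exampleOllama() error {
		// arbitrary timeout for the sketch; the test uses its own deadline
		ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
		defer cancel()

		ai := NewAIOllama("http://localhost:11434", "llama3")
		result, err := ai.Do(ctx, "Tell me a fun fact.")
		if err != nil {
			return err
		}
		fmt.Println(result)
		return nil
	}

The tests themselves run as in the subject line: go test -tags=ai -v -run=AI.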