From a48518fea6be7570e565078464a26b84ccf03eb6 Mon Sep 17 00:00:00 2001 From: bel Date: Fri, 12 Apr 2024 15:30:50 -0600 Subject: [PATCH] yay --- .gitignore | 1 + ai.go | 80 +++++++++++++++++++++++++++++++++++++++++++++++++ config.go | 5 ++-- go.mod | 8 ++--- go.sum | 10 +++++++ main.go | 10 +++++-- main_test.go | 3 +- storage_test.go | 5 ++-- 8 files changed, 107 insertions(+), 15 deletions(-) diff --git a/.gitignore b/.gitignore index e92c850..7cde8b2 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,2 @@ +/slack-bot-vr /spoc-bot-vr diff --git a/ai.go b/ai.go index 9a018d1..081147e 100644 --- a/ai.go +++ b/ai.go @@ -6,6 +6,7 @@ import ( "errors" "os" + nn "github.com/nikolaydubina/llama2.go/exp/nnfast" "github.com/nikolaydubina/llama2.go/llama2" ) @@ -50,7 +51,86 @@ func (ai AILocal) Do(ctx context.Context, prompt string) (string, error) { return "", err } + isSharedWeights := config.VocabSize > 0 + if config.VocabSize < 0 { + config.VocabSize = -config.VocabSize + } + + tokenizerFile, err := os.OpenFile(ai.tokenizerPath, os.O_RDONLY, 0) + if err != nil { + return "", err + } + defer tokenizerFile.Close() + + vocab := llama2.NewVocabFromFile(config.VocabSize, tokenizerFile) + + w := llama2.NewTransformerWeightsFromCheckpoint(config, checkpointFile, isSharedWeights) + + // right now we cannot run for more than config.SeqLen steps + steps := ai.steps + if steps <= 0 || steps > config.SeqLen { + steps = config.SeqLen + } + + runState := llama2.NewRunState(config) + + promptTokens := vocab.Encode(prompt) + out := bytes.NewBuffer(nil) + // the current position we are in + var token int = 1 // 1 = BOS token in llama-2 sentencepiece + var pos = 0 + + for pos < steps { + // forward the transformer to get logits for the next token + llama2.Transformer(token, pos, config, runState, w) + + var next int + if pos < len(promptTokens) { + next = promptTokens[pos] + } else { + // sample the next token + if ai.temperature == 0 { + // greedy argmax sampling + next = nn.ArgMax(runState.Logits) + } else { + // apply the temperature to the logits + for q := 0; q < config.VocabSize; q++ { + runState.Logits[q] /= float32(ai.temperature) + } + // apply softmax to the logits to the probabilities for next token + nn.SoftMax(runState.Logits) + // we now want to sample from this distribution to get the next token + if ai.topp <= 0 || ai.topp >= 1 { + // simply sample from the predicted probability distribution + next = nn.Sample(runState.Logits) + } else { + // top-p (nucleus) sampling, clamping the least likely tokens to zero + next = nn.SampleTopP(runState.Logits, float32(ai.topp)) + } + } + } + pos++ + + // data-dependent terminating condition: the BOS (1) token delimits sequences + if next == 1 { + break + } + + // following BOS (1) token, sentencepiece decoder strips any leading whitespace + var tokenStr string + if token == 1 && vocab.Words[next][0] == ' ' { + tokenStr = vocab.Words[next][1:] + } else { + tokenStr = vocab.Words[next] + } + out.Write([]byte(tokenStr)) + + // advance forward + token = next + } + out.Write([]byte("\n")) + return string(out.Bytes()), errors.New("not impl") } diff --git a/config.go b/config.go index ffb707b..d383f8f 100644 --- a/config.go +++ b/config.go @@ -6,7 +6,6 @@ import ( "fmt" "os" "regexp" - "slices" "strconv" "strings" "time" @@ -47,8 +46,8 @@ func newConfigFromEnv(ctx context.Context, getEnv func(string) string) (Config, for k, v := range m { envK := k idxes := re.FindAllIndex([]byte(envK), -1) - slices.Reverse(idxes) - for _, idx := range idxes { + for i := len(idxes) - 1; i >= 0; i-- { + idx := idxes[i] if idx[0] > 0 { envK = fmt.Sprintf("%s_%s", envK[:idx[0]], envK[idx[0]:]) } diff --git a/go.mod b/go.mod index 4f2035a..f48b54b 100644 --- a/go.mod +++ b/go.mod @@ -4,11 +4,9 @@ go 1.22.1 require ( github.com/go-errors/errors v1.5.1 + github.com/lib/pq v1.10.9 + github.com/nikolaydubina/llama2.go v0.7.1 go.etcd.io/bbolt v1.3.9 ) -require ( - github.com/lib/pq v1.10.9 // indirect - github.com/nikolaydubina/llama2.go v0.7.1 // indirect - golang.org/x/sys v0.4.0 // indirect -) +require golang.org/x/sys v0.4.0 // indirect diff --git a/go.sum b/go.sum index da4a17a..a496fba 100644 --- a/go.sum +++ b/go.sum @@ -1,10 +1,20 @@ +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/go-errors/errors v1.5.1 h1:ZwEMSLRCapFLflTpT7NKaAc7ukJ8ZPEjzlxt8rPN8bk= github.com/go-errors/errors v1.5.1/go.mod h1:sIVyrIiJhuEF+Pj9Ebtd6P/rEYROXFi3BopGUQ5a5Og= github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/nikolaydubina/llama2.go v0.7.1 h1:ORmH1XbwFYGIOPHprkjtUPOEovlVXhnmnMjbMckaSyE= github.com/nikolaydubina/llama2.go v0.7.1/go.mod h1:ggXhXOaDnEAgSSkcYsomqx/RLjInxe5ZAbcJ+/Y2mTM= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= +github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= go.etcd.io/bbolt v1.3.9 h1:8x7aARPEXiXbHmtUwAIv7eV2fQFHrLLavdiJ3uzJXoI= go.etcd.io/bbolt v1.3.9/go.mod h1:zaO32+Ti0PK1ivdPtgMESzuzL2VPoIG1PCQNvOdo/dE= +golang.org/x/sync v0.5.0 h1:60k92dhOjHxJkrqnwsfl8KuaHbn/5dl0lUPUklKo3qE= +golang.org/x/sync v0.5.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.4.0 h1:Zr2JFtRQNX3BCZ8YtxRE9hNJYC8J6I1MVbMg6owUp18= golang.org/x/sys v0.4.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/main.go b/main.go index a1bc241..4b218e9 100644 --- a/main.go +++ b/main.go @@ -11,7 +11,6 @@ import ( "net" "net/http" "os/signal" - "slices" "strconv" "strings" "syscall" @@ -169,7 +168,14 @@ func _newHandlerPostAPIV1EventsSlack(cfg Config) http.HandlerFunc { } else if allowList.Token != cfg.SlackToken { http.Error(w, "invalid .token", http.StatusForbidden) return - } else if !slices.Contains(cfg.SlackChannels, allowList.Event.Channel) { + } else if !func() bool { + for _, slackChannel := range cfg.SlackChannels { + if slackChannel == allowList.Event.Channel { + return true + } + } + return false + }() { return } diff --git a/main_test.go b/main_test.go index deb495d..79fcdde 100644 --- a/main_test.go +++ b/main_test.go @@ -11,7 +11,6 @@ import ( "net/url" "os" "path" - "slices" "strconv" "strings" "testing" @@ -97,7 +96,7 @@ func TestRun(t *testing.T) { } if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { t.Fatal(err) - } else if !slices.Contains(result.Threads, "1712911957.023359") { + } else if result.Threads[0] != "1712911957.023359" { t.Fatal(result.Threads) } }) diff --git a/storage_test.go b/storage_test.go index 304a872..a409eae 100644 --- a/storage_test.go +++ b/storage_test.go @@ -2,7 +2,6 @@ package main import ( "context" - "slices" "testing" "time" ) @@ -30,9 +29,9 @@ func TestStorage(t *testing.T) { t.Error(err) } else if len(threads) != 2 { t.Error(threads) - } else if !slices.Contains(threads, "X") { + } else if threads[0] != "X" { t.Error(threads, "X") - } else if !slices.Contains(threads, "Y") { + } else if threads[1] != "Y" { t.Error(threads, "Y") }