Compare commits

..

3 Commits

Author SHA1 Message Date
bel
5f31a2c572 yay ai test fails 2024-04-12 16:31:23 -06:00
bel
046dc0e1ba WE GOT AI 2024-04-12 15:58:47 -06:00
bel
a48518fea6 yay 2024-04-12 15:30:50 -06:00
10 changed files with 185 additions and 24 deletions

2
.gitignore vendored
View File

@@ -1 +1,3 @@
/slack-bot-vr
**/*.sw*
/spoc-bot-vr /spoc-bot-vr

83
ai.go
View File

@@ -3,9 +3,9 @@ package main
import ( import (
"bytes" "bytes"
"context" "context"
"errors"
"os" "os"
nn "github.com/nikolaydubina/llama2.go/exp/nnfast"
"github.com/nikolaydubina/llama2.go/llama2" "github.com/nikolaydubina/llama2.go/llama2"
) )
@@ -50,7 +50,86 @@ func (ai AILocal) Do(ctx context.Context, prompt string) (string, error) {
return "", err return "", err
} }
isSharedWeights := config.VocabSize > 0
if config.VocabSize < 0 {
config.VocabSize = -config.VocabSize
}
tokenizerFile, err := os.OpenFile(ai.tokenizerPath, os.O_RDONLY, 0)
if err != nil {
return "", err
}
defer tokenizerFile.Close()
vocab := llama2.NewVocabFromFile(config.VocabSize, tokenizerFile)
w := llama2.NewTransformerWeightsFromCheckpoint(config, checkpointFile, isSharedWeights)
// right now we cannot run for more than config.SeqLen steps
steps := ai.steps
if steps <= 0 || steps > config.SeqLen {
steps = config.SeqLen
}
runState := llama2.NewRunState(config)
promptTokens := vocab.Encode(prompt)
out := bytes.NewBuffer(nil) out := bytes.NewBuffer(nil)
return string(out.Bytes()), errors.New("not impl") // the current position we are in
var token int = 1 // 1 = BOS token in llama-2 sentencepiece
var pos = 0
for pos < steps {
// forward the transformer to get logits for the next token
llama2.Transformer(token, pos, config, runState, w)
var next int
if pos < len(promptTokens) {
next = promptTokens[pos]
} else {
// sample the next token
if ai.temperature == 0 {
// greedy argmax sampling
next = nn.ArgMax(runState.Logits)
} else {
// apply the temperature to the logits
for q := 0; q < config.VocabSize; q++ {
runState.Logits[q] /= float32(ai.temperature)
}
// apply softmax to the logits to the probabilities for next token
nn.SoftMax(runState.Logits)
// we now want to sample from this distribution to get the next token
if ai.topp <= 0 || ai.topp >= 1 {
// simply sample from the predicted probability distribution
next = nn.Sample(runState.Logits)
} else {
// top-p (nucleus) sampling, clamping the least likely tokens to zero
next = nn.SampleTopP(runState.Logits, float32(ai.topp))
}
}
}
pos++
// data-dependent terminating condition: the BOS (1) token delimits sequences
if next == 1 {
break
}
// following BOS (1) token, sentencepiece decoder strips any leading whitespace
var tokenStr string
if token == 1 && vocab.Words[next][0] == ' ' {
tokenStr = vocab.Words[next][1:]
} else {
tokenStr = vocab.Words[next]
}
out.Write([]byte(tokenStr))
// advance forward
token = next
}
out.Write([]byte("\n"))
return string(out.Bytes()), nil
} }

View File

@@ -4,6 +4,11 @@ package main
import ( import (
"context" "context"
"fmt"
"io"
"net/http"
"os"
"path"
"testing" "testing"
"time" "time"
) )
@@ -12,17 +17,80 @@ func TestAILocal(t *testing.T) {
ctx, can := context.WithTimeout(context.Background(), time.Minute) ctx, can := context.WithTimeout(context.Background(), time.Minute)
defer can() defer can()
d := os.TempDir()
for k, u := range map[string]string{
"checkpoints": "https://huggingface.co/karpathy/tinyllamas/resolve/main/stories110M.bin",
"tokenizer": "https://github.com/karpathy/llama2.c/raw/master/tokenizer.bin",
} {
func() {
if _, err := os.Stat(path.Join(d, k)); os.IsNotExist(err) {
t.Logf("downloading %s from %s", u, k)
resp, err := http.Get(u)
if err != nil {
t.Fatal(err)
}
defer resp.Body.Close()
f, err := os.Create(path.Join(d, k))
if err != nil {
t.Fatal(err)
}
defer f.Close()
if _, err := io.Copy(f, resp.Body); err != nil {
f.Close()
os.Remove(path.Join(d, k))
t.Fatal(err)
}
}
}()
}
ai := NewAILocal( ai := NewAILocal(
"checkpointPath", path.Join(d, "checkpoints"),
"tokenizerPath", path.Join(d, "tokenizer"),
0.9, 0.9,
256, 256,
0.9, 0.9,
) )
if result, err := ai.Do(ctx, "hello world"); err != nil { t.Run("mvp", func(t *testing.T) {
t.Fatal(err) if result, err := ai.Do(ctx, "hello world"); err != nil {
} else { t.Fatal(err)
t.Logf("%s", result) } else if len(result) < 250 {
} t.Error(result)
} else {
t.Logf("%s", result)
}
})
t.Run("simulation", func(t *testing.T) {
d := NewRAM()
FillWithTestdata(ctx, d)
s := NewStorage(d)
threads, err := s.Threads(ctx)
if err != nil || len(threads) < 1 {
t.Fatal(err)
}
thread, err := s.Thread(ctx, threads[0])
if err != nil || len(thread) < 1 {
t.Fatal(err)
}
input := fmt.Sprintf(`
Summarize the following forum converstion.
---
%s
`, thread[0].Plaintext)
t.Logf("\n\t%s", input)
result, err := ai.Do(ctx, input)
if err != nil {
t.Fatal(err)
}
t.Logf("\n\t%s\n->\n\t%s", input, result)
})
} }

View File

@@ -6,7 +6,6 @@ import (
"fmt" "fmt"
"os" "os"
"regexp" "regexp"
"slices"
"strconv" "strconv"
"strings" "strings"
"time" "time"
@@ -47,8 +46,8 @@ func newConfigFromEnv(ctx context.Context, getEnv func(string) string) (Config,
for k, v := range m { for k, v := range m {
envK := k envK := k
idxes := re.FindAllIndex([]byte(envK), -1) idxes := re.FindAllIndex([]byte(envK), -1)
slices.Reverse(idxes) for i := len(idxes) - 1; i >= 0; i-- {
for _, idx := range idxes { idx := idxes[i]
if idx[0] > 0 { if idx[0] > 0 {
envK = fmt.Sprintf("%s_%s", envK[:idx[0]], envK[idx[0]:]) envK = fmt.Sprintf("%s_%s", envK[:idx[0]], envK[idx[0]:])
} }

8
go.mod
View File

@@ -4,11 +4,9 @@ go 1.22.1
require ( require (
github.com/go-errors/errors v1.5.1 github.com/go-errors/errors v1.5.1
github.com/lib/pq v1.10.9
github.com/nikolaydubina/llama2.go v0.7.1
go.etcd.io/bbolt v1.3.9 go.etcd.io/bbolt v1.3.9
) )
require ( require golang.org/x/sys v0.4.0 // indirect
github.com/lib/pq v1.10.9 // indirect
github.com/nikolaydubina/llama2.go v0.7.1 // indirect
golang.org/x/sys v0.4.0 // indirect
)

10
go.sum
View File

@@ -1,10 +1,20 @@
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/go-errors/errors v1.5.1 h1:ZwEMSLRCapFLflTpT7NKaAc7ukJ8ZPEjzlxt8rPN8bk= github.com/go-errors/errors v1.5.1 h1:ZwEMSLRCapFLflTpT7NKaAc7ukJ8ZPEjzlxt8rPN8bk=
github.com/go-errors/errors v1.5.1/go.mod h1:sIVyrIiJhuEF+Pj9Ebtd6P/rEYROXFi3BopGUQ5a5Og= github.com/go-errors/errors v1.5.1/go.mod h1:sIVyrIiJhuEF+Pj9Ebtd6P/rEYROXFi3BopGUQ5a5Og=
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
github.com/nikolaydubina/llama2.go v0.7.1 h1:ORmH1XbwFYGIOPHprkjtUPOEovlVXhnmnMjbMckaSyE= github.com/nikolaydubina/llama2.go v0.7.1 h1:ORmH1XbwFYGIOPHprkjtUPOEovlVXhnmnMjbMckaSyE=
github.com/nikolaydubina/llama2.go v0.7.1/go.mod h1:ggXhXOaDnEAgSSkcYsomqx/RLjInxe5ZAbcJ+/Y2mTM= github.com/nikolaydubina/llama2.go v0.7.1/go.mod h1:ggXhXOaDnEAgSSkcYsomqx/RLjInxe5ZAbcJ+/Y2mTM=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk=
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
go.etcd.io/bbolt v1.3.9 h1:8x7aARPEXiXbHmtUwAIv7eV2fQFHrLLavdiJ3uzJXoI= go.etcd.io/bbolt v1.3.9 h1:8x7aARPEXiXbHmtUwAIv7eV2fQFHrLLavdiJ3uzJXoI=
go.etcd.io/bbolt v1.3.9/go.mod h1:zaO32+Ti0PK1ivdPtgMESzuzL2VPoIG1PCQNvOdo/dE= go.etcd.io/bbolt v1.3.9/go.mod h1:zaO32+Ti0PK1ivdPtgMESzuzL2VPoIG1PCQNvOdo/dE=
golang.org/x/sync v0.5.0 h1:60k92dhOjHxJkrqnwsfl8KuaHbn/5dl0lUPUklKo3qE=
golang.org/x/sync v0.5.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sys v0.4.0 h1:Zr2JFtRQNX3BCZ8YtxRE9hNJYC8J6I1MVbMg6owUp18= golang.org/x/sys v0.4.0 h1:Zr2JFtRQNX3BCZ8YtxRE9hNJYC8J6I1MVbMg6owUp18=
golang.org/x/sys v0.4.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.4.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

10
main.go
View File

@@ -11,7 +11,6 @@ import (
"net" "net"
"net/http" "net/http"
"os/signal" "os/signal"
"slices"
"strconv" "strconv"
"strings" "strings"
"syscall" "syscall"
@@ -169,7 +168,14 @@ func _newHandlerPostAPIV1EventsSlack(cfg Config) http.HandlerFunc {
} else if allowList.Token != cfg.SlackToken { } else if allowList.Token != cfg.SlackToken {
http.Error(w, "invalid .token", http.StatusForbidden) http.Error(w, "invalid .token", http.StatusForbidden)
return return
} else if !slices.Contains(cfg.SlackChannels, allowList.Event.Channel) { } else if !func() bool {
for _, slackChannel := range cfg.SlackChannels {
if slackChannel == allowList.Event.Channel {
return true
}
}
return false
}() {
return return
} }

View File

@@ -11,7 +11,6 @@ import (
"net/url" "net/url"
"os" "os"
"path" "path"
"slices"
"strconv" "strconv"
"strings" "strings"
"testing" "testing"
@@ -97,7 +96,7 @@ func TestRun(t *testing.T) {
} }
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
t.Fatal(err) t.Fatal(err)
} else if !slices.Contains(result.Threads, "1712911957.023359") { } else if result.Threads[0] != "1712911957.023359" {
t.Fatal(result.Threads) t.Fatal(result.Threads)
} }
}) })

View File

@@ -38,6 +38,7 @@ func (s Storage) ThreadsSince(ctx context.Context, t time.Time) ([]string, error
for k := range threads { for k := range threads {
result = append(result, k) result = append(result, k)
} }
sort.Strings(result)
return result, nil return result, nil
} }

View File

@@ -2,7 +2,6 @@ package main
import ( import (
"context" "context"
"slices"
"testing" "testing"
"time" "time"
) )
@@ -30,9 +29,9 @@ func TestStorage(t *testing.T) {
t.Error(err) t.Error(err)
} else if len(threads) != 2 { } else if len(threads) != 2 {
t.Error(threads) t.Error(threads)
} else if !slices.Contains(threads, "X") { } else if threads[0] != "X" {
t.Error(threads, "X") t.Error(threads, "X")
} else if !slices.Contains(threads, "Y") { } else if threads[1] != "Y" {
t.Error(threads, "Y") t.Error(threads, "Y")
} }