yay
parent
163f894f3a
commit
a48518fea6
|
|
@ -1 +1,2 @@
|
||||||
|
/slack-bot-vr
|
||||||
/spoc-bot-vr
|
/spoc-bot-vr
|
||||||
|
|
|
||||||
80
ai.go
80
ai.go
|
|
@ -6,6 +6,7 @@ import (
|
||||||
"errors"
|
"errors"
|
||||||
"os"
|
"os"
|
||||||
|
|
||||||
|
nn "github.com/nikolaydubina/llama2.go/exp/nnfast"
|
||||||
"github.com/nikolaydubina/llama2.go/llama2"
|
"github.com/nikolaydubina/llama2.go/llama2"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -50,7 +51,86 @@ func (ai AILocal) Do(ctx context.Context, prompt string) (string, error) {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
isSharedWeights := config.VocabSize > 0
|
||||||
|
if config.VocabSize < 0 {
|
||||||
|
config.VocabSize = -config.VocabSize
|
||||||
|
}
|
||||||
|
|
||||||
|
tokenizerFile, err := os.OpenFile(ai.tokenizerPath, os.O_RDONLY, 0)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
defer tokenizerFile.Close()
|
||||||
|
|
||||||
|
vocab := llama2.NewVocabFromFile(config.VocabSize, tokenizerFile)
|
||||||
|
|
||||||
|
w := llama2.NewTransformerWeightsFromCheckpoint(config, checkpointFile, isSharedWeights)
|
||||||
|
|
||||||
|
// right now we cannot run for more than config.SeqLen steps
|
||||||
|
steps := ai.steps
|
||||||
|
if steps <= 0 || steps > config.SeqLen {
|
||||||
|
steps = config.SeqLen
|
||||||
|
}
|
||||||
|
|
||||||
|
runState := llama2.NewRunState(config)
|
||||||
|
|
||||||
|
promptTokens := vocab.Encode(prompt)
|
||||||
|
|
||||||
out := bytes.NewBuffer(nil)
|
out := bytes.NewBuffer(nil)
|
||||||
|
|
||||||
|
// the current position we are in
|
||||||
|
var token int = 1 // 1 = BOS token in llama-2 sentencepiece
|
||||||
|
var pos = 0
|
||||||
|
|
||||||
|
for pos < steps {
|
||||||
|
// forward the transformer to get logits for the next token
|
||||||
|
llama2.Transformer(token, pos, config, runState, w)
|
||||||
|
|
||||||
|
var next int
|
||||||
|
if pos < len(promptTokens) {
|
||||||
|
next = promptTokens[pos]
|
||||||
|
} else {
|
||||||
|
// sample the next token
|
||||||
|
if ai.temperature == 0 {
|
||||||
|
// greedy argmax sampling
|
||||||
|
next = nn.ArgMax(runState.Logits)
|
||||||
|
} else {
|
||||||
|
// apply the temperature to the logits
|
||||||
|
for q := 0; q < config.VocabSize; q++ {
|
||||||
|
runState.Logits[q] /= float32(ai.temperature)
|
||||||
|
}
|
||||||
|
// apply softmax to the logits to the probabilities for next token
|
||||||
|
nn.SoftMax(runState.Logits)
|
||||||
|
// we now want to sample from this distribution to get the next token
|
||||||
|
if ai.topp <= 0 || ai.topp >= 1 {
|
||||||
|
// simply sample from the predicted probability distribution
|
||||||
|
next = nn.Sample(runState.Logits)
|
||||||
|
} else {
|
||||||
|
// top-p (nucleus) sampling, clamping the least likely tokens to zero
|
||||||
|
next = nn.SampleTopP(runState.Logits, float32(ai.topp))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
pos++
|
||||||
|
|
||||||
|
// data-dependent terminating condition: the BOS (1) token delimits sequences
|
||||||
|
if next == 1 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
// following BOS (1) token, sentencepiece decoder strips any leading whitespace
|
||||||
|
var tokenStr string
|
||||||
|
if token == 1 && vocab.Words[next][0] == ' ' {
|
||||||
|
tokenStr = vocab.Words[next][1:]
|
||||||
|
} else {
|
||||||
|
tokenStr = vocab.Words[next]
|
||||||
|
}
|
||||||
|
out.Write([]byte(tokenStr))
|
||||||
|
|
||||||
|
// advance forward
|
||||||
|
token = next
|
||||||
|
}
|
||||||
|
out.Write([]byte("\n"))
|
||||||
|
|
||||||
return string(out.Bytes()), errors.New("not impl")
|
return string(out.Bytes()), errors.New("not impl")
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -6,7 +6,6 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
"regexp"
|
"regexp"
|
||||||
"slices"
|
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
@ -47,8 +46,8 @@ func newConfigFromEnv(ctx context.Context, getEnv func(string) string) (Config,
|
||||||
for k, v := range m {
|
for k, v := range m {
|
||||||
envK := k
|
envK := k
|
||||||
idxes := re.FindAllIndex([]byte(envK), -1)
|
idxes := re.FindAllIndex([]byte(envK), -1)
|
||||||
slices.Reverse(idxes)
|
for i := len(idxes) - 1; i >= 0; i-- {
|
||||||
for _, idx := range idxes {
|
idx := idxes[i]
|
||||||
if idx[0] > 0 {
|
if idx[0] > 0 {
|
||||||
envK = fmt.Sprintf("%s_%s", envK[:idx[0]], envK[idx[0]:])
|
envK = fmt.Sprintf("%s_%s", envK[:idx[0]], envK[idx[0]:])
|
||||||
}
|
}
|
||||||
|
|
|
||||||
8
go.mod
8
go.mod
|
|
@ -4,11 +4,9 @@ go 1.22.1
|
||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/go-errors/errors v1.5.1
|
github.com/go-errors/errors v1.5.1
|
||||||
|
github.com/lib/pq v1.10.9
|
||||||
|
github.com/nikolaydubina/llama2.go v0.7.1
|
||||||
go.etcd.io/bbolt v1.3.9
|
go.etcd.io/bbolt v1.3.9
|
||||||
)
|
)
|
||||||
|
|
||||||
require (
|
require golang.org/x/sys v0.4.0 // indirect
|
||||||
github.com/lib/pq v1.10.9 // indirect
|
|
||||||
github.com/nikolaydubina/llama2.go v0.7.1 // indirect
|
|
||||||
golang.org/x/sys v0.4.0 // indirect
|
|
||||||
)
|
|
||||||
|
|
|
||||||
10
go.sum
10
go.sum
|
|
@ -1,10 +1,20 @@
|
||||||
|
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||||
|
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
github.com/go-errors/errors v1.5.1 h1:ZwEMSLRCapFLflTpT7NKaAc7ukJ8ZPEjzlxt8rPN8bk=
|
github.com/go-errors/errors v1.5.1 h1:ZwEMSLRCapFLflTpT7NKaAc7ukJ8ZPEjzlxt8rPN8bk=
|
||||||
github.com/go-errors/errors v1.5.1/go.mod h1:sIVyrIiJhuEF+Pj9Ebtd6P/rEYROXFi3BopGUQ5a5Og=
|
github.com/go-errors/errors v1.5.1/go.mod h1:sIVyrIiJhuEF+Pj9Ebtd6P/rEYROXFi3BopGUQ5a5Og=
|
||||||
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
|
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
|
||||||
github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
|
github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
|
||||||
github.com/nikolaydubina/llama2.go v0.7.1 h1:ORmH1XbwFYGIOPHprkjtUPOEovlVXhnmnMjbMckaSyE=
|
github.com/nikolaydubina/llama2.go v0.7.1 h1:ORmH1XbwFYGIOPHprkjtUPOEovlVXhnmnMjbMckaSyE=
|
||||||
github.com/nikolaydubina/llama2.go v0.7.1/go.mod h1:ggXhXOaDnEAgSSkcYsomqx/RLjInxe5ZAbcJ+/Y2mTM=
|
github.com/nikolaydubina/llama2.go v0.7.1/go.mod h1:ggXhXOaDnEAgSSkcYsomqx/RLjInxe5ZAbcJ+/Y2mTM=
|
||||||
|
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||||
|
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||||
|
github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk=
|
||||||
|
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
||||||
go.etcd.io/bbolt v1.3.9 h1:8x7aARPEXiXbHmtUwAIv7eV2fQFHrLLavdiJ3uzJXoI=
|
go.etcd.io/bbolt v1.3.9 h1:8x7aARPEXiXbHmtUwAIv7eV2fQFHrLLavdiJ3uzJXoI=
|
||||||
go.etcd.io/bbolt v1.3.9/go.mod h1:zaO32+Ti0PK1ivdPtgMESzuzL2VPoIG1PCQNvOdo/dE=
|
go.etcd.io/bbolt v1.3.9/go.mod h1:zaO32+Ti0PK1ivdPtgMESzuzL2VPoIG1PCQNvOdo/dE=
|
||||||
|
golang.org/x/sync v0.5.0 h1:60k92dhOjHxJkrqnwsfl8KuaHbn/5dl0lUPUklKo3qE=
|
||||||
|
golang.org/x/sync v0.5.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||||
golang.org/x/sys v0.4.0 h1:Zr2JFtRQNX3BCZ8YtxRE9hNJYC8J6I1MVbMg6owUp18=
|
golang.org/x/sys v0.4.0 h1:Zr2JFtRQNX3BCZ8YtxRE9hNJYC8J6I1MVbMg6owUp18=
|
||||||
golang.org/x/sys v0.4.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.4.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
|
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||||
|
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||||
|
|
|
||||||
10
main.go
10
main.go
|
|
@ -11,7 +11,6 @@ import (
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os/signal"
|
"os/signal"
|
||||||
"slices"
|
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"syscall"
|
"syscall"
|
||||||
|
|
@ -169,7 +168,14 @@ func _newHandlerPostAPIV1EventsSlack(cfg Config) http.HandlerFunc {
|
||||||
} else if allowList.Token != cfg.SlackToken {
|
} else if allowList.Token != cfg.SlackToken {
|
||||||
http.Error(w, "invalid .token", http.StatusForbidden)
|
http.Error(w, "invalid .token", http.StatusForbidden)
|
||||||
return
|
return
|
||||||
} else if !slices.Contains(cfg.SlackChannels, allowList.Event.Channel) {
|
} else if !func() bool {
|
||||||
|
for _, slackChannel := range cfg.SlackChannels {
|
||||||
|
if slackChannel == allowList.Event.Channel {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}() {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -11,7 +11,6 @@ import (
|
||||||
"net/url"
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
"path"
|
"path"
|
||||||
"slices"
|
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
@ -97,7 +96,7 @@ func TestRun(t *testing.T) {
|
||||||
}
|
}
|
||||||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
} else if !slices.Contains(result.Threads, "1712911957.023359") {
|
} else if result.Threads[0] != "1712911957.023359" {
|
||||||
t.Fatal(result.Threads)
|
t.Fatal(result.Threads)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,6 @@ package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"slices"
|
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
@ -30,9 +29,9 @@ func TestStorage(t *testing.T) {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
} else if len(threads) != 2 {
|
} else if len(threads) != 2 {
|
||||||
t.Error(threads)
|
t.Error(threads)
|
||||||
} else if !slices.Contains(threads, "X") {
|
} else if threads[0] != "X" {
|
||||||
t.Error(threads, "X")
|
t.Error(threads, "X")
|
||||||
} else if !slices.Contains(threads, "Y") {
|
} else if threads[1] != "Y" {
|
||||||
t.Error(threads, "Y")
|
t.Error(threads, "Y")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue