Compare commits
3 Commits
da0125c663
...
d5b09db0c6
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d5b09db0c6 | ||
|
|
f288a9b098 | ||
|
|
ecf29d54b8 |
38
ai.go
38
ai.go
@@ -4,15 +4,49 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"os"
|
"os"
|
||||||
|
"strings"
|
||||||
|
|
||||||
nn "github.com/nikolaydubina/llama2.go/exp/nnfast"
|
nn "github.com/nikolaydubina/llama2.go/exp/nnfast"
|
||||||
"github.com/nikolaydubina/llama2.go/llama2"
|
"github.com/nikolaydubina/llama2.go/llama2"
|
||||||
|
"github.com/tmc/langchaingo/llms"
|
||||||
|
"github.com/tmc/langchaingo/llms/ollama"
|
||||||
)
|
)
|
||||||
|
|
||||||
type AI interface {
|
type AI interface {
|
||||||
Do(context.Context, string) (string, error)
|
Do(context.Context, string) (string, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type AINoop struct {
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewAINoop() AINoop {
|
||||||
|
return AINoop{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ai AINoop) Do(ctx context.Context, prompt string) (string, error) {
|
||||||
|
return ":shrug:", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type AIOllama struct {
|
||||||
|
model string
|
||||||
|
url string
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewAIOllama(url, model string) AIOllama {
|
||||||
|
return AIOllama{url: url, model: model}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ai AIOllama) Do(ctx context.Context, prompt string) (string, error) {
|
||||||
|
llm, err := ollama.New(
|
||||||
|
ollama.WithModel(ai.model),
|
||||||
|
ollama.WithServerURL(ai.url),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return llms.GenerateFromSinglePrompt(ctx, llm, prompt)
|
||||||
|
}
|
||||||
|
|
||||||
type AILocal struct {
|
type AILocal struct {
|
||||||
checkpointPath string
|
checkpointPath string
|
||||||
tokenizerPath string
|
tokenizerPath string
|
||||||
@@ -73,7 +107,7 @@ func (ai AILocal) Do(ctx context.Context, prompt string) (string, error) {
|
|||||||
|
|
||||||
runState := llama2.NewRunState(config)
|
runState := llama2.NewRunState(config)
|
||||||
|
|
||||||
promptTokens := vocab.Encode(prompt)
|
promptTokens := vocab.Encode(strings.ReplaceAll(prompt, "\n", "<0x0A>"))
|
||||||
|
|
||||||
out := bytes.NewBuffer(nil)
|
out := bytes.NewBuffer(nil)
|
||||||
|
|
||||||
@@ -131,5 +165,5 @@ func (ai AILocal) Do(ctx context.Context, prompt string) (string, error) {
|
|||||||
}
|
}
|
||||||
out.Write([]byte("\n"))
|
out.Write([]byte("\n"))
|
||||||
|
|
||||||
return string(out.Bytes()), nil
|
return strings.ReplaceAll(string(out.Bytes()), "<0x0A>", "\n"), nil
|
||||||
}
|
}
|
||||||
|
|||||||
45
ai_test.go
45
ai_test.go
@@ -13,18 +13,30 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestAILocal(t *testing.T) {
|
func TestAINoop(t *testing.T) {
|
||||||
ctx, can := context.WithTimeout(context.Background(), time.Minute)
|
ai := NewAINoop()
|
||||||
defer can()
|
|
||||||
|
|
||||||
|
testAI(t, ai)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAIOllama(t *testing.T) {
|
||||||
|
ai := NewAIOllama("http://localhost:11434", "gemma:2b")
|
||||||
|
|
||||||
|
testAI(t, ai)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAILocal(t *testing.T) {
|
||||||
d := os.TempDir()
|
d := os.TempDir()
|
||||||
for k, u := range map[string]string{
|
checkpoints := "checkpoints"
|
||||||
"checkpoints": "https://huggingface.co/karpathy/tinyllamas/resolve/main/stories110M.bin",
|
tokenizer := "tokenizer"
|
||||||
"tokenizer": "https://github.com/karpathy/llama2.c/raw/master/tokenizer.bin",
|
for u, p := range map[string]*string{
|
||||||
|
"https://huggingface.co/karpathy/tinyllamas/resolve/main/stories110M.bin": &checkpoints,
|
||||||
|
"https://github.com/karpathy/llama2.c/raw/master/tokenizer.bin": &tokenizer,
|
||||||
} {
|
} {
|
||||||
func() {
|
func() {
|
||||||
if _, err := os.Stat(path.Join(d, k)); os.IsNotExist(err) {
|
*p = path.Base(u)
|
||||||
t.Logf("downloading %s from %s", u, k)
|
if _, err := os.Stat(path.Join(d, *p)); os.IsNotExist(err) {
|
||||||
|
t.Logf("downloading %s from %s", u, *p)
|
||||||
|
|
||||||
resp, err := http.Get(u)
|
resp, err := http.Get(u)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -32,7 +44,7 @@ func TestAILocal(t *testing.T) {
|
|||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
|
|
||||||
f, err := os.Create(path.Join(d, k))
|
f, err := os.Create(path.Join(d, *p))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -40,7 +52,7 @@ func TestAILocal(t *testing.T) {
|
|||||||
|
|
||||||
if _, err := io.Copy(f, resp.Body); err != nil {
|
if _, err := io.Copy(f, resp.Body); err != nil {
|
||||||
f.Close()
|
f.Close()
|
||||||
os.Remove(path.Join(d, k))
|
os.Remove(path.Join(d, *p))
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -48,17 +60,24 @@ func TestAILocal(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
ai := NewAILocal(
|
ai := NewAILocal(
|
||||||
path.Join(d, "checkpoints"),
|
path.Join(d, checkpoints),
|
||||||
path.Join(d, "tokenizer"),
|
path.Join(d, tokenizer),
|
||||||
0.9,
|
0.9,
|
||||||
256,
|
256,
|
||||||
0.9,
|
0.9,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
testAI(t, ai)
|
||||||
|
}
|
||||||
|
|
||||||
|
func testAI(t *testing.T, ai AI) {
|
||||||
|
ctx, can := context.WithTimeout(context.Background(), time.Minute)
|
||||||
|
defer can()
|
||||||
|
|
||||||
t.Run("mvp", func(t *testing.T) {
|
t.Run("mvp", func(t *testing.T) {
|
||||||
if result, err := ai.Do(ctx, "hello world"); err != nil {
|
if result, err := ai.Do(ctx, "hello world"); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
} else if len(result) < 250 {
|
} else if len(result) < 3 {
|
||||||
t.Error(result)
|
t.Error(result)
|
||||||
} else {
|
} else {
|
||||||
t.Logf("%s", result)
|
t.Logf("%s", result)
|
||||||
|
|||||||
16
config.go
16
config.go
@@ -21,9 +21,14 @@ type Config struct {
|
|||||||
BasicAuthUser string
|
BasicAuthUser string
|
||||||
BasicAuthPassword string
|
BasicAuthPassword string
|
||||||
FillWithTestdata bool
|
FillWithTestdata bool
|
||||||
|
OllamaURL string
|
||||||
|
OllamaModel string
|
||||||
|
LocalCheckpoint string
|
||||||
|
LocalTokenizer string
|
||||||
storage Storage
|
storage Storage
|
||||||
queue Queue
|
queue Queue
|
||||||
driver Driver
|
driver Driver
|
||||||
|
ai AI
|
||||||
}
|
}
|
||||||
|
|
||||||
func newConfig(ctx context.Context) (Config, error) {
|
func newConfig(ctx context.Context) (Config, error) {
|
||||||
@@ -32,7 +37,8 @@ func newConfig(ctx context.Context) (Config, error) {
|
|||||||
|
|
||||||
func newConfigFromEnv(ctx context.Context, getEnv func(string) string) (Config, error) {
|
func newConfigFromEnv(ctx context.Context, getEnv func(string) string) (Config, error) {
|
||||||
def := Config{
|
def := Config{
|
||||||
Port: 8080,
|
Port: 8080,
|
||||||
|
OllamaModel: "gemma:2b",
|
||||||
}
|
}
|
||||||
|
|
||||||
var m map[string]any
|
var m map[string]any
|
||||||
@@ -104,5 +110,13 @@ func newConfigFromEnv(ctx context.Context, getEnv func(string) string) (Config,
|
|||||||
result.storage = NewStorage(result.driver)
|
result.storage = NewStorage(result.driver)
|
||||||
result.queue = NewQueue(result.driver)
|
result.queue = NewQueue(result.driver)
|
||||||
|
|
||||||
|
if result.OllamaURL != "" {
|
||||||
|
result.ai = NewAIOllama(result.OllamaURL, result.OllamaModel)
|
||||||
|
} else if result.LocalCheckpoint != "" && result.LocalTokenizer != "" {
|
||||||
|
result.ai = NewAILocal(result.LocalCheckpoint, result.LocalTokenizer, 0.9, 128, 0.9)
|
||||||
|
} else {
|
||||||
|
result.ai = NewAINoop()
|
||||||
|
}
|
||||||
|
|
||||||
return result, nil
|
return result, nil
|
||||||
}
|
}
|
||||||
|
|||||||
9
go.mod
9
go.mod
@@ -6,7 +6,14 @@ require (
|
|||||||
github.com/go-errors/errors v1.5.1
|
github.com/go-errors/errors v1.5.1
|
||||||
github.com/lib/pq v1.10.9
|
github.com/lib/pq v1.10.9
|
||||||
github.com/nikolaydubina/llama2.go v0.7.1
|
github.com/nikolaydubina/llama2.go v0.7.1
|
||||||
|
github.com/tmc/langchaingo v0.1.8
|
||||||
go.etcd.io/bbolt v1.3.9
|
go.etcd.io/bbolt v1.3.9
|
||||||
)
|
)
|
||||||
|
|
||||||
require golang.org/x/sys v0.4.0 // indirect
|
require (
|
||||||
|
github.com/dlclark/regexp2 v1.10.0 // indirect
|
||||||
|
github.com/gage-technologies/mistral-go v1.0.0 // indirect
|
||||||
|
github.com/google/uuid v1.6.0 // indirect
|
||||||
|
github.com/pkoukk/tiktoken-go v0.1.6 // indirect
|
||||||
|
golang.org/x/sys v0.16.0 // indirect
|
||||||
|
)
|
||||||
|
|||||||
22
go.sum
22
go.sum
@@ -1,20 +1,30 @@
|
|||||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
|
github.com/dlclark/regexp2 v1.10.0 h1:+/GIL799phkJqYW+3YbOd8LCcbHzT0Pbo8zl70MHsq0=
|
||||||
|
github.com/dlclark/regexp2 v1.10.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
|
||||||
|
github.com/gage-technologies/mistral-go v1.0.0 h1:Hwk0uJO+Iq4kMX/EwbfGRUq9zkO36w7HZ/g53N4N73A=
|
||||||
|
github.com/gage-technologies/mistral-go v1.0.0/go.mod h1:tF++Xt7U975GcLlzhrjSQb8l/x+PrriO9QEdsgm9l28=
|
||||||
github.com/go-errors/errors v1.5.1 h1:ZwEMSLRCapFLflTpT7NKaAc7ukJ8ZPEjzlxt8rPN8bk=
|
github.com/go-errors/errors v1.5.1 h1:ZwEMSLRCapFLflTpT7NKaAc7ukJ8ZPEjzlxt8rPN8bk=
|
||||||
github.com/go-errors/errors v1.5.1/go.mod h1:sIVyrIiJhuEF+Pj9Ebtd6P/rEYROXFi3BopGUQ5a5Og=
|
github.com/go-errors/errors v1.5.1/go.mod h1:sIVyrIiJhuEF+Pj9Ebtd6P/rEYROXFi3BopGUQ5a5Og=
|
||||||
|
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||||
|
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||||
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
|
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
|
||||||
github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
|
github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
|
||||||
github.com/nikolaydubina/llama2.go v0.7.1 h1:ORmH1XbwFYGIOPHprkjtUPOEovlVXhnmnMjbMckaSyE=
|
github.com/nikolaydubina/llama2.go v0.7.1 h1:ORmH1XbwFYGIOPHprkjtUPOEovlVXhnmnMjbMckaSyE=
|
||||||
github.com/nikolaydubina/llama2.go v0.7.1/go.mod h1:ggXhXOaDnEAgSSkcYsomqx/RLjInxe5ZAbcJ+/Y2mTM=
|
github.com/nikolaydubina/llama2.go v0.7.1/go.mod h1:ggXhXOaDnEAgSSkcYsomqx/RLjInxe5ZAbcJ+/Y2mTM=
|
||||||
|
github.com/pkoukk/tiktoken-go v0.1.6 h1:JF0TlJzhTbrI30wCvFuiw6FzP2+/bR+FIxUdgEAcUsw=
|
||||||
|
github.com/pkoukk/tiktoken-go v0.1.6/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg=
|
||||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||||
github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk=
|
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
|
||||||
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||||
|
github.com/tmc/langchaingo v0.1.8 h1:nrImgh0aWdu3stJTHz80N60WGwPWY8HXCK10gQny7bA=
|
||||||
|
github.com/tmc/langchaingo v0.1.8/go.mod h1:iNBfS9e6jxBKsJSPWnlqNhoVWgdA3D1g5cdFJjbIZNQ=
|
||||||
go.etcd.io/bbolt v1.3.9 h1:8x7aARPEXiXbHmtUwAIv7eV2fQFHrLLavdiJ3uzJXoI=
|
go.etcd.io/bbolt v1.3.9 h1:8x7aARPEXiXbHmtUwAIv7eV2fQFHrLLavdiJ3uzJXoI=
|
||||||
go.etcd.io/bbolt v1.3.9/go.mod h1:zaO32+Ti0PK1ivdPtgMESzuzL2VPoIG1PCQNvOdo/dE=
|
go.etcd.io/bbolt v1.3.9/go.mod h1:zaO32+Ti0PK1ivdPtgMESzuzL2VPoIG1PCQNvOdo/dE=
|
||||||
golang.org/x/sync v0.5.0 h1:60k92dhOjHxJkrqnwsfl8KuaHbn/5dl0lUPUklKo3qE=
|
golang.org/x/sync v0.6.0 h1:5BMeUDZ7vkXGfEr1x9B4bRcTH4lpkTkpdh0T/J+qjbQ=
|
||||||
golang.org/x/sync v0.5.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||||
golang.org/x/sys v0.4.0 h1:Zr2JFtRQNX3BCZ8YtxRE9hNJYC8J6I1MVbMg6owUp18=
|
golang.org/x/sys v0.16.0 h1:xWw16ngr6ZMtmxDyKyIgsE93KNKz5HKmMa3b8ALHidU=
|
||||||
golang.org/x/sys v0.4.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||||
|
|||||||
23
testdata/llama2.d/ingest.sh
vendored
23
testdata/llama2.d/ingest.sh
vendored
@@ -1,23 +0,0 @@
|
|||||||
#! /bin/bash
|
|
||||||
|
|
||||||
set -eu
|
|
||||||
|
|
||||||
cd "$(dirname "$(realpath "$BASH_SOURCE")")"
|
|
||||||
if ! test -d ./llama2.c; then
|
|
||||||
echo downloading llama2.c >&2
|
|
||||||
git clone https://github.com/karpathy/llama2.c.git
|
|
||||||
fi
|
|
||||||
|
|
||||||
url="$1"
|
|
||||||
output="${url##*/}"
|
|
||||||
if [ ! -f "$output" ]; then
|
|
||||||
echo downloading "$url" >&2
|
|
||||||
wget -O "$output" "$url"
|
|
||||||
fi
|
|
||||||
|
|
||||||
cd ./llama2.c
|
|
||||||
pip3 install --break-system-packages -r ./requirements.txt
|
|
||||||
python3 ./export.py \
|
|
||||||
"../$output.bin" \
|
|
||||||
--meta-llama "../$output"
|
|
||||||
cd ..
|
|
||||||
Reference in New Issue
Block a user