ai not lookin great

branch: main
committed by bel on 2024-04-12 23:18:17 -06:00
parent f288a9b098
commit d5b09db0c6
3 changed files with 31 additions and 22 deletions

ai.go

@@ -9,7 +9,6 @@ import (
     nn "github.com/nikolaydubina/llama2.go/exp/nnfast"
     "github.com/nikolaydubina/llama2.go/llama2"
     "github.com/tmc/langchaingo/llms"
-    "github.com/tmc/langchaingo/llms/mistral"
     "github.com/tmc/langchaingo/llms/ollama"
 )
 
@@ -17,35 +16,31 @@ type AI interface {
     Do(context.Context, string) (string, error)
 }
 
-type AIMistral struct {
-    model string
+type AINoop struct {
 }
 
-func NewAIMistral(model string) AIMistral {
-    return AIMistral{model: model}
+func NewAINoop() AINoop {
+    return AINoop{}
 }
 
-func (ai AIMistral) Do(ctx context.Context, prompt string) (string, error) {
-    llm, err := mistral.New(mistral.WithModel(ai.model))
-    if err != nil {
-        return "", err
-    }
-    return llms.GenerateFromSinglePrompt(ctx, llm, prompt,
-        llms.WithTemperature(0.8),
-        llms.WithModel("mistral-small-latest"),
-    )
+func (ai AINoop) Do(ctx context.Context, prompt string) (string, error) {
+    return ":shrug:", nil
 }
 
 type AIOllama struct {
     model string
+    url   string
 }
 
-func NewAIOllama(model string) AIOllama {
-    return AIOllama{model: model}
+func NewAIOllama(url, model string) AIOllama {
+    return AIOllama{url: url, model: model}
 }
 
 func (ai AIOllama) Do(ctx context.Context, prompt string) (string, error) {
-    llm, err := ollama.New(ollama.WithModel(ai.model))
+    llm, err := ollama.New(
+        ollama.WithModel(ai.model),
+        ollama.WithServerURL(ai.url),
+    )
     if err != nil {
         return "", err
     }
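The diff above removes the Mistral backend outright (hence the commit message) and swaps in a no-op fallback; AIOllama also gains an explicit server URL instead of relying on the client default. For reference, a self-contained sketch of the new fallback exactly as the diff defines it; the package main wrapper and the final print are hypothetical, added only to show AINoop satisfying the AI interface:

    package main

    import (
        "context"
        "fmt"
    )

    // AI is the interface from ai.go: one prompt in, one completion out.
    type AI interface {
        Do(context.Context, string) (string, error)
    }

    // AINoop is the fallback backend introduced in this commit: it
    // ignores the prompt and always answers ":shrug:".
    type AINoop struct{}

    func NewAINoop() AINoop { return AINoop{} }

    func (ai AINoop) Do(ctx context.Context, prompt string) (string, error) {
        return ":shrug:", nil
    }

    func main() {
        var ai AI = NewAINoop()
        out, _ := ai.Do(context.Background(), "hello world")
        fmt.Println(out) // ":shrug:"
    }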

ai_test.go

@@ -13,14 +13,14 @@ import (
     "time"
 )
 
-func TestAIMistral(t *testing.T) {
-    ai := NewAIMistral("open-mistral-7b")
+func TestAINoop(t *testing.T) {
+    ai := NewAINoop()
     testAI(t, ai)
 }
 
 func TestAIOllama(t *testing.T) {
-    ai := NewAIOllama("gemma:2b")
+    ai := NewAIOllama("http://localhost:11434", "gemma:2b")
     testAI(t, ai)
 }
@@ -77,7 +77,7 @@ func testAI(t *testing.T, ai AI) {
     t.Run("mvp", func(t *testing.T) {
         if result, err := ai.Do(ctx, "hello world"); err != nil {
             t.Fatal(err)
-        } else if len(result) < 250 {
+        } else if len(result) < 3 {
             t.Error(result)
         } else {
             t.Logf("%s", result)
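Two adjustments here: TestAIOllama now targets Ollama's default address, http://localhost:11434, and the minimum-length check in testAI drops from 250 to 3 bytes. The old bound would have failed the new AINoop backend, whose fixed ":shrug:" reply is only 7 bytes, and a small model like gemma:2b can legitimately answer briefly. A minimal usage sketch of the updated constructor, assuming it is called from the same package and that an Ollama server with gemma:2b already pulled is reachable on the default port (both assumptions about local setup, not facts from the diff):

    package main

    import (
        "context"
        "log"
    )

    func main() {
        // Same arguments the test above uses.
        ai := NewAIOllama("http://localhost:11434", "gemma:2b")
        resp, err := ai.Do(context.Background(), "hello world")
        if err != nil {
            log.Fatal(err) // e.g. server unreachable or model not pulled
        }
        log.Printf("%s", resp)
    }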


@@ -21,9 +21,14 @@ type Config struct {
     BasicAuthUser     string
     BasicAuthPassword string
     FillWithTestdata  bool
+    OllamaURL         string
+    OllamaModel       string
+    LocalCheckpoint   string
+    LocalTokenizer    string
     storage           Storage
     queue             Queue
     driver            Driver
+    ai                AI
 }
 
 func newConfig(ctx context.Context) (Config, error) {
@@ -32,7 +37,8 @@ func newConfigFromEnv(ctx context.Context, getEnv func(string) string) (Config, error) {
     def := Config{
-        Port: 8080,
+        Port:        8080,
+        OllamaModel: "gemma:2b",
     }
 
     var m map[string]any
@@ -104,5 +110,13 @@ func newConfigFromEnv(ctx context.Context, getEnv func(string) string) (Config, error) {
     result.storage = NewStorage(result.driver)
     result.queue = NewQueue(result.driver)
+
+    if result.OllamaURL != "" {
+        result.ai = NewAIOllama(result.OllamaURL, result.OllamaModel)
+    } else if result.LocalCheckpoint != "" && result.LocalTokenizer != "" {
+        result.ai = NewAILocal(result.LocalCheckpoint, result.LocalTokenizer, 0.9, 128, 0.9)
+    } else {
+        result.ai = NewAINoop()
+    }
 
     return result, nil
 }
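The backend selection in this config file is ordered: an explicit OllamaURL wins, a local checkpoint/tokenizer pair (handled by the llama2.go-backed AILocal, unchanged in this commit) comes next, and AINoop catches everything else, so result.ai is never nil. The same branching restated as a hypothetical helper for readability; field and constructor names come from the diff, and the numeric arguments to NewAILocal are copied verbatim (presumably sampling temperature, step budget, and top-p, though the diff does not name them):

    // pickAI restates the backend selection from newConfigFromEnv.
    func pickAI(c Config) AI {
        switch {
        case c.OllamaURL != "":
            return NewAIOllama(c.OllamaURL, c.OllamaModel)
        case c.LocalCheckpoint != "" && c.LocalTokenizer != "":
            return NewAILocal(c.LocalCheckpoint, c.LocalTokenizer, 0.9, 128, 0.9)
        default:
            return NewAINoop() // guaranteed non-nil fallback
        }
    }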