From d5b09db0c638343b9d582d74c4d6723544408024 Mon Sep 17 00:00:00 2001
From: bel
Date: Fri, 12 Apr 2024 23:18:17 -0600
Subject: [PATCH] ai not lookin great

---
 ai.go      | 29 ++++++++++++-----------------
 ai_test.go |  8 ++++----
 config.go  | 16 +++++++++++++++-
 3 files changed, 31 insertions(+), 22 deletions(-)

diff --git a/ai.go b/ai.go
index 6f0d51b..4116b5f 100644
--- a/ai.go
+++ b/ai.go
@@ -9,7 +9,6 @@ import (
 	nn "github.com/nikolaydubina/llama2.go/exp/nnfast"
 	"github.com/nikolaydubina/llama2.go/llama2"
 	"github.com/tmc/langchaingo/llms"
-	"github.com/tmc/langchaingo/llms/mistral"
 	"github.com/tmc/langchaingo/llms/ollama"
 )
 
@@ -17,35 +16,31 @@ type AI interface {
 	Do(context.Context, string) (string, error)
 }
 
-type AIMistral struct {
-	model string
+type AINoop struct {
 }
 
-func NewAIMistral(model string) AIMistral {
-	return AIMistral{model: model}
+func NewAINoop() AINoop {
+	return AINoop{}
 }
 
-func (ai AIMistral) Do(ctx context.Context, prompt string) (string, error) {
-	llm, err := mistral.New(mistral.WithModel(ai.model))
-	if err != nil {
-		return "", err
-	}
-	return llms.GenerateFromSinglePrompt(ctx, llm, prompt,
-		llms.WithTemperature(0.8),
-		llms.WithModel("mistral-small-latest"),
-	)
+func (ai AINoop) Do(ctx context.Context, prompt string) (string, error) {
+	return ":shrug:", nil
 }
 
 type AIOllama struct {
 	model string
+	url   string
 }
 
-func NewAIOllama(model string) AIOllama {
-	return AIOllama{model: model}
+func NewAIOllama(url, model string) AIOllama {
+	return AIOllama{url: url, model: model}
 }
 
 func (ai AIOllama) Do(ctx context.Context, prompt string) (string, error) {
-	llm, err := ollama.New(ollama.WithModel(ai.model))
+	llm, err := ollama.New(
+		ollama.WithModel(ai.model),
+		ollama.WithServerURL(ai.url),
+	)
 	if err != nil {
 		return "", err
 	}
diff --git a/ai_test.go b/ai_test.go
index 4fba4c1..b7dde8f 100644
--- a/ai_test.go
+++ b/ai_test.go
@@ -13,14 +13,14 @@ import (
 	"time"
 )
 
-func TestAIMistral(t *testing.T) {
-	ai := NewAIMistral("open-mistral-7b")
+func TestAINoop(t *testing.T) {
+	ai := NewAINoop()
 	testAI(t, ai)
 }
 
 func TestAIOllama(t *testing.T) {
-	ai := NewAIOllama("gemma:2b")
+	ai := NewAIOllama("http://localhost:11434", "gemma:2b")
 	testAI(t, ai)
 }
 
@@ -77,7 +77,7 @@ func testAI(t *testing.T, ai AI) {
 	t.Run("mvp", func(t *testing.T) {
 		if result, err := ai.Do(ctx, "hello world"); err != nil {
 			t.Fatal(err)
-		} else if len(result) < 250 {
+		} else if len(result) < 3 {
 			t.Error(result)
 		} else {
 			t.Logf("%s", result)
diff --git a/config.go b/config.go
index d383f8f..4c4c851 100644
--- a/config.go
+++ b/config.go
@@ -21,9 +21,14 @@ type Config struct {
 	BasicAuthUser     string
 	BasicAuthPassword string
 	FillWithTestdata  bool
+	OllamaURL         string
+	OllamaModel       string
+	LocalCheckpoint   string
+	LocalTokenizer    string
 	storage           Storage
 	queue             Queue
 	driver            Driver
+	ai                AI
 }
 
 func newConfig(ctx context.Context) (Config, error) {
@@ -32,7 +37,8 @@ func newConfigFromEnv(ctx context.Context, getEnv func(string) string) (Config, error) {
 	def := Config{
-		Port: 8080,
+		Port:        8080,
+		OllamaModel: "gemma:2b",
 	}
 
 	var m map[string]any
@@ -104,5 +110,13 @@ func newConfigFromEnv(ctx context.Context, getEnv func(string) string) (Config, error) {
 	result.storage = NewStorage(result.driver)
 	result.queue = NewQueue(result.driver)
 
+	if result.OllamaURL != "" {
+		result.ai = NewAIOllama(result.OllamaURL, result.OllamaModel)
+	} else if result.LocalCheckpoint != "" && result.LocalTokenizer != "" {
+		result.ai = NewAILocal(result.LocalCheckpoint, result.LocalTokenizer, 0.9, 128, 0.9)
+	} else {
+		result.ai = NewAINoop()
+	}
+
 	return result, nil
 }
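
For reference, a minimal sketch of how the backend selection behaves once the config is loaded. This is hypothetical in-package usage, not part of the patch: it assumes newConfig is called from the app's entry point, and the meaning of NewAILocal's trailing arguments (0.9, 128, 0.9, presumably temperature, steps, and top-p) is inferred from the call site, not confirmed by this diff.

	// Sketch (same package, so context/log come from the file's imports):
	// load the config and exercise whichever AI backend it selected.
	func run() error {
		cfg, err := newConfig(context.Background())
		if err != nil {
			return err
		}
		// With no OllamaURL and no LocalCheckpoint/LocalTokenizer set,
		// cfg.ai falls back to AINoop, whose Do returns ":shrug:", nil.
		out, err := cfg.ai.Do(context.Background(), "hello world")
		if err != nil {
			return err
		}
		log.Println(out)
		return nil
	}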