ai not lookin great
parent
f288a9b098
commit
d5b09db0c6
29
ai.go
29
ai.go
|
|
@ -9,7 +9,6 @@ import (
|
|||
nn "github.com/nikolaydubina/llama2.go/exp/nnfast"
|
||||
"github.com/nikolaydubina/llama2.go/llama2"
|
||||
"github.com/tmc/langchaingo/llms"
|
||||
"github.com/tmc/langchaingo/llms/mistral"
|
||||
"github.com/tmc/langchaingo/llms/ollama"
|
||||
)
|
||||
|
||||
|
|
@ -17,35 +16,31 @@ type AI interface {
|
|||
Do(context.Context, string) (string, error)
|
||||
}
|
||||
|
||||
type AIMistral struct {
|
||||
model string
|
||||
type AINoop struct {
|
||||
}
|
||||
|
||||
func NewAIMistral(model string) AIMistral {
|
||||
return AIMistral{model: model}
|
||||
func NewAINoop() AINoop {
|
||||
return AINoop{}
|
||||
}
|
||||
|
||||
func (ai AIMistral) Do(ctx context.Context, prompt string) (string, error) {
|
||||
llm, err := mistral.New(mistral.WithModel(ai.model))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return llms.GenerateFromSinglePrompt(ctx, llm, prompt,
|
||||
llms.WithTemperature(0.8),
|
||||
llms.WithModel("mistral-small-latest"),
|
||||
)
|
||||
func (ai AINoop) Do(ctx context.Context, prompt string) (string, error) {
|
||||
return ":shrug:", nil
|
||||
}
|
||||
|
||||
type AIOllama struct {
|
||||
model string
|
||||
url string
|
||||
}
|
||||
|
||||
func NewAIOllama(model string) AIOllama {
|
||||
return AIOllama{model: model}
|
||||
func NewAIOllama(url, model string) AIOllama {
|
||||
return AIOllama{url: url, model: model}
|
||||
}
|
||||
|
||||
func (ai AIOllama) Do(ctx context.Context, prompt string) (string, error) {
|
||||
llm, err := ollama.New(ollama.WithModel(ai.model))
|
||||
llm, err := ollama.New(
|
||||
ollama.WithModel(ai.model),
|
||||
ollama.WithServerURL(ai.url),
|
||||
)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -13,14 +13,14 @@ import (
|
|||
"time"
|
||||
)
|
||||
|
||||
func TestAIMistral(t *testing.T) {
|
||||
ai := NewAIMistral("open-mistral-7b")
|
||||
func TestAINoop(t *testing.T) {
|
||||
ai := NewAINoop()
|
||||
|
||||
testAI(t, ai)
|
||||
}
|
||||
|
||||
func TestAIOllama(t *testing.T) {
|
||||
ai := NewAIOllama("gemma:2b")
|
||||
ai := NewAIOllama("http://localhost:11434", "gemma:2b")
|
||||
|
||||
testAI(t, ai)
|
||||
}
|
||||
|
|
@ -77,7 +77,7 @@ func testAI(t *testing.T, ai AI) {
|
|||
t.Run("mvp", func(t *testing.T) {
|
||||
if result, err := ai.Do(ctx, "hello world"); err != nil {
|
||||
t.Fatal(err)
|
||||
} else if len(result) < 250 {
|
||||
} else if len(result) < 3 {
|
||||
t.Error(result)
|
||||
} else {
|
||||
t.Logf("%s", result)
|
||||
|
|
|
|||
16
config.go
16
config.go
|
|
@ -21,9 +21,14 @@ type Config struct {
|
|||
BasicAuthUser string
|
||||
BasicAuthPassword string
|
||||
FillWithTestdata bool
|
||||
OllamaURL string
|
||||
OllamaModel string
|
||||
LocalCheckpoint string
|
||||
LocalTokenizer string
|
||||
storage Storage
|
||||
queue Queue
|
||||
driver Driver
|
||||
ai AI
|
||||
}
|
||||
|
||||
func newConfig(ctx context.Context) (Config, error) {
|
||||
|
|
@ -32,7 +37,8 @@ func newConfig(ctx context.Context) (Config, error) {
|
|||
|
||||
func newConfigFromEnv(ctx context.Context, getEnv func(string) string) (Config, error) {
|
||||
def := Config{
|
||||
Port: 8080,
|
||||
Port: 8080,
|
||||
OllamaModel: "gemma:2b",
|
||||
}
|
||||
|
||||
var m map[string]any
|
||||
|
|
@ -104,5 +110,13 @@ func newConfigFromEnv(ctx context.Context, getEnv func(string) string) (Config,
|
|||
result.storage = NewStorage(result.driver)
|
||||
result.queue = NewQueue(result.driver)
|
||||
|
||||
if result.OllamaURL != "" {
|
||||
result.ai = NewAIOllama(result.OllamaURL, result.OllamaModel)
|
||||
} else if result.LocalCheckpoint != "" && result.LocalTokenizer != "" {
|
||||
result.ai = NewAILocal(result.LocalCheckpoint, result.LocalTokenizer, 0.9, 128, 0.9)
|
||||
} else {
|
||||
result.ai = NewAINoop()
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue