ai not lookin great
parent
f288a9b098
commit
d5b09db0c6
29
ai.go
29
ai.go
|
|
@ -9,7 +9,6 @@ import (
|
||||||
nn "github.com/nikolaydubina/llama2.go/exp/nnfast"
|
nn "github.com/nikolaydubina/llama2.go/exp/nnfast"
|
||||||
"github.com/nikolaydubina/llama2.go/llama2"
|
"github.com/nikolaydubina/llama2.go/llama2"
|
||||||
"github.com/tmc/langchaingo/llms"
|
"github.com/tmc/langchaingo/llms"
|
||||||
"github.com/tmc/langchaingo/llms/mistral"
|
|
||||||
"github.com/tmc/langchaingo/llms/ollama"
|
"github.com/tmc/langchaingo/llms/ollama"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -17,35 +16,31 @@ type AI interface {
|
||||||
Do(context.Context, string) (string, error)
|
Do(context.Context, string) (string, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type AIMistral struct {
|
type AINoop struct {
|
||||||
model string
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewAIMistral(model string) AIMistral {
|
func NewAINoop() AINoop {
|
||||||
return AIMistral{model: model}
|
return AINoop{}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ai AIMistral) Do(ctx context.Context, prompt string) (string, error) {
|
func (ai AINoop) Do(ctx context.Context, prompt string) (string, error) {
|
||||||
llm, err := mistral.New(mistral.WithModel(ai.model))
|
return ":shrug:", nil
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
return llms.GenerateFromSinglePrompt(ctx, llm, prompt,
|
|
||||||
llms.WithTemperature(0.8),
|
|
||||||
llms.WithModel("mistral-small-latest"),
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type AIOllama struct {
|
type AIOllama struct {
|
||||||
model string
|
model string
|
||||||
|
url string
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewAIOllama(model string) AIOllama {
|
func NewAIOllama(url, model string) AIOllama {
|
||||||
return AIOllama{model: model}
|
return AIOllama{url: url, model: model}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ai AIOllama) Do(ctx context.Context, prompt string) (string, error) {
|
func (ai AIOllama) Do(ctx context.Context, prompt string) (string, error) {
|
||||||
llm, err := ollama.New(ollama.WithModel(ai.model))
|
llm, err := ollama.New(
|
||||||
|
ollama.WithModel(ai.model),
|
||||||
|
ollama.WithServerURL(ai.url),
|
||||||
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -13,14 +13,14 @@ import (
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestAIMistral(t *testing.T) {
|
func TestAINoop(t *testing.T) {
|
||||||
ai := NewAIMistral("open-mistral-7b")
|
ai := NewAINoop()
|
||||||
|
|
||||||
testAI(t, ai)
|
testAI(t, ai)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAIOllama(t *testing.T) {
|
func TestAIOllama(t *testing.T) {
|
||||||
ai := NewAIOllama("gemma:2b")
|
ai := NewAIOllama("http://localhost:11434", "gemma:2b")
|
||||||
|
|
||||||
testAI(t, ai)
|
testAI(t, ai)
|
||||||
}
|
}
|
||||||
|
|
@ -77,7 +77,7 @@ func testAI(t *testing.T, ai AI) {
|
||||||
t.Run("mvp", func(t *testing.T) {
|
t.Run("mvp", func(t *testing.T) {
|
||||||
if result, err := ai.Do(ctx, "hello world"); err != nil {
|
if result, err := ai.Do(ctx, "hello world"); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
} else if len(result) < 250 {
|
} else if len(result) < 3 {
|
||||||
t.Error(result)
|
t.Error(result)
|
||||||
} else {
|
} else {
|
||||||
t.Logf("%s", result)
|
t.Logf("%s", result)
|
||||||
|
|
|
||||||
14
config.go
14
config.go
|
|
@ -21,9 +21,14 @@ type Config struct {
|
||||||
BasicAuthUser string
|
BasicAuthUser string
|
||||||
BasicAuthPassword string
|
BasicAuthPassword string
|
||||||
FillWithTestdata bool
|
FillWithTestdata bool
|
||||||
|
OllamaURL string
|
||||||
|
OllamaModel string
|
||||||
|
LocalCheckpoint string
|
||||||
|
LocalTokenizer string
|
||||||
storage Storage
|
storage Storage
|
||||||
queue Queue
|
queue Queue
|
||||||
driver Driver
|
driver Driver
|
||||||
|
ai AI
|
||||||
}
|
}
|
||||||
|
|
||||||
func newConfig(ctx context.Context) (Config, error) {
|
func newConfig(ctx context.Context) (Config, error) {
|
||||||
|
|
@ -33,6 +38,7 @@ func newConfig(ctx context.Context) (Config, error) {
|
||||||
func newConfigFromEnv(ctx context.Context, getEnv func(string) string) (Config, error) {
|
func newConfigFromEnv(ctx context.Context, getEnv func(string) string) (Config, error) {
|
||||||
def := Config{
|
def := Config{
|
||||||
Port: 8080,
|
Port: 8080,
|
||||||
|
OllamaModel: "gemma:2b",
|
||||||
}
|
}
|
||||||
|
|
||||||
var m map[string]any
|
var m map[string]any
|
||||||
|
|
@ -104,5 +110,13 @@ func newConfigFromEnv(ctx context.Context, getEnv func(string) string) (Config,
|
||||||
result.storage = NewStorage(result.driver)
|
result.storage = NewStorage(result.driver)
|
||||||
result.queue = NewQueue(result.driver)
|
result.queue = NewQueue(result.driver)
|
||||||
|
|
||||||
|
if result.OllamaURL != "" {
|
||||||
|
result.ai = NewAIOllama(result.OllamaURL, result.OllamaModel)
|
||||||
|
} else if result.LocalCheckpoint != "" && result.LocalTokenizer != "" {
|
||||||
|
result.ai = NewAILocal(result.LocalCheckpoint, result.LocalTokenizer, 0.9, 128, 0.9)
|
||||||
|
} else {
|
||||||
|
result.ai = NewAINoop()
|
||||||
|
}
|
||||||
|
|
||||||
return result, nil
|
return result, nil
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue