spoc-bot-vr/config.go

150 lines
3.7 KiB
Go

package main
import (
"context"
"encoding/json"
"errors"
"fmt"
"os"
"regexp"
"strconv"
"strings"
"time"
)
// Config holds the service settings together with the components built from
// them.
//
// The exported fields are the settings. newConfigFromEnv starts from a set of
// defaults and overrides each exported field from an environment variable
// whose name is derived from the field name (for example SlackToken is read
// from SLACK_TOKEN). The unexported fields are components that
// newConfigFromEnv constructs once the settings are resolved; being
// unexported, they are skipped by encoding/json.
type Config struct {
	// Port defaults to 38080. Presumably the listen port; confirm at the
	// place where it is consumed.
	Port  int
	Debug bool

	// Slack settings. SlackChannels is read from a comma-separated value.
	InitializeSlack bool
	SlackToken      string
	SlackChannels   []string

	// DriverConn is passed to NewDriver to obtain the driver.
	DriverConn string

	// Basic-auth credentials. Where they are enforced is not visible in this
	// file.
	BasicAuthUser     string
	BasicAuthPassword string

	// FillWithTestdata is not implemented yet: enabling it makes
	// newConfigFromEnv fail with "not impl".
	FillWithTestdata bool

	// AI backend selection. A non-empty OllamaURL selects the Ollama backend
	// with OllamaModel (default "gemma:2b"). Otherwise, when both
	// LocalCheckpoint and LocalTokenizer are set, the local backend is used.
	// With neither, a no-op backend is installed.
	OllamaURL       string
	OllamaModel     string
	LocalCheckpoint string
	LocalTokenizer  string

	// Regular-expression sources. They default to renderAssetPattern,
	// renderDatacenterPattern and renderEventNamePattern respectively.
	AssetPattern      string
	DatacenterPattern string
	EventNamePattern  string

	// Components constructed by newConfigFromEnv.
	driver                     Driver
	storage                    Storage
	ai                         AI
	slackToModelPipeline       Pipeline
	modelToPersistencePipeline Pipeline
}
// Default regular-expression sources used by newConfigFromEnv to seed
// Config.AssetPattern, Config.DatacenterPattern and Config.EventNamePattern.
var (
	// renderAssetPattern matches "dpg", "svc" or "red", a hyphen, and then a
	// run of lowercase letters, digits and hyphens that ends in a letter or
	// digit.
	renderAssetPattern = `(dpg|svc|red)-[a-z0-9-]*[a-z0-9]`

	// renderDatacenterPattern matches four or more lowercase letters, a
	// hyphen, and a single digit.
	renderDatacenterPattern = `[a-z]{4}[a-z]*-[0-9]`

	// renderEventNamePattern optionally consumes a leading "[...]" tag and
	// any spaces after it, then captures the remainder in the named group
	// "result".
	renderEventNamePattern = `(\[[^\]]*\] *)?(?P<result>.*)`
)
// newConfig builds the Config for this process, reading overrides from the
// real process environment.
func newConfig(ctx context.Context) (Config, error) {
	lookup := os.Getenv
	return newConfigFromEnv(ctx, lookup)
}
// newConfigFromEnv builds a Config from built-in defaults overridden by
// environment variables looked up through getEnv, and then constructs the
// driver, storage, AI backend and pipelines that the Config carries.
//
// Every exported Config field maps to one environment variable: the field
// name with an underscore inserted before each capital letter except the
// first, upper-cased. So Port reads PORT, SlackToken reads SLACK_TOKEN and,
// because each capital is split individually, OllamaURL reads OLLAMA_U_R_L.
// A variable that is unset or empty leaves the default in place. Strings are
// taken verbatim, numbers and booleans are parsed with strconv, and list
// fields are split on ",".
//
// An error is returned when an override cannot be parsed or does not fit the
// field, when FillWithTestdata is enabled (not implemented), or when any of
// the component constructors fails.
func newConfigFromEnv(ctx context.Context, getEnv func(string) string) (Config, error) {
	def := Config{
		Port:              38080,
		OllamaModel:       "gemma:2b",
		AssetPattern:      renderAssetPattern,
		DatacenterPattern: renderDatacenterPattern,
		EventNamePattern:  renderEventNamePattern,
	}

	// Round-trip the defaults through JSON to obtain a field-name -> value
	// map; the dynamic type of each value tells us how to parse its override.
	raw, err := json.Marshal(def)
	if err != nil {
		return Config{}, err
	}
	var m map[string]any
	if err := json.Unmarshal(raw, &m); err != nil {
		return Config{}, err
	}

	capital := regexp.MustCompile(`[A-Z]`)
	for k, v := range m {
		// Prefix every capital after the first byte with "_" and upper-case
		// the result to get the environment variable name.
		envK := k
		if len(k) > 1 {
			envK = k[:1] + capital.ReplaceAllString(k[1:], "_$0")
		}
		envK = strings.ToUpper(envK)

		s := getEnv(envK)
		if s == "" {
			continue
		}

		switch v.(type) {
		case string:
			m[k] = s
		case float64:
			// JSON numbers always decode to float64. Parse at full float64
			// precision so integer overrides above 2^24 are not rounded
			// before they are decoded into the target field.
			n, err := strconv.ParseFloat(s, 64)
			if err != nil {
				return Config{}, fmt.Errorf("parse %s: %w", envK, err)
			}
			m[k] = n
		case bool:
			got, err := strconv.ParseBool(s)
			if err != nil {
				return Config{}, fmt.Errorf("parse %s: %w", envK, err)
			}
			m[k] = got
		case nil, []any:
			// A nil slice marshals to JSON null, so list fields show up here
			// as nil unless they have a non-nil default.
			m[k] = strings.Split(s, ",")
		default:
			return Config{}, fmt.Errorf("not impl: parse %s as %T", envK, v)
		}
	}

	// Decode the merged map back into a typed Config.
	var result Config
	raw, err = json.Marshal(m)
	if err != nil {
		return Config{}, err
	}
	if err := json.Unmarshal(raw, &result); err != nil {
		return Config{}, err
	}

	// Bound the construction below to ten seconds.
	// NOTE(review): this derived context is cancelled when the function
	// returns; confirm that none of the constructed components retain it.
	ctx, cancel := context.WithTimeout(ctx, time.Second*10)
	defer cancel()

	driver, err := NewDriver(ctx, result.DriverConn)
	if err != nil {
		return Config{}, err
	}
	result.driver = driver

	if result.FillWithTestdata {
		// TODO: seed the driver once implemented, e.g.
		// result.driver.FillWithTestdata(ctx, result.AssetPattern, result.DatacenterPattern, result.EventNamePattern)
		return Config{}, errors.New("not impl")
	}

	storage, err := NewStorage(ctx, result.driver)
	if err != nil {
		return Config{}, err
	}
	result.storage = storage

	// Pick the AI backend: Ollama if a URL is given, otherwise a local model
	// if both checkpoint and tokenizer are given, otherwise a no-op.
	switch {
	case result.OllamaURL != "":
		result.ai = NewAIOllama(result.OllamaURL, result.OllamaModel)
	case result.LocalCheckpoint != "" && result.LocalTokenizer != "":
		result.ai = NewAILocal(result.LocalCheckpoint, result.LocalTokenizer, 0.9, 128, 0.9)
	default:
		result.ai = NewAINoop()
	}

	// The pipelines receive result by value, so each one sees only the
	// components assigned before its constructor is called.
	slackToModelPipeline, err := NewSlackToModelPipeline(ctx, result)
	if err != nil {
		return Config{}, err
	}
	result.slackToModelPipeline = slackToModelPipeline

	modelToPersistencePipeline, err := NewModelToPersistencePipeline(ctx, result)
	if err != nil {
		return Config{}, err
	}
	result.modelToPersistencePipeline = modelToPersistencePipeline

	return result, nil
}