STORAGE TEST WERKS

main
Bel LaPointe 2024-04-16 06:43:25 -06:00
parent 0cecd5ea04
commit d70a0e313f
5 changed files with 214 additions and 14 deletions

View File

@ -30,6 +30,7 @@ type Config struct {
DatacenterPattern string DatacenterPattern string
EventNamePattern string EventNamePattern string
driver Driver driver Driver
storage Storage
ai AI ai AI
slackToMessagePipeline Pipeline slackToMessagePipeline Pipeline
messageToPersistencePipeline Pipeline messageToPersistencePipeline Pipeline
@ -119,6 +120,12 @@ func newConfigFromEnv(ctx context.Context, getEnv func(string) string) (Config,
return Config{}, errors.New("not impl") return Config{}, errors.New("not impl")
} }
storage, err := NewStorage(ctx, result.driver)
if err != nil {
return Config{}, err
}
result.storage = storage
if result.OllamaURL != "" { if result.OllamaURL != "" {
result.ai = NewAIOllama(result.OllamaURL, result.OllamaModel) result.ai = NewAIOllama(result.OllamaURL, result.OllamaModel)
} else if result.LocalCheckpoint != "" && result.LocalTokenizer != "" { } else if result.LocalCheckpoint != "" && result.LocalTokenizer != "" {

View File

@ -17,8 +17,12 @@ type Driver struct {
*sql.DB *sql.DB
} }
func NewTestDriver(t *testing.T) Driver { func NewTestDriver(t *testing.T, optionalP ...string) Driver {
driver, err := NewDriver(context.Background(), path.Join(t.TempDir(), "db")) p := path.Join(t.TempDir(), "db")
if len(optionalP) > 0 {
p = optionalP[0]
}
driver, err := NewDriver(context.Background(), p)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@ -2,7 +2,6 @@ package main
import ( import (
"context" "context"
"fmt"
) )
type MessageToPersistence struct { type MessageToPersistence struct {
@ -21,26 +20,18 @@ func NewMessageToPersistencePipeline(ctx context.Context, cfg Config) (Pipeline,
return Pipeline{ return Pipeline{
writer: writer, writer: writer,
reader: reader, reader: reader,
process: newMessageToPersistenceProcess(cfg.driver), process: newMessageToPersistenceProcess(cfg.storage),
}, nil }, nil
} }
func newMessageToPersistenceProcess(driver Driver) processFunc { func newMessageToPersistenceProcess(storage Storage) processFunc {
return func(ctx context.Context, msg []byte) ([]byte, error) { return func(ctx context.Context, msg []byte) ([]byte, error) {
m, err := Deserialize(msg) m, err := Deserialize(msg)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if result, err := driver.ExecContext(ctx, ` if err := storage.UpsertMessage(ctx, m); err != nil {
CREATE TABLE IF NOT EXISTS messages (id TEXT UNIQUE, v TEXT);
INSERT INTO messages (id, v) VALUES (?, ?)
ON CONFLICT(id) DO UPDATE set v = ?;
`, m.ID, msg, msg); err != nil {
return nil, err return nil, err
} else if n, err := result.RowsAffected(); err != nil {
return nil, err
} else if n != 1 {
return nil, fmt.Errorf("upserting event to persistence modified %v rows", n)
} }
return msg, nil return msg, nil
} }

156
storage.go Normal file
View File

@ -0,0 +1,156 @@
package main
import (
"context"
"encoding/json"
"errors"
"fmt"
"strings"
"github.com/breel-render/spoc-bot-vr/model"
)
type Storage struct {
driver Driver
}
func NewStorage(ctx context.Context, driver Driver) (Storage, error) {
if _, err := driver.ExecContext(ctx, `
CREATE TABLE IF NOT EXISTS events (ID TEXT UNIQUE);
CREATE TABLE IF NOT EXISTS messages (ID TEXT UNIQUE);
CREATE TABLE IF NOT EXISTS threads (ID TEXT UNIQUE);
`); err != nil {
return Storage{}, err
}
for table, v := range map[string]any{
"events": model.Event{},
"messages": model.Message{},
"threads": model.Thread{},
} {
b, _ := json.Marshal(v)
var m map[string]struct{}
json.Unmarshal(b, &m)
for k := range m {
if k == `ID` {
continue
}
driver.ExecContext(ctx, fmt.Sprintf(`ALTER TABLE %s ADD COLUMN %s TEXT`, table, k))
}
}
return Storage{driver: driver}, nil
}
func (s Storage) GetEvent(ctx context.Context, ID string) (model.Event, error) {
return model.Event{}, errors.New("not impl")
}
func (s Storage) UpsertEvent(ctx context.Context, event model.Event) error {
return s.upsert(ctx, "events", event)
}
func (s Storage) GetMessage(ctx context.Context, ID string) (model.Message, error) {
v := model.Message{}
err := s.selectOne(ctx, "messages", &v, "ID = ?", ID)
return v, err
}
func (s Storage) UpsertMessage(ctx context.Context, message model.Message) error {
return s.upsert(ctx, "messages", message)
}
func (s Storage) GetThread(ctx context.Context, ID string) (model.Thread, error) {
return model.Thread{}, errors.New("not impl")
}
func (s Storage) UpsertThread(ctx context.Context, thread model.Thread) error {
return s.upsert(ctx, "threads", thread)
}
func (s Storage) selectOne(ctx context.Context, table string, v any, clause string, args ...any) error {
if questions := strings.Count(clause, "?"); questions != len(args) {
return fmt.Errorf("expected %v args for clause but found %v", questions, len(args))
}
keys, _, _, _, err := keysArgsKeyargsValues(v)
if err != nil {
return err
}
for i := range args {
args[i], _ = json.Marshal(args[i])
}
q := fmt.Sprintf(`
SELECT %s FROM %s WHERE %s
`, strings.Join(keys, ", "), table, clause)
row := s.driver.QueryRowContext(ctx, q, args...)
if err := row.Err(); err != nil {
return err
}
scanTargets := make([]any, len(keys))
for i := range scanTargets {
scanTargets[i] = &[]byte{}
}
if err := row.Scan(scanTargets...); err != nil {
return err
}
m := map[string]json.RawMessage{}
for i, k := range keys {
m[k] = *scanTargets[i].(*[]byte)
}
b, _ := json.Marshal(m)
return json.Unmarshal(b, v)
}
func (s Storage) upsert(ctx context.Context, table string, v any) error {
keys, args, keyArgs, values, err := keysArgsKeyargsValues(v)
if err != nil || len(keys) == 0 {
return err
}
for i := range keys {
values = append(values, values[i])
}
q := fmt.Sprintf(`
INSERT INTO %s (%s) VALUES (%s)
ON CONFLICT (ID) DO UPDATE SET %s
`, table, strings.Join(keys, ", "), strings.Join(args, ", "), strings.Join(keyArgs, ", "))
if result, err := s.driver.ExecContext(ctx, q, values...); err != nil {
return err
} else if n, err := result.RowsAffected(); err != nil {
return err
} else if n != 1 {
return fmt.Errorf("UpsertMessage affected %v rows", n)
}
return nil
}
func keysArgsKeyargsValues(v any) ([]string, []string, []string, []any, error) {
b, _ := json.Marshal(v)
var m map[string]json.RawMessage
err := json.Unmarshal(b, &m)
keys := []string{}
for k := range m {
keys = append(keys, k)
}
args := make([]string, len(keys))
for i := range args {
args[i] = "?"
}
keyArgs := make([]string, len(keys))
for i := range keyArgs {
keyArgs[i] = fmt.Sprintf("%s=?", keys[i])
}
values := make([]any, len(keys))
for i := range values {
values[i] = []byte(m[keys[i]])
}
return keys, args, keyArgs, values, err
}

42
storage_test.go Normal file
View File

@ -0,0 +1,42 @@
package main
import (
"context"
"testing"
"time"
"github.com/breel-render/spoc-bot-vr/model"
)
func TestStorage(t *testing.T) {
ctx, can := context.WithTimeout(context.Background(), time.Minute)
defer can()
s, err := NewStorage(ctx, NewTestDriver(t, "/tmp/ff"))
if err != nil {
t.Fatal(err)
}
t.Run("upsert get message", func(t *testing.T) {
m := model.NewMessage(
"ID",
"URL",
1,
"Author",
"Plaintext",
"ThreadID",
)
if err := s.UpsertMessage(ctx, m); err != nil {
t.Fatal("unexpected error on insert:", err)
} else if err := s.UpsertMessage(ctx, m); err != nil {
t.Fatal("unexpected error on noop update:", err)
}
if got, err := s.GetMessage(ctx, m.ID); err != nil {
t.Fatal("unexpected error on get:", err)
} else if got != m {
t.Fatal("unexpected result from get:", got)
}
})
}