From d70a0e313f929bd51c31dcdf9e46ea9a3b613f4f Mon Sep 17 00:00:00 2001
From: Bel LaPointe <153096461+breel-render@users.noreply.github.com>
Date: Tue, 16 Apr 2024 06:43:25 -0600
Subject: [PATCH] STORAGE TEST WERKS

---
Review notes (git am ignores everything between the "---" line and the first
"diff --git" line):

- Line structure and indentation were reconstructed from a whitespace-collapsed
  copy of this patch. Context lines in config.go, driver.go, and persistence.go
  may differ from the real files in whitespace only; if a hunk is rejected,
  retry with "git apply --ignore-whitespace".
- storage.go: upsert no longer reports "UpsertMessage" for events and threads,
  selectOne no longer overwrites the caller's args or drops marshal errors,
  and keysArgsKeyargsValues returns marshal errors and sorts its keys.
- storage_test.go: the test database moved from the shared /tmp/ff file to the
  per-test temp dir.
- The "index" blob ids of storage.go and storage_test.go still name the
  originally submitted contents.

 config.go       |   7 ++
 driver.go       |   8 +-
 persistence.go  |  15 +---
 storage.go      | 210 ++++++++++++++++++++++++++++++++++++++++++++++++
 storage_test.go |  45 +++++++++++
 5 files changed, 271 insertions(+), 14 deletions(-)
 create mode 100644 storage.go
 create mode 100644 storage_test.go

diff --git a/config.go b/config.go
index dea18ab..315014c 100644
--- a/config.go
+++ b/config.go
@@ -30,6 +30,7 @@ type Config struct {
 	DatacenterPattern            string
 	EventNamePattern             string
 	driver                       Driver
+	storage                      Storage
 	ai                           AI
 	slackToMessagePipeline       Pipeline
 	messageToPersistencePipeline Pipeline
@@ -119,6 +120,12 @@ func newConfigFromEnv(ctx context.Context, getEnv func(string) string) (Config,
 		return Config{}, errors.New("not impl")
 	}
 
+	storage, err := NewStorage(ctx, result.driver)
+	if err != nil {
+		return Config{}, err
+	}
+	result.storage = storage
+
 	if result.OllamaURL != "" {
 		result.ai = NewAIOllama(result.OllamaURL, result.OllamaModel)
 	} else if result.LocalCheckpoint != "" && result.LocalTokenizer != "" {
diff --git a/driver.go b/driver.go
index d32bec8..25955c4 100644
--- a/driver.go
+++ b/driver.go
@@ -17,8 +17,12 @@ type Driver struct {
 	*sql.DB
 }
 
-func NewTestDriver(t *testing.T) Driver {
-	driver, err := NewDriver(context.Background(), path.Join(t.TempDir(), "db"))
+func NewTestDriver(t *testing.T, optionalP ...string) Driver {
+	p := path.Join(t.TempDir(), "db")
+	if len(optionalP) > 0 {
+		p = optionalP[0]
+	}
+	driver, err := NewDriver(context.Background(), p)
 	if err != nil {
 		t.Fatal(err)
 	}
diff --git a/persistence.go b/persistence.go
index f34264a..8fbc9f7 100644
--- a/persistence.go
+++ b/persistence.go
@@ -2,7 +2,6 @@ package main
 
 import (
 	"context"
-	"fmt"
 )
 
 type MessageToPersistence struct {
@@ -21,26 +20,18 @@ func NewMessageToPersistencePipeline(ctx context.Context, cfg Config) (Pipeline,
 	return Pipeline{
 		writer:  writer,
 		reader:  reader,
-		process: newMessageToPersistenceProcess(cfg.driver),
+		process: newMessageToPersistenceProcess(cfg.storage),
 	}, nil
 }
 
-func newMessageToPersistenceProcess(driver Driver) processFunc {
+func newMessageToPersistenceProcess(storage Storage) processFunc {
 	return func(ctx context.Context, msg []byte) ([]byte, error) {
 		m, err := Deserialize(msg)
 		if err != nil {
 			return nil, err
 		}
-		if result, err := driver.ExecContext(ctx, `
-			CREATE TABLE IF NOT EXISTS messages (id TEXT UNIQUE, v TEXT);
-			INSERT INTO messages (id, v) VALUES (?, ?)
-			ON CONFLICT(id) DO UPDATE set v = ?;
-		`, m.ID, msg, msg); err != nil {
+		if err := storage.UpsertMessage(ctx, m); err != nil {
 			return nil, err
-		} else if n, err := result.RowsAffected(); err != nil {
-			return nil, err
-		} else if n != 1 {
-			return nil, fmt.Errorf("upserting event to persistence modified %v rows", n)
 		}
 		return msg, nil
 	}
diff --git a/storage.go b/storage.go
new file mode 100644
index 0000000..69c786a
--- /dev/null
+++ b/storage.go
@@ -0,0 +1,210 @@
+package main
+
+import (
+	"context"
+	"encoding/json"
+	"errors"
+	"fmt"
+	"sort"
+	"strings"
+
+	"github.com/breel-render/spoc-bot-vr/model"
+)
+
+// Storage persists events, messages, and threads in SQL tables, keeping each
+// JSON field of a model in its own TEXT column as JSON-encoded text.
+type Storage struct {
+	driver Driver
+}
+
+// NewStorage creates the events, messages, and threads tables if they are
+// missing and tries to add one TEXT column per JSON field of each model type.
+//
+// It returns an error if the tables cannot be created or if a model's JSON
+// field names cannot be determined.
+func NewStorage(ctx context.Context, driver Driver) (Storage, error) {
+	if _, err := driver.ExecContext(ctx, `
+		CREATE TABLE IF NOT EXISTS events (ID TEXT UNIQUE);
+		CREATE TABLE IF NOT EXISTS messages (ID TEXT UNIQUE);
+		CREATE TABLE IF NOT EXISTS threads (ID TEXT UNIQUE);
+	`); err != nil {
+		return Storage{}, err
+	}
+
+	for table, v := range map[string]any{
+		"events":   model.Event{},
+		"messages": model.Message{},
+		"threads":  model.Thread{},
+	} {
+		keys, _, _, _, err := keysArgsKeyargsValues(v)
+		if err != nil {
+			return Storage{}, fmt.Errorf("listing columns of %s: %w", table, err)
+		}
+		for _, k := range keys {
+			if k == `ID` {
+				continue
+			}
+			// Errors are ignored on purpose: presumably ADD COLUMN fails once
+			// the column exists, i.e. on every start after the first.
+			_, _ = driver.ExecContext(ctx, fmt.Sprintf(`ALTER TABLE %s ADD COLUMN %s TEXT`, table, k))
+		}
+	}
+
+	return Storage{driver: driver}, nil
+}
+
+// GetEvent is not implemented yet; it always returns an error.
+func (s Storage) GetEvent(ctx context.Context, ID string) (model.Event, error) {
+	return model.Event{}, errors.New("not impl")
+}
+
+// UpsertEvent inserts event, or overwrites the stored row that has the same ID.
+func (s Storage) UpsertEvent(ctx context.Context, event model.Event) error {
+	return s.upsert(ctx, "events", event)
+}
+
+// GetMessage loads the message stored under ID.
+func (s Storage) GetMessage(ctx context.Context, ID string) (model.Message, error) {
+	v := model.Message{}
+	err := s.selectOne(ctx, "messages", &v, "ID = ?", ID)
+	return v, err
+}
+
+// UpsertMessage inserts message, or overwrites the stored row that has the
+// same ID.
+func (s Storage) UpsertMessage(ctx context.Context, message model.Message) error {
+	return s.upsert(ctx, "messages", message)
+}
+
+// GetThread is not implemented yet; it always returns an error.
+func (s Storage) GetThread(ctx context.Context, ID string) (model.Thread, error) {
+	return model.Thread{}, errors.New("not impl")
+}
+
+// UpsertThread inserts thread, or overwrites the stored row that has the same
+// ID.
+func (s Storage) UpsertThread(ctx context.Context, thread model.Thread) error {
+	return s.upsert(ctx, "threads", thread)
+}
+
+// selectOne loads one row of table matching clause into v, which must be a
+// pointer to a model value. args fill the "?" placeholders of clause and are
+// JSON-encoded before binding, matching how upsert stores column values.
+func (s Storage) selectOne(ctx context.Context, table string, v any, clause string, args ...any) error {
+	if questions := strings.Count(clause, "?"); questions != len(args) {
+		return fmt.Errorf("expected %v args for clause but found %v", questions, len(args))
+	}
+
+	keys, _, _, _, err := keysArgsKeyargsValues(v)
+	if err != nil {
+		return err
+	}
+
+	// Encode into a fresh slice so the caller's args are left untouched.
+	encoded := make([]any, len(args))
+	for i, arg := range args {
+		b, err := json.Marshal(arg)
+		if err != nil {
+			return fmt.Errorf("encoding arg %d for %s: %w", i, table, err)
+		}
+		encoded[i] = b
+	}
+
+	q := fmt.Sprintf(`
+		SELECT %s FROM %s WHERE %s
+	`, strings.Join(keys, ", "), table, clause)
+	row := s.driver.QueryRowContext(ctx, q, encoded...)
+	if err := row.Err(); err != nil {
+		return err
+	}
+
+	scanTargets := make([]any, len(keys))
+	for i := range scanTargets {
+		scanTargets[i] = &[]byte{}
+	}
+	if err := row.Scan(scanTargets...); err != nil {
+		return err
+	}
+
+	// Reassemble the JSON-encoded columns into one object and decode it into v.
+	m := make(map[string]json.RawMessage, len(keys))
+	for i, k := range keys {
+		m[k] = *scanTargets[i].(*[]byte)
+	}
+	b, err := json.Marshal(m)
+	if err != nil {
+		return fmt.Errorf("re-encoding row of %s: %w", table, err)
+	}
+	return json.Unmarshal(b, v)
+}
+
+// upsert writes v into table: it inserts a new row or, when a row with the
+// same ID exists, overwrites all of its columns. v is flattened into columns
+// through its JSON encoding; a value with no JSON fields is a no-op.
+func (s Storage) upsert(ctx context.Context, table string, v any) error {
+	keys, args, keyArgs, values, err := keysArgsKeyargsValues(v)
+	if err != nil {
+		return err
+	}
+	if len(keys) == 0 {
+		return nil
+	}
+
+	// Every value is bound twice: once for VALUES and once for DO UPDATE SET.
+	values = append(values, values...)
+
+	q := fmt.Sprintf(`
+		INSERT INTO %s (%s) VALUES (%s)
+		ON CONFLICT (ID) DO UPDATE SET %s
+	`, table, strings.Join(keys, ", "), strings.Join(args, ", "), strings.Join(keyArgs, ", "))
+	result, err := s.driver.ExecContext(ctx, q, values...)
+	if err != nil {
+		return err
+	}
+	n, err := result.RowsAffected()
+	if err != nil {
+		return err
+	}
+	if n != 1 {
+		return fmt.Errorf("upsert into %s affected %v rows", table, n)
+	}
+	return nil
+}
+
+// keysArgsKeyargsValues flattens v, through its JSON encoding, into the pieces
+// needed to build SQL statements. The returned slices are index-aligned:
+//
+//   - keys: the JSON field names, sorted so the generated SQL is deterministic
+//   - args: one "?" placeholder per key
+//   - keyArgs: one "key=?" assignment per key
+//   - values: the raw JSON encoding of each field, as []byte
+//
+// It returns an error if v cannot be marshalled or if its encoding cannot be
+// decoded into a map of fields.
+func keysArgsKeyargsValues(v any) ([]string, []string, []string, []any, error) {
+	b, err := json.Marshal(v)
+	if err != nil {
+		return nil, nil, nil, nil, err
+	}
+	var m map[string]json.RawMessage
+	if err := json.Unmarshal(b, &m); err != nil {
+		return nil, nil, nil, nil, err
+	}
+
+	keys := make([]string, 0, len(m))
+	for k := range m {
+		keys = append(keys, k)
+	}
+	sort.Strings(keys)
+
+	args := make([]string, len(keys))
+	keyArgs := make([]string, len(keys))
+	values := make([]any, len(keys))
+	for i, k := range keys {
+		args[i] = "?"
+		keyArgs[i] = k + "=?"
+		values[i] = []byte(m[k])
+	}
+
+	return keys, args, keyArgs, values, nil
+}
diff --git a/storage_test.go b/storage_test.go
new file mode 100644
index 0000000..4581214
--- /dev/null
+++ b/storage_test.go
@@ -0,0 +1,45 @@
+package main
+
+import (
+	"context"
+	"testing"
+	"time"
+
+	"github.com/breel-render/spoc-bot-vr/model"
+)
+
+// TestStorage checks that a message survives an upsert, a repeated upsert, and
+// a get.
+func TestStorage(t *testing.T) {
+	ctx, can := context.WithTimeout(context.Background(), time.Minute)
+	defer can()
+
+	// Use the driver's default per-test temp dir so runs never share a database.
+	s, err := NewStorage(ctx, NewTestDriver(t))
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	t.Run("upsert get message", func(t *testing.T) {
+		m := model.NewMessage(
+			"ID",
+			"URL",
+			1,
+			"Author",
+			"Plaintext",
+			"ThreadID",
+		)
+
+		if err := s.UpsertMessage(ctx, m); err != nil {
+			t.Fatal("unexpected error on insert:", err)
+		} else if err := s.UpsertMessage(ctx, m); err != nil {
+			t.Fatal("unexpected error on noop update:", err)
+		}
+
+		if got, err := s.GetMessage(ctx, m.ID); err != nil {
+			t.Fatal("unexpected error on get:", err)
+		} else if got != m {
+			t.Fatal("unexpected result from get:", got)
+		}
+	})
+}