STORAGE TEST WERKS
parent
0cecd5ea04
commit
d70a0e313f
|
|
@ -30,6 +30,7 @@ type Config struct {
|
|||
DatacenterPattern string
|
||||
EventNamePattern string
|
||||
driver Driver
|
||||
storage Storage
|
||||
ai AI
|
||||
slackToMessagePipeline Pipeline
|
||||
messageToPersistencePipeline Pipeline
|
||||
|
|
@ -119,6 +120,12 @@ func newConfigFromEnv(ctx context.Context, getEnv func(string) string) (Config,
|
|||
return Config{}, errors.New("not impl")
|
||||
}
|
||||
|
||||
storage, err := NewStorage(ctx, result.driver)
|
||||
if err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
result.storage = storage
|
||||
|
||||
if result.OllamaURL != "" {
|
||||
result.ai = NewAIOllama(result.OllamaURL, result.OllamaModel)
|
||||
} else if result.LocalCheckpoint != "" && result.LocalTokenizer != "" {
|
||||
|
|
|
|||
|
|
@ -17,8 +17,12 @@ type Driver struct {
|
|||
*sql.DB
|
||||
}
|
||||
|
||||
func NewTestDriver(t *testing.T) Driver {
|
||||
driver, err := NewDriver(context.Background(), path.Join(t.TempDir(), "db"))
|
||||
func NewTestDriver(t *testing.T, optionalP ...string) Driver {
|
||||
p := path.Join(t.TempDir(), "db")
|
||||
if len(optionalP) > 0 {
|
||||
p = optionalP[0]
|
||||
}
|
||||
driver, err := NewDriver(context.Background(), p)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2,7 +2,6 @@ package main
|
|||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
type MessageToPersistence struct {
|
||||
|
|
@ -21,26 +20,18 @@ func NewMessageToPersistencePipeline(ctx context.Context, cfg Config) (Pipeline,
|
|||
return Pipeline{
|
||||
writer: writer,
|
||||
reader: reader,
|
||||
process: newMessageToPersistenceProcess(cfg.driver),
|
||||
process: newMessageToPersistenceProcess(cfg.storage),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func newMessageToPersistenceProcess(driver Driver) processFunc {
|
||||
func newMessageToPersistenceProcess(storage Storage) processFunc {
|
||||
return func(ctx context.Context, msg []byte) ([]byte, error) {
|
||||
m, err := Deserialize(msg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if result, err := driver.ExecContext(ctx, `
|
||||
CREATE TABLE IF NOT EXISTS messages (id TEXT UNIQUE, v TEXT);
|
||||
INSERT INTO messages (id, v) VALUES (?, ?)
|
||||
ON CONFLICT(id) DO UPDATE set v = ?;
|
||||
`, m.ID, msg, msg); err != nil {
|
||||
if err := storage.UpsertMessage(ctx, m); err != nil {
|
||||
return nil, err
|
||||
} else if n, err := result.RowsAffected(); err != nil {
|
||||
return nil, err
|
||||
} else if n != 1 {
|
||||
return nil, fmt.Errorf("upserting event to persistence modified %v rows", n)
|
||||
}
|
||||
return msg, nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,156 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/breel-render/spoc-bot-vr/model"
|
||||
)
|
||||
|
||||
type Storage struct {
|
||||
driver Driver
|
||||
}
|
||||
|
||||
func NewStorage(ctx context.Context, driver Driver) (Storage, error) {
|
||||
if _, err := driver.ExecContext(ctx, `
|
||||
CREATE TABLE IF NOT EXISTS events (ID TEXT UNIQUE);
|
||||
CREATE TABLE IF NOT EXISTS messages (ID TEXT UNIQUE);
|
||||
CREATE TABLE IF NOT EXISTS threads (ID TEXT UNIQUE);
|
||||
`); err != nil {
|
||||
return Storage{}, err
|
||||
}
|
||||
|
||||
for table, v := range map[string]any{
|
||||
"events": model.Event{},
|
||||
"messages": model.Message{},
|
||||
"threads": model.Thread{},
|
||||
} {
|
||||
b, _ := json.Marshal(v)
|
||||
var m map[string]struct{}
|
||||
json.Unmarshal(b, &m)
|
||||
for k := range m {
|
||||
if k == `ID` {
|
||||
continue
|
||||
}
|
||||
driver.ExecContext(ctx, fmt.Sprintf(`ALTER TABLE %s ADD COLUMN %s TEXT`, table, k))
|
||||
}
|
||||
}
|
||||
|
||||
return Storage{driver: driver}, nil
|
||||
}
|
||||
|
||||
func (s Storage) GetEvent(ctx context.Context, ID string) (model.Event, error) {
|
||||
return model.Event{}, errors.New("not impl")
|
||||
}
|
||||
|
||||
func (s Storage) UpsertEvent(ctx context.Context, event model.Event) error {
|
||||
return s.upsert(ctx, "events", event)
|
||||
}
|
||||
|
||||
func (s Storage) GetMessage(ctx context.Context, ID string) (model.Message, error) {
|
||||
v := model.Message{}
|
||||
err := s.selectOne(ctx, "messages", &v, "ID = ?", ID)
|
||||
return v, err
|
||||
}
|
||||
|
||||
func (s Storage) UpsertMessage(ctx context.Context, message model.Message) error {
|
||||
return s.upsert(ctx, "messages", message)
|
||||
}
|
||||
|
||||
func (s Storage) GetThread(ctx context.Context, ID string) (model.Thread, error) {
|
||||
return model.Thread{}, errors.New("not impl")
|
||||
}
|
||||
|
||||
func (s Storage) UpsertThread(ctx context.Context, thread model.Thread) error {
|
||||
return s.upsert(ctx, "threads", thread)
|
||||
}
|
||||
|
||||
func (s Storage) selectOne(ctx context.Context, table string, v any, clause string, args ...any) error {
|
||||
if questions := strings.Count(clause, "?"); questions != len(args) {
|
||||
return fmt.Errorf("expected %v args for clause but found %v", questions, len(args))
|
||||
}
|
||||
|
||||
keys, _, _, _, err := keysArgsKeyargsValues(v)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for i := range args {
|
||||
args[i], _ = json.Marshal(args[i])
|
||||
}
|
||||
|
||||
q := fmt.Sprintf(`
|
||||
SELECT %s FROM %s WHERE %s
|
||||
`, strings.Join(keys, ", "), table, clause)
|
||||
row := s.driver.QueryRowContext(ctx, q, args...)
|
||||
if err := row.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
scanTargets := make([]any, len(keys))
|
||||
for i := range scanTargets {
|
||||
scanTargets[i] = &[]byte{}
|
||||
}
|
||||
if err := row.Scan(scanTargets...); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
m := map[string]json.RawMessage{}
|
||||
for i, k := range keys {
|
||||
m[k] = *scanTargets[i].(*[]byte)
|
||||
}
|
||||
b, _ := json.Marshal(m)
|
||||
return json.Unmarshal(b, v)
|
||||
}
|
||||
|
||||
func (s Storage) upsert(ctx context.Context, table string, v any) error {
|
||||
keys, args, keyArgs, values, err := keysArgsKeyargsValues(v)
|
||||
if err != nil || len(keys) == 0 {
|
||||
return err
|
||||
}
|
||||
|
||||
for i := range keys {
|
||||
values = append(values, values[i])
|
||||
}
|
||||
|
||||
q := fmt.Sprintf(`
|
||||
INSERT INTO %s (%s) VALUES (%s)
|
||||
ON CONFLICT (ID) DO UPDATE SET %s
|
||||
`, table, strings.Join(keys, ", "), strings.Join(args, ", "), strings.Join(keyArgs, ", "))
|
||||
if result, err := s.driver.ExecContext(ctx, q, values...); err != nil {
|
||||
return err
|
||||
} else if n, err := result.RowsAffected(); err != nil {
|
||||
return err
|
||||
} else if n != 1 {
|
||||
return fmt.Errorf("UpsertMessage affected %v rows", n)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func keysArgsKeyargsValues(v any) ([]string, []string, []string, []any, error) {
|
||||
b, _ := json.Marshal(v)
|
||||
var m map[string]json.RawMessage
|
||||
err := json.Unmarshal(b, &m)
|
||||
|
||||
keys := []string{}
|
||||
for k := range m {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
args := make([]string, len(keys))
|
||||
for i := range args {
|
||||
args[i] = "?"
|
||||
}
|
||||
keyArgs := make([]string, len(keys))
|
||||
for i := range keyArgs {
|
||||
keyArgs[i] = fmt.Sprintf("%s=?", keys[i])
|
||||
}
|
||||
values := make([]any, len(keys))
|
||||
for i := range values {
|
||||
values[i] = []byte(m[keys[i]])
|
||||
}
|
||||
|
||||
return keys, args, keyArgs, values, err
|
||||
}
|
||||
|
|
@ -0,0 +1,42 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/breel-render/spoc-bot-vr/model"
|
||||
)
|
||||
|
||||
func TestStorage(t *testing.T) {
|
||||
ctx, can := context.WithTimeout(context.Background(), time.Minute)
|
||||
defer can()
|
||||
|
||||
s, err := NewStorage(ctx, NewTestDriver(t, "/tmp/ff"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
t.Run("upsert get message", func(t *testing.T) {
|
||||
m := model.NewMessage(
|
||||
"ID",
|
||||
"URL",
|
||||
1,
|
||||
"Author",
|
||||
"Plaintext",
|
||||
"ThreadID",
|
||||
)
|
||||
|
||||
if err := s.UpsertMessage(ctx, m); err != nil {
|
||||
t.Fatal("unexpected error on insert:", err)
|
||||
} else if err := s.UpsertMessage(ctx, m); err != nil {
|
||||
t.Fatal("unexpected error on noop update:", err)
|
||||
}
|
||||
|
||||
if got, err := s.GetMessage(ctx, m.ID); err != nil {
|
||||
t.Fatal("unexpected error on get:", err)
|
||||
} else if got != m {
|
||||
t.Fatal("unexpected result from get:", got)
|
||||
}
|
||||
})
|
||||
}
|
||||
Loading…
Reference in New Issue