spoc-bot-vr/storage.go

258 lines
6.2 KiB
Go

package main
import (
"context"
"encoding/json"
"fmt"
"strings"
"github.com/breel-render/spoc-bot-vr/model"
)
// Storage reads and writes model events, messages and threads, one SQL
// table per type, issuing every statement through its driver.
type Storage struct {
	driver Driver // executes all queries and statements for this Storage
}
// NewStorage makes sure the events, messages and threads tables exist and
// returns a Storage that uses driver for all later queries.
//
// Each table is created (if missing) with a unique TEXT ID column. Then a TEXT
// column is added for every other top-level JSON key of the matching model
// type (model.Event, model.Message, model.Thread).
//
// It returns an error if table creation fails, if a model type cannot be
// JSON-encoded into an object, or if ctx is done during migration.
func NewStorage(ctx context.Context, driver Driver) (Storage, error) {
	if _, err := driver.ExecContext(ctx, `
CREATE TABLE IF NOT EXISTS events (ID TEXT UNIQUE);
CREATE TABLE IF NOT EXISTS messages (ID TEXT UNIQUE);
CREATE TABLE IF NOT EXISTS threads (ID TEXT UNIQUE);
`); err != nil {
		return Storage{}, err
	}
	// A slice rather than a map so tables are migrated in a fixed order.
	tables := []struct {
		name   string
		sample any
	}{
		{"events", model.Event{}},
		{"messages", model.Message{}},
		{"threads", model.Thread{}},
	}
	for _, t := range tables {
		// Use the same key discovery as upsert/select, so the created columns
		// are exactly the ones later queries reference.
		keys, _, _, _, err := keysArgsKeyargsValues(t.sample)
		if err != nil {
			return Storage{}, fmt.Errorf("listing columns for %s: %w", t.name, err)
		}
		for _, k := range keys {
			if k == `ID` {
				continue // created above together with the table
			}
			// Deliberately best-effort: the error is ignored, presumably because
			// ADD COLUMN fails when the column already exists from an earlier run.
			// TODO(review): confirm no other failure mode needs surfacing.
			_, _ = driver.ExecContext(ctx, fmt.Sprintf(`ALTER TABLE %s ADD COLUMN %s TEXT`, t.name, k))
		}
		// The ALTER errors above are swallowed, so check for cancellation
		// explicitly rather than returning a half-migrated Storage as success.
		if err := ctx.Err(); err != nil {
			return Storage{}, err
		}
	}
	return Storage{driver: driver}, nil
}
// GetEvent loads the events row whose ID column matches ID. Any query, scan
// or decode error from the lookup is returned unchanged, together with
// whatever was decoded so far.
func (s Storage) GetEvent(ctx context.Context, ID string) (model.Event, error) {
	var event model.Event
	if err := s.selectOne(ctx, "events", &event, "ID = $1", ID); err != nil {
		return event, err
	}
	return event, nil
}
// UpsertEvent stores event in the events table, overwriting the existing row
// that has the same ID.
func (s Storage) UpsertEvent(ctx context.Context, event model.Event) error {
	const table = "events"
	err := s.upsert(ctx, table, event)
	return err
}
// GetMessage loads the messages row whose ID column matches ID. Any lookup
// error is returned unchanged, together with whatever was decoded so far.
func (s Storage) GetMessage(ctx context.Context, ID string) (model.Message, error) {
	var message model.Message
	if err := s.selectOne(ctx, "messages", &message, "ID = $1", ID); err != nil {
		return message, err
	}
	return message, nil
}
// UpsertMessage stores message in the messages table, overwriting the
// existing row that has the same ID.
func (s Storage) UpsertMessage(ctx context.Context, message model.Message) error {
	const table = "messages"
	err := s.upsert(ctx, table, message)
	return err
}
// GetThread loads the threads row whose ID column matches ID. Any lookup
// error is returned unchanged, together with whatever was decoded so far.
func (s Storage) GetThread(ctx context.Context, ID string) (model.Thread, error) {
	var thread model.Thread
	if err := s.selectOne(ctx, "threads", &thread, "ID = $1", ID); err != nil {
		return thread, err
	}
	return thread, nil
}
// GetEventThreads lists the threads whose EventID column equals ID, ordered
// by their TS column ascending.
func (s Storage) GetEventThreads(ctx context.Context, ID string) ([]model.Thread, error) {
	const byEvent = "EventID = $1"
	threads, err := s.selectThreadsWhere(ctx, byEvent, ID)
	return threads, err
}
// GetThreadMessages lists the messages whose ThreadID column equals ID,
// ordered by their TS column ascending.
func (s Storage) GetThreadMessages(ctx context.Context, ID string) ([]model.Message, error) {
	const byThread = "ThreadID = $1"
	messages, err := s.selectMessagesWhere(ctx, byThread, ID)
	return messages, err
}
// UpsertThread stores thread in the threads table, overwriting the existing
// row that has the same ID.
func (s Storage) UpsertThread(ctx context.Context, thread model.Thread) error {
	const table = "threads"
	err := s.upsert(ctx, table, thread)
	return err
}
// selectThreadsWhere returns every threads row matching clause, ordered by
// TS ascending. See selectRowsWhere for how clause and args are used.
func (s Storage) selectThreadsWhere(ctx context.Context, clause string, args ...any) ([]model.Thread, error) {
	return selectRowsWhere[model.Thread](ctx, s, "threads", clause, args...)
}

// selectMessagesWhere returns every messages row matching clause, ordered by
// TS ascending. See selectRowsWhere for how clause and args are used.
func (s Storage) selectMessagesWhere(ctx context.Context, clause string, args ...any) ([]model.Message, error) {
	return selectRowsWhere[model.Message](ctx, s, "messages", clause, args...)
}

// selectRowsWhere runs "SELECT <columns> FROM table WHERE clause ORDER BY TS
// ASC" and decodes every row into a T.
//
// The selected columns are the top-level JSON keys of T's zero value. The
// query assumes T's table has a TS column (TODO(review): confirm for every T
// this is instantiated with). clause is interpolated verbatim, so it must be
// a trusted SQL fragment that uses $N placeholders. Each arg is JSON-encoded
// before binding, matching how upsert stores column values, and each column
// is read back as raw JSON. Rows decoded before a rows.Err() failure are
// returned alongside that error.
func selectRowsWhere[T any](ctx context.Context, s Storage, table, clause string, args ...any) ([]T, error) {
	var zero T
	keys, _, _, _, err := keysArgsKeyargsValues(zero)
	if err != nil {
		return nil, err
	}
	bound := make([]any, len(args))
	for i, arg := range args {
		b, err := json.Marshal(arg)
		if err != nil {
			return nil, fmt.Errorf("encoding arg %d for %s query: %w", i+1, table, err)
		}
		bound[i] = b
	}
	q := fmt.Sprintf(`
SELECT %s FROM %s WHERE %s
ORDER BY TS ASC
`, strings.Join(keys, ", "), table, clause)
	rows, err := s.driver.QueryContext(ctx, q, bound...)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	var result []T
	scanTargets := make([]any, len(keys))
	for rows.Next() {
		// Fresh targets per row, so the bytes kept in the previous row's map
		// are never overwritten.
		for i := range scanTargets {
			scanTargets[i] = &[]byte{}
		}
		if err := rows.Scan(scanTargets...); err != nil {
			return nil, err
		}
		m := make(map[string]json.RawMessage, len(keys))
		for i, k := range keys {
			m[k] = *scanTargets[i].(*[]byte) // a NULL column scans as nil and re-encodes as JSON null
		}
		b, err := json.Marshal(m)
		if err != nil {
			// A column holding non-JSON text lands here. Report it instead
			// of the misleading "unexpected end of JSON input".
			return nil, fmt.Errorf("decoding %s row: %w", table, err)
		}
		var one T
		if err := json.Unmarshal(b, &one); err != nil {
			return nil, err
		}
		result = append(result, one)
	}
	return result, rows.Err()
}
// selectOne reads the single row of table matching clause and decodes it into v.
//
// clause is interpolated verbatim, so it must be a trusted SQL fragment using
// $N placeholders. Its count of '$' characters must equal len(args). Each arg
// is JSON-encoded before binding, matching how upsert stores values. v should
// be a pointer to a JSON-decodable value; its top-level JSON keys choose the
// selected columns. The caller's args slice is left unmodified.
//
// Errors from the query, from Scan (for example when no row matches), and
// from JSON encoding/decoding are returned.
func (s Storage) selectOne(ctx context.Context, table string, v any, clause string, args ...any) error {
	if questions := strings.Count(clause, "$"); questions != len(args) {
		return fmt.Errorf("expected %v args for clause but found %v", questions, len(args))
	}
	keys, _, _, _, err := keysArgsKeyargsValues(v)
	if err != nil {
		return err
	}
	// Encode into a new slice: args may alias the caller's slice when it is
	// passed as xs..., and overwriting it would corrupt the caller's data.
	bound := make([]any, len(args))
	for i, arg := range args {
		b, err := json.Marshal(arg)
		if err != nil {
			return fmt.Errorf("encoding arg %d for %s query: %w", i+1, table, err)
		}
		bound[i] = b
	}
	q := fmt.Sprintf(`
SELECT %s FROM %s WHERE %s
`, strings.Join(keys, ", "), table, clause)
	row := s.driver.QueryRowContext(ctx, q, bound...)
	if err := row.Err(); err != nil {
		return err
	}
	scanTargets := make([]any, len(keys))
	for i := range scanTargets {
		scanTargets[i] = &[]byte{}
	}
	if err := row.Scan(scanTargets...); err != nil {
		return err
	}
	m := make(map[string]json.RawMessage, len(keys))
	for i, k := range keys {
		m[k] = *scanTargets[i].(*[]byte) // a NULL column scans as nil and re-encodes as JSON null
	}
	b, err := json.Marshal(m)
	if err != nil {
		return fmt.Errorf("decoding %s row: %w", table, err)
	}
	return json.Unmarshal(b, v)
}
// upsert inserts v into table, or updates the existing row with the same ID
// (ON CONFLICT (ID) DO UPDATE).
//
// The columns are v's top-level JSON keys, and each is bound to that key's raw
// JSON encoding. A v with no JSON keys is a no-op returning nil. Exactly one
// affected row is expected; any other count is reported as an error naming
// the table.
func (s Storage) upsert(ctx context.Context, table string, v any) error {
	keys, args, keyArgs, values, err := keysArgsKeyargsValues(v)
	if err != nil {
		return err
	}
	if len(keys) == 0 {
		return nil
	}
	q := fmt.Sprintf(`
INSERT INTO %s (%s) VALUES (%s)
ON CONFLICT (ID) DO UPDATE SET %s
`, table, strings.Join(keys, ", "), strings.Join(args, ", "), strings.Join(keyArgs, ", "))
	result, err := s.driver.ExecContext(ctx, q, values...)
	if err != nil {
		return err
	}
	n, err := result.RowsAffected()
	if err != nil {
		return err
	}
	if n != 1 {
		// Name the real table. The old message always said "UpsertMessage",
		// even for events and threads.
		return fmt.Errorf("upsert into %s affected %v rows", table, n)
	}
	return nil
}
// keysArgsKeyargsValues splits v's JSON encoding into the pieces used to
// build SQL statements:
//
//   - keys: the top-level JSON object keys of v (map order, so unspecified)
//   - args: positional placeholders "$1".."$N", one per key
//   - keyArgs: "key=$i" assignments pairing keys[i] with args[i]
//   - values: the raw JSON bytes ([]byte) of each key's value
//
// All four slices are index-aligned. If v encodes to JSON null (for example
// a nil pointer), all slices are empty and err is nil. If v cannot be encoded,
// or does not encode to an object or null, the encoding/json error is
// returned and all slices are nil.
func keysArgsKeyargsValues(v any) ([]string, []string, []string, []any, error) {
	b, err := json.Marshal(v)
	if err != nil {
		// Surface the real cause. Ignoring it used to turn into a misleading
		// "unexpected end of JSON input" from Unmarshal(nil).
		return nil, nil, nil, nil, err
	}
	var m map[string]json.RawMessage
	if err := json.Unmarshal(b, &m); err != nil {
		return nil, nil, nil, nil, err
	}
	keys := make([]string, 0, len(m))
	for k := range m {
		keys = append(keys, k)
	}
	args := make([]string, len(keys))
	keyArgs := make([]string, len(keys))
	values := make([]any, len(keys))
	for i, k := range keys {
		args[i] = fmt.Sprintf("$%d", i+1)
		keyArgs[i] = fmt.Sprintf("%s=%s", k, args[i])
		values[i] = []byte(m[k])
	}
	return keys, args, keyArgs, values, nil
}