258 lines
6.2 KiB
Go
258 lines
6.2 KiB
Go
package main
|
|
|
|
import (
	"context"
	"encoding/json"
	"fmt"
	"sort"
	"strings"

	"github.com/breel-render/spoc-bot-vr/model"
)
|
|
|
|
type Storage struct {
|
|
driver Driver
|
|
}
|
|
|
|
func NewStorage(ctx context.Context, driver Driver) (Storage, error) {
|
|
if _, err := driver.ExecContext(ctx, `
|
|
CREATE TABLE IF NOT EXISTS events (ID TEXT UNIQUE);
|
|
CREATE TABLE IF NOT EXISTS messages (ID TEXT UNIQUE);
|
|
CREATE TABLE IF NOT EXISTS threads (ID TEXT UNIQUE);
|
|
`); err != nil {
|
|
return Storage{}, err
|
|
}
|
|
|
|
for table, v := range map[string]any{
|
|
"events": model.Event{},
|
|
"messages": model.Message{},
|
|
"threads": model.Thread{},
|
|
} {
|
|
b, _ := json.Marshal(v)
|
|
var m map[string]struct{}
|
|
json.Unmarshal(b, &m)
|
|
for k := range m {
|
|
if k == `ID` {
|
|
continue
|
|
}
|
|
driver.ExecContext(ctx, fmt.Sprintf(`ALTER TABLE %s ADD COLUMN %s TEXT`, table, k))
|
|
}
|
|
}
|
|
|
|
return Storage{driver: driver}, nil
|
|
}
|
|
|
|
func (s Storage) GetEvent(ctx context.Context, ID string) (model.Event, error) {
|
|
v := model.Event{}
|
|
err := s.selectOne(ctx, "events", &v, "ID = $1", ID)
|
|
return v, err
|
|
}
|
|
|
|
func (s Storage) UpsertEvent(ctx context.Context, event model.Event) error {
|
|
return s.upsert(ctx, "events", event)
|
|
}
|
|
|
|
func (s Storage) GetMessage(ctx context.Context, ID string) (model.Message, error) {
|
|
v := model.Message{}
|
|
err := s.selectOne(ctx, "messages", &v, "ID = $1", ID)
|
|
return v, err
|
|
}
|
|
|
|
func (s Storage) UpsertMessage(ctx context.Context, message model.Message) error {
|
|
return s.upsert(ctx, "messages", message)
|
|
}
|
|
|
|
func (s Storage) GetThread(ctx context.Context, ID string) (model.Thread, error) {
|
|
v := model.Thread{}
|
|
err := s.selectOne(ctx, "threads", &v, "ID = $1", ID)
|
|
return v, err
|
|
}
|
|
|
|
func (s Storage) GetEventThreads(ctx context.Context, ID string) ([]model.Thread, error) {
|
|
return s.selectThreadsWhere(ctx, "EventID = $1", ID)
|
|
}
|
|
|
|
func (s Storage) GetThreadMessages(ctx context.Context, ID string) ([]model.Message, error) {
|
|
return s.selectMessagesWhere(ctx, "ThreadID = $1", ID)
|
|
}
|
|
|
|
func (s Storage) UpsertThread(ctx context.Context, thread model.Thread) error {
|
|
return s.upsert(ctx, "threads", thread)
|
|
}
|
|
|
|
func (s Storage) selectThreadsWhere(ctx context.Context, clause string, args ...any) ([]model.Thread, error) {
|
|
keys, _, _, _, err := keysArgsKeyargsValues(model.Thread{})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
args2 := make([]any, len(args))
|
|
for i := range args {
|
|
args2[i], _ = json.Marshal(args[i])
|
|
}
|
|
scanTargets := make([]any, len(keys))
|
|
|
|
q := fmt.Sprintf(`
|
|
SELECT %s FROM threads WHERE %s
|
|
ORDER BY TS ASC
|
|
`, strings.Join(keys, ", "), clause)
|
|
rows, err := s.driver.QueryContext(ctx, q, args2...)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
|
|
var result []model.Thread
|
|
for rows.Next() {
|
|
for i := range scanTargets {
|
|
scanTargets[i] = &[]byte{}
|
|
}
|
|
|
|
if err := rows.Scan(scanTargets...); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
m := map[string]json.RawMessage{}
|
|
for i, k := range keys {
|
|
m[k] = *scanTargets[i].(*[]byte)
|
|
}
|
|
b, _ := json.Marshal(m)
|
|
|
|
var one model.Thread
|
|
if err := json.Unmarshal(b, &one); err != nil {
|
|
return nil, err
|
|
}
|
|
result = append(result, one)
|
|
}
|
|
|
|
return result, rows.Err()
|
|
}
|
|
|
|
func (s Storage) selectMessagesWhere(ctx context.Context, clause string, args ...any) ([]model.Message, error) {
|
|
keys, _, _, _, err := keysArgsKeyargsValues(model.Message{})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
args2 := make([]any, len(args))
|
|
for i := range args {
|
|
args2[i], _ = json.Marshal(args[i])
|
|
}
|
|
scanTargets := make([]any, len(keys))
|
|
|
|
q := fmt.Sprintf(`
|
|
SELECT %s FROM messages WHERE %s
|
|
ORDER BY TS ASC
|
|
`, strings.Join(keys, ", "), clause)
|
|
rows, err := s.driver.QueryContext(ctx, q, args2...)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
|
|
var result []model.Message
|
|
for rows.Next() {
|
|
for i := range scanTargets {
|
|
scanTargets[i] = &[]byte{}
|
|
}
|
|
|
|
if err := rows.Scan(scanTargets...); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
m := map[string]json.RawMessage{}
|
|
for i, k := range keys {
|
|
m[k] = *scanTargets[i].(*[]byte)
|
|
}
|
|
b, _ := json.Marshal(m)
|
|
|
|
var one model.Message
|
|
if err := json.Unmarshal(b, &one); err != nil {
|
|
return nil, err
|
|
}
|
|
result = append(result, one)
|
|
}
|
|
|
|
return result, rows.Err()
|
|
}
|
|
|
|
func (s Storage) selectOne(ctx context.Context, table string, v any, clause string, args ...any) error {
|
|
if questions := strings.Count(clause, "$"); questions != len(args) {
|
|
return fmt.Errorf("expected %v args for clause but found %v", questions, len(args))
|
|
}
|
|
|
|
keys, _, _, _, err := keysArgsKeyargsValues(v)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
for i := range args {
|
|
args[i], _ = json.Marshal(args[i])
|
|
}
|
|
|
|
q := fmt.Sprintf(`
|
|
SELECT %s FROM %s WHERE %s
|
|
`, strings.Join(keys, ", "), table, clause)
|
|
row := s.driver.QueryRowContext(ctx, q, args...)
|
|
if err := row.Err(); err != nil {
|
|
return err
|
|
}
|
|
|
|
scanTargets := make([]any, len(keys))
|
|
for i := range scanTargets {
|
|
scanTargets[i] = &[]byte{}
|
|
}
|
|
if err := row.Scan(scanTargets...); err != nil {
|
|
return err
|
|
}
|
|
|
|
m := map[string]json.RawMessage{}
|
|
for i, k := range keys {
|
|
m[k] = *scanTargets[i].(*[]byte)
|
|
}
|
|
b, _ := json.Marshal(m)
|
|
return json.Unmarshal(b, v)
|
|
}
|
|
|
|
func (s Storage) upsert(ctx context.Context, table string, v any) error {
|
|
keys, args, keyArgs, values, err := keysArgsKeyargsValues(v)
|
|
if err != nil || len(keys) == 0 {
|
|
return err
|
|
}
|
|
|
|
q := fmt.Sprintf(`
|
|
INSERT INTO %s (%s) VALUES (%s)
|
|
ON CONFLICT (ID) DO UPDATE SET %s
|
|
`, table, strings.Join(keys, ", "), strings.Join(args, ", "), strings.Join(keyArgs, ", "))
|
|
if result, err := s.driver.ExecContext(ctx, q, values...); err != nil {
|
|
return err
|
|
} else if n, err := result.RowsAffected(); err != nil {
|
|
return err
|
|
} else if n != 1 {
|
|
return fmt.Errorf("UpsertMessage affected %v rows", n)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// keysArgsKeyargsValues flattens v into the pieces needed to build SQL for it.
//
// v is JSON-encoded and is expected to encode to a JSON object; each top-level
// key of that object becomes a column. The four returned slices are
// index-aligned:
//
//   - keys[i] is the column name. Keys are sorted, so a given type always
//     yields the same column order and the same SQL text.
//   - args[i] is the positional placeholder "$<i+1>".
//   - keyArgs[i] is the assignment "<key>=$<i+1>", used in an UPDATE SET clause.
//   - values[i] is the field's JSON encoding, as a []byte.
//
// An error is returned if v cannot be JSON-encoded or does not encode to an
// object. A nil v encodes to JSON null and yields empty slices and no error.
func keysArgsKeyargsValues(v any) ([]string, []string, []string, []any, error) {
	b, err := json.Marshal(v)
	if err != nil {
		return nil, nil, nil, nil, fmt.Errorf("encoding %T: %w", v, err)
	}
	var m map[string]json.RawMessage
	if err := json.Unmarshal(b, &m); err != nil {
		return nil, nil, nil, nil, fmt.Errorf("decoding %T as a JSON object: %w", v, err)
	}

	keys := make([]string, 0, len(m))
	for k := range m {
		keys = append(keys, k)
	}
	// Map iteration order is random; sort for a stable column order.
	sort.Strings(keys)

	args := make([]string, len(keys))
	keyArgs := make([]string, len(keys))
	values := make([]any, len(keys))
	for i, k := range keys {
		args[i] = fmt.Sprintf("$%d", i+1)
		keyArgs[i] = fmt.Sprintf("%s=$%d", k, i+1)
		values[i] = []byte(m[k])
	}

	return keys, args, keyArgs, values, nil
}
|