package main import ( "context" "encoding/json" "fmt" "strings" "github.com/breel-render/spoc-bot-vr/model" ) type Storage struct { driver Driver } func NewStorage(ctx context.Context, driver Driver) (Storage, error) { if _, err := driver.ExecContext(ctx, ` CREATE TABLE IF NOT EXISTS events (ID TEXT UNIQUE); CREATE TABLE IF NOT EXISTS messages (ID TEXT UNIQUE); CREATE TABLE IF NOT EXISTS threads (ID TEXT UNIQUE); `); err != nil { return Storage{}, err } for table, v := range map[string]any{ "events": model.Event{}, "messages": model.Message{}, "threads": model.Thread{}, } { b, _ := json.Marshal(v) var m map[string]struct{} json.Unmarshal(b, &m) for k := range m { if k == `ID` { continue } driver.ExecContext(ctx, fmt.Sprintf(`ALTER TABLE %s ADD COLUMN %s TEXT`, table, k)) } } return Storage{driver: driver}, nil } func (s Storage) GetEvent(ctx context.Context, ID string) (model.Event, error) { v := model.Event{} err := s.selectOne(ctx, "events", &v, "ID = $1", ID) return v, err } func (s Storage) UpsertEvent(ctx context.Context, event model.Event) error { return s.upsert(ctx, "events", event) } func (s Storage) GetMessage(ctx context.Context, ID string) (model.Message, error) { v := model.Message{} err := s.selectOne(ctx, "messages", &v, "ID = $1", ID) return v, err } func (s Storage) UpsertMessage(ctx context.Context, message model.Message) error { return s.upsert(ctx, "messages", message) } func (s Storage) GetThread(ctx context.Context, ID string) (model.Thread, error) { v := model.Thread{} err := s.selectOne(ctx, "threads", &v, "ID = $1", ID) return v, err } func (s Storage) UpsertThread(ctx context.Context, thread model.Thread) error { return s.upsert(ctx, "threads", thread) } func (s Storage) selectOne(ctx context.Context, table string, v any, clause string, args ...any) error { if questions := strings.Count(clause, "$"); questions != len(args) { return fmt.Errorf("expected %v args for clause but found %v", questions, len(args)) } keys, _, _, _, err := keysArgsKeyargsValues(v) if err != nil { return err } for i := range args { args[i], _ = json.Marshal(args[i]) } q := fmt.Sprintf(` SELECT %s FROM %s WHERE %s `, strings.Join(keys, ", "), table, clause) row := s.driver.QueryRowContext(ctx, q, args...) if err := row.Err(); err != nil { return err } scanTargets := make([]any, len(keys)) for i := range scanTargets { scanTargets[i] = &[]byte{} } if err := row.Scan(scanTargets...); err != nil { return err } m := map[string]json.RawMessage{} for i, k := range keys { m[k] = *scanTargets[i].(*[]byte) } b, _ := json.Marshal(m) return json.Unmarshal(b, v) } func (s Storage) upsert(ctx context.Context, table string, v any) error { keys, args, keyArgs, values, err := keysArgsKeyargsValues(v) if err != nil || len(keys) == 0 { return err } q := fmt.Sprintf(` INSERT INTO %s (%s) VALUES (%s) ON CONFLICT (ID) DO UPDATE SET %s `, table, strings.Join(keys, ", "), strings.Join(args, ", "), strings.Join(keyArgs, ", ")) if result, err := s.driver.ExecContext(ctx, q, values...); err != nil { return err } else if n, err := result.RowsAffected(); err != nil { return err } else if n != 1 { return fmt.Errorf("UpsertMessage affected %v rows", n) } return nil } func keysArgsKeyargsValues(v any) ([]string, []string, []string, []any, error) { b, _ := json.Marshal(v) var m map[string]json.RawMessage err := json.Unmarshal(b, &m) keys := []string{} for k := range m { keys = append(keys, k) } args := make([]string, len(keys)) for i := range args { args[i] = fmt.Sprintf("$%d", i+1) } keyArgs := make([]string, len(keys)) for i := range keyArgs { keyArgs[i] = fmt.Sprintf("%s=$%d", keys[i], i+1) } values := make([]any, len(keys)) for i := range values { values[i] = []byte(m[keys[i]]) } return keys, args, keyArgs, values, err }