package main

import (
	"context"
	"database/sql"
	"errors"
	"fmt"
	"io/ioutil"
	"os"
	"path"
	"sync"
	"time"

	_ "github.com/lib/pq"
	"go.etcd.io/bbolt"
)

// Driver is a namespaced key/value store. Every implementation in this file
// treats a nil value passed to Set as a delete, and returns (nil, nil) from
// Get when the key is absent.
type Driver interface {
	Close() error
	ForEach(context.Context, string, func(string, []byte) error) error
	Get(context.Context, string, string) ([]byte, error)
	Set(context.Context, string, string, []byte) error
}

// Compile-time checks that every backend satisfies Driver.
var (
	_ Driver = Postgres{}
	_ Driver = RAM{}
	_ Driver = BBolt{}
)

// FillWithTestdata parses every file in ./testdata/slack_events with
// ParseSlack and stores each relevant message in namespace "m", keyed by its
// ID. Files that ParseSlack reports as ErrIrrelevantMessage are skipped;
// subdirectories are ignored.
func FillWithTestdata(ctx context.Context, driver Driver) error {
	d := "./testdata/slack_events"
	entries, err := os.ReadDir(d)
	if err != nil {
		return err
	}
	for _, entry := range entries {
		if entry.IsDir() {
			continue
		}
		p := path.Join(d, entry.Name())
		b, err := os.ReadFile(p)
		if err != nil {
			return err
		}
		m, err := ParseSlack(b)
		if errors.Is(err, ErrIrrelevantMessage) {
			continue
		} else if err != nil {
			return fmt.Errorf("parsing %s: %w", p, err)
		}
		// Pass the caller's ctx through: a nil context panics in
		// database/sql's *Context methods.
		if err := driver.Set(ctx, "m", m.ID, m.Serialize()); err != nil {
			return fmt.Errorf("storing %s: %w", p, err)
		}
	}
	return nil
}

// Postgres is a Driver backed by one JSONB table per namespace.
type Postgres struct {
	db *sql.DB
}

// NewPostgres opens a connection pool with the "postgres" driver and ensures
// the backing tables and unique constraints exist. The pool is closed if
// setup fails.
func NewPostgres(ctx context.Context, conn string) (Postgres, error) {
	db, err := sql.Open("postgres", conn)
	if err != nil {
		return Postgres{}, err
	}

	pg := Postgres{db: db}
	if err := pg.setup(ctx); err != nil {
		pg.Close()
		return Postgres{}, fmt.Errorf("failed setup: %w", err)
	}

	return pg, nil
}

// setup creates the "q" and "m" tables if missing. It also drops and re-adds
// a UNIQUE(id) constraint on each, which the ON CONFLICT (id) upsert in Set
// relies on.
func (pg Postgres) setup(ctx context.Context) error {
	tableQ, err := pg.table("q")
	if err != nil {
		return err
	}
	tableM, err := pg.table("m")
	if err != nil {
		return err
	}
	if _, err := pg.db.ExecContext(ctx, fmt.Sprintf(`
		CREATE TABLE IF NOT EXISTS %s (
			id TEXT NOT NULL,
			v JSONB NOT NULL
		);
		CREATE TABLE IF NOT EXISTS %s (
			id TEXT NOT NULL,
			v JSONB NOT NULL
		);
		ALTER TABLE %s DROP CONSTRAINT IF EXISTS %s_id_unique;
		ALTER TABLE %s ADD CONSTRAINT %s_id_unique UNIQUE (id);
		ALTER TABLE %s DROP CONSTRAINT IF EXISTS %s_id_unique;
		ALTER TABLE %s ADD CONSTRAINT %s_id_unique UNIQUE (id);
	`, tableQ, tableM,
		tableQ, tableQ,
		tableQ, tableQ,
		tableM, tableM,
		tableM, tableM,
	)); err != nil {
		return err
	}
	return nil
}

// table maps a namespace to its table name. Because only these fixed names
// can come out, it is safe to interpolate the result into SQL text; all
// other values must be passed as query parameters.
func (pg Postgres) table(s string) (string, error) {
	switch s {
	case "q":
		return "spoc_bot_vr_q", nil
	case "m":
		return "spoc_bot_vr_messages", nil
	}
	return "", errors.New("invalid table " + s)
}

// Close closes the underlying connection pool.
func (pg Postgres) Close() error {
	return pg.db.Close()
}

// ForEach calls cb for every row in namespace ns, stopping at the first
// error returned by cb.
func (pg Postgres) ForEach(ctx context.Context, ns string, cb func(string, []byte) error) error {
	table, err := pg.table(ns)
	if err != nil {
		return err
	}
	rows, err := pg.db.QueryContext(ctx, fmt.Sprintf(`SELECT id, v FROM %s;`, table))
	if err != nil {
		return err
	}
	defer rows.Close()
	for rows.Next() {
		var id string
		var v []byte
		if err := rows.Scan(&id, &v); err != nil {
			return err
		}
		if err := cb(id, v); err != nil {
			return err
		}
	}
	// rows.Next also returns false on iteration errors; without this check
	// they would be reported as success.
	if err := rows.Err(); err != nil {
		return err
	}
	return ctx.Err()
}

// Get returns the value stored under id in namespace ns, or (nil, nil) if
// there is no such row.
func (pg Postgres) Get(ctx context.Context, ns, id string) ([]byte, error) {
	table, err := pg.table(ns)
	if err != nil {
		return nil, err
	}
	// id is bound as a parameter, never spliced into the SQL text.
	row := pg.db.QueryRowContext(ctx, fmt.Sprintf(`SELECT v FROM %s WHERE id=$1;`, table), id)
	if err := row.Err(); err != nil {
		return nil, err
	}
	var v []byte
	if err := row.Scan(&v); err != nil && !errors.Is(err, sql.ErrNoRows) {
		return nil, err
	}
	return v, nil
}

// Set upserts v under id in namespace ns; a nil v deletes the row instead.
// v is sent as text into a JSONB column, so the server rejects it unless it
// is valid JSON.
func (pg Postgres) Set(ctx context.Context, ns, id string, v []byte) error {
	table, err := pg.table(ns)
	if err != nil {
		return err
	}
	if v == nil {
		_, err = pg.db.ExecContext(ctx, fmt.Sprintf(`DELETE FROM %s WHERE id=$1;`, table), id)
		return err
	}
	// Bind id and v as parameters; quotes inside the payload (common in chat
	// messages) would otherwise break or inject into the statement.
	_, err = pg.db.ExecContext(ctx, fmt.Sprintf(
		`INSERT INTO %s (id, v) VALUES ($1, $2) ON CONFLICT (id) DO UPDATE SET v = EXCLUDED.v;`,
		table,
	), id, string(v))
	return err
}

// RAM is an in-memory Driver guarded by a shared RWMutex. Copies of a RAM
// value share the same map and lock.
type RAM struct {
	m    map[string]map[string][]byte
	lock *sync.RWMutex
}

// NewRAM returns an empty in-memory Driver.
func NewRAM() RAM {
	return RAM{
		m:    make(map[string]map[string][]byte),
		lock: &sync.RWMutex{},
	}
}

// Close is a no-op.
func (ram RAM) Close() error {
	return nil
}

// ForEach calls cb for every entry in namespace ns, in map iteration order.
// The read lock is held for the whole loop, so cb must not call Set on the
// same RAM (that would block on the write lock).
func (ram RAM) ForEach(ctx context.Context, ns string, cb func(string, []byte) error) error {
	ram.lock.RLock()
	defer ram.lock.RUnlock()

	for k, v := range ram.m[ns] {
		if ctx.Err() != nil {
			break
		}
		if err := cb(k, v); err != nil {
			return err
		}
	}

	return ctx.Err()
}

// Get returns the stored slice (not a copy), or nil if absent.
func (ram RAM) Get(_ context.Context, ns, id string) ([]byte, error) {
	ram.lock.RLock()
	defer ram.lock.RUnlock()

	if _, ok := ram.m[ns]; !ok {
		return nil, nil
	}
	return ram.m[ns][id], nil
}

// Set stores v (without copying) under id in namespace ns; a nil v deletes
// the entry.
func (ram RAM) Set(_ context.Context, ns, id string, v []byte) error {
	ram.lock.Lock()
	defer ram.lock.Unlock()

	if _, ok := ram.m[ns]; !ok {
		ram.m[ns] = map[string][]byte{}
	}
	ram.m[ns][id] = v
	if v == nil {
		delete(ram.m[ns], id)
	}
	return nil
}

// BBolt is a Driver backed by a bbolt file, with one bucket per namespace.
type BBolt struct {
	db *bbolt.DB
}

// NewTestDBIn creates a fresh temporary directory under d and opens a bbolt
// database inside it. It panics on failure and is intended for tests.
func NewTestDBIn(d string) BBolt {
	d, err := ioutil.TempDir(d, "test-db-*")
	if err != nil {
		panic(err)
	}
	db, err := NewDB(path.Join(d, "bb"))
	if err != nil {
		panic(err)
	}
	return db
}

// NewDB opens (or creates) the bbolt file at p with mode 0600. It gives up
// after one second if another process holds the file lock.
func NewDB(p string) (BBolt, error) {
	db, err := bbolt.Open(p, 0600, &bbolt.Options{
		Timeout: time.Second,
	})
	if err != nil {
		return BBolt{}, err
	}
	return BBolt{db: db}, nil
}

// Close closes the underlying bbolt database.
func (bb BBolt) Close() error {
	return bb.db.Close()
}

// ForEach calls cb for every key in bucket db, in bbolt cursor (byte-sorted)
// order, inside a single read transaction. A missing bucket yields no calls.
// cb receives a private copy of each value, so it may retain it after
// ForEach returns.
func (bb BBolt) ForEach(ctx context.Context, db string, cb func(string, []byte) error) error {
	return bb.db.View(func(tx *bbolt.Tx) error {
		bkt := tx.Bucket([]byte(db))
		if bkt == nil {
			return nil
		}

		c := bkt.Cursor()
		for k, v := c.First(); k != nil && ctx.Err() == nil; k, v = c.Next() {
			// bbolt values are only valid for the life of the transaction.
			if err := cb(string(k), append([]byte{}, v...)); err != nil {
				return err
			}
		}

		return ctx.Err()
	})
}

// Get returns a copy of the value stored under id in bucket db, or nil if the
// bucket or key does not exist.
func (bb BBolt) Get(_ context.Context, db, id string) ([]byte, error) {
	var b []byte
	err := bb.db.View(func(tx *bbolt.Tx) error {
		bkt := tx.Bucket([]byte(db))
		if bkt == nil {
			return nil
		}
		if v := bkt.Get([]byte(id)); v != nil {
			// The slice bbolt returns is only valid inside this transaction;
			// copy it before View returns.
			b = append([]byte{}, v...)
		}
		return nil
	})
	return b, err
}

// Set stores value under id in bucket db, creating the bucket if needed; a
// nil value deletes the key instead.
func (bb BBolt) Set(_ context.Context, db, id string, value []byte) error {
	return bb.db.Update(func(tx *bbolt.Tx) error {
		bkt := tx.Bucket([]byte(db))
		if bkt == nil {
			var err error
			bkt, err = tx.CreateBucket([]byte(db))
			if err != nil {
				return err
			}
		}

		if value == nil {
			return bkt.Delete([]byte(id))
		}
		return bkt.Put([]byte(id), value)
	})
}