spoc-bot-vr/driver.go

313 lines
6.0 KiB
Go

package main
import (
"context"
"database/sql"
"errors"
"fmt"
"io/ioutil"
"os"
"path"
"sync"
"time"
"go.etcd.io/bbolt"
_ "github.com/lib/pq"
)
// Driver is a namespaced key/value store holding raw byte values.
// Implementations in this file: Postgres (namespaces "q" and "m" only),
// RAM, and BBolt.
type Driver interface {
	// Close releases whatever resources the driver holds.
	Close() error
	// ForEach calls cb once per entry in namespace ns, stopping at and
	// returning the first error cb reports; a cancelled ctx ends the walk
	// and its error is returned.
	ForEach(ctx context.Context, ns string, cb func(id string, v []byte) error) error
	// Get returns the value stored for id in namespace ns, or (nil, nil)
	// when there is no such entry.
	Get(ctx context.Context, ns, id string) ([]byte, error)
	// Set stores v under id in namespace ns; a nil v deletes the entry.
	Set(ctx context.Context, ns, id string, v []byte) error
}
// FillWithTestdata reads every regular file in ./testdata/slack_events,
// parses it with ParseSlack(b, assetPattern, datacenterPattern) and stores
// each resulting message in the "m" namespace of driver, keyed by m.ID with
// the value m.Serialize().
//
// Subdirectories are skipped, and so are files for which ParseSlack reports
// ErrIrrelevantMessage. Any other read, parse or store error aborts the
// load and is returned.
func FillWithTestdata(ctx context.Context, driver Driver, assetPattern, datacenterPattern string) error {
	const dir = "./testdata/slack_events"
	entries, err := os.ReadDir(dir)
	if err != nil {
		return err
	}
	for _, entry := range entries {
		if entry.IsDir() {
			continue
		}
		b, err := os.ReadFile(path.Join(dir, entry.Name()))
		if err != nil {
			return fmt.Errorf("reading %s: %w", entry.Name(), err)
		}
		m, err := ParseSlack(b, assetPattern, datacenterPattern)
		if errors.Is(err, ErrIrrelevantMessage) {
			continue
		} else if err != nil {
			return fmt.Errorf("parsing %s: %w", entry.Name(), err)
		}
		// Forward the caller's ctx. A nil Context breaks drivers that use
		// it, e.g. Postgres.Set, which calls ExecContext.
		if err := driver.Set(ctx, "m", m.ID, m.Serialize()); err != nil {
			return err
		}
	}
	return nil
}
// Postgres is a Driver backed by a PostgreSQL database; each namespace maps
// to a fixed table (see table).
type Postgres struct {
	// db is the connection pool opened by NewPostgres.
	db *sql.DB
}
// NewPostgres opens a "postgres" database/sql handle for the connection
// string conn and runs setup to create the tables and constraints.
// If setup fails, the handle is closed and a zero Postgres is returned
// together with the wrapped setup error.
func NewPostgres(ctx context.Context, conn string) (Postgres, error) {
	handle, openErr := sql.Open("postgres", conn)
	if openErr != nil {
		return Postgres{}, openErr
	}
	pg := Postgres{db: handle}
	setupErr := pg.setup(ctx)
	if setupErr == nil {
		return pg, nil
	}
	// Best-effort cleanup; the setup error is the one worth reporting.
	pg.Close()
	return Postgres{}, fmt.Errorf("failed setup: %w", setupErr)
}
// setup creates the "q" and "m" tables (columns id TEXT, v JSONB) when they
// do not exist. It then drops and re-adds a UNIQUE constraint on each
// table's id column, which the ON CONFLICT (id) upsert depends on.
// All statements are sent in a single ExecContext call.
func (pg Postgres) setup(ctx context.Context) error {
	var names [2]string
	for i, ns := range []string{"q", "m"} {
		name, err := pg.table(ns)
		if err != nil {
			return err
		}
		names[i] = name
	}
	// %[1]s is the "q" table and %[2]s the "m" table.
	const ddl = `
CREATE TABLE IF NOT EXISTS %[1]s (
id TEXT NOT NULL,
v JSONB NOT NULL
);
CREATE TABLE IF NOT EXISTS %[2]s (
id TEXT NOT NULL,
v JSONB NOT NULL
);
ALTER TABLE %[1]s DROP CONSTRAINT IF EXISTS %[1]s_id_unique;
ALTER TABLE %[1]s ADD CONSTRAINT %[1]s_id_unique UNIQUE (id);
ALTER TABLE %[2]s DROP CONSTRAINT IF EXISTS %[2]s_id_unique;
ALTER TABLE %[2]s ADD CONSTRAINT %[2]s_id_unique UNIQUE (id);
`
	_, err := pg.db.ExecContext(ctx, fmt.Sprintf(ddl, names[0], names[1]))
	return err
}
// table maps a namespace to its fixed table name: "q" is the queue table
// and "m" the messages table. Any other namespace is rejected. Because of
// this whitelist, the result can be formatted into SQL directly.
func (pg Postgres) table(s string) (string, error) {
	var name string
	switch s {
	case "m":
		name = "spoc_bot_vr_messages"
	case "q":
		name = "spoc_bot_vr_q"
	default:
		return "", errors.New("invalid table " + s)
	}
	return name, nil
}
// Close closes the underlying *sql.DB and returns its error.
func (pg Postgres) Close() error {
	handle := pg.db
	return handle.Close()
}
// ForEach selects every (id, v) row of namespace ns's table and calls cb on
// each one, stopping at the first Scan or callback error.
//
// After the rows are exhausted, ForEach returns any iteration error the
// driver reported. Otherwise it returns ctx.Err(), so a cancelled context
// is surfaced.
func (pg Postgres) ForEach(ctx context.Context, ns string, cb func(string, []byte) error) error {
	table, err := pg.table(ns)
	if err != nil {
		return err
	}
	// table comes from the fixed whitelist in pg.table, so formatting it in is safe.
	rows, err := pg.db.QueryContext(ctx, fmt.Sprintf(`SELECT id, v FROM %s;`, table))
	if err != nil {
		return err
	}
	defer rows.Close()
	for rows.Next() {
		var id string
		var v []byte
		if err := rows.Scan(&id, &v); err != nil {
			return err
		}
		if err := cb(id, v); err != nil {
			return err
		}
	}
	// rows.Next returns false both at the end of the result set and on
	// failure (e.g. a dropped connection); only rows.Err tells them apart.
	if err := rows.Err(); err != nil {
		return err
	}
	return ctx.Err()
}
// Get returns the v column of the row whose id matches, in namespace ns's
// table. It returns (nil, nil) when no such row exists.
func (pg Postgres) Get(ctx context.Context, ns, id string) ([]byte, error) {
	table, err := pg.table(ns)
	if err != nil {
		return nil, err
	}
	// The table name comes from pg.table's whitelist. id is caller data, so
	// it is bound as a parameter instead of being spliced into the SQL:
	// that prevents injection and failures on ids containing quotes.
	row := pg.db.QueryRowContext(ctx, fmt.Sprintf(`SELECT v FROM %s WHERE id=$1;`, table), id)
	var v []byte
	// Scan also reports any error deferred from QueryRowContext.
	if err := row.Scan(&v); err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			return nil, nil
		}
		return nil, err
	}
	return v, nil
}
// Set upserts v under id in namespace ns's table. A nil v deletes the row.
// Upserting relies on the UNIQUE constraint on id that setup installs.
func (pg Postgres) Set(ctx context.Context, ns, id string, v []byte) error {
	table, err := pg.table(ns)
	if err != nil {
		return err
	}
	// Only the whitelisted table name is formatted into the SQL. id and v
	// are bound as parameters: v is serialized JSON and often contains
	// apostrophes, which broke the former '%s' quoting and allowed SQL
	// injection.
	if v == nil {
		_, err = pg.db.ExecContext(ctx, fmt.Sprintf(`DELETE FROM %s WHERE id=$1;`, table), id)
		return err
	}
	// EXCLUDED.v is the value proposed for insertion, so v is sent only once.
	_, err = pg.db.ExecContext(ctx,
		fmt.Sprintf(`INSERT INTO %s (id, v) VALUES ($1, $2) ON CONFLICT (id) DO UPDATE SET v = EXCLUDED.v`, table),
		id, string(v))
	return err
}
// RAM is an in-memory Driver: a map of namespace -> id -> value guarded by
// an RWMutex.
type RAM struct {
	// m holds the data, keyed first by namespace and then by id.
	m map[string]map[string][]byte
	// lock is a pointer so that copies of RAM made by its value receivers
	// all share the same mutex.
	lock *sync.RWMutex
}
// NewRAM returns an empty RAM driver with its map and mutex initialized.
func NewRAM() RAM {
	var ram RAM
	ram.m = map[string]map[string][]byte{}
	ram.lock = new(sync.RWMutex)
	return ram
}
// Close is a no-op for the in-memory driver; it always returns nil.
func (RAM) Close() error { return nil }
// ForEach calls cb on each id/value pair in namespace ns, in Go map order.
// It stops at the first callback error, or returns ctx.Err() once ctx is
// cancelled.
//
// The read lock is held for the whole walk, so cb must not call Set on the
// same RAM: Set needs the write lock and would deadlock.
func (ram RAM) ForEach(ctx context.Context, ns string, cb func(string, []byte) error) error {
	ram.lock.RLock()
	defer ram.lock.RUnlock()
	bucket := ram.m[ns]
	for id, val := range bucket {
		if err := ctx.Err(); err != nil {
			return err
		}
		if err := cb(id, val); err != nil {
			return err
		}
	}
	return ctx.Err()
}
// Get returns the value stored for id in namespace ns, or nil if either is
// absent. The error is always nil.
func (ram RAM) Get(_ context.Context, ns, id string) ([]byte, error) {
	ram.lock.RLock()
	defer ram.lock.RUnlock()
	// Indexing a missing namespace yields a nil inner map, and indexing a
	// nil map yields nil, so unknown ns and unknown id both return nil.
	return ram.m[ns][id], nil
}
// Set stores v under id in namespace ns, creating the namespace map if
// needed (even when deleting). A nil v removes the entry. Always returns nil.
func (ram RAM) Set(_ context.Context, ns, id string, v []byte) error {
	ram.lock.Lock()
	defer ram.lock.Unlock()
	bucket, ok := ram.m[ns]
	if !ok {
		bucket = make(map[string][]byte)
		ram.m[ns] = bucket
	}
	if v == nil {
		delete(bucket, id)
	} else {
		bucket[id] = v
	}
	return nil
}
// BBolt is a Driver backed by a bbolt database file. Each namespace is a
// top-level bucket of the same name.
type BBolt struct {
	// db is the handle returned by bbolt.Open in NewDB.
	db *bbolt.DB
}
// NewTestDBIn creates a fresh "test-db-*" temporary directory under d and
// opens a bbolt database file named "bb" inside it. It panics on any
// failure, which makes it suitable only for tests.
// NOTE(review): ioutil.TempDir is deprecated in favour of os.MkdirTemp;
// it is kept here because this is the file's only use of io/ioutil.
func NewTestDBIn(d string) BBolt {
	dir, err := ioutil.TempDir(d, "test-db-*")
	if err != nil {
		panic(err)
	}
	bb, err := NewDB(path.Join(dir, "bb"))
	if err != nil {
		panic(err)
	}
	return bb
}
// NewDB opens (creating if needed) the bbolt file at p with mode 0600.
// Options.Timeout is set to one second; per the bbolt docs, it bounds how
// long Open waits for the file lock. The BBolt is returned even when err
// is non-nil, matching bbolt.Open's result.
func NewDB(p string) (BBolt, error) {
	opts := &bbolt.Options{Timeout: time.Second}
	handle, err := bbolt.Open(p, 0o600, opts)
	return BBolt{db: handle}, err
}
// Close closes the underlying bbolt database and returns its error.
func (bb BBolt) Close() error {
	handle := bb.db
	return handle.Close()
}
// ForEach walks bucket db in key order inside a read-only transaction and
// calls cb on each key/value pair. A missing bucket yields no calls and a
// nil error. The walk stops at the first callback error, or when ctx is
// cancelled, in which case ctx.Err() is returned.
//
// Per the bbolt docs, v points into the memory map and is only valid while
// the transaction is open, i.e. during cb; callers that keep it must copy it.
func (bb BBolt) ForEach(ctx context.Context, db string, cb func(string, []byte) error) error {
	return bb.db.View(func(tx *bbolt.Tx) error {
		bkt := tx.Bucket([]byte(db))
		if bkt == nil {
			return nil
		}
		cur := bkt.Cursor()
		k, v := cur.First()
		for k != nil && ctx.Err() == nil {
			if err := cb(string(k), v); err != nil {
				return err
			}
			k, v = cur.Next()
		}
		return ctx.Err()
	})
}
// Get returns a copy of the value stored under id in bucket db. It returns
// nil if the bucket or the key does not exist.
func (bb BBolt) Get(_ context.Context, db, id string) ([]byte, error) {
	var out []byte
	err := bb.db.View(func(tx *bbolt.Tx) error {
		bkt := tx.Bucket([]byte(db))
		if bkt == nil {
			return nil
		}
		// Bucket.Get returns a slice into bbolt's memory map that is valid
		// only for the life of the transaction. Returning it after View
		// closes risks reading unmapped or reused memory, so copy it now.
		// A zero-length but non-nil value stays non-nil after the copy.
		if v := bkt.Get([]byte(id)); v != nil {
			out = make([]byte, len(v))
			copy(out, v)
		}
		return nil
	})
	return out, err
}
// Set writes value under id in bucket db inside a read-write transaction,
// creating the bucket first if it is missing (even for deletes). A nil
// value deletes the key instead.
func (bb BBolt) Set(_ context.Context, db, id string, value []byte) error {
	return bb.db.Update(func(tx *bbolt.Tx) error {
		bkt, err := tx.CreateBucketIfNotExists([]byte(db))
		if err != nil {
			return err
		}
		key := []byte(id)
		if value == nil {
			return bkt.Delete(key)
		}
		return bkt.Put(key, value)
	})
}