285 lines
5.4 KiB
Go
285 lines
5.4 KiB
Go
package main
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"errors"
|
|
"fmt"
|
|
"io/ioutil"
|
|
"path"
|
|
"sync"
|
|
"time"
|
|
|
|
"go.etcd.io/bbolt"
|
|
|
|
_ "github.com/lib/pq"
|
|
)
|
|
|
|
// Driver is a namespaced key/value store holding raw byte values keyed by
// (namespace, id). The implementations in this file treat a nil value passed
// to Set as a delete, and return a nil slice with a nil error from Get when
// the id is absent.
type Driver interface {
	// Get returns the value stored under id in namespace ns.
	Get(ctx context.Context, ns, id string) ([]byte, error)

	// Set stores value under id in namespace ns; a nil value removes the entry.
	Set(ctx context.Context, ns, id string, value []byte) error

	// ForEach invokes cb for every entry in namespace ns, stopping at the
	// first error cb returns.
	ForEach(ctx context.Context, ns string, cb func(id string, value []byte) error) error

	// Close releases whatever resources the driver holds.
	Close() error
}
|
|
|
|
// Postgres is a Driver backed by a PostgreSQL database. Every namespace
// accepted by table maps to its own table with an `id TEXT` column (made
// unique by setup) and a `v JSONB` value column.
type Postgres struct {
	db *sql.DB
}

// NewPostgres opens a connection pool for conn using the "postgres" SQL
// driver and ensures the schema exists via setup. If setup fails the pool
// is closed and the setup error is returned wrapped.
func NewPostgres(ctx context.Context, conn string) (Postgres, error) {
	db, err := sql.Open("postgres", conn)
	if err != nil {
		return Postgres{}, err
	}

	pg := Postgres{db: db}
	if err := pg.setup(ctx); err != nil {
		// Best-effort cleanup; the setup error is the one worth reporting.
		_ = pg.Close()
		return Postgres{}, fmt.Errorf("failed setup: %w", err)
	}

	return pg, nil
}

// setup creates the queue ("q") and messages ("m") tables if they are
// missing and (re)creates the UNIQUE constraint on id that Set's
// `ON CONFLICT (id)` clause relies on.
//
// NOTE(review): dropping and re-adding the constraint rebuilds its index on
// every startup and will fail if duplicate ids already exist in a table.
func (pg Postgres) setup(ctx context.Context) error {
	tableQ, err := pg.table("q")
	if err != nil {
		return err
	}
	tableM, err := pg.table("m")
	if err != nil {
		return err
	}

	// Table names come from the fixed whitelist in table, so interpolating
	// them is safe. %[1]s is the queue table and %[2]s the messages table.
	if _, err := pg.db.ExecContext(ctx, fmt.Sprintf(`
		CREATE TABLE IF NOT EXISTS %[1]s (
			id TEXT NOT NULL,
			v JSONB NOT NULL
		);
		CREATE TABLE IF NOT EXISTS %[2]s (
			id TEXT NOT NULL,
			v JSONB NOT NULL
		);
		ALTER TABLE %[1]s DROP CONSTRAINT IF EXISTS %[1]s_id_unique;
		ALTER TABLE %[1]s ADD CONSTRAINT %[1]s_id_unique UNIQUE (id);
		ALTER TABLE %[2]s DROP CONSTRAINT IF EXISTS %[2]s_id_unique;
		ALTER TABLE %[2]s ADD CONSTRAINT %[2]s_id_unique UNIQUE (id);
	`, tableQ, tableM)); err != nil {
		return err
	}

	return nil
}

// table maps a namespace to its table name. Only "q" and "m" are valid;
// anything else is an error. Because callers splice the result directly
// into SQL text, this whitelist is what keeps table names injection-free.
func (pg Postgres) table(s string) (string, error) {
	switch s {
	case "q":
		return "spoc_bot_vr_q", nil
	case "m":
		return "spoc_bot_vr_messages", nil
	}
	return "", errors.New("invalid table " + s)
}

// Close closes the underlying connection pool.
func (pg Postgres) Close() error {
	return pg.db.Close()
}

// ForEach calls cb with the id and raw JSON bytes of every row in namespace
// ns, in whatever order the query returns them (there is no ORDER BY). It
// stops at the first error from the query, a scan, or cb. Errors that end
// the row stream early are reported, and a cancelled ctx yields ctx.Err().
func (pg Postgres) ForEach(ctx context.Context, ns string, cb func(string, []byte) error) error {
	table, err := pg.table(ns)
	if err != nil {
		return err
	}

	rows, err := pg.db.QueryContext(ctx, fmt.Sprintf(`SELECT id, v FROM %s;`, table))
	if err != nil {
		return err
	}
	defer rows.Close()

	for rows.Next() {
		var id string
		var v []byte
		if err := rows.Scan(&id, &v); err != nil {
			return err
		}
		if err := cb(id, v); err != nil {
			return err
		}
	}
	// rows.Next returns false both when the result set is exhausted and when
	// iteration failed; only rows.Err tells the two apart.
	if err := rows.Err(); err != nil {
		return err
	}

	return ctx.Err()
}

// Get returns the JSON bytes stored under id in namespace ns, or (nil, nil)
// when no row has that id.
func (pg Postgres) Get(ctx context.Context, ns, id string) ([]byte, error) {
	table, err := pg.table(ns)
	if err != nil {
		return nil, err
	}

	// id is bound as a query parameter rather than spliced into the SQL text,
	// so ids containing quotes can neither break nor inject into the query.
	row := pg.db.QueryRowContext(ctx, fmt.Sprintf(`SELECT v FROM %s WHERE id=$1;`, table), id)

	var v []byte
	if err := row.Scan(&v); err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			// A missing id is not an error for callers.
			return nil, nil
		}
		return nil, err
	}

	return v, nil
}

// Set upserts v under id in namespace ns, or deletes the row when v is nil.
// v is stored in a JSONB column, so it must be valid JSON.
func (pg Postgres) Set(ctx context.Context, ns, id string, v []byte) error {
	table, err := pg.table(ns)
	if err != nil {
		return err
	}

	if v == nil {
		_, err = pg.db.ExecContext(ctx, fmt.Sprintf(`DELETE FROM %s WHERE id=$1;`, table), id)
		return err
	}

	// Both id and v are bound parameters, so quotes inside the JSON payload
	// (e.g. "it's") no longer corrupt the statement. EXCLUDED.v is the value
	// proposed for insertion, so the payload is sent only once. v is passed
	// as a string so it travels as text for the JSONB column rather than
	// possibly being encoded as binary by the driver.
	_, err = pg.db.ExecContext(ctx,
		fmt.Sprintf(`INSERT INTO %s (id, v) VALUES ($1, $2) ON CONFLICT (id) DO UPDATE SET v = EXCLUDED.v;`, table),
		id, string(v))
	return err
}
|
|
|
|
// RAM is an in-memory Driver. Entries live in a two-level map
// (namespace -> id -> value) guarded by a shared *sync.RWMutex, so copies of
// a RAM value share both the storage and the lock. The zero RAM is not
// usable (its map and lock are nil); build one with NewRAM.
type RAM struct {
	m    map[string]map[string][]byte
	lock *sync.RWMutex
}

// NewRAM returns an empty RAM driver ready for use.
func NewRAM() RAM {
	store := make(map[string]map[string][]byte)
	return RAM{lock: new(sync.RWMutex), m: store}
}

// Close is a no-op; RAM holds no external resources.
func (r RAM) Close() error {
	return nil
}

// ForEach calls cb for every entry in namespace ns, in Go's unspecified map
// iteration order. It checks ctx before each callback and returns ctx.Err()
// once the context is done, or the first error cb returns.
//
// cb runs while the read lock is held: calling Set on the same RAM from cb
// deadlocks, and calling Get can deadlock if another goroutine is waiting
// in Set.
func (r RAM) ForEach(ctx context.Context, ns string, cb func(string, []byte) error) error {
	r.lock.RLock()
	defer r.lock.RUnlock()

	for id, val := range r.m[ns] {
		if err := ctx.Err(); err != nil {
			return err
		}
		if err := cb(id, val); err != nil {
			return err
		}
	}
	return ctx.Err()
}

// Get returns the value stored under id in namespace ns, or nil if either
// is absent. The returned slice is the stored one, not a copy.
func (r RAM) Get(_ context.Context, ns, id string) ([]byte, error) {
	r.lock.RLock()
	defer r.lock.RUnlock()

	// A missing namespace yields a nil inner map, and indexing a nil map
	// yields the zero value, so both "no ns" and "no id" give nil here.
	return r.m[ns][id], nil
}

// Set stores v under id in namespace ns, or removes the entry when v is nil.
// The namespace map is created on first use (even for a delete). v is kept
// as-is, without copying.
func (r RAM) Set(_ context.Context, ns, id string, v []byte) error {
	r.lock.Lock()
	defer r.lock.Unlock()

	bucket, ok := r.m[ns]
	if !ok {
		bucket = map[string][]byte{}
		r.m[ns] = bucket
	}

	if v == nil {
		delete(bucket, id)
	} else {
		bucket[id] = v
	}
	return nil
}
|
|
|
|
type BBolt struct {
|
|
db *bbolt.DB
|
|
}
|
|
|
|
func NewTestDBIn(d string) BBolt {
|
|
d, err := ioutil.TempDir(d, "test-db-*")
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
db, err := NewDB(path.Join(d, "bb"))
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
return db
|
|
}
|
|
|
|
func NewDB(p string) (BBolt, error) {
|
|
db, err := bbolt.Open(p, 0600, &bbolt.Options{
|
|
Timeout: time.Second,
|
|
})
|
|
return BBolt{db: db}, err
|
|
}
|
|
|
|
func (bb BBolt) Close() error {
|
|
return bb.db.Close()
|
|
}
|
|
|
|
func (bb BBolt) ForEach(ctx context.Context, db string, cb func(string, []byte) error) error {
|
|
return bb.db.View(func(tx *bbolt.Tx) error {
|
|
bkt := tx.Bucket([]byte(db))
|
|
if bkt == nil {
|
|
return nil
|
|
}
|
|
|
|
c := bkt.Cursor()
|
|
for k, v := c.First(); k != nil && ctx.Err() == nil; k, v = c.Next() {
|
|
if err := cb(string(k), v); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
return ctx.Err()
|
|
})
|
|
}
|
|
|
|
func (bb BBolt) Get(_ context.Context, db, id string) ([]byte, error) {
|
|
var b []byte
|
|
err := bb.db.View(func(tx *bbolt.Tx) error {
|
|
bkt := tx.Bucket([]byte(db))
|
|
if bkt == nil {
|
|
return nil
|
|
}
|
|
|
|
b = bkt.Get([]byte(id))
|
|
return nil
|
|
})
|
|
return b, err
|
|
}
|
|
|
|
func (bb BBolt) Set(_ context.Context, db, id string, value []byte) error {
|
|
return bb.db.Update(func(tx *bbolt.Tx) error {
|
|
bkt := tx.Bucket([]byte(db))
|
|
if bkt == nil {
|
|
var err error
|
|
bkt, err = tx.CreateBucket([]byte(db))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
if value == nil {
|
|
return bkt.Delete([]byte(id))
|
|
}
|
|
return bkt.Put([]byte(id), value)
|
|
})
|
|
}
|