diff --git a/driver.go b/driver.go index 9a38f30..d991156 100644 --- a/driver.go +++ b/driver.go @@ -2,12 +2,17 @@ package main import ( "context" + "database/sql" + "errors" + "fmt" "io/ioutil" "path" "sync" "time" "go.etcd.io/bbolt" + + _ "github.com/lib/pq" ) type Driver interface { @@ -17,6 +22,133 @@ type Driver interface { Set(context.Context, string, string, []byte) error } +type Postgres struct { + db *sql.DB +} + +func NewPostgres(ctx context.Context, conn string) (Postgres, error) { + db, err := sql.Open("postgres", conn) + if err != nil { + return Postgres{}, err + } + + pg := Postgres{db: db} + if err := pg.setup(ctx); err != nil { + pg.Close() + return Postgres{}, fmt.Errorf("failed setup: %w", err) + } + + return pg, nil +} + +func (pg Postgres) setup(ctx context.Context) error { + tableQ, err := pg.table("q") + if err != nil { + return err + } + tableM, err := pg.table("m") + if err != nil { + return err + } + if _, err := pg.db.ExecContext(ctx, fmt.Sprintf(` + CREATE TABLE IF NOT EXISTS %s ( + id TEXT NOT NULL, + v JSONB NOT NULL + ); + CREATE TABLE IF NOT EXISTS %s ( + id TEXT NOT NULL, + v JSONB NOT NULL + ); + ALTER TABLE %s DROP CONSTRAINT IF EXISTS %s_id_unique; + ALTER TABLE %s ADD CONSTRAINT %s_id_unique UNIQUE (id); + ALTER TABLE %s DROP CONSTRAINT IF EXISTS %s_id_unique; + ALTER TABLE %s ADD CONSTRAINT %s_id_unique UNIQUE (id); + `, tableQ, + tableM, + tableQ, tableQ, + tableQ, tableQ, + tableM, tableM, + tableM, tableM, + )); err != nil { + return err + } + return nil +} + +func (pg Postgres) table(s string) (string, error) { + switch s { + case "q": + return "spoc_bot_vr_q", nil + case "m": + return "spoc_bot_vr_messages", nil + } + return "", errors.New("invalid table " + s) +} + +func (pg Postgres) Close() error { + return pg.db.Close() +} + +func (pg Postgres) ForEach(ctx context.Context, ns string, cb func(string, []byte) error) error { + table, err := pg.table(ns) + if err != nil { + return err + } + + rows, err := pg.db.QueryContext(ctx, fmt.Sprintf(`SELECT id, v FROM %s;`, table)) + if err != nil { + return err + } + defer rows.Close() + + for rows.Next() { + var id string + var v []byte + if err := rows.Scan(&id, &v); err != nil { + return err + } else if err := cb(id, v); err != nil { + return err + } + } + + return ctx.Err() +} + +func (pg Postgres) Get(ctx context.Context, ns, id string) ([]byte, error) { + table, err := pg.table(ns) + if err != nil { + return nil, err + } + + row := pg.db.QueryRowContext(ctx, fmt.Sprintf(`SELECT v FROM %s WHERE id='%s';`, table, id)) + if err := row.Err(); err != nil { + return nil, err + } + + var v []byte + if err := row.Scan(&v); err != nil && !errors.Is(err, sql.ErrNoRows) { + return nil, err + } + + return v, nil +} + +func (pg Postgres) Set(ctx context.Context, ns, id string, v []byte) error { + table, err := pg.table(ns) + if err != nil { + return err + } + + if v == nil { + _, err = pg.db.ExecContext(ctx, fmt.Sprintf(`DELETE FROM %s WHERE id='%s';`, table, id)) + return err + + } + + _, err = pg.db.ExecContext(ctx, fmt.Sprintf(`INSERT INTO %s (id, v) VALUES ('%s', '%s') ON CONFLICT (id) DO UPDATE SET v = '%s'`, table, id, v, v)) + return err +} + type RAM struct { m map[string]map[string][]byte lock *sync.RWMutex diff --git a/driver_test.go b/driver_test.go index 025b536..bd0c9ac 100644 --- a/driver_test.go +++ b/driver_test.go @@ -4,9 +4,26 @@ import ( "context" "errors" "io" + "os" "testing" + "time" ) +func TestPostgres(t *testing.T) { + ctx, can := context.WithTimeout(context.Background(), time.Second*15) + defer can() + + conn := os.Getenv("INTEGRATION_POSTGRES_CONN") + if conn == "" { + t.Skip() + } + pg, err := NewPostgres(ctx, conn) + if err != nil { + t.Fatal(err) + } + testDriver(t, pg) +} + func TestDriverRAM(t *testing.T) { testDriver(t, NewRAM()) } @@ -16,48 +33,51 @@ func TestDriverBBolt(t *testing.T) { } func testDriver(t *testing.T, d Driver) { + ctx, can := context.WithTimeout(context.Background(), time.Second*15) + defer can() + defer d.Close() - if b, err := d.Get(nil, "db", "id"); err != nil { - t.Error("cannot get from empty", err) + if b, err := d.Get(ctx, "m", "id"); err != nil { + t.Error("cannot get from empty:", err) } else if b != nil { t.Error("got fake from empty") } - if err := d.ForEach(context.Background(), "db", func(string, []byte) error { + if err := d.ForEach(ctx, "m", func(string, []byte) error { return errors.New("should have no hits") }); err != nil { - t.Error("failed to forEach empty", err) + t.Error("failed to forEach empty:", err) } - if err := d.Set(nil, "db", "id", []byte("hello world")); err != nil { - t.Error("cannot set from empty", err) + if err := d.Set(ctx, "m", "id", []byte(`"hello world"`)); err != nil { + t.Error("cannot set from empty:", err) } - if b, err := d.Get(nil, "db", "id"); err != nil { - t.Error("cannot get from full", err) - } else if string(b) != "hello world" { + if b, err := d.Get(ctx, "m", "id"); err != nil { + t.Error("cannot get from full:", err) + } else if string(b) != `"hello world"` { t.Error("got fake from full") } - if err := d.ForEach(context.Background(), "db", func(id string, v []byte) error { + if err := d.ForEach(ctx, "m", func(id string, v []byte) error { if id != "id" { - t.Error(id) + t.Error("for each id weird:", id) } - if string(v) != "hello world" { - t.Error(string(v)) + if string(v) != `"hello world"` { + t.Error("for each value weird:", string(v)) } return io.EOF }); err != io.EOF { - t.Error("failed to forEach full", err) + t.Error("failed to forEach full:", err) } - if err := d.Set(nil, "db", "id", nil); err != nil { - t.Error("cannot set from full", err) + if err := d.Set(ctx, "m", "id", nil); err != nil { + t.Error("cannot set from full:", err) } - if b, err := d.Get(nil, "db", "id"); err != nil { - t.Error("cannot get from deleted", err) + if b, err := d.Get(ctx, "m", "id"); err != nil { + t.Error("cannot get from deleted:", err) } else if b != nil { t.Error("got fake from deleted") } diff --git a/go.mod b/go.mod index 0fec412..5b55d72 100644 --- a/go.mod +++ b/go.mod @@ -7,4 +7,7 @@ require ( go.etcd.io/bbolt v1.3.9 ) -require golang.org/x/sys v0.4.0 // indirect +require ( + github.com/lib/pq v1.10.9 // indirect + golang.org/x/sys v0.4.0 // indirect +) diff --git a/go.sum b/go.sum index fe9ccb0..ddfc625 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,7 @@ github.com/go-errors/errors v1.5.1 h1:ZwEMSLRCapFLflTpT7NKaAc7ukJ8ZPEjzlxt8rPN8bk= github.com/go-errors/errors v1.5.1/go.mod h1:sIVyrIiJhuEF+Pj9Ebtd6P/rEYROXFi3BopGUQ5a5Og= +github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= +github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= go.etcd.io/bbolt v1.3.9 h1:8x7aARPEXiXbHmtUwAIv7eV2fQFHrLLavdiJ3uzJXoI= go.etcd.io/bbolt v1.3.9/go.mod h1:zaO32+Ti0PK1ivdPtgMESzuzL2VPoIG1PCQNvOdo/dE= golang.org/x/sys v0.4.0 h1:Zr2JFtRQNX3BCZ8YtxRE9hNJYC8J6I1MVbMg6owUp18= diff --git a/main.go b/main.go index 9e384c9..cc3ec33 100644 --- a/main.go +++ b/main.go @@ -122,7 +122,7 @@ func _newHandlerPostAPIV1EventsSlack(cfg Config) http.HandlerFunc { return } - if err := cfg.storage.Upsert(ctx, m); err != nil { + if err := cfg.storage.Upsert(r.Context(), m); err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return }