spoc-bot-vr/driver.go

113 lines
2.2 KiB
Go

package main
import (
"context"
"database/sql"
"errors"
"fmt"
"net/url"
"os"
"path"
"testing"
_ "github.com/glebarez/go-sqlite"
_ "github.com/lib/pq"
)
// Driver is the bot's database handle. It embeds *sql.DB, so Exec/Query
// style methods are called on a Driver value directly.
type Driver struct {
	// engine is the database/sql driver name handed to sql.Open (e.g. "sqlite");
	// conn is the data source name handed to sql.Open alongside it.
	engine, conn string
	*sql.DB
}
// NewTestDriver returns a Driver for use in tests and closes it once t (and
// its subtests) complete.
//
// By default the database is a file named "db" inside a fresh t.TempDir(),
// so each call gets its own storage. If optionalP is given, only optionalP[0]
// is used, and it is passed to NewDriver verbatim as the connection string.
// An empty override therefore selects NewDriver's temporary-file default.
// Any error from NewDriver fails the test immediately.
func NewTestDriver(t *testing.T, optionalP ...string) Driver {
	t.Helper() // attribute t.Fatal below to the caller's line, not this helper

	var p string
	if len(optionalP) > 0 {
		p = optionalP[0]
	} else {
		// Only create a temp dir when it is actually going to be used.
		p = path.Join(t.TempDir(), "db")
	}

	driver, err := NewDriver(context.Background(), p)
	if err != nil {
		t.Fatal(err)
	}
	t.Cleanup(func() {
		// Surface close failures instead of silently dropping them; logging
		// (not failing) keeps cleanup best-effort.
		if err := driver.Close(); err != nil {
			t.Logf("closing test driver: %v", err)
		}
	})
	return driver
}
// NewDriver opens a database handle for conn and runs setup on it.
//
// The engine (the database/sql driver name) is chosen from conn:
//   - "" creates an empty temporary file in os.TempDir() and opens it with
//     the "sqlite" driver. The file is not removed afterwards.
//   - No URL scheme (a plain path), a "file:" SQLite URI filename, or a
//     one-letter scheme uses "sqlite". url.Parse reports a Windows path such
//     as `C:\dir\db` as scheme "c".
//   - The "postgresql" scheme uses the "postgres" driver. lib/pq registers
//     itself only under that name, although its DSN parser accepts both
//     URL schemes.
//   - Any other scheme is used verbatim as the driver name, e.g.
//     "postgres://..." -> "postgres".
//
// conn itself is passed to sql.Open unchanged. If url.Parse or sql.Open fails,
// that error is returned as-is. If setup fails, the handle is closed and the
// error is wrapped. On any error the zero Driver is returned.
func NewDriver(ctx context.Context, conn string) (Driver, error) {
	engine := "sqlite"
	if conn == "" {
		f, err := os.CreateTemp(os.TempDir(), "spoc-bot-vr-undef-*.db")
		if err != nil {
			return Driver{}, err
		}
		// Only the reserved unique name is kept; the handle is closed so the
		// path can be reopened by sql.Open below.
		f.Close()
		conn = f.Name()
	} else {
		u, err := url.Parse(conn)
		if err != nil {
			return Driver{}, err
		}
		switch {
		case u.Scheme == "", u.Scheme == "file", len(u.Scheme) == 1:
			// Plain path, SQLite URI filename, or Windows drive letter:
			// keep the sqlite default.
		case u.Scheme == "postgresql":
			engine = "postgres"
		default:
			engine = u.Scheme
		}
	}

	db, err := sql.Open(engine, conn)
	if err != nil {
		return Driver{}, err
	}
	driver := Driver{DB: db, conn: conn, engine: engine}
	if err := driver.setup(ctx); err != nil {
		driver.Close()
		return Driver{}, fmt.Errorf("failed setup: %w", err)
	}
	return driver, nil
}
/*
func (driver Driver) FillWithTestdata(ctx context.Context, assetPattern, datacenterPattern, eventNamePattern string) error {
d := "./testdata/slack_events"
entries, err := os.ReadDir(d)
if err != nil {
return err
}
for _, entry := range entries {
if entry.IsDir() {
continue
}
b, err := os.ReadFile(path.Join(d, entry.Name()))
if err != nil {
return err
}
m, err := ParseSlack(b, assetPattern, datacenterPattern, eventNamePattern)
if errors.Is(err, ErrIrrelevantMessage) {
continue
} else if err != nil {
return err
}
if err := driver.Set(nil, "m", m.ID, m.Serialize()); err != nil {
return err
}
}
return nil
}
*/
// setup resets the schema by dropping the spoc_bot_vr_q and
// spoc_bot_vr_messages tables if they exist, in a single ExecContext call.
// It returns whatever error that call reports.
//
// NOTE(review): this drops existing data on every open; confirm that is
// intended for non-test connections.
func (driver Driver) setup(ctx context.Context) error {
	const resetSchema = `
DROP TABLE IF EXISTS spoc_bot_vr_q;
DROP TABLE IF EXISTS spoc_bot_vr_messages;
`
	if _, err := driver.ExecContext(ctx, resetSchema); err != nil {
		return err
	}
	return nil
}
// table resolves a short table key to its SQL table name:
// "q" -> "spoc_bot_vr_q" and "m" -> "spoc_bot_vr_messages".
// Any other key yields an empty name and an "invalid table" error.
func (driver Driver) table(s string) (string, error) {
	var name string
	switch s {
	case "q":
		name = "spoc_bot_vr_q"
	case "m":
		name = "spoc_bot_vr_messages"
	default:
		return "", errors.New("invalid table " + s)
	}
	return name, nil
}