package main

import (
	"context"
	"database/sql"
	"io"
	"sync"
	"time"

	_ "github.com/glebarez/sqlite"
)

// DB describes how to reach a SQL database and serialises access to it.
//
// It does not hold an open connection: every Exec/Query call dials a fresh
// *sql.DB via dial and closes it again before returning. The mutex and the
// locked flag are pointers so that copies of a DB value (all methods use
// value receivers) share the same locking state.
type DB struct {
	scheme string        // driver name handed to sql.Open
	conn   string        // data source name handed to sql.Open
	rw     *sync.RWMutex // write-locked by WithLock, read-locked by Exec/Query
	locked *bool         // true while a WithLock callback is running
}

// NewDB builds a DB for the given driver name and data source name and
// verifies that the database is reachable by dialing it once. The probe is
// bounded by a 10 second timeout layered on top of ctx. On failure the zero
// DB and the dial error are returned.
func NewDB(ctx context.Context, scheme, conn string) (DB, error) {
	ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
	defer cancel()

	locked := false
	db := DB{
		scheme: scheme,
		conn:   conn,
		rw:     &sync.RWMutex{},
		locked: &locked,
	}

	// The probe handle is only used to validate the settings; it is closed
	// straight away because each operation dials its own handle.
	probe, err := db.dial(ctx)
	if err != nil {
		return DB{}, err
	}
	defer probe.Close()

	return db, nil
}

// WithLock runs cb while holding the write lock and returns cb's error.
//
// While cb runs, the locked flag is set so that Exec and Query skip taking
// the read lock; this lets cb call them on the same DB without deadlocking.
//
// NOTE(review): the flag is a plain bool shared by all goroutines. Exec and
// Query read it without synchronisation, and any goroutine that observes it
// as true skips the read lock even if it is not the one running cb. Fixing
// that needs a way to identify the lock holder (for example via the
// context), which would change the callback signature.
func (db DB) WithLock(cb func() error) error {
	db.rw.Lock()
	defer db.rw.Unlock()

	*db.locked = true
	defer func() { *db.locked = false }()

	return cb()
}

// Exec runs a statement that returns no rows. Unless a WithLock callback is
// in progress, the read lock is held for the duration of the call.
func (db DB) Exec(ctx context.Context, q string, args ...any) error {
	if !*db.locked {
		db.rw.RLock()
		defer db.rw.RUnlock()
	}
	return db.exec(ctx, q, args...)
}

// exec dials a fresh handle, executes q with args and closes the handle.
// It performs no locking of its own.
func (db DB) exec(ctx context.Context, q string, args ...any) error {
	c, err := db.dial(ctx)
	if err != nil {
		return err
	}
	defer c.Close()

	_, err = c.ExecContext(ctx, q, args...)
	return err
}

// Query runs q with args and invokes cb once per result row. Unless a
// WithLock callback is in progress, the read lock is held for the duration
// of the call.
func (db DB) Query(ctx context.Context, cb func(*sql.Rows) error, q string, args ...any) error {
	if !*db.locked {
		db.rw.RLock()
		defer db.rw.RUnlock()
	}
	return db.query(ctx, cb, q, args...)
}

// query dials a fresh handle, runs q and feeds each row to cb. Iteration
// stops at the first error returned by cb, which is returned as is;
// otherwise the iteration error from rows.Err is returned. It performs no
// locking of its own.
func (db DB) query(ctx context.Context, cb func(*sql.Rows) error, q string, args ...any) error {
	c, err := db.dial(ctx)
	if err != nil {
		return err
	}
	defer c.Close()

	rows, err := c.QueryContext(ctx, q, args...)
	if err != nil {
		return err
	}
	defer rows.Close()

	for rows.Next() {
		if err := cb(rows); err != nil {
			return err
		}
	}
	return rows.Err()
}

// dial opens a new *sql.DB for the configured driver and data source and
// pings it to confirm the database is reachable. The caller owns the
// returned handle and must Close it.
func (db DB) dial(ctx context.Context) (*sql.DB, error) {
	c, err := sql.Open(db.scheme, db.conn)
	if err != nil {
		return nil, err
	}
	if err := c.PingContext(ctx); err != nil {
		// Release the handle before bailing out; otherwise it leaks on
		// every failed ping. The ping error is the one worth reporting, so
		// a secondary Close error is deliberately dropped.
		_ = c.Close()
		return nil, err
	}
	return c, nil
}

// GetParty is currently a stub: it ignores id and always returns an empty
// string together with io.EOF.
func (db DB) GetParty(id string) (string, error) {
	return "", io.EOF
}