to generic func Select

This commit is contained in:
Bel LaPointe
2026-02-03 17:08:08 -07:00
parent 72a976bdd0
commit f78d0d1f59
2 changed files with 44 additions and 16 deletions

View File

@@ -13,6 +13,8 @@ type DB struct {
*sql.DB
}
var global DB
func NewDB(ctx context.Context, driver, conn string) (DB, error) {
sql, err := sql.Open(driver, conn)
if err != nil {
@@ -25,6 +27,7 @@ func NewDB(ctx context.Context, driver, conn string) (DB, error) {
return db, err
}
global = db
return db, nil
}
@@ -45,30 +48,47 @@ func (db DB) setup(ctx context.Context) error {
return nil
}
func (db DB) Get(ctx context.Context, database int, k string) (any, error) {
func Select[V any](ctx context.Context, database int, k string) (*V, error) {
var v V
err := global.Get(ctx, database, k, &v)
if err != nil {
if err != sql.ErrNoRows {
return nil, nil
}
return nil, err
}
return &v, nil
}
func (db DB) Get(ctx context.Context, database int, k string, vPtr any) error {
if err := db.get(ctx, database, k, vPtr); err != nil {
if err != sql.ErrNoRows {
return err
}
}
return nil
}
func (db DB) get(ctx context.Context, database int, k string, vPtr any) error {
row := db.QueryRowContext(ctx, `
SELECT value
FROM data
WHERE database=$1 AND key=$2
`, database, k)
if err := row.Err(); err != nil {
return nil, err
return err
}
var v string
if err := row.Scan(&v); err != nil {
if err == sql.ErrNoRows {
return nil, nil
}
return nil, err
return err
}
var a any
if err := json.Unmarshal([]byte(v), &a); err != nil {
return nil, err
if err := json.Unmarshal([]byte(v), vPtr); err != nil {
return err
}
return a, nil
return nil
}
func (db DB) Put(ctx context.Context, database int, k string, v any) error {

View File

@@ -18,9 +18,10 @@ func TestDBSqlite(t *testing.T) {
}
t.Cleanup(func() { db.Close() })
if v, err := db.Get(ctx, 5, "k"); err != nil {
v := -1
if err := db.Get(ctx, 5, "k", &v); err != nil {
t.Error("failed get 404:", err)
} else if v != nil {
} else if v != -1 {
t.Error("got 404:", v)
}
@@ -30,9 +31,15 @@ func TestDBSqlite(t *testing.T) {
t.Error("failed update:", err)
}
if v, err := db.Get(ctx, 5, "k"); err != nil {
if err := db.Get(ctx, 5, "k", &v); err != nil {
t.Error("failed get:", err)
} else if v != float64(7) {
} else if v != 7 {
t.Errorf("wrong get: (%T) %v but wanted 7", v, v)
}
if v, err := src.Select[int](ctx, 5, "k"); err != nil {
t.Error("failed select[int]:", err)
} else if v == nil || *v != 7 {
t.Errorf("wrong get: (%T) %v but wanted 7", v, v)
}
@@ -42,9 +49,10 @@ func TestDBSqlite(t *testing.T) {
t.Error("failed del 404:", err)
}
if v, err := db.Get(ctx, 5, "k"); err != nil {
v = -1
if err := db.Get(ctx, 5, "k", &v); err != nil {
t.Error("failed get 410:", err)
} else if v != nil {
} else if v != -1 {
t.Error("wrong get 410:", v)
}
}