From f78d0d1f59166aaeca50d5f3dd90b66c989b8a9a Mon Sep 17 00:00:00 2001 From: Bel LaPointe <153096461+breel-render@users.noreply.github.com> Date: Tue, 3 Feb 2026 17:08:08 -0700 Subject: [PATCH] to generic func Select --- src/db.go | 40 ++++++++++++++++++++++++++++++---------- src/db_test.go | 20 ++++++++++++++------ 2 files changed, 44 insertions(+), 16 deletions(-) diff --git a/src/db.go b/src/db.go index d963c5f..5920765 100644 --- a/src/db.go +++ b/src/db.go @@ -13,6 +13,8 @@ type DB struct { *sql.DB } +var global DB + func NewDB(ctx context.Context, driver, conn string) (DB, error) { sql, err := sql.Open(driver, conn) if err != nil { @@ -25,6 +27,7 @@ func NewDB(ctx context.Context, driver, conn string) (DB, error) { return db, err } + global = db return db, nil } @@ -45,30 +48,47 @@ func (db DB) setup(ctx context.Context) error { return nil } -func (db DB) Get(ctx context.Context, database int, k string) (any, error) { +func Select[V any](ctx context.Context, database int, k string) (*V, error) { + var v V + err := global.Get(ctx, database, k, &v) + if err != nil { + if err != sql.ErrNoRows { + return nil, nil + } + return nil, err + } + return &v, nil +} + +func (db DB) Get(ctx context.Context, database int, k string, vPtr any) error { + if err := db.get(ctx, database, k, vPtr); err != nil { + if err != sql.ErrNoRows { + return err + } + } + return nil +} + +func (db DB) get(ctx context.Context, database int, k string, vPtr any) error { row := db.QueryRowContext(ctx, ` SELECT value FROM data WHERE database=$1 AND key=$2 `, database, k) if err := row.Err(); err != nil { - return nil, err + return err } var v string if err := row.Scan(&v); err != nil { - if err == sql.ErrNoRows { - return nil, nil - } - return nil, err + return err } - var a any - if err := json.Unmarshal([]byte(v), &a); err != nil { - return nil, err + if err := json.Unmarshal([]byte(v), vPtr); err != nil { + return err } - return a, nil + return nil } func (db DB) Put(ctx context.Context, database int, k string, v any) error { diff --git a/src/db_test.go b/src/db_test.go index 11ab434..fadc5b2 100644 --- a/src/db_test.go +++ b/src/db_test.go @@ -18,9 +18,10 @@ func TestDBSqlite(t *testing.T) { } t.Cleanup(func() { db.Close() }) - if v, err := db.Get(ctx, 5, "k"); err != nil { + v := -1 + if err := db.Get(ctx, 5, "k", &v); err != nil { t.Error("failed get 404:", err) - } else if v != nil { + } else if v != -1 { t.Error("got 404:", v) } @@ -30,9 +31,15 @@ func TestDBSqlite(t *testing.T) { t.Error("failed update:", err) } - if v, err := db.Get(ctx, 5, "k"); err != nil { + if err := db.Get(ctx, 5, "k", &v); err != nil { t.Error("failed get:", err) - } else if v != float64(7) { + } else if v != 7 { + t.Errorf("wrong get: (%T) %v but wanted 7", v, v) + } + + if v, err := src.Select[int](ctx, 5, "k"); err != nil { + t.Error("failed select[int]:", err) + } else if v == nil || *v != 7 { t.Errorf("wrong get: (%T) %v but wanted 7", v, v) } @@ -42,9 +49,10 @@ func TestDBSqlite(t *testing.T) { t.Error("failed del 404:", err) } - if v, err := db.Get(ctx, 5, "k"); err != nil { + v = -1 + if err := db.Get(ctx, 5, "k", &v); err != nil { t.Error("failed get 410:", err) - } else if v != nil { + } else if v != -1 { t.Error("wrong get 410:", v) } }