From 738992468a06491fb13c3e69e4ba19b5fe56adfc Mon Sep 17 00:00:00 2001 From: bel Date: Thu, 24 Apr 2025 22:10:56 -0600 Subject: [PATCH] ooooo generic db.Query, db.QueryOne --- src/db/db.go | 82 ++++++++++++++++++++++++++++++++++++++++++++--- src/db/db_test.go | 8 ++--- 2 files changed, 81 insertions(+), 9 deletions(-) diff --git a/src/db/db.go b/src/db/db.go index 6a9500c..506ead9 100644 --- a/src/db/db.go +++ b/src/db/db.go @@ -3,14 +3,86 @@ package db import ( "context" "database/sql" + "encoding/json" + "fmt" ) -func QueryOne(ctx context.Context, q string, args ...any) error { - return with(ctx, func(db *sql.DB) error { - row := db.QueryRowContext(ctx, q, args...) - TODO generic and return value - return row.Err() +func QueryOne[T any](ctx context.Context, q string, args ...any) (T, error) { + results, err := Query[T](ctx, q, args...) + if err != nil || len(results) == 0 { + var a T + return a, err + } + + if len(results) > 1 { + var a T + return a, fmt.Errorf("expected exactly 1 result but got %d", len(results)) + } + + return results[0], nil +} + +func Query[T any](ctx context.Context, q string, args ...any) ([]T, error) { + var result []T + + var a T + var m map[string]any + b, _ := json.Marshal(a) + if err := json.Unmarshal(b, &m); err != nil { + return nil, fmt.Errorf("%T is not map-like: %w", a, err) + } + scanners := func(columns []string) ([]any, error) { + s := make([]any, len(columns)) + for i, k := range columns { + v, ok := m[k] + if !ok { + return nil, fmt.Errorf("cannot scan column %s to %T (%+v)", k, a, m) + } + s[i] = &v + } + return s, nil + } + + err := with(ctx, func(db *sql.DB) error { + rows, err := db.QueryContext(ctx, q, args...) + if err != nil { + return err + } + defer rows.Close() + + columns, err := rows.Columns() + if err != nil { + return err + } + + for rows.Next() { + scanners, err := scanners(columns) + if err != nil { + return err + } + + if err := rows.Scan(scanners...); err != nil { + return err + } + + m := map[string]any{} + for i, column := range columns { + m[column] = scanners[i] + } + + var a T + if b, err := json.Marshal(m); err != nil { + return err + } else if err := json.Unmarshal(b, &a); err != nil { + return err + } else { + result = append(result, a) + } + } + + return rows.Err() }) + return result, err } func Exec(ctx context.Context, q string, args ...any) error { diff --git a/src/db/db_test.go b/src/db/db_test.go index c70ee94..4ea3f31 100644 --- a/src/db/db_test.go +++ b/src/db/db_test.go @@ -28,17 +28,17 @@ func TestDB(t *testing.T) { t.Fatal(err) } - var result struct { - K string + type result struct { + K string `json:"k"` } if got, err := db.QueryOne[result](ctx, `SELECT k FROM test WHERE k='a'`); err != nil { - t.Errorf("failed query one: %w", err) + t.Errorf("failed query one: %v", err) } else if got.K != "a" { t.Errorf("bad query one: %+v", got) } if gots, err := db.Query[result](ctx, `SELECT k FROM test`); err != nil { - t.Errorf("failed query: %w", err) + t.Errorf("failed query: %v", err) } else if len(gots) != 2 { t.Errorf("expected 2 but got %d gots", len(gots)) } else if gots[0].K != "a" {