ooooo generic db.Query, db.QueryOne

main
bel 2025-04-24 22:10:56 -06:00
parent b7a7f2a82f
commit 738992468a
2 changed files with 81 additions and 9 deletions

View File

@ -3,14 +3,86 @@ package db
import ( import (
"context" "context"
"database/sql" "database/sql"
"encoding/json"
"fmt"
) )
func QueryOne(ctx context.Context, q string, args ...any) error { func QueryOne[T any](ctx context.Context, q string, args ...any) (T, error) {
return with(ctx, func(db *sql.DB) error { results, err := Query[T](ctx, q, args...)
row := db.QueryRowContext(ctx, q, args...) if err != nil || len(results) == 0 {
TODO generic and return value var a T
return row.Err() return a, err
}
if len(results) > 1 {
var a T
return a, fmt.Errorf("expected exactly 1 result but got %d", len(results))
}
return results[0], nil
}
func Query[T any](ctx context.Context, q string, args ...any) ([]T, error) {
var result []T
var a T
var m map[string]any
b, _ := json.Marshal(a)
if err := json.Unmarshal(b, &m); err != nil {
return nil, fmt.Errorf("%T is not map-like: %w", a, err)
}
scanners := func(columns []string) ([]any, error) {
s := make([]any, len(columns))
for i, k := range columns {
v, ok := m[k]
if !ok {
return nil, fmt.Errorf("cannot scan column %s to %T (%+v)", k, a, m)
}
s[i] = &v
}
return s, nil
}
err := with(ctx, func(db *sql.DB) error {
rows, err := db.QueryContext(ctx, q, args...)
if err != nil {
return err
}
defer rows.Close()
columns, err := rows.Columns()
if err != nil {
return err
}
for rows.Next() {
scanners, err := scanners(columns)
if err != nil {
return err
}
if err := rows.Scan(scanners...); err != nil {
return err
}
m := map[string]any{}
for i, column := range columns {
m[column] = scanners[i]
}
var a T
if b, err := json.Marshal(m); err != nil {
return err
} else if err := json.Unmarshal(b, &a); err != nil {
return err
} else {
result = append(result, a)
}
}
return rows.Err()
}) })
return result, err
} }
func Exec(ctx context.Context, q string, args ...any) error { func Exec(ctx context.Context, q string, args ...any) error {

View File

@ -28,17 +28,17 @@ func TestDB(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
var result struct { type result struct {
K string K string `json:"k"`
} }
if got, err := db.QueryOne[result](ctx, `SELECT k FROM test WHERE k='a'`); err != nil { if got, err := db.QueryOne[result](ctx, `SELECT k FROM test WHERE k='a'`); err != nil {
t.Errorf("failed query one: %w", err) t.Errorf("failed query one: %v", err)
} else if got.K != "a" { } else if got.K != "a" {
t.Errorf("bad query one: %+v", got) t.Errorf("bad query one: %+v", got)
} }
if gots, err := db.Query[result](ctx, `SELECT k FROM test`); err != nil { if gots, err := db.Query[result](ctx, `SELECT k FROM test`); err != nil {
t.Errorf("failed query: %w", err) t.Errorf("failed query: %v", err)
} else if len(gots) != 2 { } else if len(gots) != 2 {
t.Errorf("expected 2 but got %d gots", len(gots)) t.Errorf("expected 2 but got %d gots", len(gots))
} else if gots[0].K != "a" { } else if gots[0].K != "a" {