ooooo generic db.Query, db.QueryOne

main
bel 2025-04-24 22:10:56 -06:00
parent b7a7f2a82f
commit 738992468a
2 changed files with 81 additions and 9 deletions

View File

@ -3,14 +3,86 @@ package db
import (
"context"
"database/sql"
"encoding/json"
"fmt"
)
func QueryOne(ctx context.Context, q string, args ...any) error {
return with(ctx, func(db *sql.DB) error {
row := db.QueryRowContext(ctx, q, args...)
TODO generic and return value
return row.Err()
func QueryOne[T any](ctx context.Context, q string, args ...any) (T, error) {
results, err := Query[T](ctx, q, args...)
if err != nil || len(results) == 0 {
var a T
return a, err
}
if len(results) > 1 {
var a T
return a, fmt.Errorf("expected exactly 1 result but got %d", len(results))
}
return results[0], nil
}
func Query[T any](ctx context.Context, q string, args ...any) ([]T, error) {
var result []T
var a T
var m map[string]any
b, _ := json.Marshal(a)
if err := json.Unmarshal(b, &m); err != nil {
return nil, fmt.Errorf("%T is not map-like: %w", a, err)
}
scanners := func(columns []string) ([]any, error) {
s := make([]any, len(columns))
for i, k := range columns {
v, ok := m[k]
if !ok {
return nil, fmt.Errorf("cannot scan column %s to %T (%+v)", k, a, m)
}
s[i] = &v
}
return s, nil
}
err := with(ctx, func(db *sql.DB) error {
rows, err := db.QueryContext(ctx, q, args...)
if err != nil {
return err
}
defer rows.Close()
columns, err := rows.Columns()
if err != nil {
return err
}
for rows.Next() {
scanners, err := scanners(columns)
if err != nil {
return err
}
if err := rows.Scan(scanners...); err != nil {
return err
}
m := map[string]any{}
for i, column := range columns {
m[column] = scanners[i]
}
var a T
if b, err := json.Marshal(m); err != nil {
return err
} else if err := json.Unmarshal(b, &a); err != nil {
return err
} else {
result = append(result, a)
}
}
return rows.Err()
})
return result, err
}
func Exec(ctx context.Context, q string, args ...any) error {

View File

@ -28,17 +28,17 @@ func TestDB(t *testing.T) {
t.Fatal(err)
}
var result struct {
K string
type result struct {
K string `json:"k"`
}
if got, err := db.QueryOne[result](ctx, `SELECT k FROM test WHERE k='a'`); err != nil {
t.Errorf("failed query one: %w", err)
t.Errorf("failed query one: %v", err)
} else if got.K != "a" {
t.Errorf("bad query one: %+v", got)
}
if gots, err := db.Query[result](ctx, `SELECT k FROM test`); err != nil {
t.Errorf("failed query: %w", err)
t.Errorf("failed query: %v", err)
} else if len(gots) != 2 {
t.Errorf("expected 2 but got %d gots", len(gots))
} else if gots[0].K != "a" {