ooooo generic db.Query, db.QueryOne
parent
b7a7f2a82f
commit
738992468a
82
src/db/db.go
82
src/db/db.go
|
|
@ -3,14 +3,86 @@ package db
|
|||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
func QueryOne(ctx context.Context, q string, args ...any) error {
|
||||
return with(ctx, func(db *sql.DB) error {
|
||||
row := db.QueryRowContext(ctx, q, args...)
|
||||
TODO generic and return value
|
||||
return row.Err()
|
||||
func QueryOne[T any](ctx context.Context, q string, args ...any) (T, error) {
|
||||
results, err := Query[T](ctx, q, args...)
|
||||
if err != nil || len(results) == 0 {
|
||||
var a T
|
||||
return a, err
|
||||
}
|
||||
|
||||
if len(results) > 1 {
|
||||
var a T
|
||||
return a, fmt.Errorf("expected exactly 1 result but got %d", len(results))
|
||||
}
|
||||
|
||||
return results[0], nil
|
||||
}
|
||||
|
||||
func Query[T any](ctx context.Context, q string, args ...any) ([]T, error) {
|
||||
var result []T
|
||||
|
||||
var a T
|
||||
var m map[string]any
|
||||
b, _ := json.Marshal(a)
|
||||
if err := json.Unmarshal(b, &m); err != nil {
|
||||
return nil, fmt.Errorf("%T is not map-like: %w", a, err)
|
||||
}
|
||||
scanners := func(columns []string) ([]any, error) {
|
||||
s := make([]any, len(columns))
|
||||
for i, k := range columns {
|
||||
v, ok := m[k]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("cannot scan column %s to %T (%+v)", k, a, m)
|
||||
}
|
||||
s[i] = &v
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
err := with(ctx, func(db *sql.DB) error {
|
||||
rows, err := db.QueryContext(ctx, q, args...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
columns, err := rows.Columns()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for rows.Next() {
|
||||
scanners, err := scanners(columns)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := rows.Scan(scanners...); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
m := map[string]any{}
|
||||
for i, column := range columns {
|
||||
m[column] = scanners[i]
|
||||
}
|
||||
|
||||
var a T
|
||||
if b, err := json.Marshal(m); err != nil {
|
||||
return err
|
||||
} else if err := json.Unmarshal(b, &a); err != nil {
|
||||
return err
|
||||
} else {
|
||||
result = append(result, a)
|
||||
}
|
||||
}
|
||||
|
||||
return rows.Err()
|
||||
})
|
||||
return result, err
|
||||
}
|
||||
|
||||
func Exec(ctx context.Context, q string, args ...any) error {
|
||||
|
|
|
|||
|
|
@ -28,17 +28,17 @@ func TestDB(t *testing.T) {
|
|||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var result struct {
|
||||
K string
|
||||
type result struct {
|
||||
K string `json:"k"`
|
||||
}
|
||||
if got, err := db.QueryOne[result](ctx, `SELECT k FROM test WHERE k='a'`); err != nil {
|
||||
t.Errorf("failed query one: %w", err)
|
||||
t.Errorf("failed query one: %v", err)
|
||||
} else if got.K != "a" {
|
||||
t.Errorf("bad query one: %+v", got)
|
||||
}
|
||||
|
||||
if gots, err := db.Query[result](ctx, `SELECT k FROM test`); err != nil {
|
||||
t.Errorf("failed query: %w", err)
|
||||
t.Errorf("failed query: %v", err)
|
||||
} else if len(gots) != 2 {
|
||||
t.Errorf("expected 2 but got %d gots", len(gots))
|
||||
} else if gots[0].K != "a" {
|
||||
|
|
|
|||
Loading…
Reference in New Issue