diff --git a/src/db/db.go b/src/db/db.go index 506ead9..5e75fc8 100644 --- a/src/db/db.go +++ b/src/db/db.go @@ -5,6 +5,7 @@ import ( "database/sql" "encoding/json" "fmt" + "strings" ) func QueryOne[T any](ctx context.Context, q string, args ...any) (T, error) { @@ -33,12 +34,33 @@ func Query[T any](ctx context.Context, q string, args ...any) ([]T, error) { } scanners := func(columns []string) ([]any, error) { s := make([]any, len(columns)) - for i, k := range columns { - v, ok := m[k] - if !ok { - return nil, fmt.Errorf("cannot scan column %s to %T (%+v)", k, a, m) + i := 0 + for i < len(columns) { + k := columns[i] + if strings.Contains(k, ".") { + columns := strings.SplitN(k, ".", 2) + m2, ok := m[columns[0]] + if !ok { + return nil, fmt.Errorf("no column %s in %T (%+v)", columns[0], a, m) + } + m3, ok := m2.(map[string]any) + if !ok { + return nil, fmt.Errorf("cannot scan subfield %s of %s of %T (%+v)", columns[1], columns[0], a, m) + } + v, ok := m3[columns[1]] + if !ok { + return nil, fmt.Errorf("no subfield %s of %s of %T (%+v)", columns[1], columns[0], a, m) + } + s[i] = &v + i += 1 + } else { + v, ok := m[k] + if !ok { + return nil, fmt.Errorf("no column %s in %T (%+v)", k, a, m) + } + s[i] = &v + i += 1 } - s[i] = &v } return s, nil } @@ -67,7 +89,21 @@ func Query[T any](ctx context.Context, q string, args ...any) ([]T, error) { m := map[string]any{} for i, column := range columns { - m[column] = scanners[i] + if !strings.Contains(column, ".") { + m[column] = scanners[i] + } else { + columns := strings.SplitN(column, ".", 2) + m2, ok := m[columns[0]] + if !ok { + m2 = map[string]any{} + } + m3, ok := m2.(map[string]any) + if !ok { + return fmt.Errorf("%s is not a submap", columns[0]) + } + m3[columns[1]] = scanners[i] + m[columns[0]] = m3 + } } var a T diff --git a/src/db/db_test.go b/src/db/db_test.go index 4ea3f31..29d36ad 100644 --- a/src/db/db_test.go +++ b/src/db/db_test.go @@ -46,4 +46,15 @@ func TestDB(t *testing.T) { } else if gots[1].K != "b" { t.Errorf("expected [1]='b' but got %q", gots[1].K) } + + type NestedResult struct { + Nest struct { + K string `json:"k"` + } + } + if got, err := db.QueryOne[NestedResult](ctx, `SELECT k AS "Nest.k" FROM test WHERE k='a'`); err != nil { + t.Errorf("failed nested query one: %v", err) + } else if got.Nest.K != "a" { + t.Errorf("bad nested query one: %+v", got) + } }