This commit is contained in:
Bel LaPointe
2025-04-24 21:50:18 -06:00
parent 36f8efd0e7
commit b7a7f2a82f
9 changed files with 278 additions and 0 deletions

28
src/cleanup/cleanup.go Normal file
View File

@@ -0,0 +1,28 @@
package cleanup
import "context"
const ctxKey = "__cleanup"
func Inject(ctx context.Context, foo func()) context.Context {
before := Extract(ctx)
after := func() {
foo()
before()
}
return context.WithValue(ctx, ctxKey, after)
}
func Extract(ctx context.Context) func() {
v := ctx.Value(ctxKey)
if v == nil {
return func() {}
}
v2, _ := v.(func())
if v2 == nil {
return func() {}
}
return v2
}

View File

@@ -0,0 +1,33 @@
package cleanup_test
import (
"context"
"show-rss/src/cleanup"
"testing"
)
func TestCleanup(t *testing.T) {
ctx := context.Background()
called := make([]bool, 100)
for i := range called {
i := i
ctx = cleanup.Inject(ctx, func() {
t.Logf("cleaning %d", i)
if i < len(called)-1 && !called[i+1] {
t.Errorf("cleaning %d before %d", i, i+1)
}
called[i] = true
})
}
t.Logf("cleaning")
cleanup.Extract(ctx)()
t.Logf("cleaned")
for i := range called {
if !called[i] {
t.Fatalf("missing called[%d]", i)
}
}
}

15
src/cmd/config.go Normal file
View File

@@ -0,0 +1,15 @@
package cmd
import (
"context"
"show-rss/src/db"
)
func Config(ctx context.Context) (context.Context, error) {
ctx, err := db.Inject(ctx)
if err != nil {
return ctx, err
}
return ctx, nil
}

View File

@@ -14,6 +14,11 @@ func Main(ctx context.Context) error {
ctx, can := context.WithCancel(ctx)
defer can()
ctx, err := Config(ctx)
if err != nil {
return fmt.Errorf("failed to inject: %w", err)
}
foos := map[string]func(context.Context) error{
"server": server.Main,
"cron": cron.Main,

57
src/db/ctx.go Normal file
View File

@@ -0,0 +1,57 @@
package db
import (
"context"
"fmt"
"time"
"database/sql"
"show-rss/src/cleanup"
_ "modernc.org/sqlite"
)
const ctxKey = "__db"
func Inject(ctx context.Context, conn string) (context.Context, error) {
connctx, can := context.WithTimeout(ctx, 15*time.Second)
defer can()
db, err := sql.Open("sqlite", conn)
if err != nil {
return ctx, err
}
ctx = cleanup.Inject(ctx, func() {
db.Close()
})
if err := func() error {
c := time.NewTicker(100 * time.Millisecond)
defer c.Stop()
var err error
for connctx.Err() == nil {
if err = db.PingContext(connctx); err == nil {
return nil
}
select {
case <-connctx.Done():
case <-c.C:
}
}
return err
}(); err != nil {
return ctx, err
}
return context.WithValue(ctx, ctxKey, db), ctx.Err()
}
func extract(ctx context.Context) (*sql.DB, error) {
db := ctx.Value(ctxKey)
if db == nil {
return nil, fmt.Errorf("db not injected")
}
return db.(*sql.DB), nil
}

29
src/db/db.go Normal file
View File

@@ -0,0 +1,29 @@
package db
import (
"context"
"database/sql"
)
func QueryOne(ctx context.Context, q string, args ...any) error {
return with(ctx, func(db *sql.DB) error {
row := db.QueryRowContext(ctx, q, args...)
TODO generic and return value
return row.Err()
})
}
func Exec(ctx context.Context, q string, args ...any) error {
return with(ctx, func(db *sql.DB) error {
_, err := db.ExecContext(ctx, q, args...)
return err
})
}
func with(ctx context.Context, foo func(*sql.DB) error) error {
db, err := extract(ctx)
if err != nil {
return err
}
return foo(db)
}

49
src/db/db_test.go Normal file
View File

@@ -0,0 +1,49 @@
package db_test
import (
"context"
"path"
"show-rss/src/cleanup"
"show-rss/src/db"
"testing"
)
func TestDB(t *testing.T) {
ctx := context.Background()
ctx, err := db.Inject(ctx, path.Join(t.TempDir(), "db"))
if err != nil {
t.Fatal(err)
}
defer func() {
cleanup.Extract(ctx)()
}()
if err := db.Exec(ctx, `
CREATE TABLE IF NOT EXISTS test (k TEXT);
CREATE UNIQUE INDEX IF NOT EXISTS test_idx ON test (k);
INSERT INTO test (k) SELECT 'a';
INSERT INTO test (k) SELECT 'b';
`); err != nil {
t.Fatal(err)
}
var result struct {
K string
}
if got, err := db.QueryOne[result](ctx, `SELECT k FROM test WHERE k='a'`); err != nil {
t.Errorf("failed query one: %w", err)
} else if got.K != "a" {
t.Errorf("bad query one: %+v", got)
}
if gots, err := db.Query[result](ctx, `SELECT k FROM test`); err != nil {
t.Errorf("failed query: %w", err)
} else if len(gots) != 2 {
t.Errorf("expected 2 but got %d gots", len(gots))
} else if gots[0].K != "a" {
t.Errorf("expected [0]='a' but got %q", gots[0].K)
} else if gots[1].K != "b" {
t.Errorf("expected [1]='b' but got %q", gots[1].K)
}
}