diff --git a/src/db/ctx.go b/src/db/ctx.go index 1ca8c30..f44e123 100644 --- a/src/db/ctx.go +++ b/src/db/ctx.go @@ -5,6 +5,7 @@ import ( "fmt" "path" "strings" + "sync" "testing" "time" @@ -25,7 +26,7 @@ func Test(t *testing.T, ctx context.Context) context.Context { t.Fatalf("failed to inject db %s: %v", p, err) } t.Cleanup(func() { - db, err := extract(ctx) + db, _, err := extract(ctx) if err != nil { return } @@ -35,7 +36,7 @@ func Test(t *testing.T, ctx context.Context) context.Context { } func Inject(ctx context.Context, conn string) (context.Context, error) { - if _, err := extract(ctx); err == nil { + if _, _, err := extract(ctx); err == nil { return ctx, nil } @@ -69,13 +70,13 @@ func Inject(ctx context.Context, conn string) (context.Context, error) { return ctx, err } - return context.WithValue(ctx, ctxKey, db), ctx.Err() + return context.WithValue(context.WithValue(ctx, ctxKey+"_lock", &sync.Mutex{}), ctxKey, db), ctx.Err() } -func extract(ctx context.Context) (*sql.DB, error) { +func extract(ctx context.Context) (*sql.DB, *sync.Mutex, error) { db := ctx.Value(ctxKey) if db == nil { - return nil, fmt.Errorf("db not injected") + return nil, nil, fmt.Errorf("db not injected") } - return db.(*sql.DB), nil + return db.(*sql.DB), ctx.Value(ctxKey + "_lock").(*sync.Mutex), nil } diff --git a/src/db/db.go b/src/db/db.go index 5e75fc8..7762ce5 100644 --- a/src/db/db.go +++ b/src/db/db.go @@ -129,9 +129,11 @@ func Exec(ctx context.Context, q string, args ...any) error { } func with(ctx context.Context, foo func(*sql.DB) error) error { - db, err := extract(ctx) + db, lock, err := extract(ctx) if err != nil { return err } + lock.Lock() + defer lock.Unlock() return foo(db) }