Compare commits

..

2 Commits

Author SHA1 Message Date
Bel LaPointe
6259a4f179 from mutex to semaphore chan 2025-05-08 11:35:41 -06:00
Bel LaPointe
3ac7ae63b6 db locks rather than returning dbInUse errs 2025-05-08 11:31:37 -06:00
3 changed files with 31 additions and 8 deletions

View File

@@ -25,7 +25,7 @@ func Test(t *testing.T, ctx context.Context) context.Context {
t.Fatalf("failed to inject db %s: %v", p, err) t.Fatalf("failed to inject db %s: %v", p, err)
} }
t.Cleanup(func() { t.Cleanup(func() {
db, err := extract(ctx) db, _, err := extract(ctx)
if err != nil { if err != nil {
return return
} }
@@ -35,7 +35,7 @@ func Test(t *testing.T, ctx context.Context) context.Context {
} }
func Inject(ctx context.Context, conn string) (context.Context, error) { func Inject(ctx context.Context, conn string) (context.Context, error) {
if _, err := extract(ctx); err == nil { if _, _, err := extract(ctx); err == nil {
return ctx, nil return ctx, nil
} }
@@ -69,13 +69,13 @@ func Inject(ctx context.Context, conn string) (context.Context, error) {
return ctx, err return ctx, err
} }
return context.WithValue(ctx, ctxKey, db), ctx.Err() return context.WithValue(context.WithValue(ctx, ctxKey+"_lock", newSemaphore()), ctxKey, db), ctx.Err()
} }
func extract(ctx context.Context) (*sql.DB, error) { func extract(ctx context.Context) (*sql.DB, semaphore, error) {
db := ctx.Value(ctxKey) db := ctx.Value(ctxKey)
if db == nil { if db == nil {
return nil, fmt.Errorf("db not injected") return nil, nil, fmt.Errorf("db not injected")
} }
return db.(*sql.DB), nil return db.(*sql.DB), ctx.Value(ctxKey + "_lock").(semaphore), nil
} }

View File

@@ -129,9 +129,11 @@ func Exec(ctx context.Context, q string, args ...any) error {
} }
func with(ctx context.Context, foo func(*sql.DB) error) error { func with(ctx context.Context, foo func(*sql.DB) error) error {
db, err := extract(ctx) db, sem, err := extract(ctx)
if err != nil { if err != nil {
return err return err
} }
return sem.With(ctx, func() error {
return foo(db) return foo(db)
})
} }

21
src/db/sem.go Normal file
View File

@@ -0,0 +1,21 @@
package db
import "context"
type semaphore chan struct{}
func newSemaphore() semaphore {
return make(semaphore, 1)
}
func (semaphore semaphore) With(ctx context.Context, cb func() error) error {
select {
case semaphore <- struct{}{}:
case <-ctx.Done():
return ctx.Err()
}
defer func() {
<-semaphore
}()
return cb()
}