From 6259a4f1791d8341fcbe3a177d61e0810c1c7bc9 Mon Sep 17 00:00:00 2001 From: Bel LaPointe <153096461+breel-render@users.noreply.github.com> Date: Thu, 8 May 2025 11:35:41 -0600 Subject: [PATCH] from mutex to semaphore chan --- src/db/ctx.go | 7 +++---- src/db/db.go | 8 ++++---- src/db/sem.go | 21 +++++++++++++++++++++ 3 files changed, 28 insertions(+), 8 deletions(-) create mode 100644 src/db/sem.go diff --git a/src/db/ctx.go b/src/db/ctx.go index f44e123..ab9059f 100644 --- a/src/db/ctx.go +++ b/src/db/ctx.go @@ -5,7 +5,6 @@ import ( "fmt" "path" "strings" - "sync" "testing" "time" @@ -70,13 +69,13 @@ func Inject(ctx context.Context, conn string) (context.Context, error) { return ctx, err } - return context.WithValue(context.WithValue(ctx, ctxKey+"_lock", &sync.Mutex{}), ctxKey, db), ctx.Err() + return context.WithValue(context.WithValue(ctx, ctxKey+"_lock", newSemaphore()), ctxKey, db), ctx.Err() } -func extract(ctx context.Context) (*sql.DB, *sync.Mutex, error) { +func extract(ctx context.Context) (*sql.DB, semaphore, error) { db := ctx.Value(ctxKey) if db == nil { return nil, nil, fmt.Errorf("db not injected") } - return db.(*sql.DB), ctx.Value(ctxKey + "_lock").(*sync.Mutex), nil + return db.(*sql.DB), ctx.Value(ctxKey + "_lock").(semaphore), nil } diff --git a/src/db/db.go b/src/db/db.go index 7762ce5..3a10fc8 100644 --- a/src/db/db.go +++ b/src/db/db.go @@ -129,11 +129,11 @@ func Exec(ctx context.Context, q string, args ...any) error { } func with(ctx context.Context, foo func(*sql.DB) error) error { - db, lock, err := extract(ctx) + db, sem, err := extract(ctx) if err != nil { return err } - lock.Lock() - defer lock.Unlock() - return foo(db) + return sem.With(ctx, func() error { + return foo(db) + }) } diff --git a/src/db/sem.go b/src/db/sem.go new file mode 100644 index 0000000..a23cba9 --- /dev/null +++ b/src/db/sem.go @@ -0,0 +1,21 @@ +package db + +import "context" + +type semaphore chan struct{} + +func newSemaphore() semaphore { + return make(semaphore, 1) +} + +func (semaphore semaphore) With(ctx context.Context, cb func() error) error { + select { + case semaphore <- struct{}{}: + case <-ctx.Done(): + return ctx.Err() + } + defer func() { + <-semaphore + }() + return cb() +}