from mutex to semaphore chan

main
Bel LaPointe 2025-05-08 11:35:41 -06:00
parent 3ac7ae63b6
commit 6259a4f179
3 changed files with 28 additions and 8 deletions

View File

@ -5,7 +5,6 @@ import (
"fmt" "fmt"
"path" "path"
"strings" "strings"
"sync"
"testing" "testing"
"time" "time"
@ -70,13 +69,13 @@ func Inject(ctx context.Context, conn string) (context.Context, error) {
return ctx, err return ctx, err
} }
return context.WithValue(context.WithValue(ctx, ctxKey+"_lock", &sync.Mutex{}), ctxKey, db), ctx.Err() return context.WithValue(context.WithValue(ctx, ctxKey+"_lock", newSemaphore()), ctxKey, db), ctx.Err()
} }
func extract(ctx context.Context) (*sql.DB, *sync.Mutex, error) { func extract(ctx context.Context) (*sql.DB, semaphore, error) {
db := ctx.Value(ctxKey) db := ctx.Value(ctxKey)
if db == nil { if db == nil {
return nil, nil, fmt.Errorf("db not injected") return nil, nil, fmt.Errorf("db not injected")
} }
return db.(*sql.DB), ctx.Value(ctxKey + "_lock").(*sync.Mutex), nil return db.(*sql.DB), ctx.Value(ctxKey + "_lock").(semaphore), nil
} }

View File

@ -129,11 +129,11 @@ func Exec(ctx context.Context, q string, args ...any) error {
} }
func with(ctx context.Context, foo func(*sql.DB) error) error { func with(ctx context.Context, foo func(*sql.DB) error) error {
db, lock, err := extract(ctx) db, sem, err := extract(ctx)
if err != nil { if err != nil {
return err return err
} }
lock.Lock() return sem.With(ctx, func() error {
defer lock.Unlock() return foo(db)
return foo(db) })
} }

21
src/db/sem.go Normal file
View File

@ -0,0 +1,21 @@
package db
import "context"
type semaphore chan struct{}
func newSemaphore() semaphore {
return make(semaphore, 1)
}
func (semaphore semaphore) With(ctx context.Context, cb func() error) error {
select {
case semaphore <- struct{}{}:
case <-ctx.Done():
return ctx.Err()
}
defer func() {
<-semaphore
}()
return cb()
}