from mutex to semaphore chan
parent
3ac7ae63b6
commit
6259a4f179
|
|
@ -5,7 +5,6 @@ import (
|
|||
"fmt"
|
||||
"path"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
|
|
@ -70,13 +69,13 @@ func Inject(ctx context.Context, conn string) (context.Context, error) {
|
|||
return ctx, err
|
||||
}
|
||||
|
||||
return context.WithValue(context.WithValue(ctx, ctxKey+"_lock", &sync.Mutex{}), ctxKey, db), ctx.Err()
|
||||
return context.WithValue(context.WithValue(ctx, ctxKey+"_lock", newSemaphore()), ctxKey, db), ctx.Err()
|
||||
}
|
||||
|
||||
func extract(ctx context.Context) (*sql.DB, *sync.Mutex, error) {
|
||||
func extract(ctx context.Context) (*sql.DB, semaphore, error) {
|
||||
db := ctx.Value(ctxKey)
|
||||
if db == nil {
|
||||
return nil, nil, fmt.Errorf("db not injected")
|
||||
}
|
||||
return db.(*sql.DB), ctx.Value(ctxKey + "_lock").(*sync.Mutex), nil
|
||||
return db.(*sql.DB), ctx.Value(ctxKey + "_lock").(semaphore), nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -129,11 +129,11 @@ func Exec(ctx context.Context, q string, args ...any) error {
|
|||
}
|
||||
|
||||
func with(ctx context.Context, foo func(*sql.DB) error) error {
|
||||
db, lock, err := extract(ctx)
|
||||
db, sem, err := extract(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
return foo(db)
|
||||
return sem.With(ctx, func() error {
|
||||
return foo(db)
|
||||
})
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,21 @@
|
|||
package db
|
||||
|
||||
import "context"
|
||||
|
||||
type semaphore chan struct{}
|
||||
|
||||
func newSemaphore() semaphore {
|
||||
return make(semaphore, 1)
|
||||
}
|
||||
|
||||
func (semaphore semaphore) With(ctx context.Context, cb func() error) error {
|
||||
select {
|
||||
case semaphore <- struct{}{}:
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
defer func() {
|
||||
<-semaphore
|
||||
}()
|
||||
return cb()
|
||||
}
|
||||
Loading…
Reference in New Issue