Compare commits
2 Commits
786ea3ef8f
...
6259a4f179
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6259a4f179 | ||
|
|
3ac7ae63b6 |
@@ -25,7 +25,7 @@ func Test(t *testing.T, ctx context.Context) context.Context {
|
|||||||
t.Fatalf("failed to inject db %s: %v", p, err)
|
t.Fatalf("failed to inject db %s: %v", p, err)
|
||||||
}
|
}
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
db, err := extract(ctx)
|
db, _, err := extract(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -35,7 +35,7 @@ func Test(t *testing.T, ctx context.Context) context.Context {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func Inject(ctx context.Context, conn string) (context.Context, error) {
|
func Inject(ctx context.Context, conn string) (context.Context, error) {
|
||||||
if _, err := extract(ctx); err == nil {
|
if _, _, err := extract(ctx); err == nil {
|
||||||
return ctx, nil
|
return ctx, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -69,13 +69,13 @@ func Inject(ctx context.Context, conn string) (context.Context, error) {
|
|||||||
return ctx, err
|
return ctx, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return context.WithValue(ctx, ctxKey, db), ctx.Err()
|
return context.WithValue(context.WithValue(ctx, ctxKey+"_lock", newSemaphore()), ctxKey, db), ctx.Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
func extract(ctx context.Context) (*sql.DB, error) {
|
func extract(ctx context.Context) (*sql.DB, semaphore, error) {
|
||||||
db := ctx.Value(ctxKey)
|
db := ctx.Value(ctxKey)
|
||||||
if db == nil {
|
if db == nil {
|
||||||
return nil, fmt.Errorf("db not injected")
|
return nil, nil, fmt.Errorf("db not injected")
|
||||||
}
|
}
|
||||||
return db.(*sql.DB), nil
|
return db.(*sql.DB), ctx.Value(ctxKey + "_lock").(semaphore), nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -129,9 +129,11 @@ func Exec(ctx context.Context, q string, args ...any) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func with(ctx context.Context, foo func(*sql.DB) error) error {
|
func with(ctx context.Context, foo func(*sql.DB) error) error {
|
||||||
db, err := extract(ctx)
|
db, sem, err := extract(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return foo(db)
|
return sem.With(ctx, func() error {
|
||||||
|
return foo(db)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
21
src/db/sem.go
Normal file
21
src/db/sem.go
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
package db
|
||||||
|
|
||||||
|
import "context"
|
||||||
|
|
||||||
|
type semaphore chan struct{}
|
||||||
|
|
||||||
|
func newSemaphore() semaphore {
|
||||||
|
return make(semaphore, 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (semaphore semaphore) With(ctx context.Context, cb func() error) error {
|
||||||
|
select {
|
||||||
|
case semaphore <- struct{}{}:
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
<-semaphore
|
||||||
|
}()
|
||||||
|
return cb()
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user