if rw.Lock then no need for nested rw.RLock

main
Bel LaPointe 2024-12-15 00:45:27 -07:00
parent 78511f25f3
commit 58904c8619
2 changed files with 17 additions and 6 deletions

View File

@ -14,16 +14,19 @@ type DB struct {
scheme string scheme string
conn string conn string
rw *sync.RWMutex rw *sync.RWMutex
locked *bool
} }
func NewDB(ctx context.Context, scheme, conn string) (DB, error) { func NewDB(ctx context.Context, scheme, conn string) (DB, error) {
ctx, can := context.WithTimeout(ctx, time.Second*10) ctx, can := context.WithTimeout(ctx, time.Second*10)
defer can() defer can()
locked := false
db := DB{ db := DB{
scheme: scheme, scheme: scheme,
conn: conn, conn: conn,
rw: &sync.RWMutex{}, rw: &sync.RWMutex{},
locked: &locked,
} }
sql, err := db.dial(ctx) sql, err := db.dial(ctx)
@ -38,12 +41,18 @@ func NewDB(ctx context.Context, scheme, conn string) (DB, error) {
func (db DB) WithLock(cb func() error) error { func (db DB) WithLock(cb func() error) error {
db.rw.Lock() db.rw.Lock()
defer db.rw.Unlock() defer db.rw.Unlock()
*db.locked = true
defer func() {
*db.locked = false
}()
return cb() return cb()
} }
func (db DB) Exec(ctx context.Context, q string, args ...any) error { func (db DB) Exec(ctx context.Context, q string, args ...any) error {
db.rw.RLock() if !*db.locked {
defer db.rw.RUnlock() db.rw.RLock()
defer db.rw.RUnlock()
}
return db.exec(ctx, q, args...) return db.exec(ctx, q, args...)
} }
@ -59,8 +68,10 @@ func (db DB) exec(ctx context.Context, q string, args ...any) error {
} }
func (db DB) Query(ctx context.Context, cb func(*sql.Rows) error, q string, args ...any) error { func (db DB) Query(ctx context.Context, cb func(*sql.Rows) error, q string, args ...any) error {
db.rw.RLock() if !*db.locked {
defer db.rw.RUnlock() db.rw.RLock()
defer db.rw.RUnlock()
}
return db.query(ctx, cb, q, args...) return db.query(ctx, cb, q, args...)
} }

View File

@ -51,7 +51,7 @@ func (games Games) GamesForUser(ctx context.Context, id string) ([]string, error
}, ` }, `
SELECT players.game_uuid SELECT players.game_uuid
FROM players FROM players
WHERE players.user_uuid=? WHERE players.user_uuid=?
`, id) `, id)
return result, err return result, err
} }
@ -225,7 +225,7 @@ func (games Games) createEvent(ctx context.Context, id string, v any) error {
return games.db.Exec(ctx, ` return games.db.Exec(ctx, `
INSERT INTO events ( INSERT INTO events (
game_uuid, game_uuid,
timestamp, timestamp,
payload payload
) VALUES (?, ?, ?) ) VALUES (?, ?, ?)