From 81af991c58bcd2c69bf2c3424253a8167f520c0f Mon Sep 17 00:00:00 2001 From: bel Date: Sat, 14 Dec 2024 21:32:39 -0700 Subject: [PATCH] locking --- cmd/server/db.go | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/cmd/server/db.go b/cmd/server/db.go index bd8f05a..c8f60fa 100644 --- a/cmd/server/db.go +++ b/cmd/server/db.go @@ -4,6 +4,7 @@ import ( "context" "database/sql" "io" + "sync" "time" _ "github.com/glebarez/sqlite" @@ -12,6 +13,7 @@ import ( type DB struct { scheme string conn string + rw *sync.RWMutex } func NewDB(ctx context.Context, scheme, conn string) (DB, error) { @@ -21,6 +23,7 @@ func NewDB(ctx context.Context, scheme, conn string) (DB, error) { db := DB{ scheme: scheme, conn: conn, + rw: &sync.RWMutex{}, } sql, err := db.dial(ctx) @@ -53,7 +56,19 @@ func NewDB(ctx context.Context, scheme, conn string) (DB, error) { return db, err } +func (db DB) WithLock(cb func() error) error { + db.rw.Lock() + defer db.rw.Unlock() + return cb() +} + func (db DB) Exec(ctx context.Context, q string, args ...any) error { + db.rw.RLock() + defer db.rw.RUnlock() + return db.exec(ctx, q, args...) +} + +func (db DB) exec(ctx context.Context, q string, args ...any) error { c, err := db.dial(ctx) if err != nil { return err @@ -65,6 +80,12 @@ func (db DB) Exec(ctx context.Context, q string, args ...any) error { } func (db DB) Query(ctx context.Context, cb func(*sql.Rows) error, q string, args ...any) error { + db.rw.RLock() + defer db.rw.RUnlock() + return db.query(ctx, cb, q, args...) +} + +func (db DB) query(ctx context.Context, cb func(*sql.Rows) error, q string, args ...any) error { c, err := db.dial(ctx) if err != nil { return err