main
Bel LaPointe 2025-02-12 10:37:02 -07:00
parent a2c44498eb
commit 688b7d9c01
8 changed files with 90 additions and 53 deletions

View File

@ -1,15 +0,0 @@
package lib
import (
"context"
"database/sql"
)
func InjectDB(ctx context.Context, db *sql.DB) context.Context {
return context.WithValue(ctx, "__db__", db)
}
func ExtractDB(ctx context.Context) *sql.DB {
v, _ := ctx.Value("__db__").(*sql.DB)
return v
}

19
src/lib/db/db.go Normal file
View File

@ -0,0 +1,19 @@
package db
import (
"context"
"database/sql"
)
func Inject(ctx context.Context, db *sql.DB) context.Context {
return context.WithValue(ctx, "__db__", db)
}
func From(ctx context.Context) *sql.DB {
return Extract(ctx)
}
func Extract(ctx context.Context) *sql.DB {
v, _ := ctx.Value("__db__").(*sql.DB)
return v
}

22
src/lib/db/db_test.go Normal file
View File

@ -0,0 +1,22 @@
package db_test
import (
"context"
"database/sql"
"gitea/price-is-wrong/src/lib/db"
"testing"
)
func TestInjectDB(t *testing.T) {
ctx := context.Background()
d := &sql.DB{}
injected := db.Inject(ctx, d)
extracted := db.Extract(injected)
if d != extracted {
t.Fatal("couldnt extract injected db")
} else if extracted != db.From(injected) {
t.Fatal("couldnt from extracted db")
}
}

View File

@ -1,20 +0,0 @@
package lib_test
import (
"context"
"database/sql"
"gitea/price-is-wrong/src/lib"
"testing"
)
func TestInjectDB(t *testing.T) {
ctx := context.Background()
db := &sql.DB{}
injected := lib.InjectDB(ctx, db)
extracted := lib.ExtractDB(injected)
if db != extracted {
t.Fatal("couldnt extract injected db")
}
}

View File

@ -6,15 +6,17 @@ import (
"path"
"testing"
"gitea/price-is-wrong/src/lib/db"
_ "github.com/glebarez/sqlite"
)
func NewTestCtx(t *testing.T) context.Context {
d := t.TempDir()
db, err := sql.Open("sqlite", path.Join(d, "db.db"))
b, err := sql.Open("sqlite", path.Join(d, "db.db"))
if err != nil {
t.Fatal(err)
}
t.Cleanup(func() { db.Close() })
return InjectDB(context.Background(), db)
t.Cleanup(func() { b.Close() })
return db.Inject(context.Background(), b)
}

View File

@ -3,6 +3,7 @@ package lib_test
import (
"context"
"gitea/price-is-wrong/src/lib"
"gitea/price-is-wrong/src/lib/db"
"sync"
"testing"
)
@ -14,14 +15,12 @@ func TestTestCtx(t *testing.T) {
t.Run("subtest", func(t *testing.T) {
defer wg.Done()
ctx = lib.NewTestCtx(t)
db := lib.ExtractDB(ctx)
if _, err := db.Exec(`SELECT 1`); err != nil {
if _, err := db.Extract(ctx).Exec(`SELECT 1`); err != nil {
t.Fatal(err)
}
})
wg.Wait()
db := lib.ExtractDB(ctx)
if _, err := db.Exec(`SELECT 1`); err == nil {
if _, err := db.Extract(ctx).Exec(`SELECT 1`); err == nil {
t.Fatal(err)
}
}

View File

@ -1,11 +0,0 @@
package lobby
import "context"
type Lobby interface{}
type lobby struct{}
func NewLobby(ctx context.Context) (Lobby, error) {
return lobby{}, nil
}

View File

@ -0,0 +1,41 @@
package lobby
import (
"context"
"fmt"
"gitea/price-is-wrong/src/lib"
"io"
)
type Lobby struct{}
func Open(ctx context.Context, id string) (Lobby, error) {
result, err := open(ctx, id)
if err != nil {
return Lobby{}, err
}
if result != nil {
} else if err := create(ctx, id); err != nil {
return Lobby{}, err
} else if result, err = open(ctx, id); err != nil {
return Lobby{}, err
} else if result != nil {
} else {
return Lobby{}, fmt.Errorf("unable to create new lobby %s", id)
}
return *result, err
}
func open(ctx context.Context, id string) (*Lobby, error) {
return nil, io.EOF
}
func create(ctx context.Context, id string) error {
return io.EOF
}
func init(ctx context.Context) error {
_, err := lib.ExtractDB(ctx).ExecContext(ctx, `
`)
return err
}