diff --git a/src/lib/db.go b/src/lib/db.go new file mode 100644 index 0000000..a4722f2 --- /dev/null +++ b/src/lib/db.go @@ -0,0 +1,15 @@ +package lib + +import ( + "context" + "database/sql" +) + +func InjectDB(ctx context.Context, db *sql.DB) context.Context { + return context.WithValue(ctx, "__db__", db) +} + +func ExtractDB(ctx context.Context) *sql.DB { + v, _ := ctx.Value("__db__").(*sql.DB) + return v +} diff --git a/src/lib/db_test.go b/src/lib/db_test.go new file mode 100644 index 0000000..aec2322 --- /dev/null +++ b/src/lib/db_test.go @@ -0,0 +1,20 @@ +package lib_test + +import ( + "context" + "database/sql" + "gitea/price-is-wrong/src/lib" + "testing" +) + +func TestInjectDB(t *testing.T) { + ctx := context.Background() + db := &sql.DB{} + + injected := lib.InjectDB(ctx, db) + extracted := lib.ExtractDB(injected) + + if db != extracted { + t.Fatal("couldnt extract injected db") + } +}