diff --git a/main_test.go b/main_test.go index 1111c75..6ce3a3d 100644 --- a/main_test.go +++ b/main_test.go @@ -3,6 +3,7 @@ package main_test import ( "context" main "show-rss" + "show-rss/src/db" "testing" "time" ) @@ -11,7 +12,7 @@ func TestMain(t *testing.T) { ctx, can := context.WithTimeout(context.Background(), 2*time.Second) defer can() - if err := main.Main(ctx); err != nil && ctx.Err() == nil { + if err := main.Main(db.Test(t, ctx)); err != nil && ctx.Err() == nil { t.Fatal(err) } } diff --git a/src/cmd/config.go b/src/cmd/config.go index 1e99165..391f70b 100644 --- a/src/cmd/config.go +++ b/src/cmd/config.go @@ -2,14 +2,17 @@ package cmd import ( "context" + "show-rss/src/cleanup" "show-rss/src/db" ) -func Config(ctx context.Context) (context.Context, error) { +func Config(ctx context.Context) (context.Context, func(), error) { ctx, err := db.Inject(ctx, "/tmp/f.db") if err != nil { - return ctx, err + return ctx, nil, err } - return ctx, nil + return ctx, func() { + cleanup.Extract(ctx)() + }, nil } diff --git a/src/cmd/cron/main.go b/src/cmd/cron/main.go index f6d81ca..2ce8bc2 100644 --- a/src/cmd/cron/main.go +++ b/src/cmd/cron/main.go @@ -2,8 +2,9 @@ package cron import ( "context" - "io" + "fmt" "show-rss/src/db" + "strings" "time" ) @@ -11,7 +12,7 @@ func Main(ctx context.Context) error { c := time.NewTicker(time.Minute) defer c.Stop() for { - if err := one(ctx); err != nil { + if err := One(ctx); err != nil { return err } @@ -23,19 +24,49 @@ func Main(ctx context.Context) error { return ctx.Err() } -func one(ctx context.Context) error { +func One(ctx context.Context) error { if err := initDB(ctx); err != nil { - return err + return fmt.Errorf("failed init db: %w", err) } - return io.EOF + return nil } func initDB(ctx context.Context) error { - return db.Exec(ctx, ` - CREATE TABLE IF NOT EXISTS feeds ( - id SERIAL - ); - ALTER TABLE feeds ADD COLUMN IF NOT EXISTS b TEXT; - ALTER TABLE feeds ADD COLUMN IF NOT EXISTS b TEXT; - `) + if err := db.Exec(ctx, `CREATE TABLE IF NOT EXISTS database_version (v NUMBER, t TIMESTAMP)`); err != nil { + return fmt.Errorf("failed to create database_version table: %w", err) + } + + type DatabaseVersion struct { + V int `json:"v"` + T time.Time `json:"t"` + } + vs, err := db.Query[DatabaseVersion](ctx, `SELECT v, t FROM database_version ORDER BY v DESC LIMIT 1`) + if err != nil { + return err + } + var v DatabaseVersion + if len(vs) > 0 { + v = vs[0] + } + + mods := []string{ + `CREATE TABLE feeds ( + id SERIAL PRIMARY KEY NOT NULL, + created_at TIMESTAMP, + updated_at TIMESTAMP, + deleted_at TIMESTAMP + )`, + } + mods = append([]string{""}, mods...) + for i := v.V + 1; i < len(mods); i++ { + q := mods[i] + q = strings.TrimSpace(q) + q = strings.TrimSuffix(q, ";") + q = fmt.Sprintf("BEGIN; %s; INSERT INTO database_version (v, t) VALUES (?, ?); COMMIT;", q) + if err := db.Exec(ctx, q, i, time.Now()); err != nil { + return fmt.Errorf("[%d] failed mod %s: %w", i, mods[i], err) + } + } + + return nil } diff --git a/src/cmd/cron/main_test.go b/src/cmd/cron/main_test.go new file mode 100644 index 0000000..90d43a1 --- /dev/null +++ b/src/cmd/cron/main_test.go @@ -0,0 +1,36 @@ +package cron_test + +import ( + "context" + "show-rss/src/cmd/cron" + "show-rss/src/db" + "strconv" + "testing" + "time" +) + +func TestOne(t *testing.T) { + ctx, can := context.WithTimeout(context.Background(), 5*time.Second) + defer can() + + t.Run("same ctx", func(t *testing.T) { + ctx := db.Test(t, ctx) + for i := 0; i < 2; i++ { + t.Run(strconv.Itoa(i), func(t *testing.T) { + if err := cron.One(ctx); err != nil && ctx.Err() == nil { + t.Fatalf("failed %d: %v", i, err) + } + }) + } + }) + + t.Run("new ctx", func(t *testing.T) { + for i := 0; i < 2; i++ { + t.Run(strconv.Itoa(i), func(t *testing.T) { + if err := cron.One(db.Test(t, ctx)); err != nil && ctx.Err() == nil { + t.Fatalf("failed %d: %v", i, err) + } + }) + } + }) +} diff --git a/src/cmd/main.go b/src/cmd/main.go index 3467860..950cdd7 100644 --- a/src/cmd/main.go +++ b/src/cmd/main.go @@ -14,10 +14,11 @@ func Main(ctx context.Context) error { ctx, can := context.WithCancel(ctx) defer can() - ctx, err := Config(ctx) + ctx, can, err := Config(ctx) if err != nil { return fmt.Errorf("failed to inject: %w", err) } + defer can() foos := map[string]func(context.Context) error{ "server": server.Main, diff --git a/src/db/ctx.go b/src/db/ctx.go index 7a6ede9..0909bc1 100644 --- a/src/db/ctx.go +++ b/src/db/ctx.go @@ -3,6 +3,9 @@ package db import ( "context" "fmt" + "path" + "strings" + "testing" "time" "database/sql" @@ -14,7 +17,27 @@ import ( const ctxKey = "__db" +func Test(t *testing.T, ctx context.Context) context.Context { + p := path.Join(t.TempDir(), strings.ReplaceAll(t.Name()+".db", "/", "_")) + ctx, err := Inject(ctx, p) + if err != nil { + t.Fatalf("failed to inject db %s: %v", p, err) + } + t.Cleanup(func() { + db, err := extract(ctx) + if err != nil { + return + } + db.Close() + }) + return ctx +} + func Inject(ctx context.Context, conn string) (context.Context, error) { + if _, err := extract(ctx); err == nil { + return ctx, nil + } + connctx, can := context.WithTimeout(ctx, 15*time.Second) defer can()