diff --git a/main.go b/main.go index 63d7b8d..a89b618 100644 --- a/main.go +++ b/main.go @@ -2,6 +2,7 @@ package main import ( "context" + "os" "os/signal" "show-rss/src/cmd" "syscall" @@ -17,7 +18,7 @@ func Main(ctx context.Context) error { ctx, can := signal.NotifyContext(ctx, syscall.SIGINT) defer can() - if err := cmd.Main(ctx); err != nil && ctx.Err() == nil { + if err := cmd.Main(ctx, os.Args[1:]); err != nil && ctx.Err() == nil { return err } diff --git a/main_test.go b/main_test.go index 6ce3a3d..ebd202c 100644 --- a/main_test.go +++ b/main_test.go @@ -2,6 +2,8 @@ package main_test import ( "context" + "os" + "path" main "show-rss" "show-rss/src/db" "testing" @@ -12,6 +14,8 @@ func TestMain(t *testing.T) { ctx, can := context.WithTimeout(context.Background(), 2*time.Second) defer can() + os.Args = []string{os.Args[0], "-db", path.Join(t.TempDir(), "db.db")} + if err := main.Main(db.Test(t, ctx)); err != nil && ctx.Err() == nil { t.Fatal(err) } diff --git a/src/cmd/config.go b/src/cmd/config.go index 391f70b..ac4d71b 100644 --- a/src/cmd/config.go +++ b/src/cmd/config.go @@ -2,12 +2,33 @@ package cmd import ( "context" + "flag" + "os" "show-rss/src/cleanup" "show-rss/src/db" ) -func Config(ctx context.Context) (context.Context, func(), error) { - ctx, err := db.Inject(ctx, "/tmp/f.db") +type Flags struct { + DB string +} + +func NewFlags(args []string) (Flags, error) { + var result Flags + + fs := flag.NewFlagSet(os.Args[0], flag.ContinueOnError) + fs.StringVar(&result.DB, "db", "/tmp/f.db", "path to sqlite.db") + err := fs.Parse(args) + + return result, err +} + +func Config(ctx context.Context, args []string) (context.Context, func(), error) { + flags, err := NewFlags(args) + if err != nil { + return ctx, nil, err + } + + ctx, err = db.Inject(ctx, flags.DB) if err != nil { return ctx, nil, err } diff --git a/src/cmd/main.go b/src/cmd/main.go index 950cdd7..0ef70cd 100644 --- a/src/cmd/main.go +++ b/src/cmd/main.go @@ -10,11 +10,11 @@ import ( "time" ) -func Main(ctx context.Context) error { +func Main(ctx context.Context, args []string) error { ctx, can := context.WithCancel(ctx) defer can() - ctx, can, err := Config(ctx) + ctx, can, err := Config(ctx, args) if err != nil { return fmt.Errorf("failed to inject: %w", err) }