From bcf1369f286c8be25330c83f17c7108ee49c35ff Mon Sep 17 00:00:00 2001 From: Bel LaPointe Date: Sat, 23 Mar 2019 15:29:31 -0600 Subject: [PATCH] this seeks to work --- arg.go | 72 +++++++++++++++++++++++--- arg_test.go | 137 +++++++++++++++++++++++++++++++++++++++++++++++++ argset.go | 55 ++++++++++++++++---- argset_test.go | 60 ++++++++++++++++++++++ type.go | 20 ++++++++ type_test.go | 18 +++++++ 6 files changed, 346 insertions(+), 16 deletions(-) create mode 100644 arg_test.go create mode 100644 argset_test.go create mode 100644 type_test.go diff --git a/arg.go b/arg.go index b452f11..63364ec 100644 --- a/arg.go +++ b/arg.go @@ -1,6 +1,10 @@ package args -import "time" +import ( + "fmt" + "strconv" + "time" +) type Arg struct { Env string @@ -17,34 +21,90 @@ func NewArg(argType Type, key, help string, def interface{}) *Arg { Flag: key, Help: help, Default: def, + Value: def, ArgType: argType, } } -func (a *Arg) GetInt(key string) int { +func (a *Arg) GetInt() int { if a.ArgType != INT { return -1 } + if _, ok := a.Value.(int); !ok { + return -1 + } return a.Value.(int) } -func (a *Arg) GetString(key string) string { +func (a *Arg) GetString() string { if a.ArgType != STRING { return "" } + if _, ok := a.Value.(string); !ok { + return "" + } return a.Value.(string) } -func (a *Arg) GetBool(key string) bool { +func (a *Arg) GetBool() bool { if a.ArgType != BOOL { return false } + if _, ok := a.Value.(bool); !ok { + return false + } return a.Value.(bool) } -func (a *Arg) GetDuration(key string) time.Duration { +func (a *Arg) GetDuration() time.Duration { if a.ArgType != DURATION { - return time.Duration(0) + return time.Duration(-1) + } + if _, ok := a.Value.(time.Duration); !ok { + return time.Duration(-1) } return a.Value.(time.Duration) } + +func (a *Arg) Set(value interface{}) error { + err := fmt.Errorf("incompatible set of type %T for arg $%v/-%v type %v", value, a.Env, a.Flag, a.ArgType) + switch a.ArgType { + case INT: + i, ok := value.(int) + if ok { + a.Value = i + err = nil + } else if s, ok := value.(string); !ok { + } else if i, err = strconv.Atoi(s); err == nil { + a.Value = i + err = nil + } + case STRING: + i, ok := value.(string) + if ok { + a.Value = i + err = nil + } + case BOOL: + i, ok := value.(bool) + if ok { + a.Value = i + err = nil + } else if s, ok := value.(string); !ok { + } else if i, err = strconv.ParseBool(s); err == nil { + a.Value = i + err = nil + } + case DURATION: + i, ok := value.(time.Duration) + if ok { + a.Value = i + err = nil + } else if s, ok := value.(string); !ok { + } else if i, err = time.ParseDuration(s); err == nil { + a.Value = i + err = nil + } + } + return err +} diff --git a/arg_test.go b/arg_test.go new file mode 100644 index 0000000..2c43ac5 --- /dev/null +++ b/arg_test.go @@ -0,0 +1,137 @@ +package args + +import ( + "testing" + "time" +) + +func TestNewArg(t *testing.T) { + NewArg(INT, "key", "help", 5) +} + +func TestGets(t *testing.T) { + cases := []struct { + t Type + v interface{} + ok bool + }{ + { + t: STRING, + v: 5, + ok: false, + }, + { + t: DURATION, + v: 5, + ok: false, + }, + { + t: BOOL, + v: 5, + ok: false, + }, + { + t: INT, + v: "5", + ok: false, + }, + { + t: STRING, + v: "5", + ok: true, + }, + { + t: DURATION, + v: time.Duration(5), + ok: true, + }, + { + t: BOOL, + v: true, + ok: true, + }, + { + t: INT, + v: 5, + ok: true, + }, + } + + for _, c := range cases { + a := NewArg(c.t, "key", "help", c.v) + switch c.t { + case INT: + if (a.GetInt() == c.v) != c.ok { + t.Errorf("failed to get int: want %v, got %v", c.v, a.GetInt()) + } + case STRING: + if (a.GetString() == c.v) != c.ok { + t.Errorf("failed to get string: want %v, got %v", c.v, a.GetString()) + } + case BOOL: + if (a.GetBool() == c.v) != c.ok { + t.Errorf("failed to get bool: want %v, got %v", c.v, a.GetBool()) + } + case DURATION: + if (a.GetDuration() == c.v) != c.ok { + t.Errorf("failed to get duration: want %v, got %v", c.v, a.GetDuration()) + } + } + } +} + +func TestSets(t *testing.T) { + cases := []struct { + t Type + v interface{} + ok bool + }{ + { + t: STRING, + v: 5, + ok: false, + }, + { + t: DURATION, + v: 5, + ok: false, + }, + { + t: BOOL, + v: 5, + ok: false, + }, + { + t: INT, + v: "a", + ok: false, + }, + { + t: STRING, + v: "5", + ok: true, + }, + { + t: DURATION, + v: time.Duration(5), + ok: true, + }, + { + t: BOOL, + v: true, + ok: true, + }, + { + t: INT, + v: 5, + ok: true, + }, + } + + for i, c := range cases { + a := NewArg(c.t, "key", "help", nil) + if err := a.Set(c.v); (err == nil) != c.ok { + t.Errorf("[%d] failed to set %v: want ok=%v, got %v", i, c.t, c.ok, err) + } + } +} diff --git a/argset.go b/argset.go index 74d0c47..18febe5 100644 --- a/argset.go +++ b/argset.go @@ -2,7 +2,9 @@ package args import ( "errors" + "flag" "os" + "time" ) type ArgSet struct { @@ -25,8 +27,8 @@ func (as *ArgSet) Get(key string) *Arg { return nil } for i := range as.args { - if as.args[i].Key == key { - return as[i] + if as.args[i].Env == key || as.args[i].Flag == key { + return as.args[i] } } return nil @@ -46,7 +48,7 @@ func (as *ArgSet) Parse() error { if err := as.setValueFromEnv(); err != nil { return err } - if err := as.setValueFromFlag(); err != nil { + if err := as.setValueFromFlags(); err != nil { return err } @@ -54,25 +56,58 @@ func (as *ArgSet) Parse() error { } func (as *ArgSet) setValueFromDefaults() error { - // TODO if not casted type, return err for i := range as.args { - args[i].Value = args[i].Default + if err := as.args[i].Set(as.args[i].Default); err != nil { + return err + } } return nil } func (as *ArgSet) setValueFromEnv() error { - // TODO if not casted type, return err for i := range as.args { - key := args[i].Key + key := as.args[i].Env if v, ok := os.LookupEnv(key); ok { - args[i].Value = v + if err := as.args[i].Set(v); err != nil { + return err + } } } return nil } func (as *ArgSet) setValueFromFlags() error { - // TODO if not casted type, return err - return errors.New("not impl") + fs := flag.NewFlagSet(os.Args[0], flag.ContinueOnError) + for i := range as.args { + arg := as.args[i] + switch arg.ArgType { + case INT: + arg.Default = fs.Int(arg.Flag, arg.Value.(int), arg.Help) + case STRING: + arg.Default = fs.String(arg.Flag, arg.Value.(string), arg.Help) + case BOOL: + arg.Default = fs.Bool(arg.Flag, arg.Value.(bool), arg.Help) + case DURATION: + arg.Default = fs.Duration(arg.Flag, arg.Value.(time.Duration), arg.Help) + default: + return errors.New("unknown type, cannot set from flag") + } + } + if err := fs.Parse(os.Args[1:]); err != nil { + return err + } + for i := range as.args { + arg := as.args[i] + switch arg.ArgType { + case INT: + arg.Value = *arg.Default.(*int) + case STRING: + arg.Value = *arg.Default.(*string) + case BOOL: + arg.Value = *arg.Default.(*bool) + case DURATION: + arg.Value = *arg.Default.(*time.Duration) + } + } + return nil } diff --git a/argset_test.go b/argset_test.go new file mode 100644 index 0000000..d51ac22 --- /dev/null +++ b/argset_test.go @@ -0,0 +1,60 @@ +package args + +import ( + "os" + "testing" +) + +func TestNewArgSet(t *testing.T) { + NewArgSet() +} + +func TestAppend(t *testing.T) { + as := NewArgSet() + as.Append(INT, "key", "help", nil) + if len(as.args) != 1 || *as.args[0] != *NewArg(INT, "key", "help", nil) { + t.Errorf("argset.Append failed: got %v, want %v", *as.args[0], *NewArg(INT, "key", "help", nil)) + } +} + +func TestParse(t *testing.T) { + osArgs := append([]string{}, os.Args[:]...) + defer func() { + os.Args = osArgs + }() + os.Args = []string{"nothing", "-testkeyFlag", "5"} + os.Setenv("testkeyEnv", "9") + as := NewArgSet() + as.Append(INT, "testkeyFlag", "help", 1) + as.Append(INT, "testkeyEnv", "help", 2) + as.Append(INT, "testkeyDefault", "help", 3) + if err := as.Parse(); err != nil { + t.Fatalf("cannot parse legal: %v", err) + } else if arg := as.Get("testkeyFlag"); arg == nil || arg.Value != 5 { + t.Errorf("cannot set from flag: got %v", arg) + } else if arg := as.Get("testkeyEnv"); arg == nil || arg.Value != 9 { + t.Errorf("cannot set from env: got %v", arg) + } else if arg := as.Get("testkeyDefault"); arg == nil || arg.Value != 3 { + t.Errorf("cannot set from default: got %v", arg) + } + + os.Args = []string{"nothing", "-testkey5", "bad"} + as = NewArgSet() + as.Append(INT, "testkey4", "help", "4") + if err := as.Parse(); err == nil { + t.Fatal(err) + } +} + +func TestGet(t *testing.T) { + as := NewArgSet() + as.Append(INT, "testkey1", "help", 1) + as.Append(INT, "testkey2", "help", 2) + as.Append(INT, "testkey3", "help", 3) + if err := as.Parse(); err != nil { + t.Fatal(err) + } + if as.Get("testkey1") == nil { + t.Errorf("cannot get parsed at %q", "testkey1") + } +} diff --git a/type.go b/type.go index c9f41cf..acd1588 100644 --- a/type.go +++ b/type.go @@ -1,5 +1,10 @@ package args +import ( + "fmt" + "time" +) + type Type int const ( @@ -8,3 +13,18 @@ const ( BOOL = Type(iota) DURATION = Type(iota) ) + +func (t Type) String() string { + var i interface{} = nil + switch t { + case INT: + i = 1 + case STRING: + i = "" + case BOOL: + i = false + case DURATION: + i = time.Duration(0) + } + return fmt.Sprintf("%T", i) +} diff --git a/type_test.go b/type_test.go new file mode 100644 index 0000000..f52e047 --- /dev/null +++ b/type_test.go @@ -0,0 +1,18 @@ +package args + +import "testing" + +func TestType(t *testing.T) { + if INT.String() != "int" { + t.Errorf("wrong string for INT") + } + if STRING.String() != "string" { + t.Errorf("wrong string for STRING") + } + if BOOL.String() != "bool" { + t.Errorf("wrong string for BOOL") + } + if DURATION.String() != "time.Duration" { + t.Errorf("wrong string for DURATION") + } +}