From 8b2b095c105d8e101806ee93e2681c14349f0991 Mon Sep 17 00:00:00 2001 From: Bel LaPointe Date: Sun, 3 Nov 2019 07:42:57 -0700 Subject: [PATCH] Optional yaml and json config files --- .gitignore | 0 arg.go | 23 +++++++++++----- arg_test.go | 0 argset.go | 55 ++++++++++++++++++++++++++++++++++--- argset_test.go | 73 ++++++++++++++++++++++++++++++++++++++++++++++++++ type.go | 0 type_test.go | 0 7 files changed, 140 insertions(+), 11 deletions(-) mode change 100644 => 100755 .gitignore mode change 100644 => 100755 arg.go mode change 100644 => 100755 arg_test.go mode change 100644 => 100755 argset.go mode change 100644 => 100755 argset_test.go mode change 100644 => 100755 type.go mode change 100644 => 100755 type_test.go diff --git a/.gitignore b/.gitignore old mode 100644 new mode 100755 diff --git a/arg.go b/arg.go old mode 100644 new mode 100755 index dde7eb4..c9f3f76 --- a/arg.go +++ b/arg.go @@ -102,14 +102,23 @@ func (a *Arg) Set(value interface{}) error { err := fmt.Errorf("incompatible set of type %T for arg $%v/-%v type %v", value, a.Env, a.Flag, a.ArgType) switch a.ArgType { case INT: - i, ok := value.(int) - if ok { - a.Value = i - err = nil - } else if s, ok := value.(string); !ok { - } else if i, err = strconv.Atoi(s); err == nil { - a.Value = i + switch value.(type) { + case int: + a.Value = value.(int) err = nil + case string: + s := value.(string) + var i int + if i, err = strconv.Atoi(s); err == nil { + a.Value = i + err = nil + } + case float64: + f := value.(float64) + if f-float64(int(f)) == 0 { + a.Value = int(f) + err = nil + } } case STRING: i, ok := value.(string) diff --git a/arg_test.go b/arg_test.go old mode 100644 new mode 100755 diff --git a/argset.go b/argset.go old mode 100644 new mode 100755 index 5c9d31f..dd54cbc --- a/argset.go +++ b/argset.go @@ -1,20 +1,28 @@ package args import ( + "encoding/json" "errors" "flag" + "fmt" + "io/ioutil" "os" + "strings" "time" + + yaml "gopkg.in/yaml.v2" ) type ArgSet struct { - parsed bool - args []*Arg + parsed bool + args []*Arg + configFiles []string } -func NewArgSet() *ArgSet { +func NewArgSet(configFiles ...string) *ArgSet { return &ArgSet{ - args: []*Arg{}, + args: []*Arg{}, + configFiles: configFiles, } } @@ -45,6 +53,9 @@ func (as *ArgSet) Parse() error { if err := as.setValueFromDefaults(); err != nil { return err } + if err := as.setValueFromFiles(); err != nil { + return err + } if err := as.setValueFromEnv(); err != nil { return err } @@ -122,3 +133,39 @@ func (as *ArgSet) setValueFromFlags() error { } return err } + +func (as *ArgSet) setValueFromFiles() error { + for _, file := range as.configFiles { + if err := as.setValueFromFile(file); err != nil { + return err + } + } + return nil +} + +func (as *ArgSet) setValueFromFile(path string) error { + b, err := ioutil.ReadFile(path) + if err != nil { + return err + } + var config map[string]interface{} + if strings.HasSuffix(path, ".json") { + err = json.Unmarshal(b, &config) + } else if strings.HasSuffix(path, ".yaml") { + err = yaml.Unmarshal(b, &config) + } else { + err = fmt.Errorf("unknown config file suffix: %s", path) + } + if err != nil { + return err + } + for i := range as.args { + key := as.args[i].Flag + if v, ok := config[key]; ok { + if err := as.args[i].Set(v); err != nil { + return err + } + } + } + return nil +} diff --git a/argset_test.go b/argset_test.go old mode 100644 new mode 100755 index 64af2fb..5c6f758 --- a/argset_test.go +++ b/argset_test.go @@ -1,7 +1,10 @@ package args import ( + "fmt" + "io/ioutil" "os" + "path" "testing" "time" ) @@ -68,6 +71,11 @@ func TestParseFloatTime(t *testing.T) { } func TestGet(t *testing.T) { + was := os.Args + defer func() { + os.Args = was + }() + os.Args = []string{"exec", "-testkey1", "5"} as := NewArgSet() as.Append(INT, "testkey1", "help", 1) as.Append(INT, "testkey2", "help", 2) @@ -79,3 +87,68 @@ func TestGet(t *testing.T) { t.Errorf("cannot get parsed at %q", "testkey1") } } + +func TestSetValueFromFileYAML(t *testing.T) { + was := os.Args + defer func() { + os.Args = was + }() + os.Args = []string{"exec"} + testSetValueFromFile(t, "yaml", ` + 1: 5 + 2: hello + 3: 5m + 4: true + 5: 2019-01-01 + 6: 5.5 + `) + testSetValueFromFile(t, "json", ` + { + "1": 5, + "2": "hello", + "3": "5m", + "4": true, + "5": "2019-01-01", + "6": 5.5 + } + `) +} + +func testSetValueFromFile(t *testing.T, suffix, content string) { + d, err := ioutil.TempDir(os.TempDir(), "prefix") + defer os.RemoveAll(d) + f, err := os.Create(path.Join(d, "conf."+suffix)) + if err != nil { + t.Fatal(err) + } + fmt.Fprintln(f, content) + f.Close() + as := NewArgSet(f.Name()) + as.Append(INT, "1", "help", 1) + as.Append(STRING, "2", "help", "world") + as.Append(DURATION, "3", "help", time.Hour) + as.Append(BOOL, "4", "help", false) + as.Append(TIME, "5", "help", time.Now()) + as.Append(FLOAT, "6", "help", 1.1) + if err := as.Parse(); err != nil { + t.Fatal(err) + } + if v := as.Get("1").GetInt(); v != 5 { + t.Fatal(v) + } + if v := as.Get("2").GetString(); v != "hello" { + t.Fatal(v) + } + if v := as.Get("3").GetDuration(); v != time.Minute*5 { + t.Fatal(v) + } + if v := as.Get("4").GetBool(); v != true { + t.Fatal(v) + } + if v := as.Get("5").GetTime(); v != time.Date(2019, 1, 1, 0, 0, 0, 0, time.UTC) { + t.Fatal(v) + } + if v := as.Get("6").GetFloat(); v != 5.5 { + t.Fatal(v) + } +} diff --git a/type.go b/type.go old mode 100644 new mode 100755 diff --git a/type_test.go b/type_test.go old mode 100644 new mode 100755