package args import ( "encoding/json" "errors" "flag" "fmt" "io/ioutil" "os" "strings" "time" yaml "gopkg.in/yaml.v2" ) type ArgSet struct { parsed bool args []*Arg configFiles []string } func NewArgSet(configFiles ...string) *ArgSet { return &ArgSet{ args: []*Arg{}, configFiles: configFiles, } } func (as *ArgSet) String() string { s := "{" for _, arg := range as.args { if len(s) > 2 { s += ", " } s += fmt.Sprintf("%s:%v", arg.Flag, arg.Value) } s += "}" return s } func (as *ArgSet) Append(argType Type, key, help string, def interface{}) { as.args = append(as.args, NewArg(argType, key, help, def)) } func (as *ArgSet) Get(key string) *Arg { if !as.parsed { return nil } for i := range as.args { if as.args[i].Env == key || as.args[i].Flag == key { return as.args[i] } } return nil } func (as *ArgSet) GetFloat(key string) float32 { a := as.Get(key) return a.GetFloat() } func (as *ArgSet) GetInt(key string) int { a := as.Get(key) return a.GetInt() } func (as *ArgSet) GetBool(key string) bool { a := as.Get(key) return a.GetBool() } func (as *ArgSet) GetDuration(key string) time.Duration { a := as.Get(key) return a.GetDuration() } func (as *ArgSet) GetString(key string) string { a := as.Get(key) return a.GetString() } func (as *ArgSet) Parse() error { if as.parsed { return nil } defer func() { as.parsed = true }() if err := as.setValueFromDefaults(); err != nil { return err } if err := as.setValueFromFiles(); err != nil { return err } if err := as.setValueFromEnv(); err != nil { return err } if err := as.setValueFromFlags(); err != nil { return err } return nil } func (as *ArgSet) setValueFromDefaults() error { for i := range as.args { if err := as.args[i].Set(as.args[i].Default); err != nil { return err } } return nil } func (as *ArgSet) setValueFromEnv() error { for i := range as.args { key := as.args[i].Env key = strings.ReplaceAll(key, "-", "_") if v, ok := os.LookupEnv(key); ok { if err := as.args[i].Set(v); err != nil { return err } } } return nil } func (as *ArgSet) setValueFromFlags() error { var err error fs := flag.NewFlagSet(os.Args[0], flag.ContinueOnError) for i := range as.args { arg := as.args[i] switch arg.ArgType { case INT: arg.Default = fs.Int(arg.Flag, arg.Value.(int), arg.Help) case STRING: arg.Default = fs.String(arg.Flag, arg.Value.(string), arg.Help) case BOOL: arg.Default = fs.Bool(arg.Flag, arg.Value.(bool), arg.Help) case DURATION: arg.Default = fs.Duration(arg.Flag, arg.Value.(time.Duration), arg.Help) case TIME: arg.Default = fs.String(arg.Flag, arg.Value.(time.Time).Format("2006-01-02"), arg.Help) case FLOAT: arg.Default = fs.Float64(arg.Flag, float64(arg.Value.(float32)), arg.Help) default: return errors.New("unknown type, cannot set from flag") } } if err := fs.Parse(os.Args[1:]); err != nil { return err } for i := range as.args { arg := as.args[i] switch arg.ArgType { case INT: arg.Value = *arg.Default.(*int) case STRING: arg.Value = *arg.Default.(*string) case BOOL: arg.Value = *arg.Default.(*bool) case DURATION: arg.Value = *arg.Default.(*time.Duration) case TIME: arg.Value = *arg.Default.(*string) err = arg.Set(arg.Value) case FLOAT: arg.Value = *arg.Default.(*float64) err = arg.Set(arg.Value) } } return err } func (as *ArgSet) setValueFromFiles() error { for _, file := range as.configFiles { if err := as.setValueFromFile(file); err != nil { return err } } return nil } func (as *ArgSet) setValueFromFile(path string) error { b, err := ioutil.ReadFile(path) if err != nil { return err } var config map[string]interface{} if strings.HasSuffix(path, ".json") { err = json.Unmarshal(b, &config) } else if strings.HasSuffix(path, ".yaml") { err = yaml.Unmarshal(b, &config) } else { err = fmt.Errorf("unknown config file suffix: %s", path) } if err != nil { return err } for i := range as.args { key := as.args[i].Flag if v, ok := config[key]; ok { if err := as.args[i].Set(v); err != nil { return err } } } return nil }