package config import ( "flag" "io/ioutil" "local/rproxy3/storage" "local/rproxy3/storage/packable" "log" "os" "strings" yaml "gopkg.in/yaml.v2" ) const nsConf = "configuration" const flagPort = "p" const flagRoutes = "r" const flagConf = "c" const flagCert = "crt" const flagTCP = "tcp" const flagKey = "key" const flagUser = "user" const flagPass = "pass" const flagRate = "rate" const flagBurst = "burst" const flagTimeout = "timeout" const flagRewrites = "rw" var conf = storage.NewMap() type toBind struct { flag string value *string } type fileConf struct { Port string `yaml:"p"` Routes []string `yaml:"r"` CertPath string `yaml:"crt"` TCPPath string `yaml:"tcp"` KeyPath string `yaml:"key"` Username string `yaml:"user"` Password string `yaml:"pass"` Rate string `yaml:"rate"` Burst string `yaml:"burst"` Timeout string `yaml:"timeout"` Rewrites []string `yaml:"rw"` } func Init() error { log.SetFlags(log.Ldate | log.Ltime | log.Llongfile) log.SetFlags(log.Ltime | log.Lshortfile) if err := fromFile(); err != nil { return err } if err := fromFlags(); err != nil { return err } return nil } func fromFile() error { flag.CommandLine = flag.NewFlagSet(os.Args[0], flag.ContinueOnError) defer func() { flag.CommandLine = flag.NewFlagSet(os.Args[0], flag.ExitOnError) }() flag.String(flagConf, "/dev/null", "yaml config file path") flag.Parse() confFlag := flag.Lookup(flagConf) if confFlag == nil || confFlag.Value.String() == "" { return nil } confBytes, err := ioutil.ReadFile(confFlag.Value.String()) if err != nil { return err } var c fileConf if err := yaml.Unmarshal(confBytes, &c); err != nil { return err } if err := conf.Set(nsConf, flagPort, packable.NewString(c.Port)); err != nil { return err } if err := conf.Set(nsConf, flagRoutes, packable.NewString(strings.Join(c.Routes, ","))); err != nil { return err } if err := conf.Set(nsConf, flagCert, packable.NewString(c.CertPath)); err != nil { return err } if err := conf.Set(nsConf, flagTCP, packable.NewString(c.TCPPath)); err != nil { return err } if err := conf.Set(nsConf, flagKey, packable.NewString(c.KeyPath)); err != nil { return err } if err := conf.Set(nsConf, flagUser, packable.NewString(c.Username)); err != nil { return err } if err := conf.Set(nsConf, flagPass, packable.NewString(c.Password)); err != nil { return err } if err := conf.Set(nsConf, flagRate, packable.NewString(c.Rate)); err != nil { return err } if err := conf.Set(nsConf, flagBurst, packable.NewString(c.Burst)); err != nil { return err } if err := conf.Set(nsConf, flagTimeout, packable.NewString(c.Timeout)); err != nil { return err } if err := conf.Set(nsConf, flagRewrites, packable.NewString(strings.Join(c.Rewrites, ","))); err != nil { return err } return nil } func fromFlags() error { binds := make([]toBind, 0) binds = append(binds, addFlag(flagPort, "51555", "port to bind to")) binds = append(binds, addFlag(flagConf, "", "configuration file path")) binds = append(binds, addFlag(flagRoutes, "", "comma-separated routes to map, each as from:scheme://to.tld:port")) binds = append(binds, addFlag(flagCert, "", "path to .crt")) binds = append(binds, addFlag(flagTCP, "", "tcp addr")) binds = append(binds, addFlag(flagKey, "", "path to .key")) binds = append(binds, addFlag(flagUser, "", "basic auth username")) binds = append(binds, addFlag(flagPass, "", "basic auth password")) binds = append(binds, addFlag(flagRate, "100", "rate limit per second")) binds = append(binds, addFlag(flagBurst, "100", "rate limit burst")) binds = append(binds, addFlag(flagTimeout, "30", "seconds to wait for limiter")) binds = append(binds, addFlag(flagRewrites, "", "comma-separated from:replace:replacement:oauth to rewrite in response bodies")) flag.Parse() for _, bind := range binds { confFlag := flag.Lookup(bind.flag) if confFlag == nil || confFlag.Value.String() == "" { continue } if err := conf.Set(nsConf, bind.flag, packable.NewString(*bind.value)); err != nil { return err } } return nil } func addFlag(key, def, help string) toBind { def = getFlagOrDefault(key, def) v := flag.String(key, def, help) return toBind{ flag: key, value: v, } } func getFlagOrDefault(key, def string) string { v := packable.NewString() if err := conf.Get(nsConf, key, v); err != nil { return def } return v.String() }