change conf to argsset and flag for oauth

This commit is contained in:
bel
2019-10-22 04:52:04 +00:00
parent e20ba5361d
commit bc11dd7f82
10 changed files with 398 additions and 250 deletions

140
config/config.go Executable file → Normal file
View File

@@ -1,115 +1,77 @@
package config
import (
"local/rproxy3/storage/packable"
"log"
"strconv"
"fmt"
"strings"
"time"
)
func GetPort() string {
v := packable.NewString()
conf.Get(nsConf, flagPort, v)
return ":" + strings.TrimPrefix(v.String(), ":")
type Proxy struct {
To string
BOAuthZ bool
}
func GetRoutes() map[string]string {
v := packable.NewString()
conf.Get(nsConf, flagRoutes, v)
m := make(map[string]string)
for _, v := range strings.Split(v.String(), ",") {
if len(v) == 0 {
return m
}
from := v[:strings.Index(v, ":")]
to := v[strings.Index(v, ":")+1:]
m[from] = to
func parseProxy(s string) (string, Proxy) {
p := Proxy{}
key := ""
l := strings.Split(s, ",")
if len(l) > 0 {
key = l[0]
}
return m
if len(l) > 1 {
p.To = l[1]
}
if len(l) > 2 {
p.BOAuthZ = l[2] == "true"
}
return key, p
}
func GetTCP() (string, bool) {
v := packable.NewString()
conf.Get(nsConf, flagTCP, v)
tcpAddr := v.String()
return tcpAddr, notEmpty(tcpAddr)
}
func GetSSL() (string, string, bool) {
v := packable.NewString()
conf.Get(nsConf, flagCert, v)
certPath := v.String()
conf.Get(nsConf, flagKey, v)
keyPath := v.String()
return certPath, keyPath, notEmpty(certPath, keyPath)
func GetBOAuthZ() (string, bool) {
boauthz := conf.Get("oauth").GetString()
return boauthz, boauthz != ""
}
func GetAuth() (string, string, bool) {
v := packable.NewString()
conf.Get(nsConf, flagUser, v)
user := v.String()
conf.Get(nsConf, flagPass, v)
pass := v.String()
return user, pass, notEmpty(user, pass)
user := conf.Get("user").GetString()
pass := conf.Get("pass").GetString()
return user, pass, user != "" && pass != ""
}
func notEmpty(s ...string) bool {
for i := range s {
if s[i] == "" || s[i] == "/dev/null" {
return false
}
}
return true
func GetPort() string {
port := conf.Get("p").GetInt()
return ":" + fmt.Sprint(port)
}
func GetRate() (int, int) {
r := packable.NewString()
conf.Get(nsConf, flagRate, r)
b := packable.NewString()
conf.Get(nsConf, flagBurst, b)
rate, err := strconv.Atoi(r.String())
if err != nil {
log.Printf("illegal rate: %v", err)
rate = 5
}
burst, _ := strconv.Atoi(b.String())
if err != nil {
log.Printf("illegal burst: %v", err)
burst = 5
}
rate := conf.Get("r").GetInt()
burst := conf.Get("b").GetInt()
return rate, burst
}
func GetTimeout() int {
t := packable.NewString()
conf.Get(nsConf, flagTimeout, t)
timeout, err := strconv.Atoi(t.String())
if err != nil || timeout == 5 {
return 5
func GetRoutes() map[string]Proxy {
list := conf.Get("proxy").GetString()
definitions := strings.Split(list, ",,")
routes := make(map[string]Proxy)
for _, definition := range definitions {
k, v := parseProxy(definition)
routes[k] = v
}
return routes
}
func GetSSL() (string, string, bool) {
crt := conf.Get("crt").GetString()
key := conf.Get("key").GetString()
return crt, key, crt != "" && key != ""
}
func GetTCP() (string, bool) {
tcp := conf.Get("tcp").GetString()
return tcp, tcp != ""
}
func GetTimeout() time.Duration {
timeout := conf.Get("timeout").GetDuration()
return timeout
}
func GetRewrites(hostMatch string) map[string]string {
v := packable.NewString()
conf.Get(nsConf, flagRewrites, v)
m := make(map[string]string)
for _, v := range strings.Split(v.String(), ",") {
vs := strings.Split(v, ":")
if len(v) < 3 {
continue
}
host := vs[0]
if host != hostMatch {
continue
}
from := vs[1]
to := strings.Join(vs[2:], ":")
m[from] = to
}
return m
}

172
config/new.go Executable file → Normal file
View File

@@ -1,161 +1,49 @@
package config
import (
"flag"
"io/ioutil"
"local/rproxy3/storage"
"local/rproxy3/storage/packable"
"fmt"
"local/args"
"log"
"os"
"strings"
yaml "gopkg.in/yaml.v2"
"time"
)
const nsConf = "configuration"
const flagPort = "p"
const flagRoutes = "r"
const flagConf = "c"
const flagCert = "crt"
const flagTCP = "tcp"
const flagKey = "key"
const flagUser = "user"
const flagPass = "pass"
const flagRate = "rate"
const flagBurst = "burst"
const flagTimeout = "timeout"
const flagRewrites = "rw"
var conf *args.ArgSet
var conf = storage.NewMap()
type toBind struct {
flag string
value *string
func init() {
if err := Refresh(); err != nil {
panic(err)
}
}
type fileConf struct {
Port string `yaml:"p"`
Routes []string `yaml:"r"`
CertPath string `yaml:"crt"`
TCPPath string `yaml:"tcp"`
KeyPath string `yaml:"key"`
Username string `yaml:"user"`
Password string `yaml:"pass"`
Rate string `yaml:"rate"`
Burst string `yaml:"burst"`
Timeout string `yaml:"timeout"`
Rewrites []string `yaml:"rw"`
}
func Init() error {
func Refresh() error {
log.SetFlags(log.Ldate | log.Ltime | log.Llongfile)
log.SetFlags(log.Ltime | log.Lshortfile)
if err := fromFile(); err != nil {
return err
}
if err := fromFlags(); err != nil {
as, err := parseArgs()
if err != nil && !strings.Contains(fmt.Sprint(os.Args), "-test") {
return err
}
conf = as
return nil
}
func fromFile() error {
flag.CommandLine = flag.NewFlagSet(os.Args[0], flag.ContinueOnError)
defer func() {
flag.CommandLine = flag.NewFlagSet(os.Args[0], flag.ExitOnError)
}()
flag.String(flagConf, "/dev/null", "yaml config file path")
flag.Parse()
confFlag := flag.Lookup(flagConf)
if confFlag == nil || confFlag.Value.String() == "" {
return nil
}
confBytes, err := ioutil.ReadFile(confFlag.Value.String())
if err != nil {
return err
}
var c fileConf
if err := yaml.Unmarshal(confBytes, &c); err != nil {
return err
}
if err := conf.Set(nsConf, flagPort, packable.NewString(c.Port)); err != nil {
return err
}
if err := conf.Set(nsConf, flagRoutes, packable.NewString(strings.Join(c.Routes, ","))); err != nil {
return err
}
if err := conf.Set(nsConf, flagCert, packable.NewString(c.CertPath)); err != nil {
return err
}
if err := conf.Set(nsConf, flagTCP, packable.NewString(c.TCPPath)); err != nil {
return err
}
if err := conf.Set(nsConf, flagKey, packable.NewString(c.KeyPath)); err != nil {
return err
}
if err := conf.Set(nsConf, flagUser, packable.NewString(c.Username)); err != nil {
return err
}
if err := conf.Set(nsConf, flagPass, packable.NewString(c.Password)); err != nil {
return err
}
if err := conf.Set(nsConf, flagRate, packable.NewString(c.Rate)); err != nil {
return err
}
if err := conf.Set(nsConf, flagBurst, packable.NewString(c.Burst)); err != nil {
return err
}
if err := conf.Set(nsConf, flagTimeout, packable.NewString(c.Timeout)); err != nil {
return err
}
if err := conf.Set(nsConf, flagRewrites, packable.NewString(strings.Join(c.Rewrites, ","))); err != nil {
return err
}
return nil
}
func fromFlags() error {
binds := make([]toBind, 0)
binds = append(binds, addFlag(flagPort, "51555", "port to bind to"))
binds = append(binds, addFlag(flagConf, "", "configuration file path"))
binds = append(binds, addFlag(flagRoutes, "", "comma-separated routes to map, each as from:scheme://to.tld:port"))
binds = append(binds, addFlag(flagCert, "", "path to .crt"))
binds = append(binds, addFlag(flagTCP, "", "tcp addr"))
binds = append(binds, addFlag(flagKey, "", "path to .key"))
binds = append(binds, addFlag(flagUser, "", "basic auth username"))
binds = append(binds, addFlag(flagPass, "", "basic auth password"))
binds = append(binds, addFlag(flagRate, "100", "rate limit per second"))
binds = append(binds, addFlag(flagBurst, "100", "rate limit burst"))
binds = append(binds, addFlag(flagTimeout, "30", "seconds to wait for limiter"))
binds = append(binds, addFlag(flagRewrites, "", "comma-separated from:replace:replacement:oauth to rewrite in response bodies"))
flag.Parse()
for _, bind := range binds {
confFlag := flag.Lookup(bind.flag)
if confFlag == nil || confFlag.Value.String() == "" {
continue
}
if err := conf.Set(nsConf, bind.flag, packable.NewString(*bind.value)); err != nil {
return err
}
}
return nil
}
func addFlag(key, def, help string) toBind {
def = getFlagOrDefault(key, def)
v := flag.String(key, def, help)
return toBind{
flag: key,
value: v,
}
}
func getFlagOrDefault(key, def string) string {
v := packable.NewString()
if err := conf.Get(nsConf, key, v); err != nil {
return def
}
return v.String()
func parseArgs() (*args.ArgSet, error) {
as := args.NewArgSet()
as.Append(args.STRING, "user", "username for basic auth", "")
as.Append(args.STRING, "pass", "password for basic auth", "")
as.Append(args.INT, "p", "port for service", 51555)
as.Append(args.INT, "r", "rate per second for requests", 100)
as.Append(args.INT, "b", "burst requests", 100)
as.Append(args.STRING, "crt", "path to crt for ssl", "")
as.Append(args.STRING, "key", "path to key for ssl", "")
as.Append(args.STRING, "tcp", "address for tcp only tunnel", "")
as.Append(args.DURATION, "timeout", "timeout for tunnel", time.Minute)
as.Append(args.STRING, "proxy", "double-comma separated from,scheme://to.tld:port,oauth,,", "")
as.Append(args.STRING, "oauth", "url for boauthz", "")
err := as.Parse()
return as, err
}

View File

@@ -1,46 +0,0 @@
package config
import (
"flag"
"os"
"testing"
)
func TestInit(t *testing.T) {
was := os.Args[:]
os.Args = []string{"program"}
flag.CommandLine = flag.NewFlagSet(os.Args[0], flag.ExitOnError)
defer func() {
os.Args = was[:]
}()
if err := Init(); err != nil {
t.Errorf("failed to init: %v", err)
}
}
func TestFromFile(t *testing.T) {
was := os.Args[:]
os.Args = []string{"program"}
flag.CommandLine = flag.NewFlagSet(os.Args[0], flag.ExitOnError)
defer func() {
os.Args = was[:]
}()
if err := fromFile(); err != nil {
t.Errorf("failed from file: %v", err)
}
}
func TestFromFlags(t *testing.T) {
was := os.Args[:]
os.Args = []string{"program"}
flag.CommandLine = flag.NewFlagSet(os.Args[0], flag.ExitOnError)
defer func() {
os.Args = was[:]
}()
if err := fromFlags(); err != nil {
t.Errorf("failed from flags: %v", err)
}
}