change conf to argsset and flag for oauth

master v0.4
bel 2019-10-22 04:52:04 +00:00
parent e20ba5361d
commit bc11dd7f82
10 changed files with 398 additions and 250 deletions

115
.config/config.go Executable file
View File

@ -0,0 +1,115 @@
package config
import (
"local/rproxy3/storage/packable"
"log"
"strconv"
"strings"
)
func GetPort() string {
v := packable.NewString()
conf.Get(nsConf, flagPort, v)
return ":" + strings.TrimPrefix(v.String(), ":")
}
func GetRoutes() map[string]string {
v := packable.NewString()
conf.Get(nsConf, flagRoutes, v)
m := make(map[string]string)
for _, v := range strings.Split(v.String(), ",") {
if len(v) == 0 {
return m
}
from := v[:strings.Index(v, ":")]
to := v[strings.Index(v, ":")+1:]
m[from] = to
}
return m
}
func GetTCP() (string, bool) {
v := packable.NewString()
conf.Get(nsConf, flagTCP, v)
tcpAddr := v.String()
return tcpAddr, notEmpty(tcpAddr)
}
func GetSSL() (string, string, bool) {
v := packable.NewString()
conf.Get(nsConf, flagCert, v)
certPath := v.String()
conf.Get(nsConf, flagKey, v)
keyPath := v.String()
return certPath, keyPath, notEmpty(certPath, keyPath)
}
func GetAuth() (string, string, bool) {
v := packable.NewString()
conf.Get(nsConf, flagUser, v)
user := v.String()
conf.Get(nsConf, flagPass, v)
pass := v.String()
return user, pass, notEmpty(user, pass)
}
func notEmpty(s ...string) bool {
for i := range s {
if s[i] == "" || s[i] == "/dev/null" {
return false
}
}
return true
}
func GetRate() (int, int) {
r := packable.NewString()
conf.Get(nsConf, flagRate, r)
b := packable.NewString()
conf.Get(nsConf, flagBurst, b)
rate, err := strconv.Atoi(r.String())
if err != nil {
log.Printf("illegal rate: %v", err)
rate = 5
}
burst, _ := strconv.Atoi(b.String())
if err != nil {
log.Printf("illegal burst: %v", err)
burst = 5
}
return rate, burst
}
func GetTimeout() int {
t := packable.NewString()
conf.Get(nsConf, flagTimeout, t)
timeout, err := strconv.Atoi(t.String())
if err != nil || timeout == 5 {
return 5
}
return timeout
}
func GetRewrites(hostMatch string) map[string]string {
v := packable.NewString()
conf.Get(nsConf, flagRewrites, v)
m := make(map[string]string)
for _, v := range strings.Split(v.String(), ",") {
vs := strings.Split(v, ":")
if len(v) < 3 {
continue
}
host := vs[0]
if host != hostMatch {
continue
}
from := vs[1]
to := strings.Join(vs[2:], ":")
m[from] = to
}
return m
}

161
.config/new.go Executable file
View File

@ -0,0 +1,161 @@
package config
import (
"flag"
"io/ioutil"
"local/rproxy3/storage"
"local/rproxy3/storage/packable"
"log"
"os"
"strings"
yaml "gopkg.in/yaml.v2"
)
const nsConf = "configuration"
const flagPort = "p"
const flagRoutes = "r"
const flagConf = "c"
const flagCert = "crt"
const flagTCP = "tcp"
const flagKey = "key"
const flagUser = "user"
const flagPass = "pass"
const flagRate = "rate"
const flagBurst = "burst"
const flagTimeout = "timeout"
const flagRewrites = "rw"
var conf = storage.NewMap()
type toBind struct {
flag string
value *string
}
type fileConf struct {
Port string `yaml:"p"`
Routes []string `yaml:"r"`
CertPath string `yaml:"crt"`
TCPPath string `yaml:"tcp"`
KeyPath string `yaml:"key"`
Username string `yaml:"user"`
Password string `yaml:"pass"`
Rate string `yaml:"rate"`
Burst string `yaml:"burst"`
Timeout string `yaml:"timeout"`
Rewrites []string `yaml:"rw"`
}
func Init() error {
log.SetFlags(log.Ldate | log.Ltime | log.Llongfile)
log.SetFlags(log.Ltime | log.Lshortfile)
if err := fromFile(); err != nil {
return err
}
if err := fromFlags(); err != nil {
return err
}
return nil
}
func fromFile() error {
flag.CommandLine = flag.NewFlagSet(os.Args[0], flag.ContinueOnError)
defer func() {
flag.CommandLine = flag.NewFlagSet(os.Args[0], flag.ExitOnError)
}()
flag.String(flagConf, "/dev/null", "yaml config file path")
flag.Parse()
confFlag := flag.Lookup(flagConf)
if confFlag == nil || confFlag.Value.String() == "" {
return nil
}
confBytes, err := ioutil.ReadFile(confFlag.Value.String())
if err != nil {
return err
}
var c fileConf
if err := yaml.Unmarshal(confBytes, &c); err != nil {
return err
}
if err := conf.Set(nsConf, flagPort, packable.NewString(c.Port)); err != nil {
return err
}
if err := conf.Set(nsConf, flagRoutes, packable.NewString(strings.Join(c.Routes, ","))); err != nil {
return err
}
if err := conf.Set(nsConf, flagCert, packable.NewString(c.CertPath)); err != nil {
return err
}
if err := conf.Set(nsConf, flagTCP, packable.NewString(c.TCPPath)); err != nil {
return err
}
if err := conf.Set(nsConf, flagKey, packable.NewString(c.KeyPath)); err != nil {
return err
}
if err := conf.Set(nsConf, flagUser, packable.NewString(c.Username)); err != nil {
return err
}
if err := conf.Set(nsConf, flagPass, packable.NewString(c.Password)); err != nil {
return err
}
if err := conf.Set(nsConf, flagRate, packable.NewString(c.Rate)); err != nil {
return err
}
if err := conf.Set(nsConf, flagBurst, packable.NewString(c.Burst)); err != nil {
return err
}
if err := conf.Set(nsConf, flagTimeout, packable.NewString(c.Timeout)); err != nil {
return err
}
if err := conf.Set(nsConf, flagRewrites, packable.NewString(strings.Join(c.Rewrites, ","))); err != nil {
return err
}
return nil
}
func fromFlags() error {
binds := make([]toBind, 0)
binds = append(binds, addFlag(flagPort, "51555", "port to bind to"))
binds = append(binds, addFlag(flagConf, "", "configuration file path"))
binds = append(binds, addFlag(flagRoutes, "", "comma-separated routes to map, each as from:scheme://to.tld:port"))
binds = append(binds, addFlag(flagCert, "", "path to .crt"))
binds = append(binds, addFlag(flagTCP, "", "tcp addr"))
binds = append(binds, addFlag(flagKey, "", "path to .key"))
binds = append(binds, addFlag(flagUser, "", "basic auth username"))
binds = append(binds, addFlag(flagPass, "", "basic auth password"))
binds = append(binds, addFlag(flagRate, "100", "rate limit per second"))
binds = append(binds, addFlag(flagBurst, "100", "rate limit burst"))
binds = append(binds, addFlag(flagTimeout, "30", "seconds to wait for limiter"))
binds = append(binds, addFlag(flagRewrites, "", "comma-separated from:replace:replacement:oauth to rewrite in response bodies"))
flag.Parse()
for _, bind := range binds {
confFlag := flag.Lookup(bind.flag)
if confFlag == nil || confFlag.Value.String() == "" {
continue
}
if err := conf.Set(nsConf, bind.flag, packable.NewString(*bind.value)); err != nil {
return err
}
}
return nil
}
func addFlag(key, def, help string) toBind {
def = getFlagOrDefault(key, def)
v := flag.String(key, def, help)
return toBind{
flag: key,
value: v,
}
}
func getFlagOrDefault(key, def string) string {
v := packable.NewString()
if err := conf.Get(nsConf, key, v); err != nil {
return def
}
return v.String()
}

138
config/config.go Executable file → Normal file
View File

@ -1,115 +1,77 @@
package config
import (
"local/rproxy3/storage/packable"
"log"
"strconv"
"fmt"
"strings"
"time"
)
func GetPort() string {
v := packable.NewString()
conf.Get(nsConf, flagPort, v)
return ":" + strings.TrimPrefix(v.String(), ":")
type Proxy struct {
To string
BOAuthZ bool
}
func GetRoutes() map[string]string {
v := packable.NewString()
conf.Get(nsConf, flagRoutes, v)
m := make(map[string]string)
for _, v := range strings.Split(v.String(), ",") {
if len(v) == 0 {
return m
func parseProxy(s string) (string, Proxy) {
p := Proxy{}
key := ""
l := strings.Split(s, ",")
if len(l) > 0 {
key = l[0]
}
from := v[:strings.Index(v, ":")]
to := v[strings.Index(v, ":")+1:]
m[from] = to
if len(l) > 1 {
p.To = l[1]
}
return m
if len(l) > 2 {
p.BOAuthZ = l[2] == "true"
}
return key, p
}
func GetTCP() (string, bool) {
v := packable.NewString()
conf.Get(nsConf, flagTCP, v)
tcpAddr := v.String()
return tcpAddr, notEmpty(tcpAddr)
}
func GetSSL() (string, string, bool) {
v := packable.NewString()
conf.Get(nsConf, flagCert, v)
certPath := v.String()
conf.Get(nsConf, flagKey, v)
keyPath := v.String()
return certPath, keyPath, notEmpty(certPath, keyPath)
func GetBOAuthZ() (string, bool) {
boauthz := conf.Get("oauth").GetString()
return boauthz, boauthz != ""
}
func GetAuth() (string, string, bool) {
v := packable.NewString()
conf.Get(nsConf, flagUser, v)
user := v.String()
conf.Get(nsConf, flagPass, v)
pass := v.String()
return user, pass, notEmpty(user, pass)
user := conf.Get("user").GetString()
pass := conf.Get("pass").GetString()
return user, pass, user != "" && pass != ""
}
func notEmpty(s ...string) bool {
for i := range s {
if s[i] == "" || s[i] == "/dev/null" {
return false
}
}
return true
func GetPort() string {
port := conf.Get("p").GetInt()
return ":" + fmt.Sprint(port)
}
func GetRate() (int, int) {
r := packable.NewString()
conf.Get(nsConf, flagRate, r)
b := packable.NewString()
conf.Get(nsConf, flagBurst, b)
rate, err := strconv.Atoi(r.String())
if err != nil {
log.Printf("illegal rate: %v", err)
rate = 5
}
burst, _ := strconv.Atoi(b.String())
if err != nil {
log.Printf("illegal burst: %v", err)
burst = 5
}
rate := conf.Get("r").GetInt()
burst := conf.Get("b").GetInt()
return rate, burst
}
func GetTimeout() int {
t := packable.NewString()
conf.Get(nsConf, flagTimeout, t)
timeout, err := strconv.Atoi(t.String())
if err != nil || timeout == 5 {
return 5
func GetRoutes() map[string]Proxy {
list := conf.Get("proxy").GetString()
definitions := strings.Split(list, ",,")
routes := make(map[string]Proxy)
for _, definition := range definitions {
k, v := parseProxy(definition)
routes[k] = v
}
return routes
}
func GetSSL() (string, string, bool) {
crt := conf.Get("crt").GetString()
key := conf.Get("key").GetString()
return crt, key, crt != "" && key != ""
}
func GetTCP() (string, bool) {
tcp := conf.Get("tcp").GetString()
return tcp, tcp != ""
}
func GetTimeout() time.Duration {
timeout := conf.Get("timeout").GetDuration()
return timeout
}
func GetRewrites(hostMatch string) map[string]string {
v := packable.NewString()
conf.Get(nsConf, flagRewrites, v)
m := make(map[string]string)
for _, v := range strings.Split(v.String(), ",") {
vs := strings.Split(v, ":")
if len(v) < 3 {
continue
}
host := vs[0]
if host != hostMatch {
continue
}
from := vs[1]
to := strings.Join(vs[2:], ":")
m[from] = to
}
return m
}

172
config/new.go Executable file → Normal file
View File

@ -1,161 +1,49 @@
package config
import (
"flag"
"io/ioutil"
"local/rproxy3/storage"
"local/rproxy3/storage/packable"
"fmt"
"local/args"
"log"
"os"
"strings"
yaml "gopkg.in/yaml.v2"
"time"
)
const nsConf = "configuration"
const flagPort = "p"
const flagRoutes = "r"
const flagConf = "c"
const flagCert = "crt"
const flagTCP = "tcp"
const flagKey = "key"
const flagUser = "user"
const flagPass = "pass"
const flagRate = "rate"
const flagBurst = "burst"
const flagTimeout = "timeout"
const flagRewrites = "rw"
var conf *args.ArgSet
var conf = storage.NewMap()
type toBind struct {
flag string
value *string
func init() {
if err := Refresh(); err != nil {
panic(err)
}
}
type fileConf struct {
Port string `yaml:"p"`
Routes []string `yaml:"r"`
CertPath string `yaml:"crt"`
TCPPath string `yaml:"tcp"`
KeyPath string `yaml:"key"`
Username string `yaml:"user"`
Password string `yaml:"pass"`
Rate string `yaml:"rate"`
Burst string `yaml:"burst"`
Timeout string `yaml:"timeout"`
Rewrites []string `yaml:"rw"`
}
func Init() error {
func Refresh() error {
log.SetFlags(log.Ldate | log.Ltime | log.Llongfile)
log.SetFlags(log.Ltime | log.Lshortfile)
if err := fromFile(); err != nil {
return err
}
if err := fromFlags(); err != nil {
as, err := parseArgs()
if err != nil && !strings.Contains(fmt.Sprint(os.Args), "-test") {
return err
}
conf = as
return nil
}
func fromFile() error {
flag.CommandLine = flag.NewFlagSet(os.Args[0], flag.ContinueOnError)
defer func() {
flag.CommandLine = flag.NewFlagSet(os.Args[0], flag.ExitOnError)
}()
flag.String(flagConf, "/dev/null", "yaml config file path")
flag.Parse()
confFlag := flag.Lookup(flagConf)
if confFlag == nil || confFlag.Value.String() == "" {
return nil
}
confBytes, err := ioutil.ReadFile(confFlag.Value.String())
if err != nil {
return err
}
var c fileConf
if err := yaml.Unmarshal(confBytes, &c); err != nil {
return err
}
if err := conf.Set(nsConf, flagPort, packable.NewString(c.Port)); err != nil {
return err
}
if err := conf.Set(nsConf, flagRoutes, packable.NewString(strings.Join(c.Routes, ","))); err != nil {
return err
}
if err := conf.Set(nsConf, flagCert, packable.NewString(c.CertPath)); err != nil {
return err
}
if err := conf.Set(nsConf, flagTCP, packable.NewString(c.TCPPath)); err != nil {
return err
}
if err := conf.Set(nsConf, flagKey, packable.NewString(c.KeyPath)); err != nil {
return err
}
if err := conf.Set(nsConf, flagUser, packable.NewString(c.Username)); err != nil {
return err
}
if err := conf.Set(nsConf, flagPass, packable.NewString(c.Password)); err != nil {
return err
}
if err := conf.Set(nsConf, flagRate, packable.NewString(c.Rate)); err != nil {
return err
}
if err := conf.Set(nsConf, flagBurst, packable.NewString(c.Burst)); err != nil {
return err
}
if err := conf.Set(nsConf, flagTimeout, packable.NewString(c.Timeout)); err != nil {
return err
}
if err := conf.Set(nsConf, flagRewrites, packable.NewString(strings.Join(c.Rewrites, ","))); err != nil {
return err
}
return nil
}
func fromFlags() error {
binds := make([]toBind, 0)
binds = append(binds, addFlag(flagPort, "51555", "port to bind to"))
binds = append(binds, addFlag(flagConf, "", "configuration file path"))
binds = append(binds, addFlag(flagRoutes, "", "comma-separated routes to map, each as from:scheme://to.tld:port"))
binds = append(binds, addFlag(flagCert, "", "path to .crt"))
binds = append(binds, addFlag(flagTCP, "", "tcp addr"))
binds = append(binds, addFlag(flagKey, "", "path to .key"))
binds = append(binds, addFlag(flagUser, "", "basic auth username"))
binds = append(binds, addFlag(flagPass, "", "basic auth password"))
binds = append(binds, addFlag(flagRate, "100", "rate limit per second"))
binds = append(binds, addFlag(flagBurst, "100", "rate limit burst"))
binds = append(binds, addFlag(flagTimeout, "30", "seconds to wait for limiter"))
binds = append(binds, addFlag(flagRewrites, "", "comma-separated from:replace:replacement:oauth to rewrite in response bodies"))
flag.Parse()
for _, bind := range binds {
confFlag := flag.Lookup(bind.flag)
if confFlag == nil || confFlag.Value.String() == "" {
continue
}
if err := conf.Set(nsConf, bind.flag, packable.NewString(*bind.value)); err != nil {
return err
}
}
return nil
}
func addFlag(key, def, help string) toBind {
def = getFlagOrDefault(key, def)
v := flag.String(key, def, help)
return toBind{
flag: key,
value: v,
}
}
func getFlagOrDefault(key, def string) string {
v := packable.NewString()
if err := conf.Get(nsConf, key, v); err != nil {
return def
}
return v.String()
func parseArgs() (*args.ArgSet, error) {
as := args.NewArgSet()
as.Append(args.STRING, "user", "username for basic auth", "")
as.Append(args.STRING, "pass", "password for basic auth", "")
as.Append(args.INT, "p", "port for service", 51555)
as.Append(args.INT, "r", "rate per second for requests", 100)
as.Append(args.INT, "b", "burst requests", 100)
as.Append(args.STRING, "crt", "path to crt for ssl", "")
as.Append(args.STRING, "key", "path to key for ssl", "")
as.Append(args.STRING, "tcp", "address for tcp only tunnel", "")
as.Append(args.DURATION, "timeout", "timeout for tunnel", time.Minute)
as.Append(args.STRING, "proxy", "double-comma separated from,scheme://to.tld:port,oauth,,", "")
as.Append(args.STRING, "oauth", "url for boauthz", "")
err := as.Parse()
return as, err
}

View File

@ -6,7 +6,7 @@ import (
)
func main() {
if err := config.Init(); err != nil {
if err := config.Refresh(); err != nil {
panic(err)
}

View File

@ -34,8 +34,8 @@ func TestHTTPSMain(t *testing.T) {
"username",
"-pass",
"password",
"-r",
"hello:" + addr,
"-proxy",
"hello," + addr,
"-crt",
"./testdata/rproxy3server.crt",
"-key",
@ -51,7 +51,7 @@ func TestHTTPSMain(t *testing.T) {
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
},
}
r, _ := http.NewRequest("GET", "https://hello.localhost"+port, nil)
r, _ := http.NewRequest("GET", "https://hello.localhost:"+port, nil)
if resp, err := client.Do(r); err != nil {
t.Fatalf("client failed: %v", err)
@ -89,8 +89,8 @@ func TestHTTPMain(t *testing.T) {
"username",
"-pass",
"password",
"-r",
"hello:" + addr,
"-proxy",
"hello," + addr,
}
main()
}()
@ -98,7 +98,7 @@ func TestHTTPMain(t *testing.T) {
time.Sleep(time.Millisecond * 100)
client := &http.Client{}
r, _ := http.NewRequest("GET", "http://hello.localhost"+port, nil)
r, _ := http.NewRequest("GET", "http://hello.localhost:"+port, nil)
if resp, err := client.Do(r); err != nil {
t.Fatalf("client failed: %v", err)
@ -127,5 +127,5 @@ func echoServer() (string, func()) {
func getPort() string {
s := httptest.NewServer(nil)
s.Close()
return s.URL[strings.LastIndex(s.URL, ":"):]
return s.URL[strings.LastIndex(s.URL, ":")+1:]
}

View File

@ -4,7 +4,6 @@ import (
"bytes"
"crypto/tls"
"io"
"local/rproxy3/config"
"local/rproxy3/storage/packable"
"log"
"net/http"
@ -33,10 +32,6 @@ func (s *Server) Proxy(w http.ResponseWriter, r *http.Request) {
targetHost: newURL.Host,
baseTransport: http.DefaultTransport,
}
transport = &rewrite{
rewrites: config.GetRewrites(mapKey(r.Host)),
baseTransport: transport,
}
if err != nil {
http.NotFound(w, r)
log.Printf("unknown host lookup %q", r.Host)
@ -54,6 +49,12 @@ func (s *Server) lookup(host string) (*url.URL, error) {
return v.URL(), err
}
func (s *Server) lookupBOAuthZ(host string) (bool, error) {
v := packable.NewString()
err := s.db.Get(nsBOAuthZ, host, v)
return v.String() != "", err
}
func mapKey(host string) string {
host = strings.Split(host, ".")[0]
host = strings.Split(host, ":")[0]

View File

@ -5,7 +5,9 @@ import (
"crypto/tls"
"encoding/base64"
"errors"
"fmt"
"io"
"local/oauth2/oauth2client"
"local/rproxy3/config"
"local/rproxy3/storage"
"local/rproxy3/storage/packable"
@ -20,6 +22,7 @@ import (
)
const nsRouting = "routing"
const nsBOAuthZ = "oauth"
type listenerScheme int
@ -49,12 +52,13 @@ type Server struct {
limiter *rate.Limiter
}
func (s *Server) Route(src, dst string) error {
log.Printf("Adding route %q -> %q...\n", src, dst)
u, err := url.Parse(dst)
func (s *Server) Route(src string, dst config.Proxy) error {
log.Printf("Adding route %q -> %v...\n", src, dst)
u, err := url.Parse(dst.To)
if err != nil {
return err
}
s.db.Set(nsBOAuthZ, src, packable.NewString(fmt.Sprint(dst.BOAuthZ)))
return s.db.Set(nsRouting, src, packable.NewURL(u))
}
@ -103,7 +107,6 @@ func (s *Server) doAuth(foo http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
rusr, rpwd, ok := config.GetAuth()
if ok {
//usr, pwd := getProxyAuth(r)
usr, pwd, ok := r.BasicAuth()
if !ok || rusr != usr || rpwd != pwd {
w.WriteHeader(http.StatusUnauthorized)
@ -111,6 +114,17 @@ func (s *Server) doAuth(foo http.HandlerFunc) http.HandlerFunc {
return
}
}
ok, err := s.lookupBOAuthZ(mapKey(r.Host))
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
return
}
if boauthz, useoauth := config.GetBOAuthZ(); ok && useoauth {
err := oauth2client.Authenticate(boauthz, w, r)
if err != nil {
return
}
}
foo(w, r)
}
}

View File

@ -3,6 +3,7 @@ package server
import (
"context"
"fmt"
"local/rproxy3/config"
"local/rproxy3/storage"
"net/http"
"net/http/httptest"
@ -15,7 +16,10 @@ import (
func TestServerStart(t *testing.T) {
server := mockServer()
if err := server.Route("world", "http://hello.localhost"+server.addr); err != nil {
p := config.Proxy{
To: "http://hello.localhost" + server.addr,
}
if err := server.Route("world", p); err != nil {
t.Fatalf("cannot add route: %v", err)
}
@ -48,7 +52,10 @@ func mockServer() *Server {
func TestServerRoute(t *testing.T) {
server := mockServer()
if err := server.Route("world", "http://hello.localhost"+server.addr); err != nil {
p := config.Proxy{
To: "http://hello.localhost" + server.addr,
}
if err := server.Route("world", p); err != nil {
t.Fatalf("cannot add route: %v", err)
}
w := httptest.NewRecorder()