From 2099ae50c6ff4c8441757eccd06ac5f0e66ba776 Mon Sep 17 00:00:00 2001 From: Bel LaPointe Date: Sat, 23 Feb 2019 18:54:33 -0700 Subject: [PATCH] defaults dont work for conf --- conf.yaml | 2 ++ config/config.go | 24 ++++++++++++++++++++++++ config/new.go | 19 +++++++++++++++++++ server/new.go | 8 ++++++-- server/server.go | 16 +++++++++++++++- 5 files changed, 66 insertions(+), 3 deletions(-) diff --git a/conf.yaml b/conf.yaml index e28b074..bdea289 100644 --- a/conf.yaml +++ b/conf.yaml @@ -6,3 +6,5 @@ crt: ./testdata/rproxy3server.crt key: ./testdata/rproxy3server.key user: bel pass: bel +rate: 1 +burst: 2 diff --git a/config/config.go b/config/config.go index 7b2b73c..20f6043 100644 --- a/config/config.go +++ b/config/config.go @@ -2,6 +2,8 @@ package config import ( "local/rproxy3/storage/packable" + "log" + "strconv" "strings" ) @@ -52,3 +54,25 @@ func notEmpty(s ...string) bool { } return true } + +func GetRate() (int, int) { + r := packable.NewString() + conf.Get(nsConf, flagRate, r) + b := packable.NewString() + conf.Get(nsConf, flagBurst, b) + + rate, _ := strconv.Atoi(r.String()) + burst, _ := strconv.Atoi(b.String()) + + return rate, burst +} + +func GetTimeout() int { + t := packable.NewString() + conf.Get(nsConf, flagTimeout, t) + + timeout, _ := strconv.Atoi(t.String()) + log.Printf("TIMEOUT t:%q, i:%v", t.String(), timeout) + + return timeout +} diff --git a/config/new.go b/config/new.go index 21e1469..65879fc 100644 --- a/config/new.go +++ b/config/new.go @@ -20,6 +20,9 @@ const flagCert = "crt" const flagKey = "key" const flagUser = "user" const flagPass = "pass" +const flagRate = "rate" +const flagBurst = "burst" +const flagTimeout = "timeout" var conf = storage.NewMap() @@ -35,6 +38,9 @@ type fileConf struct { KeyPath string `yaml:"key"` Username string `yaml:"user"` Password string `yaml:"pass"` + Rate string `yaml:"rate"` + Burst string `yaml:"burst"` + Timeout string `yaml:"timeout"` } func Init() error { @@ -85,6 +91,15 @@ func fromFile() error { if err := conf.Set(nsConf, flagPass, packable.NewString(c.Password)); err != nil { return err } + if err := conf.Set(nsConf, flagRate, packable.NewString(c.Rate)); err != nil { + return err + } + if err := conf.Set(nsConf, flagBurst, packable.NewString(c.Burst)); err != nil { + return err + } + if err := conf.Set(nsConf, flagTimeout, packable.NewString(c.Timeout)); err != nil { + return err + } return nil } @@ -97,10 +112,14 @@ func fromFlags() error { binds = append(binds, addFlag(flagKey, "", "path to .key")) binds = append(binds, addFlag(flagUser, "", "basic auth username")) binds = append(binds, addFlag(flagPass, "", "basic auth password")) + binds = append(binds, addFlag(flagRate, "100", "rate limit per second")) + binds = append(binds, addFlag(flagBurst, "100", "rate limit burst")) + binds = append(binds, addFlag(flagTimeout, "30", "seconds to wait for limiter")) flag.Parse() for _, bind := range binds { confFlag := flag.Lookup(bind.flag) + log.Printf("flag:%v", confFlag) if confFlag == nil || confFlag.Value.String() == "" { continue } diff --git a/server/new.go b/server/new.go index cbc94a5..96f2eeb 100644 --- a/server/new.go +++ b/server/new.go @@ -3,12 +3,16 @@ package server import ( "local/rproxy3/config" "local/rproxy3/storage" + + "golang.org/x/time/rate" ) func New() *Server { port := config.GetPort() + r, b := config.GetRate() return &Server{ - db: storage.NewMap(), - addr: port, + db: storage.NewMap(), + addr: port, + limiter: rate.NewLimiter(rate.Limit(r), b), } } diff --git a/server/server.go b/server/server.go index fa3a079..58e8395 100644 --- a/server/server.go +++ b/server/server.go @@ -1,6 +1,7 @@ package server import ( + "context" "encoding/base64" "errors" "local/rproxy3/config" @@ -10,6 +11,9 @@ import ( "net/http" "net/url" "strings" + "time" + + "golang.org/x/time/rate" ) const nsRouting = "routing" @@ -36,6 +40,7 @@ type Server struct { addr string username string password string + limiter *rate.Limiter } func (s *Server) Route(src, dst string) error { @@ -80,7 +85,16 @@ func (s *Server) doAuth(foo http.HandlerFunc) http.HandlerFunc { } func (s *Server) Pre(foo http.HandlerFunc) http.HandlerFunc { - return s.doAuth(foo) + return func(w http.ResponseWriter, r *http.Request) { + ctx, can := context.WithTimeout(r.Context(), time.Second*time.Duration(config.GetTimeout())) + log.Printf("context iwth tiemout: %v", time.Second*time.Duration(config.GetTimeout())) + defer can() + if err := s.limiter.Wait(ctx); err != nil { + w.WriteHeader(http.StatusTooManyRequests) + return + } + s.doAuth(foo)(w, r) + } } func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {