diff --git a/config.go b/config.go index 5e21e69..68b2c6e 100644 --- a/config.go +++ b/config.go @@ -4,18 +4,22 @@ import ( "fmt" "local/args" "time" + + "golang.org/x/time/rate" ) type Config struct { Listen string Timeout time.Duration TLSInsecure bool + Limiter *rate.Limiter } func NewConfig() *Config { as := args.NewArgSet() as.Append(args.INT, "p", "port to listen on", 61113) + as.Append(args.INT, "kbps", "port to listen on", 61113) as.Append(args.BOOL, "tls-insecure", "permit tls insecure", false) as.Append(args.DURATION, "t", "timeout", time.Minute) @@ -23,9 +27,15 @@ func NewConfig() *Config { panic(err) } + var limiter *rate.Limiter + if kbps := as.GetInt("kbps"); kbps > 0 { + limiter = rate.NewLimiter(rate.Limit(kbps), 100*1024) + } + return &Config{ Listen: fmt.Sprintf(":%v", as.GetInt("p")), Timeout: as.GetDuration("t"), TLSInsecure: as.GetBool("tls-insecure"), + Limiter: limiter, } } diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..49bd2de --- /dev/null +++ b/go.mod @@ -0,0 +1,12 @@ +module gogs.inhome.blapointe.com/mfproxy + +go 1.18 + +require ( + golang.org/x/time v0.0.0-20220722155302-e5dcc9cfc0b9 + local/args v0.0.0-00010101000000-000000000000 +) + +require gopkg.in/yaml.v2 v2.4.0 // indirect + +replace local/args => ../../local/args diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..7df3092 --- /dev/null +++ b/go.sum @@ -0,0 +1,6 @@ +golang.org/x/time v0.0.0-20220722155302-e5dcc9cfc0b9 h1:ftMN5LMiBFjbzleLqtoBZk7KdJwhuybIU+FckUHgoyQ= +golang.org/x/time v0.0.0-20220722155302-e5dcc9cfc0b9/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= +gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= diff --git a/server.go b/server.go index ae334ed..785228a 100644 --- a/server.go +++ b/server.go @@ -8,10 +8,13 @@ import ( "net" "net/http" "time" + + "golang.org/x/time/rate" ) type Server struct { Transport http.RoundTripper + limiter *rate.Limiter Timeout time.Duration } @@ -21,6 +24,7 @@ func NewServer(c *Config) *Server { InsecureSkipVerify: c.TLSInsecure, } return &Server{ + limiter: c.Limiter, Transport: transport, Timeout: c.Timeout, } @@ -86,5 +90,12 @@ func (s *Server) Serve(w http.ResponseWriter, r *http.Request) { w.Header().Add(k, s) } } - io.Copy(w, resp.Body) + io.Copy( + throttledWriter{ + ctx: r.Context(), + w: w, + limiter: s.limiter, + }, + resp.Body, + ) } diff --git a/throttle.go b/throttle.go new file mode 100644 index 0000000..c6fc37c --- /dev/null +++ b/throttle.go @@ -0,0 +1,26 @@ +package main + +import ( + "context" + "io" + + "golang.org/x/time/rate" +) + +type throttledWriter struct { + ctx context.Context + w io.Writer + limiter *rate.Limiter +} + +func (tw throttledWriter) Write(b []byte) (int, error) { + if tw.limiter != nil { + if block := tw.limiter.Burst(); len(b) > block { + b = b[:block] + } + if err := tw.limiter.WaitN(tw.ctx, len(b)); err != nil { + return 0, err + } + } + return tw.w.Write(b) +}