From f72ecc5e5357dadf068a188085619addee01d167 Mon Sep 17 00:00:00 2001 From: Bel LaPointe Date: Wed, 10 Apr 2019 10:08:47 -0600 Subject: [PATCH] passes tests with no rewrites --- config/config.go | 26 +++++++++++++++++++++++- config/new.go | 6 ++++++ server/proxy.go | 53 ++++++++++++++++++++++++++++++++++++++++++------ 3 files changed, 78 insertions(+), 7 deletions(-) diff --git a/config/config.go b/config/config.go index 6c64328..62c811f 100644 --- a/config/config.go +++ b/config/config.go @@ -2,6 +2,7 @@ package config import ( "local/rproxy3/storage/packable" + "log" "strconv" "strings" ) @@ -60,8 +61,16 @@ func GetRate() (int, int) { b := packable.NewString() conf.Get(nsConf, flagBurst, b) - rate, _ := strconv.Atoi(r.String()) + rate, err := strconv.Atoi(r.String()) + if err != nil { + log.Printf("illegal rate: %v", err) + rate = 5 + } burst, _ := strconv.Atoi(b.String()) + if err != nil { + log.Printf("illegal burst: %v", err) + burst = 5 + } return rate, burst } @@ -77,3 +86,18 @@ func GetTimeout() int { return timeout } + +func GetRewrites() map[string]string { + v := packable.NewString() + conf.Get(nsConf, flagRewrites, v) + m := make(map[string]string) + for _, v := range strings.Split(v.String(), ",") { + if len(v) == 0 { + return m + } + from := v[:strings.Index(v, ":")] + to := v[strings.Index(v, ":")+1:] + m[from] = to + } + return m +} diff --git a/config/new.go b/config/new.go index 87d3ff6..6750b7f 100644 --- a/config/new.go +++ b/config/new.go @@ -23,6 +23,7 @@ const flagPass = "pass" const flagRate = "rate" const flagBurst = "burst" const flagTimeout = "timeout" +const flagRewrites = "rw" var conf = storage.NewMap() @@ -41,6 +42,7 @@ type fileConf struct { Rate string `yaml:"rate"` Burst string `yaml:"burst"` Timeout string `yaml:"timeout"` + Rewrites []string `yaml:"rw"` } func Init() error { @@ -100,6 +102,9 @@ func fromFile() error { if err := conf.Set(nsConf, flagTimeout, packable.NewString(c.Timeout)); err != nil { return err } + if err := conf.Set(nsConf, flagRewrites, packable.NewString(strings.Join(c.Rewrites, ","))); err != nil { + return err + } return nil } @@ -115,6 +120,7 @@ func fromFlags() error { binds = append(binds, addFlag(flagRate, "100", "rate limit per second")) binds = append(binds, addFlag(flagBurst, "100", "rate limit burst")) binds = append(binds, addFlag(flagTimeout, "30", "seconds to wait for limiter")) + binds = append(binds, addFlag(flagRewrites, "", "comma-separated regex:v to rewrite in response bodies")) flag.Parse() for _, bind := range binds { diff --git a/server/proxy.go b/server/proxy.go index 3348ec3..7dcf34a 100644 --- a/server/proxy.go +++ b/server/proxy.go @@ -1,6 +1,8 @@ package server import ( + "io" + "local/rproxy3/config" "local/rproxy3/storage/packable" "log" "net/http" @@ -10,15 +12,27 @@ import ( ) type redirPurge struct { - proxyHost string - targetHost string + proxyHost string + targetHost string + baseTransport http.RoundTripper +} + +type rewrite struct { + rewrites map[string]string + baseTransport http.RoundTripper } func (s *Server) Proxy(w http.ResponseWriter, r *http.Request) { newURL, err := s.lookup(r.Host) - transport := &redirPurge{ - proxyHost: r.Host, - targetHost: newURL.Host, + var transport http.RoundTripper + transport = &redirPurge{ + proxyHost: r.Host, + targetHost: newURL.Host, + baseTransport: http.DefaultTransport, + } + transport = &rewrite{ + rewrites: config.GetRewrites(), + baseTransport: transport, } if err != nil { http.NotFound(w, r) @@ -40,7 +54,7 @@ func (s *Server) lookup(host string) (*url.URL, error) { } func (rp *redirPurge) RoundTrip(r *http.Request) (*http.Response, error) { - resp, err := http.DefaultTransport.RoundTrip(r) + resp, err := rp.baseTransport.RoundTrip(r) if err != nil { return resp, err } @@ -49,3 +63,30 @@ func (rp *redirPurge) RoundTrip(r *http.Request) (*http.Response, error) { } return resp, err } + +func (rw *rewrite) RoundTrip(r *http.Request) (*http.Response, error) { + resp, err := rw.baseTransport.RoundTrip(r) + if err != nil { + return resp, err + } + pr, pw := io.Pipe() + body := resp.Body + resp.Body = pr + go func() { + buff := make([]byte, 1024) + for n, err := body.Read(buff); err == nil || n > 0; n, err = body.Read(buff) { + chunk := buff[:n] + m := 0 + for m < n { + l, err := pw.Write(chunk[m:]) + if err != nil { + pw.CloseWithError(err) + return + } + m += l + } + } + pw.CloseWithError(err) + }() + return resp, err +}