diff --git a/config/config.go b/config/config.go index 62c811f..d50034f 100644 --- a/config/config.go +++ b/config/config.go @@ -87,16 +87,21 @@ func GetTimeout() int { return timeout } -func GetRewrites() map[string]string { +func GetRewrites(hostMatch string) map[string]string { v := packable.NewString() conf.Get(nsConf, flagRewrites, v) m := make(map[string]string) for _, v := range strings.Split(v.String(), ",") { - if len(v) == 0 { - return m + vs := strings.Split(v, ":") + if len(v) < 3 { + continue } - from := v[:strings.Index(v, ":")] - to := v[strings.Index(v, ":")+1:] + host := vs[0] + if host != hostMatch { + continue + } + from := vs[1] + to := strings.Join(vs[2:], ":") m[from] = to } return m diff --git a/config/new.go b/config/new.go index 6750b7f..4bdebdf 100644 --- a/config/new.go +++ b/config/new.go @@ -47,6 +47,7 @@ type fileConf struct { func Init() error { log.SetFlags(log.Ldate | log.Ltime | log.Llongfile) + log.SetFlags(log.Ltime | log.Lshortfile) if err := fromFile(); err != nil { return err } @@ -120,7 +121,7 @@ func fromFlags() error { binds = append(binds, addFlag(flagRate, "100", "rate limit per second")) binds = append(binds, addFlag(flagBurst, "100", "rate limit burst")) binds = append(binds, addFlag(flagTimeout, "30", "seconds to wait for limiter")) - binds = append(binds, addFlag(flagRewrites, "", "comma-separated regex:v to rewrite in response bodies")) + binds = append(binds, addFlag(flagRewrites, "", "comma-separated from:replace:replacement to rewrite in response bodies")) flag.Parse() for _, bind := range binds { diff --git a/server/proxy.go b/server/proxy.go index 7dcf34a..6ee58cc 100644 --- a/server/proxy.go +++ b/server/proxy.go @@ -1,6 +1,7 @@ package server import ( + "bytes" "io" "local/rproxy3/config" "local/rproxy3/storage/packable" @@ -23,7 +24,7 @@ type rewrite struct { } func (s *Server) Proxy(w http.ResponseWriter, r *http.Request) { - newURL, err := s.lookup(r.Host) + newURL, err := s.lookup(mapKey(r.Host)) var transport http.RoundTripper transport = &redirPurge{ proxyHost: r.Host, @@ -31,7 +32,7 @@ func (s *Server) Proxy(w http.ResponseWriter, r *http.Request) { baseTransport: http.DefaultTransport, } transport = &rewrite{ - rewrites: config.GetRewrites(), + rewrites: config.GetRewrites(mapKey(r.Host)), baseTransport: transport, } if err != nil { @@ -46,13 +47,17 @@ func (s *Server) Proxy(w http.ResponseWriter, r *http.Request) { } func (s *Server) lookup(host string) (*url.URL, error) { - host = strings.Split(host, ".")[0] - host = strings.Split(host, ":")[0] v := packable.NewURL() err := s.db.Get(nsRouting, host, v) return v.URL(), err } +func mapKey(host string) string { + host = strings.Split(host, ".")[0] + host = strings.Split(host, ":")[0] + return host +} + func (rp *redirPurge) RoundTrip(r *http.Request) (*http.Response, error) { resp, err := rp.baseTransport.RoundTrip(r) if err != nil { @@ -69,13 +74,22 @@ func (rw *rewrite) RoundTrip(r *http.Request) (*http.Response, error) { if err != nil { return resp, err } + if len(rw.rewrites) == 0 { + return resp, err + } + resp.Header.Del("Content-Length") pr, pw := io.Pipe() body := resp.Body resp.Body = pr go func() { buff := make([]byte, 1024) - for n, err := body.Read(buff); err == nil || n > 0; n, err = body.Read(buff) { + n, err := body.Read(buff) + for err == nil || n > 0 { chunk := buff[:n] + for k, v := range rw.rewrites { + chunk = bytes.Replace(chunk, []byte(k), []byte(v), -1) + } + n = len(chunk) m := 0 for m < n { l, err := pw.Write(chunk[m:]) @@ -85,6 +99,7 @@ func (rw *rewrite) RoundTrip(r *http.Request) (*http.Response, error) { } m += l } + n, err = body.Read(buff) } pw.CloseWithError(err) }() diff --git a/server/proxy_test.go b/server/proxy_test.go index 51ff208..ca1c1f1 100644 --- a/server/proxy_test.go +++ b/server/proxy_test.go @@ -1,3 +1,42 @@ package server +import ( + "io/ioutil" + "net/http" + "strings" + "testing" +) + +type fakeTransport struct{} + +func (ft fakeTransport) RoundTrip(r *http.Request) (*http.Response, error) { + return &http.Response{ + Body: r.Body, + }, nil +} + // empty url -> OK //TODO +func TestRewrite(t *testing.T) { + transport := &rewrite{ + rewrites: map[string]string{ + "a": "b", + }, + baseTransport: fakeTransport{}, + } + + r, err := http.NewRequest("GET", "asdf", strings.NewReader("mary had a little lamb")) + if err != nil { + t.Fatal(err) + } + resp, err := transport.RoundTrip(r) + if err != nil { + t.Fatal(err) + } + b, err := ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + if string(b) != "mbry hbd b little lbmb" { + t.Errorf("failed to replace: got %q, want \"b\"", b) + } +}