Support body rewrite

master v0.2
Bel LaPointe 2019-04-10 10:47:30 -06:00
parent f72ecc5e53
commit f58f6f7cf3
4 changed files with 71 additions and 11 deletions

View File

@ -87,16 +87,21 @@ func GetTimeout() int {
return timeout return timeout
} }
func GetRewrites() map[string]string { func GetRewrites(hostMatch string) map[string]string {
v := packable.NewString() v := packable.NewString()
conf.Get(nsConf, flagRewrites, v) conf.Get(nsConf, flagRewrites, v)
m := make(map[string]string) m := make(map[string]string)
for _, v := range strings.Split(v.String(), ",") { for _, v := range strings.Split(v.String(), ",") {
if len(v) == 0 { vs := strings.Split(v, ":")
return m if len(v) < 3 {
continue
} }
from := v[:strings.Index(v, ":")] host := vs[0]
to := v[strings.Index(v, ":")+1:] if host != hostMatch {
continue
}
from := vs[1]
to := strings.Join(vs[2:], ":")
m[from] = to m[from] = to
} }
return m return m

View File

@ -47,6 +47,7 @@ type fileConf struct {
func Init() error { func Init() error {
log.SetFlags(log.Ldate | log.Ltime | log.Llongfile) log.SetFlags(log.Ldate | log.Ltime | log.Llongfile)
log.SetFlags(log.Ltime | log.Lshortfile)
if err := fromFile(); err != nil { if err := fromFile(); err != nil {
return err return err
} }
@ -120,7 +121,7 @@ func fromFlags() error {
binds = append(binds, addFlag(flagRate, "100", "rate limit per second")) binds = append(binds, addFlag(flagRate, "100", "rate limit per second"))
binds = append(binds, addFlag(flagBurst, "100", "rate limit burst")) binds = append(binds, addFlag(flagBurst, "100", "rate limit burst"))
binds = append(binds, addFlag(flagTimeout, "30", "seconds to wait for limiter")) binds = append(binds, addFlag(flagTimeout, "30", "seconds to wait for limiter"))
binds = append(binds, addFlag(flagRewrites, "", "comma-separated regex:v to rewrite in response bodies")) binds = append(binds, addFlag(flagRewrites, "", "comma-separated from:replace:replacement to rewrite in response bodies"))
flag.Parse() flag.Parse()
for _, bind := range binds { for _, bind := range binds {

View File

@ -1,6 +1,7 @@
package server package server
import ( import (
"bytes"
"io" "io"
"local/rproxy3/config" "local/rproxy3/config"
"local/rproxy3/storage/packable" "local/rproxy3/storage/packable"
@ -23,7 +24,7 @@ type rewrite struct {
} }
func (s *Server) Proxy(w http.ResponseWriter, r *http.Request) { func (s *Server) Proxy(w http.ResponseWriter, r *http.Request) {
newURL, err := s.lookup(r.Host) newURL, err := s.lookup(mapKey(r.Host))
var transport http.RoundTripper var transport http.RoundTripper
transport = &redirPurge{ transport = &redirPurge{
proxyHost: r.Host, proxyHost: r.Host,
@ -31,7 +32,7 @@ func (s *Server) Proxy(w http.ResponseWriter, r *http.Request) {
baseTransport: http.DefaultTransport, baseTransport: http.DefaultTransport,
} }
transport = &rewrite{ transport = &rewrite{
rewrites: config.GetRewrites(), rewrites: config.GetRewrites(mapKey(r.Host)),
baseTransport: transport, baseTransport: transport,
} }
if err != nil { if err != nil {
@ -46,13 +47,17 @@ func (s *Server) Proxy(w http.ResponseWriter, r *http.Request) {
} }
func (s *Server) lookup(host string) (*url.URL, error) { func (s *Server) lookup(host string) (*url.URL, error) {
host = strings.Split(host, ".")[0]
host = strings.Split(host, ":")[0]
v := packable.NewURL() v := packable.NewURL()
err := s.db.Get(nsRouting, host, v) err := s.db.Get(nsRouting, host, v)
return v.URL(), err return v.URL(), err
} }
func mapKey(host string) string {
host = strings.Split(host, ".")[0]
host = strings.Split(host, ":")[0]
return host
}
func (rp *redirPurge) RoundTrip(r *http.Request) (*http.Response, error) { func (rp *redirPurge) RoundTrip(r *http.Request) (*http.Response, error) {
resp, err := rp.baseTransport.RoundTrip(r) resp, err := rp.baseTransport.RoundTrip(r)
if err != nil { if err != nil {
@ -69,13 +74,22 @@ func (rw *rewrite) RoundTrip(r *http.Request) (*http.Response, error) {
if err != nil { if err != nil {
return resp, err return resp, err
} }
if len(rw.rewrites) == 0 {
return resp, err
}
resp.Header.Del("Content-Length")
pr, pw := io.Pipe() pr, pw := io.Pipe()
body := resp.Body body := resp.Body
resp.Body = pr resp.Body = pr
go func() { go func() {
buff := make([]byte, 1024) buff := make([]byte, 1024)
for n, err := body.Read(buff); err == nil || n > 0; n, err = body.Read(buff) { n, err := body.Read(buff)
for err == nil || n > 0 {
chunk := buff[:n] chunk := buff[:n]
for k, v := range rw.rewrites {
chunk = bytes.Replace(chunk, []byte(k), []byte(v), -1)
}
n = len(chunk)
m := 0 m := 0
for m < n { for m < n {
l, err := pw.Write(chunk[m:]) l, err := pw.Write(chunk[m:])
@ -85,6 +99,7 @@ func (rw *rewrite) RoundTrip(r *http.Request) (*http.Response, error) {
} }
m += l m += l
} }
n, err = body.Read(buff)
} }
pw.CloseWithError(err) pw.CloseWithError(err)
}() }()

View File

@ -1,3 +1,42 @@
package server package server
import (
"io/ioutil"
"net/http"
"strings"
"testing"
)
type fakeTransport struct{}
func (ft fakeTransport) RoundTrip(r *http.Request) (*http.Response, error) {
return &http.Response{
Body: r.Body,
}, nil
}
// empty url -> OK //TODO // empty url -> OK //TODO
func TestRewrite(t *testing.T) {
transport := &rewrite{
rewrites: map[string]string{
"a": "b",
},
baseTransport: fakeTransport{},
}
r, err := http.NewRequest("GET", "asdf", strings.NewReader("mary had a little lamb"))
if err != nil {
t.Fatal(err)
}
resp, err := transport.RoundTrip(r)
if err != nil {
t.Fatal(err)
}
b, err := ioutil.ReadAll(resp.Body)
if err != nil {
t.Fatal(err)
}
if string(b) != "mbry hbd b little lbmb" {
t.Errorf("failed to replace: got %q, want \"b\"", b)
}
}