parent
f72ecc5e53
commit
f58f6f7cf3
|
|
@ -87,16 +87,21 @@ func GetTimeout() int {
|
||||||
return timeout
|
return timeout
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetRewrites() map[string]string {
|
func GetRewrites(hostMatch string) map[string]string {
|
||||||
v := packable.NewString()
|
v := packable.NewString()
|
||||||
conf.Get(nsConf, flagRewrites, v)
|
conf.Get(nsConf, flagRewrites, v)
|
||||||
m := make(map[string]string)
|
m := make(map[string]string)
|
||||||
for _, v := range strings.Split(v.String(), ",") {
|
for _, v := range strings.Split(v.String(), ",") {
|
||||||
if len(v) == 0 {
|
vs := strings.Split(v, ":")
|
||||||
return m
|
if len(v) < 3 {
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
from := v[:strings.Index(v, ":")]
|
host := vs[0]
|
||||||
to := v[strings.Index(v, ":")+1:]
|
if host != hostMatch {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
from := vs[1]
|
||||||
|
to := strings.Join(vs[2:], ":")
|
||||||
m[from] = to
|
m[from] = to
|
||||||
}
|
}
|
||||||
return m
|
return m
|
||||||
|
|
|
||||||
|
|
@ -47,6 +47,7 @@ type fileConf struct {
|
||||||
|
|
||||||
func Init() error {
|
func Init() error {
|
||||||
log.SetFlags(log.Ldate | log.Ltime | log.Llongfile)
|
log.SetFlags(log.Ldate | log.Ltime | log.Llongfile)
|
||||||
|
log.SetFlags(log.Ltime | log.Lshortfile)
|
||||||
if err := fromFile(); err != nil {
|
if err := fromFile(); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
@ -120,7 +121,7 @@ func fromFlags() error {
|
||||||
binds = append(binds, addFlag(flagRate, "100", "rate limit per second"))
|
binds = append(binds, addFlag(flagRate, "100", "rate limit per second"))
|
||||||
binds = append(binds, addFlag(flagBurst, "100", "rate limit burst"))
|
binds = append(binds, addFlag(flagBurst, "100", "rate limit burst"))
|
||||||
binds = append(binds, addFlag(flagTimeout, "30", "seconds to wait for limiter"))
|
binds = append(binds, addFlag(flagTimeout, "30", "seconds to wait for limiter"))
|
||||||
binds = append(binds, addFlag(flagRewrites, "", "comma-separated regex:v to rewrite in response bodies"))
|
binds = append(binds, addFlag(flagRewrites, "", "comma-separated from:replace:replacement to rewrite in response bodies"))
|
||||||
flag.Parse()
|
flag.Parse()
|
||||||
|
|
||||||
for _, bind := range binds {
|
for _, bind := range binds {
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
package server
|
package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"io"
|
"io"
|
||||||
"local/rproxy3/config"
|
"local/rproxy3/config"
|
||||||
"local/rproxy3/storage/packable"
|
"local/rproxy3/storage/packable"
|
||||||
|
|
@ -23,7 +24,7 @@ type rewrite struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) Proxy(w http.ResponseWriter, r *http.Request) {
|
func (s *Server) Proxy(w http.ResponseWriter, r *http.Request) {
|
||||||
newURL, err := s.lookup(r.Host)
|
newURL, err := s.lookup(mapKey(r.Host))
|
||||||
var transport http.RoundTripper
|
var transport http.RoundTripper
|
||||||
transport = &redirPurge{
|
transport = &redirPurge{
|
||||||
proxyHost: r.Host,
|
proxyHost: r.Host,
|
||||||
|
|
@ -31,7 +32,7 @@ func (s *Server) Proxy(w http.ResponseWriter, r *http.Request) {
|
||||||
baseTransport: http.DefaultTransport,
|
baseTransport: http.DefaultTransport,
|
||||||
}
|
}
|
||||||
transport = &rewrite{
|
transport = &rewrite{
|
||||||
rewrites: config.GetRewrites(),
|
rewrites: config.GetRewrites(mapKey(r.Host)),
|
||||||
baseTransport: transport,
|
baseTransport: transport,
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -46,13 +47,17 @@ func (s *Server) Proxy(w http.ResponseWriter, r *http.Request) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) lookup(host string) (*url.URL, error) {
|
func (s *Server) lookup(host string) (*url.URL, error) {
|
||||||
host = strings.Split(host, ".")[0]
|
|
||||||
host = strings.Split(host, ":")[0]
|
|
||||||
v := packable.NewURL()
|
v := packable.NewURL()
|
||||||
err := s.db.Get(nsRouting, host, v)
|
err := s.db.Get(nsRouting, host, v)
|
||||||
return v.URL(), err
|
return v.URL(), err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func mapKey(host string) string {
|
||||||
|
host = strings.Split(host, ".")[0]
|
||||||
|
host = strings.Split(host, ":")[0]
|
||||||
|
return host
|
||||||
|
}
|
||||||
|
|
||||||
func (rp *redirPurge) RoundTrip(r *http.Request) (*http.Response, error) {
|
func (rp *redirPurge) RoundTrip(r *http.Request) (*http.Response, error) {
|
||||||
resp, err := rp.baseTransport.RoundTrip(r)
|
resp, err := rp.baseTransport.RoundTrip(r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -69,13 +74,22 @@ func (rw *rewrite) RoundTrip(r *http.Request) (*http.Response, error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return resp, err
|
return resp, err
|
||||||
}
|
}
|
||||||
|
if len(rw.rewrites) == 0 {
|
||||||
|
return resp, err
|
||||||
|
}
|
||||||
|
resp.Header.Del("Content-Length")
|
||||||
pr, pw := io.Pipe()
|
pr, pw := io.Pipe()
|
||||||
body := resp.Body
|
body := resp.Body
|
||||||
resp.Body = pr
|
resp.Body = pr
|
||||||
go func() {
|
go func() {
|
||||||
buff := make([]byte, 1024)
|
buff := make([]byte, 1024)
|
||||||
for n, err := body.Read(buff); err == nil || n > 0; n, err = body.Read(buff) {
|
n, err := body.Read(buff)
|
||||||
|
for err == nil || n > 0 {
|
||||||
chunk := buff[:n]
|
chunk := buff[:n]
|
||||||
|
for k, v := range rw.rewrites {
|
||||||
|
chunk = bytes.Replace(chunk, []byte(k), []byte(v), -1)
|
||||||
|
}
|
||||||
|
n = len(chunk)
|
||||||
m := 0
|
m := 0
|
||||||
for m < n {
|
for m < n {
|
||||||
l, err := pw.Write(chunk[m:])
|
l, err := pw.Write(chunk[m:])
|
||||||
|
|
@ -85,6 +99,7 @@ func (rw *rewrite) RoundTrip(r *http.Request) (*http.Response, error) {
|
||||||
}
|
}
|
||||||
m += l
|
m += l
|
||||||
}
|
}
|
||||||
|
n, err = body.Read(buff)
|
||||||
}
|
}
|
||||||
pw.CloseWithError(err)
|
pw.CloseWithError(err)
|
||||||
}()
|
}()
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,42 @@
|
||||||
package server
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io/ioutil"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
type fakeTransport struct{}
|
||||||
|
|
||||||
|
func (ft fakeTransport) RoundTrip(r *http.Request) (*http.Response, error) {
|
||||||
|
return &http.Response{
|
||||||
|
Body: r.Body,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
// empty url -> OK //TODO
|
// empty url -> OK //TODO
|
||||||
|
func TestRewrite(t *testing.T) {
|
||||||
|
transport := &rewrite{
|
||||||
|
rewrites: map[string]string{
|
||||||
|
"a": "b",
|
||||||
|
},
|
||||||
|
baseTransport: fakeTransport{},
|
||||||
|
}
|
||||||
|
|
||||||
|
r, err := http.NewRequest("GET", "asdf", strings.NewReader("mary had a little lamb"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
resp, err := transport.RoundTrip(r)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
b, err := ioutil.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if string(b) != "mbry hbd b little lbmb" {
|
||||||
|
t.Errorf("failed to replace: got %q, want \"b\"", b)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue