4 Commits
v0.1 ... v0.2.1

Author SHA1 Message Date
Bel LaPointe
06ba2dfdc1 Add path option for proxying 2019-04-10 11:15:17 -06:00
Bel LaPointe
f58f6f7cf3 Support body rewrite 2019-04-10 10:47:30 -06:00
Bel LaPointe
f72ecc5e53 passes tests with no rewrites 2019-04-10 10:08:47 -06:00
Bel LaPointe
3bd1527b98 Fix tests 2019-04-10 09:52:54 -06:00
5 changed files with 211 additions and 13 deletions

View File

@@ -2,6 +2,7 @@ package config
import ( import (
"local/rproxy3/storage/packable" "local/rproxy3/storage/packable"
"log"
"strconv" "strconv"
"strings" "strings"
) )
@@ -60,8 +61,16 @@ func GetRate() (int, int) {
b := packable.NewString() b := packable.NewString()
conf.Get(nsConf, flagBurst, b) conf.Get(nsConf, flagBurst, b)
rate, _ := strconv.Atoi(r.String()) rate, err := strconv.Atoi(r.String())
if err != nil {
log.Printf("illegal rate: %v", err)
rate = 5
}
burst, _ := strconv.Atoi(b.String()) burst, _ := strconv.Atoi(b.String())
if err != nil {
log.Printf("illegal burst: %v", err)
burst = 5
}
return rate, burst return rate, burst
} }
@@ -70,7 +79,40 @@ func GetTimeout() int {
t := packable.NewString() t := packable.NewString()
conf.Get(nsConf, flagTimeout, t) conf.Get(nsConf, flagTimeout, t)
timeout, _ := strconv.Atoi(t.String()) timeout, err := strconv.Atoi(t.String())
if err != nil || timeout == 5 {
return 5
}
return timeout return timeout
} }
func GetRewrites(hostMatch string) map[string]string {
v := packable.NewString()
conf.Get(nsConf, flagRewrites, v)
m := make(map[string]string)
for _, v := range strings.Split(v.String(), ",") {
vs := strings.Split(v, ":")
if len(v) < 3 {
continue
}
host := vs[0]
if host != hostMatch {
continue
}
from := vs[1]
to := strings.Join(vs[2:], ":")
m[from] = to
}
return m
}
func GetProxyMode() string {
v := packable.NewString()
conf.Get(nsConf, flagMode, v)
s := v.String()
if s == "" {
return "domain"
}
return s
}

View File

@@ -14,6 +14,7 @@ import (
const nsConf = "configuration" const nsConf = "configuration"
const flagPort = "p" const flagPort = "p"
const flagMode = "mode"
const flagRoutes = "r" const flagRoutes = "r"
const flagConf = "c" const flagConf = "c"
const flagCert = "crt" const flagCert = "crt"
@@ -23,6 +24,7 @@ const flagPass = "pass"
const flagRate = "rate" const flagRate = "rate"
const flagBurst = "burst" const flagBurst = "burst"
const flagTimeout = "timeout" const flagTimeout = "timeout"
const flagRewrites = "rw"
var conf = storage.NewMap() var conf = storage.NewMap()
@@ -33,6 +35,7 @@ type toBind struct {
type fileConf struct { type fileConf struct {
Port string `yaml:"p"` Port string `yaml:"p"`
Mode string `yaml:"mode"`
Routes []string `yaml:"r"` Routes []string `yaml:"r"`
CertPath string `yaml:"crt"` CertPath string `yaml:"crt"`
KeyPath string `yaml:"key"` KeyPath string `yaml:"key"`
@@ -41,10 +44,12 @@ type fileConf struct {
Rate string `yaml:"rate"` Rate string `yaml:"rate"`
Burst string `yaml:"burst"` Burst string `yaml:"burst"`
Timeout string `yaml:"timeout"` Timeout string `yaml:"timeout"`
Rewrites []string `yaml:"rw"`
} }
func Init() error { func Init() error {
log.SetFlags(log.Ldate | log.Ltime | log.Llongfile) log.SetFlags(log.Ldate | log.Ltime | log.Llongfile)
log.SetFlags(log.Ltime | log.Lshortfile)
if err := fromFile(); err != nil { if err := fromFile(); err != nil {
return err return err
} }
@@ -76,6 +81,9 @@ func fromFile() error {
if err := conf.Set(nsConf, flagPort, packable.NewString(c.Port)); err != nil { if err := conf.Set(nsConf, flagPort, packable.NewString(c.Port)); err != nil {
return err return err
} }
if err := conf.Set(nsConf, flagMode, packable.NewString(c.Mode)); err != nil {
return err
}
if err := conf.Set(nsConf, flagRoutes, packable.NewString(strings.Join(c.Routes, ","))); err != nil { if err := conf.Set(nsConf, flagRoutes, packable.NewString(strings.Join(c.Routes, ","))); err != nil {
return err return err
} }
@@ -100,12 +108,16 @@ func fromFile() error {
if err := conf.Set(nsConf, flagTimeout, packable.NewString(c.Timeout)); err != nil { if err := conf.Set(nsConf, flagTimeout, packable.NewString(c.Timeout)); err != nil {
return err return err
} }
if err := conf.Set(nsConf, flagRewrites, packable.NewString(strings.Join(c.Rewrites, ","))); err != nil {
return err
}
return nil return nil
} }
func fromFlags() error { func fromFlags() error {
binds := make([]toBind, 0) binds := make([]toBind, 0)
binds = append(binds, addFlag(flagPort, "51555", "port to bind to")) binds = append(binds, addFlag(flagPort, "51555", "port to bind to"))
binds = append(binds, addFlag(flagMode, "domain", "[domain] or [path] to match"))
binds = append(binds, addFlag(flagConf, "", "configuration file path")) binds = append(binds, addFlag(flagConf, "", "configuration file path"))
binds = append(binds, addFlag(flagRoutes, "", "comma-separated routes to map, each as from:scheme://to.tld:port")) binds = append(binds, addFlag(flagRoutes, "", "comma-separated routes to map, each as from:scheme://to.tld:port"))
binds = append(binds, addFlag(flagCert, "", "path to .crt")) binds = append(binds, addFlag(flagCert, "", "path to .crt"))
@@ -115,6 +127,7 @@ func fromFlags() error {
binds = append(binds, addFlag(flagRate, "100", "rate limit per second")) binds = append(binds, addFlag(flagRate, "100", "rate limit per second"))
binds = append(binds, addFlag(flagBurst, "100", "rate limit burst")) binds = append(binds, addFlag(flagBurst, "100", "rate limit burst"))
binds = append(binds, addFlag(flagTimeout, "30", "seconds to wait for limiter")) binds = append(binds, addFlag(flagTimeout, "30", "seconds to wait for limiter"))
binds = append(binds, addFlag(flagRewrites, "", "comma-separated from:replace:replacement to rewrite in response bodies"))
flag.Parse() flag.Parse()
for _, bind := range binds { for _, bind := range binds {

View File

@@ -1,6 +1,9 @@
package server package server
import ( import (
"bytes"
"io"
"local/rproxy3/config"
"local/rproxy3/storage/packable" "local/rproxy3/storage/packable"
"log" "log"
"net/http" "net/http"
@@ -12,13 +15,25 @@ import (
type redirPurge struct { type redirPurge struct {
proxyHost string proxyHost string
targetHost string targetHost string
baseTransport http.RoundTripper
}
type rewrite struct {
rewrites map[string]string
baseTransport http.RoundTripper
} }
func (s *Server) Proxy(w http.ResponseWriter, r *http.Request) { func (s *Server) Proxy(w http.ResponseWriter, r *http.Request) {
newURL, err := s.lookup(r.Host) newURL, err := s.lookup(mapKey(r, config.GetProxyMode()))
transport := &redirPurge{ var transport http.RoundTripper
transport = &redirPurge{
proxyHost: r.Host, proxyHost: r.Host,
targetHost: newURL.Host, targetHost: newURL.Host,
baseTransport: http.DefaultTransport,
}
transport = &rewrite{
rewrites: config.GetRewrites(mapKey(r, config.GetProxyMode())),
baseTransport: transport,
} }
if err != nil { if err != nil {
http.NotFound(w, r) http.NotFound(w, r)
@@ -32,15 +47,29 @@ func (s *Server) Proxy(w http.ResponseWriter, r *http.Request) {
} }
func (s *Server) lookup(host string) (*url.URL, error) { func (s *Server) lookup(host string) (*url.URL, error) {
host = strings.Split(host, ".")[0]
host = strings.Split(host, ":")[0]
v := packable.NewURL() v := packable.NewURL()
err := s.db.Get(nsRouting, host, v) err := s.db.Get(nsRouting, host, v)
return v.URL(), err return v.URL(), err
} }
func mapKey(r *http.Request, proxyMode string) string {
switch proxyMode {
case "domain":
host := strings.Split(r.Host, ".")[0]
host = strings.Split(host, ":")[0]
return host
case "path":
paths := strings.Split(r.URL.Path, "/")
if len(paths) < 2 {
return ""
}
return paths[1]
}
return ""
}
func (rp *redirPurge) RoundTrip(r *http.Request) (*http.Response, error) { func (rp *redirPurge) RoundTrip(r *http.Request) (*http.Response, error) {
resp, err := http.DefaultTransport.RoundTrip(r) resp, err := rp.baseTransport.RoundTrip(r)
if err != nil { if err != nil {
return resp, err return resp, err
} }
@@ -49,3 +78,40 @@ func (rp *redirPurge) RoundTrip(r *http.Request) (*http.Response, error) {
} }
return resp, err return resp, err
} }
func (rw *rewrite) RoundTrip(r *http.Request) (*http.Response, error) {
resp, err := rw.baseTransport.RoundTrip(r)
if err != nil {
return resp, err
}
if len(rw.rewrites) == 0 {
return resp, err
}
resp.Header.Del("Content-Length")
pr, pw := io.Pipe()
body := resp.Body
resp.Body = pr
go func() {
buff := make([]byte, 1024)
n, err := body.Read(buff)
for err == nil || n > 0 {
chunk := buff[:n]
for k, v := range rw.rewrites {
chunk = bytes.Replace(chunk, []byte(k), []byte(v), -1)
}
n = len(chunk)
m := 0
for m < n {
l, err := pw.Write(chunk[m:])
if err != nil {
pw.CloseWithError(err)
return
}
m += l
}
n, err = body.Read(buff)
}
pw.CloseWithError(err)
}()
return resp, err
}

View File

@@ -1,3 +1,75 @@
package server package server
import (
"io/ioutil"
"net/http"
"net/url"
"strings"
"testing"
)
type fakeTransport struct{}
func (ft fakeTransport) RoundTrip(r *http.Request) (*http.Response, error) {
return &http.Response{
Body: r.Body,
}, nil
}
// empty url -> OK //TODO // empty url -> OK //TODO
func TestRewrite(t *testing.T) {
transport := &rewrite{
rewrites: map[string]string{
"a": "b",
},
baseTransport: fakeTransport{},
}
r, err := http.NewRequest("GET", "asdf", strings.NewReader("mary had a little lamb"))
if err != nil {
t.Fatal(err)
}
resp, err := transport.RoundTrip(r)
if err != nil {
t.Fatal(err)
}
b, err := ioutil.ReadAll(resp.Body)
if err != nil {
t.Fatal(err)
}
if string(b) != "mbry hbd b little lbmb" {
t.Errorf("failed to replace: got %q, want \"b\"", b)
}
}
func TestMapKey(t *testing.T) {
r := &http.Request{
Host: "a.b.c:123",
URL: &url.URL{
Path: "/c/d/e",
},
}
if v := mapKey(r, "domain"); v != "a" {
t.Errorf("failed to get domain: got %v", v)
}
if v := mapKey(r, "path"); v != "c" {
t.Errorf("failed to get domain: got %v", v)
}
r.Host = "a:123"
if v := mapKey(r, "domain"); v != "a" {
t.Errorf("failed to get domain: got %v", v)
}
r.URL.Path = ""
if v := mapKey(r, "path"); v != "" {
t.Errorf("failed to get domain: got %v", v)
}
r.URL.Path = "/"
if v := mapKey(r, "path"); v != "" {
t.Errorf("failed to get domain: got %v", v)
}
}

View File

@@ -1,12 +1,15 @@
package server package server
import ( import (
"context"
"fmt" "fmt"
"local/rproxy3/storage" "local/rproxy3/storage"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"strings" "strings"
"testing" "testing"
"golang.org/x/time/rate"
) )
func TestServerStart(t *testing.T) { func TestServerStart(t *testing.T) {
@@ -35,6 +38,7 @@ func mockServer() *Server {
s := &Server{ s := &Server{
db: storage.NewMap(), db: storage.NewMap(),
addr: ":" + port, addr: ":" + port,
limiter: rate.NewLimiter(rate.Limit(50), 50),
} }
if err := s.Routes(); err != nil { if err := s.Routes(); err != nil {
panic(fmt.Sprintf("cannot initiate server routes; %v", err)) panic(fmt.Sprintf("cannot initiate server routes; %v", err))
@@ -49,6 +53,7 @@ func TestServerRoute(t *testing.T) {
} }
w := httptest.NewRecorder() w := httptest.NewRecorder()
r, _ := http.NewRequest("GET", "http://world.localhost"+server.addr, nil) r, _ := http.NewRequest("GET", "http://world.localhost"+server.addr, nil)
r = r.WithContext(context.Background())
server.ServeHTTP(w, r) server.ServeHTTP(w, r)
if w.Code != 502 { if w.Code != 502 {
t.Fatalf("cannot proxy from 'world' to 'hello', status %v", w.Code) t.Fatalf("cannot proxy from 'world' to 'hello', status %v", w.Code)