Compare commits
4 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
06ba2dfdc1 | ||
|
|
f58f6f7cf3 | ||
|
|
f72ecc5e53 | ||
|
|
3bd1527b98 |
@@ -2,6 +2,7 @@ package config
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"local/rproxy3/storage/packable"
|
"local/rproxy3/storage/packable"
|
||||||
|
"log"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
@@ -60,8 +61,16 @@ func GetRate() (int, int) {
|
|||||||
b := packable.NewString()
|
b := packable.NewString()
|
||||||
conf.Get(nsConf, flagBurst, b)
|
conf.Get(nsConf, flagBurst, b)
|
||||||
|
|
||||||
rate, _ := strconv.Atoi(r.String())
|
rate, err := strconv.Atoi(r.String())
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("illegal rate: %v", err)
|
||||||
|
rate = 5
|
||||||
|
}
|
||||||
burst, _ := strconv.Atoi(b.String())
|
burst, _ := strconv.Atoi(b.String())
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("illegal burst: %v", err)
|
||||||
|
burst = 5
|
||||||
|
}
|
||||||
|
|
||||||
return rate, burst
|
return rate, burst
|
||||||
}
|
}
|
||||||
@@ -70,7 +79,40 @@ func GetTimeout() int {
|
|||||||
t := packable.NewString()
|
t := packable.NewString()
|
||||||
conf.Get(nsConf, flagTimeout, t)
|
conf.Get(nsConf, flagTimeout, t)
|
||||||
|
|
||||||
timeout, _ := strconv.Atoi(t.String())
|
timeout, err := strconv.Atoi(t.String())
|
||||||
|
if err != nil || timeout == 5 {
|
||||||
|
return 5
|
||||||
|
}
|
||||||
|
|
||||||
return timeout
|
return timeout
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func GetRewrites(hostMatch string) map[string]string {
|
||||||
|
v := packable.NewString()
|
||||||
|
conf.Get(nsConf, flagRewrites, v)
|
||||||
|
m := make(map[string]string)
|
||||||
|
for _, v := range strings.Split(v.String(), ",") {
|
||||||
|
vs := strings.Split(v, ":")
|
||||||
|
if len(v) < 3 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
host := vs[0]
|
||||||
|
if host != hostMatch {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
from := vs[1]
|
||||||
|
to := strings.Join(vs[2:], ":")
|
||||||
|
m[from] = to
|
||||||
|
}
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetProxyMode() string {
|
||||||
|
v := packable.NewString()
|
||||||
|
conf.Get(nsConf, flagMode, v)
|
||||||
|
s := v.String()
|
||||||
|
if s == "" {
|
||||||
|
return "domain"
|
||||||
|
}
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ import (
|
|||||||
|
|
||||||
const nsConf = "configuration"
|
const nsConf = "configuration"
|
||||||
const flagPort = "p"
|
const flagPort = "p"
|
||||||
|
const flagMode = "mode"
|
||||||
const flagRoutes = "r"
|
const flagRoutes = "r"
|
||||||
const flagConf = "c"
|
const flagConf = "c"
|
||||||
const flagCert = "crt"
|
const flagCert = "crt"
|
||||||
@@ -23,6 +24,7 @@ const flagPass = "pass"
|
|||||||
const flagRate = "rate"
|
const flagRate = "rate"
|
||||||
const flagBurst = "burst"
|
const flagBurst = "burst"
|
||||||
const flagTimeout = "timeout"
|
const flagTimeout = "timeout"
|
||||||
|
const flagRewrites = "rw"
|
||||||
|
|
||||||
var conf = storage.NewMap()
|
var conf = storage.NewMap()
|
||||||
|
|
||||||
@@ -33,6 +35,7 @@ type toBind struct {
|
|||||||
|
|
||||||
type fileConf struct {
|
type fileConf struct {
|
||||||
Port string `yaml:"p"`
|
Port string `yaml:"p"`
|
||||||
|
Mode string `yaml:"mode"`
|
||||||
Routes []string `yaml:"r"`
|
Routes []string `yaml:"r"`
|
||||||
CertPath string `yaml:"crt"`
|
CertPath string `yaml:"crt"`
|
||||||
KeyPath string `yaml:"key"`
|
KeyPath string `yaml:"key"`
|
||||||
@@ -41,10 +44,12 @@ type fileConf struct {
|
|||||||
Rate string `yaml:"rate"`
|
Rate string `yaml:"rate"`
|
||||||
Burst string `yaml:"burst"`
|
Burst string `yaml:"burst"`
|
||||||
Timeout string `yaml:"timeout"`
|
Timeout string `yaml:"timeout"`
|
||||||
|
Rewrites []string `yaml:"rw"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func Init() error {
|
func Init() error {
|
||||||
log.SetFlags(log.Ldate | log.Ltime | log.Llongfile)
|
log.SetFlags(log.Ldate | log.Ltime | log.Llongfile)
|
||||||
|
log.SetFlags(log.Ltime | log.Lshortfile)
|
||||||
if err := fromFile(); err != nil {
|
if err := fromFile(); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -76,6 +81,9 @@ func fromFile() error {
|
|||||||
if err := conf.Set(nsConf, flagPort, packable.NewString(c.Port)); err != nil {
|
if err := conf.Set(nsConf, flagPort, packable.NewString(c.Port)); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
if err := conf.Set(nsConf, flagMode, packable.NewString(c.Mode)); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
if err := conf.Set(nsConf, flagRoutes, packable.NewString(strings.Join(c.Routes, ","))); err != nil {
|
if err := conf.Set(nsConf, flagRoutes, packable.NewString(strings.Join(c.Routes, ","))); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -100,12 +108,16 @@ func fromFile() error {
|
|||||||
if err := conf.Set(nsConf, flagTimeout, packable.NewString(c.Timeout)); err != nil {
|
if err := conf.Set(nsConf, flagTimeout, packable.NewString(c.Timeout)); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
if err := conf.Set(nsConf, flagRewrites, packable.NewString(strings.Join(c.Rewrites, ","))); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func fromFlags() error {
|
func fromFlags() error {
|
||||||
binds := make([]toBind, 0)
|
binds := make([]toBind, 0)
|
||||||
binds = append(binds, addFlag(flagPort, "51555", "port to bind to"))
|
binds = append(binds, addFlag(flagPort, "51555", "port to bind to"))
|
||||||
|
binds = append(binds, addFlag(flagMode, "domain", "[domain] or [path] to match"))
|
||||||
binds = append(binds, addFlag(flagConf, "", "configuration file path"))
|
binds = append(binds, addFlag(flagConf, "", "configuration file path"))
|
||||||
binds = append(binds, addFlag(flagRoutes, "", "comma-separated routes to map, each as from:scheme://to.tld:port"))
|
binds = append(binds, addFlag(flagRoutes, "", "comma-separated routes to map, each as from:scheme://to.tld:port"))
|
||||||
binds = append(binds, addFlag(flagCert, "", "path to .crt"))
|
binds = append(binds, addFlag(flagCert, "", "path to .crt"))
|
||||||
@@ -115,6 +127,7 @@ func fromFlags() error {
|
|||||||
binds = append(binds, addFlag(flagRate, "100", "rate limit per second"))
|
binds = append(binds, addFlag(flagRate, "100", "rate limit per second"))
|
||||||
binds = append(binds, addFlag(flagBurst, "100", "rate limit burst"))
|
binds = append(binds, addFlag(flagBurst, "100", "rate limit burst"))
|
||||||
binds = append(binds, addFlag(flagTimeout, "30", "seconds to wait for limiter"))
|
binds = append(binds, addFlag(flagTimeout, "30", "seconds to wait for limiter"))
|
||||||
|
binds = append(binds, addFlag(flagRewrites, "", "comma-separated from:replace:replacement to rewrite in response bodies"))
|
||||||
flag.Parse()
|
flag.Parse()
|
||||||
|
|
||||||
for _, bind := range binds {
|
for _, bind := range binds {
|
||||||
|
|||||||
@@ -1,6 +1,9 @@
|
|||||||
package server
|
package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
|
"io"
|
||||||
|
"local/rproxy3/config"
|
||||||
"local/rproxy3/storage/packable"
|
"local/rproxy3/storage/packable"
|
||||||
"log"
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
@@ -12,13 +15,25 @@ import (
|
|||||||
type redirPurge struct {
|
type redirPurge struct {
|
||||||
proxyHost string
|
proxyHost string
|
||||||
targetHost string
|
targetHost string
|
||||||
|
baseTransport http.RoundTripper
|
||||||
|
}
|
||||||
|
|
||||||
|
type rewrite struct {
|
||||||
|
rewrites map[string]string
|
||||||
|
baseTransport http.RoundTripper
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) Proxy(w http.ResponseWriter, r *http.Request) {
|
func (s *Server) Proxy(w http.ResponseWriter, r *http.Request) {
|
||||||
newURL, err := s.lookup(r.Host)
|
newURL, err := s.lookup(mapKey(r, config.GetProxyMode()))
|
||||||
transport := &redirPurge{
|
var transport http.RoundTripper
|
||||||
|
transport = &redirPurge{
|
||||||
proxyHost: r.Host,
|
proxyHost: r.Host,
|
||||||
targetHost: newURL.Host,
|
targetHost: newURL.Host,
|
||||||
|
baseTransport: http.DefaultTransport,
|
||||||
|
}
|
||||||
|
transport = &rewrite{
|
||||||
|
rewrites: config.GetRewrites(mapKey(r, config.GetProxyMode())),
|
||||||
|
baseTransport: transport,
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.NotFound(w, r)
|
http.NotFound(w, r)
|
||||||
@@ -32,15 +47,29 @@ func (s *Server) Proxy(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) lookup(host string) (*url.URL, error) {
|
func (s *Server) lookup(host string) (*url.URL, error) {
|
||||||
host = strings.Split(host, ".")[0]
|
|
||||||
host = strings.Split(host, ":")[0]
|
|
||||||
v := packable.NewURL()
|
v := packable.NewURL()
|
||||||
err := s.db.Get(nsRouting, host, v)
|
err := s.db.Get(nsRouting, host, v)
|
||||||
return v.URL(), err
|
return v.URL(), err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func mapKey(r *http.Request, proxyMode string) string {
|
||||||
|
switch proxyMode {
|
||||||
|
case "domain":
|
||||||
|
host := strings.Split(r.Host, ".")[0]
|
||||||
|
host = strings.Split(host, ":")[0]
|
||||||
|
return host
|
||||||
|
case "path":
|
||||||
|
paths := strings.Split(r.URL.Path, "/")
|
||||||
|
if len(paths) < 2 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return paths[1]
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
func (rp *redirPurge) RoundTrip(r *http.Request) (*http.Response, error) {
|
func (rp *redirPurge) RoundTrip(r *http.Request) (*http.Response, error) {
|
||||||
resp, err := http.DefaultTransport.RoundTrip(r)
|
resp, err := rp.baseTransport.RoundTrip(r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return resp, err
|
return resp, err
|
||||||
}
|
}
|
||||||
@@ -49,3 +78,40 @@ func (rp *redirPurge) RoundTrip(r *http.Request) (*http.Response, error) {
|
|||||||
}
|
}
|
||||||
return resp, err
|
return resp, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (rw *rewrite) RoundTrip(r *http.Request) (*http.Response, error) {
|
||||||
|
resp, err := rw.baseTransport.RoundTrip(r)
|
||||||
|
if err != nil {
|
||||||
|
return resp, err
|
||||||
|
}
|
||||||
|
if len(rw.rewrites) == 0 {
|
||||||
|
return resp, err
|
||||||
|
}
|
||||||
|
resp.Header.Del("Content-Length")
|
||||||
|
pr, pw := io.Pipe()
|
||||||
|
body := resp.Body
|
||||||
|
resp.Body = pr
|
||||||
|
go func() {
|
||||||
|
buff := make([]byte, 1024)
|
||||||
|
n, err := body.Read(buff)
|
||||||
|
for err == nil || n > 0 {
|
||||||
|
chunk := buff[:n]
|
||||||
|
for k, v := range rw.rewrites {
|
||||||
|
chunk = bytes.Replace(chunk, []byte(k), []byte(v), -1)
|
||||||
|
}
|
||||||
|
n = len(chunk)
|
||||||
|
m := 0
|
||||||
|
for m < n {
|
||||||
|
l, err := pw.Write(chunk[m:])
|
||||||
|
if err != nil {
|
||||||
|
pw.CloseWithError(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
m += l
|
||||||
|
}
|
||||||
|
n, err = body.Read(buff)
|
||||||
|
}
|
||||||
|
pw.CloseWithError(err)
|
||||||
|
}()
|
||||||
|
return resp, err
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,3 +1,75 @@
|
|||||||
package server
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io/ioutil"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
type fakeTransport struct{}
|
||||||
|
|
||||||
|
func (ft fakeTransport) RoundTrip(r *http.Request) (*http.Response, error) {
|
||||||
|
return &http.Response{
|
||||||
|
Body: r.Body,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
// empty url -> OK //TODO
|
// empty url -> OK //TODO
|
||||||
|
func TestRewrite(t *testing.T) {
|
||||||
|
transport := &rewrite{
|
||||||
|
rewrites: map[string]string{
|
||||||
|
"a": "b",
|
||||||
|
},
|
||||||
|
baseTransport: fakeTransport{},
|
||||||
|
}
|
||||||
|
|
||||||
|
r, err := http.NewRequest("GET", "asdf", strings.NewReader("mary had a little lamb"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
resp, err := transport.RoundTrip(r)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
b, err := ioutil.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if string(b) != "mbry hbd b little lbmb" {
|
||||||
|
t.Errorf("failed to replace: got %q, want \"b\"", b)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMapKey(t *testing.T) {
|
||||||
|
r := &http.Request{
|
||||||
|
Host: "a.b.c:123",
|
||||||
|
URL: &url.URL{
|
||||||
|
Path: "/c/d/e",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
if v := mapKey(r, "domain"); v != "a" {
|
||||||
|
t.Errorf("failed to get domain: got %v", v)
|
||||||
|
}
|
||||||
|
|
||||||
|
if v := mapKey(r, "path"); v != "c" {
|
||||||
|
t.Errorf("failed to get domain: got %v", v)
|
||||||
|
}
|
||||||
|
|
||||||
|
r.Host = "a:123"
|
||||||
|
if v := mapKey(r, "domain"); v != "a" {
|
||||||
|
t.Errorf("failed to get domain: got %v", v)
|
||||||
|
}
|
||||||
|
|
||||||
|
r.URL.Path = ""
|
||||||
|
if v := mapKey(r, "path"); v != "" {
|
||||||
|
t.Errorf("failed to get domain: got %v", v)
|
||||||
|
}
|
||||||
|
|
||||||
|
r.URL.Path = "/"
|
||||||
|
if v := mapKey(r, "path"); v != "" {
|
||||||
|
t.Errorf("failed to get domain: got %v", v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,12 +1,15 @@
|
|||||||
package server
|
package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"local/rproxy3/storage"
|
"local/rproxy3/storage"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"golang.org/x/time/rate"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestServerStart(t *testing.T) {
|
func TestServerStart(t *testing.T) {
|
||||||
@@ -35,6 +38,7 @@ func mockServer() *Server {
|
|||||||
s := &Server{
|
s := &Server{
|
||||||
db: storage.NewMap(),
|
db: storage.NewMap(),
|
||||||
addr: ":" + port,
|
addr: ":" + port,
|
||||||
|
limiter: rate.NewLimiter(rate.Limit(50), 50),
|
||||||
}
|
}
|
||||||
if err := s.Routes(); err != nil {
|
if err := s.Routes(); err != nil {
|
||||||
panic(fmt.Sprintf("cannot initiate server routes; %v", err))
|
panic(fmt.Sprintf("cannot initiate server routes; %v", err))
|
||||||
@@ -49,6 +53,7 @@ func TestServerRoute(t *testing.T) {
|
|||||||
}
|
}
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
r, _ := http.NewRequest("GET", "http://world.localhost"+server.addr, nil)
|
r, _ := http.NewRequest("GET", "http://world.localhost"+server.addr, nil)
|
||||||
|
r = r.WithContext(context.Background())
|
||||||
server.ServeHTTP(w, r)
|
server.ServeHTTP(w, r)
|
||||||
if w.Code != 502 {
|
if w.Code != 502 {
|
||||||
t.Fatalf("cannot proxy from 'world' to 'hello', status %v", w.Code)
|
t.Fatalf("cannot proxy from 'world' to 'hello', status %v", w.Code)
|
||||||
|
|||||||
Reference in New Issue
Block a user