Compare commits
5 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
698edf7e45 | ||
|
|
48e0048216 | ||
|
|
f58f6f7cf3 | ||
|
|
f72ecc5e53 | ||
|
|
3bd1527b98 |
@@ -1,5 +1,5 @@
|
||||
p: 54243
|
||||
r:
|
||||
r:
|
||||
- echo:http://localhost:49982
|
||||
- echo2:http://192.168.0.86:38090
|
||||
#crt: ./testdata/rproxy3server.crt
|
||||
|
||||
@@ -2,6 +2,7 @@ package config
|
||||
|
||||
import (
|
||||
"local/rproxy3/storage/packable"
|
||||
"log"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
@@ -60,8 +61,16 @@ func GetRate() (int, int) {
|
||||
b := packable.NewString()
|
||||
conf.Get(nsConf, flagBurst, b)
|
||||
|
||||
rate, _ := strconv.Atoi(r.String())
|
||||
rate, err := strconv.Atoi(r.String())
|
||||
if err != nil {
|
||||
log.Printf("illegal rate: %v", err)
|
||||
rate = 5
|
||||
}
|
||||
burst, _ := strconv.Atoi(b.String())
|
||||
if err != nil {
|
||||
log.Printf("illegal burst: %v", err)
|
||||
burst = 5
|
||||
}
|
||||
|
||||
return rate, burst
|
||||
}
|
||||
@@ -70,7 +79,30 @@ func GetTimeout() int {
|
||||
t := packable.NewString()
|
||||
conf.Get(nsConf, flagTimeout, t)
|
||||
|
||||
timeout, _ := strconv.Atoi(t.String())
|
||||
timeout, err := strconv.Atoi(t.String())
|
||||
if err != nil || timeout == 5 {
|
||||
return 5
|
||||
}
|
||||
|
||||
return timeout
|
||||
}
|
||||
|
||||
func GetRewrites(hostMatch string) map[string]string {
|
||||
v := packable.NewString()
|
||||
conf.Get(nsConf, flagRewrites, v)
|
||||
m := make(map[string]string)
|
||||
for _, v := range strings.Split(v.String(), ",") {
|
||||
vs := strings.Split(v, ":")
|
||||
if len(v) < 3 {
|
||||
continue
|
||||
}
|
||||
host := vs[0]
|
||||
if host != hostMatch {
|
||||
continue
|
||||
}
|
||||
from := vs[1]
|
||||
to := strings.Join(vs[2:], ":")
|
||||
m[from] = to
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
@@ -23,6 +23,7 @@ const flagPass = "pass"
|
||||
const flagRate = "rate"
|
||||
const flagBurst = "burst"
|
||||
const flagTimeout = "timeout"
|
||||
const flagRewrites = "rw"
|
||||
|
||||
var conf = storage.NewMap()
|
||||
|
||||
@@ -41,10 +42,12 @@ type fileConf struct {
|
||||
Rate string `yaml:"rate"`
|
||||
Burst string `yaml:"burst"`
|
||||
Timeout string `yaml:"timeout"`
|
||||
Rewrites []string `yaml:"rw"`
|
||||
}
|
||||
|
||||
func Init() error {
|
||||
log.SetFlags(log.Ldate | log.Ltime | log.Llongfile)
|
||||
log.SetFlags(log.Ltime | log.Lshortfile)
|
||||
if err := fromFile(); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -100,6 +103,9 @@ func fromFile() error {
|
||||
if err := conf.Set(nsConf, flagTimeout, packable.NewString(c.Timeout)); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := conf.Set(nsConf, flagRewrites, packable.NewString(strings.Join(c.Rewrites, ","))); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -115,6 +121,7 @@ func fromFlags() error {
|
||||
binds = append(binds, addFlag(flagRate, "100", "rate limit per second"))
|
||||
binds = append(binds, addFlag(flagBurst, "100", "rate limit burst"))
|
||||
binds = append(binds, addFlag(flagTimeout, "30", "seconds to wait for limiter"))
|
||||
binds = append(binds, addFlag(flagRewrites, "", "comma-separated from:replace:replacement to rewrite in response bodies"))
|
||||
flag.Parse()
|
||||
|
||||
for _, bind := range binds {
|
||||
|
||||
@@ -1,6 +1,10 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/tls"
|
||||
"io"
|
||||
"local/rproxy3/config"
|
||||
"local/rproxy3/storage/packable"
|
||||
"log"
|
||||
"net/http"
|
||||
@@ -10,15 +14,28 @@ import (
|
||||
)
|
||||
|
||||
type redirPurge struct {
|
||||
proxyHost string
|
||||
targetHost string
|
||||
proxyHost string
|
||||
targetHost string
|
||||
baseTransport http.RoundTripper
|
||||
}
|
||||
|
||||
type rewrite struct {
|
||||
rewrites map[string]string
|
||||
baseTransport http.RoundTripper
|
||||
}
|
||||
|
||||
func (s *Server) Proxy(w http.ResponseWriter, r *http.Request) {
|
||||
newURL, err := s.lookup(r.Host)
|
||||
transport := &redirPurge{
|
||||
proxyHost: r.Host,
|
||||
targetHost: newURL.Host,
|
||||
newURL, err := s.lookup(mapKey(r.Host))
|
||||
var transport http.RoundTripper
|
||||
http.DefaultTransport.(*http.Transport).TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
|
||||
transport = &redirPurge{
|
||||
proxyHost: r.Host,
|
||||
targetHost: newURL.Host,
|
||||
baseTransport: http.DefaultTransport,
|
||||
}
|
||||
transport = &rewrite{
|
||||
rewrites: config.GetRewrites(mapKey(r.Host)),
|
||||
baseTransport: transport,
|
||||
}
|
||||
if err != nil {
|
||||
http.NotFound(w, r)
|
||||
@@ -32,15 +49,19 @@ func (s *Server) Proxy(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
func (s *Server) lookup(host string) (*url.URL, error) {
|
||||
host = strings.Split(host, ".")[0]
|
||||
host = strings.Split(host, ":")[0]
|
||||
v := packable.NewURL()
|
||||
err := s.db.Get(nsRouting, host, v)
|
||||
return v.URL(), err
|
||||
}
|
||||
|
||||
func mapKey(host string) string {
|
||||
host = strings.Split(host, ".")[0]
|
||||
host = strings.Split(host, ":")[0]
|
||||
return host
|
||||
}
|
||||
|
||||
func (rp *redirPurge) RoundTrip(r *http.Request) (*http.Response, error) {
|
||||
resp, err := http.DefaultTransport.RoundTrip(r)
|
||||
resp, err := rp.baseTransport.RoundTrip(r)
|
||||
if err != nil {
|
||||
return resp, err
|
||||
}
|
||||
@@ -49,3 +70,40 @@ func (rp *redirPurge) RoundTrip(r *http.Request) (*http.Response, error) {
|
||||
}
|
||||
return resp, err
|
||||
}
|
||||
|
||||
func (rw *rewrite) RoundTrip(r *http.Request) (*http.Response, error) {
|
||||
resp, err := rw.baseTransport.RoundTrip(r)
|
||||
if err != nil {
|
||||
return resp, err
|
||||
}
|
||||
if len(rw.rewrites) == 0 {
|
||||
return resp, err
|
||||
}
|
||||
resp.Header.Del("Content-Length")
|
||||
pr, pw := io.Pipe()
|
||||
body := resp.Body
|
||||
resp.Body = pr
|
||||
go func() {
|
||||
buff := make([]byte, 1024)
|
||||
n, err := body.Read(buff)
|
||||
for err == nil || n > 0 {
|
||||
chunk := buff[:n]
|
||||
for k, v := range rw.rewrites {
|
||||
chunk = bytes.Replace(chunk, []byte(k), []byte(v), -1)
|
||||
}
|
||||
n = len(chunk)
|
||||
m := 0
|
||||
for m < n {
|
||||
l, err := pw.Write(chunk[m:])
|
||||
if err != nil {
|
||||
pw.CloseWithError(err)
|
||||
return
|
||||
}
|
||||
m += l
|
||||
}
|
||||
n, err = body.Read(buff)
|
||||
}
|
||||
pw.CloseWithError(err)
|
||||
}()
|
||||
return resp, err
|
||||
}
|
||||
|
||||
@@ -1,3 +1,42 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type fakeTransport struct{}
|
||||
|
||||
func (ft fakeTransport) RoundTrip(r *http.Request) (*http.Response, error) {
|
||||
return &http.Response{
|
||||
Body: r.Body,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// empty url -> OK //TODO
|
||||
func TestRewrite(t *testing.T) {
|
||||
transport := &rewrite{
|
||||
rewrites: map[string]string{
|
||||
"a": "b",
|
||||
},
|
||||
baseTransport: fakeTransport{},
|
||||
}
|
||||
|
||||
r, err := http.NewRequest("GET", "asdf", strings.NewReader("mary had a little lamb"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
resp, err := transport.RoundTrip(r)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
b, err := ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if string(b) != "mbry hbd b little lbmb" {
|
||||
t.Errorf("failed to replace: got %q, want \"b\"", b)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,12 +1,15 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"local/rproxy3/storage"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
func TestServerStart(t *testing.T) {
|
||||
@@ -33,8 +36,9 @@ func mockServer() *Server {
|
||||
port := strings.Split(portServer.URL, ":")[2]
|
||||
portServer.Close()
|
||||
s := &Server{
|
||||
db: storage.NewMap(),
|
||||
addr: ":" + port,
|
||||
db: storage.NewMap(),
|
||||
addr: ":" + port,
|
||||
limiter: rate.NewLimiter(rate.Limit(50), 50),
|
||||
}
|
||||
if err := s.Routes(); err != nil {
|
||||
panic(fmt.Sprintf("cannot initiate server routes; %v", err))
|
||||
@@ -49,6 +53,7 @@ func TestServerRoute(t *testing.T) {
|
||||
}
|
||||
w := httptest.NewRecorder()
|
||||
r, _ := http.NewRequest("GET", "http://world.localhost"+server.addr, nil)
|
||||
r = r.WithContext(context.Background())
|
||||
server.ServeHTTP(w, r)
|
||||
if w.Code != 502 {
|
||||
t.Fatalf("cannot proxy from 'world' to 'hello', status %v", w.Code)
|
||||
|
||||
25
vendor/vendor.json
vendored
25
vendor/vendor.json
vendored
@@ -1,25 +0,0 @@
|
||||
{
|
||||
"comment": "",
|
||||
"ignore": "test",
|
||||
"package": [
|
||||
{
|
||||
"checksumSHA1": "GtamqiJoL7PGHsN454AoffBFMa8=",
|
||||
"path": "golang.org/x/net/context",
|
||||
"revision": "65e2d4e15006aab9813ff8769e768bbf4bb667a0",
|
||||
"revisionTime": "2019-02-01T23:59:58Z"
|
||||
},
|
||||
{
|
||||
"checksumSHA1": "HoCvrd3hEhsFeBOdEw7cbcfyk50=",
|
||||
"path": "golang.org/x/time/rate",
|
||||
"revision": "fbb02b2291d28baffd63558aa44b4b56f178d650",
|
||||
"revisionTime": "2018-04-12T16:56:04Z"
|
||||
},
|
||||
{
|
||||
"checksumSHA1": "QqDq2x8XOU7IoOR98Cx1eiV5QY8=",
|
||||
"path": "gopkg.in/yaml.v2",
|
||||
"revision": "51d6538a90f86fe93ac480b35f37b2be17fef232",
|
||||
"revisionTime": "2018-11-15T11:05:04Z"
|
||||
}
|
||||
],
|
||||
"rootPath": "local/rproxy3"
|
||||
}
|
||||
Reference in New Issue
Block a user