passes tests with no rewrites
parent
3bd1527b98
commit
f72ecc5e53
|
|
@ -2,6 +2,7 @@ package config
|
|||
|
||||
import (
|
||||
"local/rproxy3/storage/packable"
|
||||
"log"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
|
@ -60,8 +61,16 @@ func GetRate() (int, int) {
|
|||
b := packable.NewString()
|
||||
conf.Get(nsConf, flagBurst, b)
|
||||
|
||||
rate, _ := strconv.Atoi(r.String())
|
||||
rate, err := strconv.Atoi(r.String())
|
||||
if err != nil {
|
||||
log.Printf("illegal rate: %v", err)
|
||||
rate = 5
|
||||
}
|
||||
burst, _ := strconv.Atoi(b.String())
|
||||
if err != nil {
|
||||
log.Printf("illegal burst: %v", err)
|
||||
burst = 5
|
||||
}
|
||||
|
||||
return rate, burst
|
||||
}
|
||||
|
|
@ -77,3 +86,18 @@ func GetTimeout() int {
|
|||
|
||||
return timeout
|
||||
}
|
||||
|
||||
func GetRewrites() map[string]string {
|
||||
v := packable.NewString()
|
||||
conf.Get(nsConf, flagRewrites, v)
|
||||
m := make(map[string]string)
|
||||
for _, v := range strings.Split(v.String(), ",") {
|
||||
if len(v) == 0 {
|
||||
return m
|
||||
}
|
||||
from := v[:strings.Index(v, ":")]
|
||||
to := v[strings.Index(v, ":")+1:]
|
||||
m[from] = to
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
|
|
|||
|
|
@ -23,6 +23,7 @@ const flagPass = "pass"
|
|||
const flagRate = "rate"
|
||||
const flagBurst = "burst"
|
||||
const flagTimeout = "timeout"
|
||||
const flagRewrites = "rw"
|
||||
|
||||
var conf = storage.NewMap()
|
||||
|
||||
|
|
@ -41,6 +42,7 @@ type fileConf struct {
|
|||
Rate string `yaml:"rate"`
|
||||
Burst string `yaml:"burst"`
|
||||
Timeout string `yaml:"timeout"`
|
||||
Rewrites []string `yaml:"rw"`
|
||||
}
|
||||
|
||||
func Init() error {
|
||||
|
|
@ -100,6 +102,9 @@ func fromFile() error {
|
|||
if err := conf.Set(nsConf, flagTimeout, packable.NewString(c.Timeout)); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := conf.Set(nsConf, flagRewrites, packable.NewString(strings.Join(c.Rewrites, ","))); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
@ -115,6 +120,7 @@ func fromFlags() error {
|
|||
binds = append(binds, addFlag(flagRate, "100", "rate limit per second"))
|
||||
binds = append(binds, addFlag(flagBurst, "100", "rate limit burst"))
|
||||
binds = append(binds, addFlag(flagTimeout, "30", "seconds to wait for limiter"))
|
||||
binds = append(binds, addFlag(flagRewrites, "", "comma-separated regex:v to rewrite in response bodies"))
|
||||
flag.Parse()
|
||||
|
||||
for _, bind := range binds {
|
||||
|
|
|
|||
|
|
@ -1,6 +1,8 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"io"
|
||||
"local/rproxy3/config"
|
||||
"local/rproxy3/storage/packable"
|
||||
"log"
|
||||
"net/http"
|
||||
|
|
@ -10,15 +12,27 @@ import (
|
|||
)
|
||||
|
||||
type redirPurge struct {
|
||||
proxyHost string
|
||||
targetHost string
|
||||
proxyHost string
|
||||
targetHost string
|
||||
baseTransport http.RoundTripper
|
||||
}
|
||||
|
||||
type rewrite struct {
|
||||
rewrites map[string]string
|
||||
baseTransport http.RoundTripper
|
||||
}
|
||||
|
||||
func (s *Server) Proxy(w http.ResponseWriter, r *http.Request) {
|
||||
newURL, err := s.lookup(r.Host)
|
||||
transport := &redirPurge{
|
||||
proxyHost: r.Host,
|
||||
targetHost: newURL.Host,
|
||||
var transport http.RoundTripper
|
||||
transport = &redirPurge{
|
||||
proxyHost: r.Host,
|
||||
targetHost: newURL.Host,
|
||||
baseTransport: http.DefaultTransport,
|
||||
}
|
||||
transport = &rewrite{
|
||||
rewrites: config.GetRewrites(),
|
||||
baseTransport: transport,
|
||||
}
|
||||
if err != nil {
|
||||
http.NotFound(w, r)
|
||||
|
|
@ -40,7 +54,7 @@ func (s *Server) lookup(host string) (*url.URL, error) {
|
|||
}
|
||||
|
||||
func (rp *redirPurge) RoundTrip(r *http.Request) (*http.Response, error) {
|
||||
resp, err := http.DefaultTransport.RoundTrip(r)
|
||||
resp, err := rp.baseTransport.RoundTrip(r)
|
||||
if err != nil {
|
||||
return resp, err
|
||||
}
|
||||
|
|
@ -49,3 +63,30 @@ func (rp *redirPurge) RoundTrip(r *http.Request) (*http.Response, error) {
|
|||
}
|
||||
return resp, err
|
||||
}
|
||||
|
||||
func (rw *rewrite) RoundTrip(r *http.Request) (*http.Response, error) {
|
||||
resp, err := rw.baseTransport.RoundTrip(r)
|
||||
if err != nil {
|
||||
return resp, err
|
||||
}
|
||||
pr, pw := io.Pipe()
|
||||
body := resp.Body
|
||||
resp.Body = pr
|
||||
go func() {
|
||||
buff := make([]byte, 1024)
|
||||
for n, err := body.Read(buff); err == nil || n > 0; n, err = body.Read(buff) {
|
||||
chunk := buff[:n]
|
||||
m := 0
|
||||
for m < n {
|
||||
l, err := pw.Write(chunk[m:])
|
||||
if err != nil {
|
||||
pw.CloseWithError(err)
|
||||
return
|
||||
}
|
||||
m += l
|
||||
}
|
||||
}
|
||||
pw.CloseWithError(err)
|
||||
}()
|
||||
return resp, err
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue