1 Commits

Author SHA1 Message Date
Bel LaPointe
06ba2dfdc1 Add path option for proxying 2019-04-10 11:15:17 -06:00
33 changed files with 346 additions and 833 deletions

View File

@@ -1,115 +0,0 @@
package config
import (
"local/rproxy3/storage/packable"
"log"
"strconv"
"strings"
)
func GetPort() string {
v := packable.NewString()
conf.Get(nsConf, flagPort, v)
return ":" + strings.TrimPrefix(v.String(), ":")
}
func GetRoutes() map[string]string {
v := packable.NewString()
conf.Get(nsConf, flagRoutes, v)
m := make(map[string]string)
for _, v := range strings.Split(v.String(), ",") {
if len(v) == 0 {
return m
}
from := v[:strings.Index(v, ":")]
to := v[strings.Index(v, ":")+1:]
m[from] = to
}
return m
}
func GetTCP() (string, bool) {
v := packable.NewString()
conf.Get(nsConf, flagTCP, v)
tcpAddr := v.String()
return tcpAddr, notEmpty(tcpAddr)
}
func GetSSL() (string, string, bool) {
v := packable.NewString()
conf.Get(nsConf, flagCert, v)
certPath := v.String()
conf.Get(nsConf, flagKey, v)
keyPath := v.String()
return certPath, keyPath, notEmpty(certPath, keyPath)
}
func GetAuth() (string, string, bool) {
v := packable.NewString()
conf.Get(nsConf, flagUser, v)
user := v.String()
conf.Get(nsConf, flagPass, v)
pass := v.String()
return user, pass, notEmpty(user, pass)
}
func notEmpty(s ...string) bool {
for i := range s {
if s[i] == "" || s[i] == "/dev/null" {
return false
}
}
return true
}
func GetRate() (int, int) {
r := packable.NewString()
conf.Get(nsConf, flagRate, r)
b := packable.NewString()
conf.Get(nsConf, flagBurst, b)
rate, err := strconv.Atoi(r.String())
if err != nil {
log.Printf("illegal rate: %v", err)
rate = 5
}
burst, _ := strconv.Atoi(b.String())
if err != nil {
log.Printf("illegal burst: %v", err)
burst = 5
}
return rate, burst
}
func GetTimeout() int {
t := packable.NewString()
conf.Get(nsConf, flagTimeout, t)
timeout, err := strconv.Atoi(t.String())
if err != nil || timeout == 5 {
return 5
}
return timeout
}
func GetRewrites(hostMatch string) map[string]string {
v := packable.NewString()
conf.Get(nsConf, flagRewrites, v)
m := make(map[string]string)
for _, v := range strings.Split(v.String(), ",") {
vs := strings.Split(v, ":")
if len(v) < 3 {
continue
}
host := vs[0]
if host != hostMatch {
continue
}
from := vs[1]
to := strings.Join(vs[2:], ":")
m[from] = to
}
return m
}

View File

@@ -1,161 +0,0 @@
package config
import (
"flag"
"io/ioutil"
"local/rproxy3/storage"
"local/rproxy3/storage/packable"
"log"
"os"
"strings"
yaml "gopkg.in/yaml.v2"
)
const nsConf = "configuration"
const flagPort = "p"
const flagRoutes = "r"
const flagConf = "c"
const flagCert = "crt"
const flagTCP = "tcp"
const flagKey = "key"
const flagUser = "user"
const flagPass = "pass"
const flagRate = "rate"
const flagBurst = "burst"
const flagTimeout = "timeout"
const flagRewrites = "rw"
var conf = storage.NewMap()
type toBind struct {
flag string
value *string
}
type fileConf struct {
Port string `yaml:"p"`
Routes []string `yaml:"r"`
CertPath string `yaml:"crt"`
TCPPath string `yaml:"tcp"`
KeyPath string `yaml:"key"`
Username string `yaml:"user"`
Password string `yaml:"pass"`
Rate string `yaml:"rate"`
Burst string `yaml:"burst"`
Timeout string `yaml:"timeout"`
Rewrites []string `yaml:"rw"`
}
func Init() error {
log.SetFlags(log.Ldate | log.Ltime | log.Llongfile)
log.SetFlags(log.Ltime | log.Lshortfile)
if err := fromFile(); err != nil {
return err
}
if err := fromFlags(); err != nil {
return err
}
return nil
}
func fromFile() error {
flag.CommandLine = flag.NewFlagSet(os.Args[0], flag.ContinueOnError)
defer func() {
flag.CommandLine = flag.NewFlagSet(os.Args[0], flag.ExitOnError)
}()
flag.String(flagConf, "/dev/null", "yaml config file path")
flag.Parse()
confFlag := flag.Lookup(flagConf)
if confFlag == nil || confFlag.Value.String() == "" {
return nil
}
confBytes, err := ioutil.ReadFile(confFlag.Value.String())
if err != nil {
return err
}
var c fileConf
if err := yaml.Unmarshal(confBytes, &c); err != nil {
return err
}
if err := conf.Set(nsConf, flagPort, packable.NewString(c.Port)); err != nil {
return err
}
if err := conf.Set(nsConf, flagRoutes, packable.NewString(strings.Join(c.Routes, ","))); err != nil {
return err
}
if err := conf.Set(nsConf, flagCert, packable.NewString(c.CertPath)); err != nil {
return err
}
if err := conf.Set(nsConf, flagTCP, packable.NewString(c.TCPPath)); err != nil {
return err
}
if err := conf.Set(nsConf, flagKey, packable.NewString(c.KeyPath)); err != nil {
return err
}
if err := conf.Set(nsConf, flagUser, packable.NewString(c.Username)); err != nil {
return err
}
if err := conf.Set(nsConf, flagPass, packable.NewString(c.Password)); err != nil {
return err
}
if err := conf.Set(nsConf, flagRate, packable.NewString(c.Rate)); err != nil {
return err
}
if err := conf.Set(nsConf, flagBurst, packable.NewString(c.Burst)); err != nil {
return err
}
if err := conf.Set(nsConf, flagTimeout, packable.NewString(c.Timeout)); err != nil {
return err
}
if err := conf.Set(nsConf, flagRewrites, packable.NewString(strings.Join(c.Rewrites, ","))); err != nil {
return err
}
return nil
}
func fromFlags() error {
binds := make([]toBind, 0)
binds = append(binds, addFlag(flagPort, "51555", "port to bind to"))
binds = append(binds, addFlag(flagConf, "", "configuration file path"))
binds = append(binds, addFlag(flagRoutes, "", "comma-separated routes to map, each as from:scheme://to.tld:port"))
binds = append(binds, addFlag(flagCert, "", "path to .crt"))
binds = append(binds, addFlag(flagTCP, "", "tcp addr"))
binds = append(binds, addFlag(flagKey, "", "path to .key"))
binds = append(binds, addFlag(flagUser, "", "basic auth username"))
binds = append(binds, addFlag(flagPass, "", "basic auth password"))
binds = append(binds, addFlag(flagRate, "100", "rate limit per second"))
binds = append(binds, addFlag(flagBurst, "100", "rate limit burst"))
binds = append(binds, addFlag(flagTimeout, "30", "seconds to wait for limiter"))
binds = append(binds, addFlag(flagRewrites, "", "comma-separated from:replace:replacement:oauth to rewrite in response bodies"))
flag.Parse()
for _, bind := range binds {
confFlag := flag.Lookup(bind.flag)
if confFlag == nil || confFlag.Value.String() == "" {
continue
}
if err := conf.Set(nsConf, bind.flag, packable.NewString(*bind.value)); err != nil {
return err
}
}
return nil
}
func addFlag(key, def, help string) toBind {
def = getFlagOrDefault(key, def)
v := flag.String(key, def, help)
return toBind{
flag: key,
value: v,
}
}
func getFlagOrDefault(key, def string) string {
v := packable.NewString()
if err := conf.Get(nsConf, key, v); err != nil {
return def
}
return v.String()
}

4
.gitignore vendored Executable file → Normal file
View File

@@ -1,10 +1,6 @@
lz4
rclone
rcloner
exec
exec-*
**/exec
**/exec-*
Go
cloudly
dockfile

View File

@@ -1,16 +0,0 @@
FROM golang:1.13-alpine as certs
RUN apk update && apk add --no-cache ca-certificates
FROM busybox:glibc
RUN mkdir -p /var/log
WORKDIR /main
COPY --from=certs /etc/ssl/certs /etc/ssl/certs
COPY . .
ENV GOPATH=""
ENV MNT="/mnt/"
ENTRYPOINT ["/main/exec-rproxy3"]
CMD []

11
conf.yaml Normal file
View File

@@ -0,0 +1,11 @@
p: 54243
r:
- echo:http://localhost:49982
- echo2:http://192.168.0.86:38090
#crt: ./testdata/rproxy3server.crt
#key: ./testdata/rproxy3server.key
#user: bel
#pass: bel
rate: 1
burst: 2
timeout: 10

180
config/config.go Executable file → Normal file
View File

@@ -1,114 +1,118 @@
package config
import (
"encoding/json"
"fmt"
"local/rproxy3/storage/packable"
"log"
"strconv"
"strings"
"time"
)
type Proxy struct {
To string
}
func parseProxy(s string) (string, Proxy) {
p := Proxy{}
key := ""
l := strings.Split(s, ",")
if len(l) > 0 {
key = l[0]
}
if len(l) > 1 {
p.To = l[1]
}
return key, p
}
func GetAuthelia() (string, bool) {
authelia := conf.Get("authelia").GetString()
return authelia, authelia != ""
}
func GetBOAuthZ() (string, bool) {
boauthz := conf.Get("oauth").GetString()
return boauthz, boauthz != ""
}
func GetAuth() (string, string, bool) {
user := conf.Get("user").GetString()
pass := conf.Get("pass").GetString()
return user, pass, user != "" && pass != ""
}
func GetTrim() string {
return conf.Get("trim").GetString()
}
func GetPort() string {
port := conf.Get("p").GetInt()
return ":" + fmt.Sprint(port)
v := packable.NewString()
conf.Get(nsConf, flagPort, v)
return ":" + strings.TrimPrefix(v.String(), ":")
}
func GetAltPort() string {
port := conf.Get("ap").GetInt()
return ":" + fmt.Sprint(port)
}
func GetRate() (int, int) {
rate := conf.Get("r").GetInt()
burst := conf.Get("b").GetInt()
log.Println("rate/burst:", rate, burst)
return rate, burst
}
func GetRoutes() map[string]Proxy {
list := conf.Get("proxy").GetString()
definitions := strings.Split(list, ",,")
routes := make(map[string]Proxy)
for _, definition := range definitions {
k, v := parseProxy(definition)
routes[k] = v
func GetRoutes() map[string]string {
v := packable.NewString()
conf.Get(nsConf, flagRoutes, v)
m := make(map[string]string)
for _, v := range strings.Split(v.String(), ",") {
if len(v) == 0 {
return m
}
from := v[:strings.Index(v, ":")]
to := v[strings.Index(v, ":")+1:]
m[from] = to
}
return routes
return m
}
func GetSSL() (string, string, bool) {
crt := conf.Get("crt").GetString()
key := conf.Get("key").GetString()
return crt, key, crt != "" && key != ""
v := packable.NewString()
conf.Get(nsConf, flagCert, v)
certPath := v.String()
conf.Get(nsConf, flagKey, v)
keyPath := v.String()
return certPath, keyPath, notEmpty(certPath, keyPath)
}
func GetTCP() (string, bool) {
tcp := conf.Get("tcp").GetString()
return tcp, tcp != ""
func GetAuth() (string, string, bool) {
v := packable.NewString()
conf.Get(nsConf, flagUser, v)
user := v.String()
conf.Get(nsConf, flagPass, v)
pass := v.String()
return user, pass, notEmpty(user, pass)
}
func GetTimeout() time.Duration {
timeout := conf.Get("timeout").GetDuration()
func notEmpty(s ...string) bool {
for i := range s {
if s[i] == "" || s[i] == "/dev/null" {
return false
}
}
return true
}
func GetRate() (int, int) {
r := packable.NewString()
conf.Get(nsConf, flagRate, r)
b := packable.NewString()
conf.Get(nsConf, flagBurst, b)
rate, err := strconv.Atoi(r.String())
if err != nil {
log.Printf("illegal rate: %v", err)
rate = 5
}
burst, _ := strconv.Atoi(b.String())
if err != nil {
log.Printf("illegal burst: %v", err)
burst = 5
}
return rate, burst
}
func GetTimeout() int {
t := packable.NewString()
conf.Get(nsConf, flagTimeout, t)
timeout, err := strconv.Atoi(t.String())
if err != nil || timeout == 5 {
return 5
}
return timeout
}
func GetCORS(key string) bool {
cors := conf.GetString("cors")
var m map[string]bool
if err := json.Unmarshal([]byte(cors), &m); err != nil {
return false
func GetRewrites(hostMatch string) map[string]string {
v := packable.NewString()
conf.Get(nsConf, flagRewrites, v)
m := make(map[string]string)
for _, v := range strings.Split(v.String(), ",") {
vs := strings.Split(v, ":")
if len(v) < 3 {
continue
}
host := vs[0]
if host != hostMatch {
continue
}
from := vs[1]
to := strings.Join(vs[2:], ":")
m[from] = to
}
_, ok := m[key]
return ok
return m
}
func GetNoPath(key string) bool {
nopath := conf.GetString("nopath")
var m map[string]bool
if err := json.Unmarshal([]byte(nopath), &m); err != nil {
return false
func GetProxyMode() string {
v := packable.NewString()
conf.Get(nsConf, flagMode, v)
s := v.String()
if s == "" {
return "domain"
}
_, ok := m[key]
return ok
}
func GetCompression() bool {
return conf.GetBool("compression")
return s
}

183
config/new.go Executable file → Normal file
View File

@@ -1,62 +1,161 @@
package config
import (
"fmt"
"local/args"
"local/logb"
"flag"
"io/ioutil"
"local/rproxy3/storage"
"local/rproxy3/storage/packable"
"log"
"os"
"strings"
"time"
yaml "gopkg.in/yaml.v2"
)
var conf *args.ArgSet
const nsConf = "configuration"
const flagPort = "p"
const flagMode = "mode"
const flagRoutes = "r"
const flagConf = "c"
const flagCert = "crt"
const flagKey = "key"
const flagUser = "user"
const flagPass = "pass"
const flagRate = "rate"
const flagBurst = "burst"
const flagTimeout = "timeout"
const flagRewrites = "rw"
func init() {
if err := Refresh(); err != nil {
panic(err)
}
var conf = storage.NewMap()
type toBind struct {
flag string
value *string
}
func Refresh() error {
type fileConf struct {
Port string `yaml:"p"`
Mode string `yaml:"mode"`
Routes []string `yaml:"r"`
CertPath string `yaml:"crt"`
KeyPath string `yaml:"key"`
Username string `yaml:"user"`
Password string `yaml:"pass"`
Rate string `yaml:"rate"`
Burst string `yaml:"burst"`
Timeout string `yaml:"timeout"`
Rewrites []string `yaml:"rw"`
}
func Init() error {
log.SetFlags(log.Ldate | log.Ltime | log.Llongfile)
log.SetFlags(log.Ltime | log.Lshortfile)
as, err := parseArgs()
if err != nil && !strings.Contains(fmt.Sprint(os.Args), "-test") {
if err := fromFile(); err != nil {
return err
}
if err := fromFlags(); err != nil {
return err
}
conf = as
logb.Set(logb.LevelFromString(as.GetString("level")))
return nil
}
func parseArgs() (*args.ArgSet, error) {
configFiles := []string{}
if v, ok := os.LookupEnv("CONFIG"); ok {
configFiles = strings.Split(v, ",")
func fromFile() error {
flag.CommandLine = flag.NewFlagSet(os.Args[0], flag.ContinueOnError)
defer func() {
flag.CommandLine = flag.NewFlagSet(os.Args[0], flag.ExitOnError)
}()
flag.String(flagConf, "/dev/null", "yaml config file path")
flag.Parse()
confFlag := flag.Lookup(flagConf)
if confFlag == nil || confFlag.Value.String() == "" {
return nil
}
as := args.NewArgSet(configFiles...)
as.Append(args.STRING, "user", "username for basic auth", "")
as.Append(args.STRING, "pass", "password for basic auth", "")
as.Append(args.INT, "p", "port for service", 51555)
as.Append(args.INT, "ap", "alt port for always http service", 51556)
as.Append(args.INT, "r", "rate per second for requests", 100)
as.Append(args.INT, "b", "burst requests", 100)
as.Append(args.BOOL, "compress", "enable compression", true)
as.Append(args.STRING, "crt", "path to crt for ssl", "")
as.Append(args.STRING, "key", "path to key for ssl", "")
as.Append(args.STRING, "trim", "path prefix to trim, like '/abc' to change '/abc/def' to '/def'", "")
as.Append(args.STRING, "tcp", "address for tcp only tunnel", "")
as.Append(args.DURATION, "timeout", "timeout for tunnel", time.Minute)
as.Append(args.STRING, "proxy", "double-comma separated (+ if auth)from,scheme://to.tld:port,,", "")
as.Append(args.STRING, "oauth", "url for boauthz", "")
as.Append(args.STRING, "authelia", "url for authelia", "")
as.Append(args.STRING, "cors", "json dict key:true for keys to set CORS permissive headers, like {\"from\":true}", "{}")
as.Append(args.STRING, "nopath", "json dict key:true for keys to remove all path info from forwarded request, like -cors", "{}")
as.Append(args.STRING, "level", "log level", "info")
err := as.Parse()
return as, err
confBytes, err := ioutil.ReadFile(confFlag.Value.String())
if err != nil {
return err
}
var c fileConf
if err := yaml.Unmarshal(confBytes, &c); err != nil {
return err
}
if err := conf.Set(nsConf, flagPort, packable.NewString(c.Port)); err != nil {
return err
}
if err := conf.Set(nsConf, flagMode, packable.NewString(c.Mode)); err != nil {
return err
}
if err := conf.Set(nsConf, flagRoutes, packable.NewString(strings.Join(c.Routes, ","))); err != nil {
return err
}
if err := conf.Set(nsConf, flagCert, packable.NewString(c.CertPath)); err != nil {
return err
}
if err := conf.Set(nsConf, flagKey, packable.NewString(c.KeyPath)); err != nil {
return err
}
if err := conf.Set(nsConf, flagUser, packable.NewString(c.Username)); err != nil {
return err
}
if err := conf.Set(nsConf, flagPass, packable.NewString(c.Password)); err != nil {
return err
}
if err := conf.Set(nsConf, flagRate, packable.NewString(c.Rate)); err != nil {
return err
}
if err := conf.Set(nsConf, flagBurst, packable.NewString(c.Burst)); err != nil {
return err
}
if err := conf.Set(nsConf, flagTimeout, packable.NewString(c.Timeout)); err != nil {
return err
}
if err := conf.Set(nsConf, flagRewrites, packable.NewString(strings.Join(c.Rewrites, ","))); err != nil {
return err
}
return nil
}
func fromFlags() error {
binds := make([]toBind, 0)
binds = append(binds, addFlag(flagPort, "51555", "port to bind to"))
binds = append(binds, addFlag(flagMode, "domain", "[domain] or [path] to match"))
binds = append(binds, addFlag(flagConf, "", "configuration file path"))
binds = append(binds, addFlag(flagRoutes, "", "comma-separated routes to map, each as from:scheme://to.tld:port"))
binds = append(binds, addFlag(flagCert, "", "path to .crt"))
binds = append(binds, addFlag(flagKey, "", "path to .key"))
binds = append(binds, addFlag(flagUser, "", "basic auth username"))
binds = append(binds, addFlag(flagPass, "", "basic auth password"))
binds = append(binds, addFlag(flagRate, "100", "rate limit per second"))
binds = append(binds, addFlag(flagBurst, "100", "rate limit burst"))
binds = append(binds, addFlag(flagTimeout, "30", "seconds to wait for limiter"))
binds = append(binds, addFlag(flagRewrites, "", "comma-separated from:replace:replacement to rewrite in response bodies"))
flag.Parse()
for _, bind := range binds {
confFlag := flag.Lookup(bind.flag)
if confFlag == nil || confFlag.Value.String() == "" {
continue
}
if err := conf.Set(nsConf, bind.flag, packable.NewString(*bind.value)); err != nil {
return err
}
}
return nil
}
func addFlag(key, def, help string) toBind {
def = getFlagOrDefault(key, def)
v := flag.String(key, def, help)
return toBind{
flag: key,
value: v,
}
}
func getFlagOrDefault(key, def string) string {
v := packable.NewString()
if err := conf.Get(nsConf, key, v); err != nil {
return def
}
return v.String()
}

0
.config/new_test.go → config/new_test.go Executable file → Normal file
View File

View File

@@ -1,11 +0,0 @@
user: ""
pass: ""
port: 51555
r: 100
b: 100
crt: ""
key: ""
tcp: ""
timeout: 1m
proxy: a,http://localhost:41912,,+b,http://localhost:41912
oauth: http://localhost:23456

2
main.go Executable file → Normal file
View File

@@ -6,7 +6,7 @@ import (
)
func main() {
if err := config.Refresh(); err != nil {
if err := config.Init(); err != nil {
panic(err)
}

14
main_test.go Executable file → Normal file
View File

@@ -34,8 +34,8 @@ func TestHTTPSMain(t *testing.T) {
"username",
"-pass",
"password",
"-proxy",
"hello," + addr,
"-r",
"hello:" + addr,
"-crt",
"./testdata/rproxy3server.crt",
"-key",
@@ -51,7 +51,7 @@ func TestHTTPSMain(t *testing.T) {
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
},
}
r, _ := http.NewRequest("GET", "https://hello.localhost:"+port, nil)
r, _ := http.NewRequest("GET", "https://hello.localhost"+port, nil)
if resp, err := client.Do(r); err != nil {
t.Fatalf("client failed: %v", err)
@@ -89,8 +89,8 @@ func TestHTTPMain(t *testing.T) {
"username",
"-pass",
"password",
"-proxy",
"hello," + addr,
"-r",
"hello:" + addr,
}
main()
}()
@@ -98,7 +98,7 @@ func TestHTTPMain(t *testing.T) {
time.Sleep(time.Millisecond * 100)
client := &http.Client{}
r, _ := http.NewRequest("GET", "http://hello.localhost:"+port, nil)
r, _ := http.NewRequest("GET", "http://hello.localhost"+port, nil)
if resp, err := client.Do(r); err != nil {
t.Fatalf("client failed: %v", err)
@@ -127,5 +127,5 @@ func echoServer() (string, func()) {
func getPort() string {
s := httptest.NewServer(nil)
s.Close()
return s.URL[strings.LastIndex(s.URL, ":")+1:]
return s.URL[strings.LastIndex(s.URL, ":"):]
}

7
server/new.go Executable file → Normal file
View File

@@ -9,15 +9,10 @@ import (
func New() *Server {
port := config.GetPort()
altport := config.GetAltPort()
r, b := config.GetRate()
server := &Server{
return &Server{
db: storage.NewMap(),
addr: port,
altaddr: altport,
limiter: rate.NewLimiter(rate.Limit(r), b),
}
_, server.auth.BOAuthZ = config.GetBOAuthZ()
_, server.auth.Authelia = config.GetAuthelia()
return server
}

0
server/new_test.go Executable file → Normal file
View File

37
server/proxy.go Executable file → Normal file
View File

@@ -2,7 +2,6 @@ package server
import (
"bytes"
"crypto/tls"
"io"
"local/rproxy3/config"
"local/rproxy3/storage/packable"
@@ -25,21 +24,23 @@ type rewrite struct {
}
func (s *Server) Proxy(w http.ResponseWriter, r *http.Request) {
newURL, err := s.lookup(mapKey(r.Host))
r.URL.Path = strings.TrimPrefix(r.URL.Path, config.GetTrim())
newURL, err := s.lookup(mapKey(r, config.GetProxyMode()))
var transport http.RoundTripper
http.DefaultTransport.(*http.Transport).TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
transport = &redirPurge{
proxyHost: r.Host,
targetHost: newURL.Host,
baseTransport: http.DefaultTransport,
}
transport = &rewrite{
rewrites: config.GetRewrites(mapKey(r, config.GetProxyMode())),
baseTransport: transport,
}
if err != nil {
http.NotFound(w, r)
log.Printf("unknown host lookup %q", r.Host)
return
}
//r.Host = newURL.Host
r.Host = newURL.Host
proxy := httputil.NewSingleHostReverseProxy(newURL)
proxy.Transport = transport
proxy.ServeHTTP(w, r)
@@ -51,16 +52,20 @@ func (s *Server) lookup(host string) (*url.URL, error) {
return v.URL(), err
}
func (s *Server) lookupAuth(host string) (bool, error) {
v := packable.NewString()
err := s.db.Get(nsBOAuthZ, host, v)
return v.String() == "true", err
}
func mapKey(host string) string {
host = strings.Split(host, ".")[0]
host = strings.Split(host, ":")[0]
return host
func mapKey(r *http.Request, proxyMode string) string {
switch proxyMode {
case "domain":
host := strings.Split(r.Host, ".")[0]
host = strings.Split(host, ":")[0]
return host
case "path":
paths := strings.Split(r.URL.Path, "/")
if len(paths) < 2 {
return ""
}
return paths[1]
}
return ""
}
func (rp *redirPurge) RoundTrip(r *http.Request) (*http.Response, error) {
@@ -71,8 +76,6 @@ func (rp *redirPurge) RoundTrip(r *http.Request) (*http.Response, error) {
if loc := resp.Header.Get("Location"); loc != "" {
resp.Header.Set("Location", strings.Replace(loc, rp.targetHost, rp.proxyHost, 1))
}
// google floc https://paramdeo.com/blog/opting-your-website-out-of-googles-floc-network
resp.Header.Set("Permissions-Policy", "interest-cohort=()")
return resp, err
}

33
server/proxy_test.go Executable file → Normal file
View File

@@ -3,6 +3,7 @@ package server
import (
"io/ioutil"
"net/http"
"net/url"
"strings"
"testing"
)
@@ -40,3 +41,35 @@ func TestRewrite(t *testing.T) {
t.Errorf("failed to replace: got %q, want \"b\"", b)
}
}
func TestMapKey(t *testing.T) {
r := &http.Request{
Host: "a.b.c:123",
URL: &url.URL{
Path: "/c/d/e",
},
}
if v := mapKey(r, "domain"); v != "a" {
t.Errorf("failed to get domain: got %v", v)
}
if v := mapKey(r, "path"); v != "c" {
t.Errorf("failed to get domain: got %v", v)
}
r.Host = "a:123"
if v := mapKey(r, "domain"); v != "a" {
t.Errorf("failed to get domain: got %v", v)
}
r.URL.Path = ""
if v := mapKey(r, "path"); v != "" {
t.Errorf("failed to get domain: got %v", v)
}
r.URL.Path = "/"
if v := mapKey(r, "path"); v != "" {
t.Errorf("failed to get domain: got %v", v)
}
}

0
server/routes.go Executable file → Normal file
View File

0
server/routes_test.go Executable file → Normal file
View File

257
server/server.go Executable file → Normal file
View File

@@ -5,18 +5,12 @@ import (
"crypto/tls"
"encoding/base64"
"errors"
"fmt"
"io"
"local/logb"
"local/oauth2/oauth2client"
"local/rproxy3/config"
"local/rproxy3/storage"
"local/rproxy3/storage/packable"
"log"
"net"
"net/http"
"net/url"
"path"
"strings"
"time"
@@ -24,14 +18,12 @@ import (
)
const nsRouting = "routing"
const nsBOAuthZ = "oauth"
type listenerScheme int
const (
schemeHTTP listenerScheme = iota
schemeHTTPS listenerScheme = iota
schemeTCP listenerScheme = iota
)
func (ls listenerScheme) String() string {
@@ -40,8 +32,6 @@ func (ls listenerScheme) String() string {
return "http"
case schemeHTTPS:
return "https"
case schemeTCP:
return "tcp"
}
return ""
}
@@ -49,36 +39,32 @@ func (ls listenerScheme) String() string {
type Server struct {
db storage.DB
addr string
altaddr string
username string
password string
limiter *rate.Limiter
auth struct {
BOAuthZ bool
Authelia bool
}
}
func (s *Server) Route(src string, dst config.Proxy) error {
hasOAuth := strings.HasPrefix(src, "+")
src = strings.TrimPrefix(src, "+")
log.Printf("Adding route %q -> %v...\n", src, dst)
u, err := url.Parse(dst.To)
func (s *Server) Route(src, dst string) error {
log.Printf("Adding route %q -> %q...\n", src, dst)
u, err := url.Parse(dst)
if err != nil {
return err
}
s.db.Set(nsBOAuthZ, src, packable.NewString(fmt.Sprint(hasOAuth)))
return s.db.Set(nsRouting, src, packable.NewURL(u))
}
func (s *Server) Run() error {
go s.alt()
scheme := getScheme()
scheme := schemeHTTP
if _, _, ok := config.GetSSL(); ok {
scheme = schemeHTTPS
}
log.Printf("Listening for %v on %v...\n", scheme, s.addr)
switch scheme {
case schemeHTTP:
log.Printf("Serve http")
return http.ListenAndServe(s.addr, s)
case schemeHTTPS:
log.Printf("Serve https")
c, k, _ := config.GetSSL()
httpsServer := &http.Server{
Addr: s.addr,
@@ -97,115 +83,15 @@ func (s *Server) Run() error {
TLSNextProto: make(map[string]func(*http.Server, *tls.Conn, http.Handler), 0),
}
return httpsServer.ListenAndServeTLS(c, k)
case schemeTCP:
addr, _ := config.GetTCP()
return s.ServeTCP(addr)
}
return errors.New("did not load server")
}
func (s *Server) doAuthelia(foo http.HandlerFunc) http.HandlerFunc {
func (s *Server) doAuth(foo http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
authelia, ok := config.GetAuthelia()
if !ok {
panic("howd i get here")
}
url, err := url.Parse(authelia)
if err != nil {
panic(fmt.Sprintf("bad config for authelia url: %v", err))
}
url.Path = "/api/verify"
logb.Verbosef("authelia @ %s", url.String())
req, err := http.NewRequest(http.MethodGet, url.String(), nil)
if err != nil {
panic(err.Error())
}
r2 := r.Clone(r.Context())
if r2.URL.Host == "" {
r2.URL.Host = r2.Host
}
if r2.URL.Scheme == "" {
r2.URL.Scheme = "https"
}
for _, httpreq := range []*http.Request{r, req} {
for k, v := range map[string]string{
"X-Original-Url": r2.URL.String(),
"X-Forwarded-Proto": r2.URL.Scheme,
"X-Forwarded-Host": r2.URL.Host,
"X-Forwarded-Uri": r2.URL.String(),
} {
if _, ok := httpreq.Header[k]; !ok {
logb.Verbosef("authelia header setting %s:%s", k, v)
httpreq.Header.Set(k, v)
}
}
}
if cookie, err := r.Cookie("authelia_session"); err == nil {
logb.Verbosef("authelia session found in cookies; %+v", cookie)
req.AddCookie(cookie)
}
c := &http.Client{
Timeout: time.Minute,
Transport: &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
},
}
autheliaKey := mapKey(req.Host)
logb.Verbosef("request to %s is authelia %s? %v", r.Host, autheliaKey, strings.HasPrefix(r.Host, autheliaKey))
if strings.HasPrefix(r.Host, autheliaKey) {
logb.Debugf("no authelia for %s because it has prefix %s", r.Host, autheliaKey)
foo(w, r)
return
}
resp, err := c.Do(req)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
logb.Debugf(
"authelia: %+v, %+v \n\t-> \n\t(%d) %+v, %+v",
req,
req.Cookies(),
resp.StatusCode,
resp.Header,
resp.Cookies(),
)
defer resp.Body.Close()
if resp.StatusCode == http.StatusOK {
for k := range resp.Header {
if strings.HasPrefix(k, "Remote-") {
cookie := &http.Cookie{
Name: k,
Value: resp.Header.Get(k),
Path: "/",
SameSite: http.SameSiteLaxMode,
Expires: time.Now().Add(24 * time.Hour * 30),
}
logb.Verbosef("setting authelia cookie in response: %+v", cookie)
http.SetCookie(w, cookie)
logb.Verbosef("setting authelia cookie in request: %+v", cookie)
r.AddCookie(cookie)
}
}
foo(w, r)
return
}
url.Path = ""
q := url.Query()
q.Set("rd", r2.URL.String())
url.RawQuery = q.Encode()
logb.Verbosef("authelia status %d, rd'ing %s", resp.StatusCode, url.String())
http.Redirect(w, r, url.String(), http.StatusFound)
}
}
func (s *Server) doBOAuthZ(foo http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
key := mapKey(r.Host)
rusr, rpwd, ok := config.GetAuth()
if ok {
//usr, pwd := getProxyAuth(r)
usr, pwd, ok := r.BasicAuth()
if !ok || rusr != usr || rpwd != pwd {
w.WriteHeader(http.StatusUnauthorized)
@@ -213,75 +99,19 @@ func (s *Server) doBOAuthZ(foo http.HandlerFunc) http.HandlerFunc {
return
}
}
ok, err := s.lookupAuth(key)
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
return
}
if url, exists := config.GetBOAuthZ(); ok && exists {
err := oauth2client.Authenticate(url, key, w, r)
if err != nil {
return
}
}
if config.GetNoPath(key) && path.Ext(r.URL.Path) == "" {
r.URL.Path = "/"
}
foo(w, r)
}
}
func (s *Server) ServeTCP(addr string) error {
listen, err := net.Listen("tcp", s.addr)
if err != nil {
return err
}
for {
c, err := listen.Accept()
if err != nil {
return err
}
go func(c net.Conn) {
d, err := net.Dial("tcp", addr)
if err != nil {
log.Println(err)
return
}
go pipe(c, d)
go pipe(d, c)
}(c)
}
}
func pipe(a, b net.Conn) {
log.Println("open pipe")
defer log.Println("close pipe")
defer a.Close()
defer b.Close()
io.Copy(a, b)
}
func (s *Server) Pre(foo http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
ctx, can := context.WithTimeout(r.Context(), time.Duration(config.GetTimeout()))
ctx, can := context.WithTimeout(r.Context(), time.Second*time.Duration(config.GetTimeout()))
defer can()
if err := s.limiter.Wait(ctx); err != nil {
w.WriteHeader(http.StatusTooManyRequests)
return
}
w, did := s.doCORS(w, r)
if did {
return
}
if s.auth.BOAuthZ {
logb.Verbosef("doing boauthz for request to %s", r.URL.String())
s.doBOAuthZ(foo)(w, r)
} else if s.auth.Authelia {
logb.Verbosef("doing authelia for request to %s", r.URL.String())
s.doAuthelia(foo)(w, r)
} else {
foo(w, r)
}
s.doAuth(foo)(w, r)
}
}
@@ -289,31 +119,6 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
s.Pre(s.Proxy)(w, r)
}
type corsResponseWriter struct {
http.ResponseWriter
}
func (cb corsResponseWriter) WriteHeader(code int) {
cb.Header().Set("Access-Control-Allow-Origin", "*")
cb.Header().Set("Access-Control-Allow-Headers", "X-Auth-Token, content-type, Content-Type")
cb.ResponseWriter.WriteHeader(code)
}
func (s *Server) doCORS(w http.ResponseWriter, r *http.Request) (http.ResponseWriter, bool) {
key := mapKey(r.Host)
if !config.GetCORS(key) {
return w, false
}
w = corsResponseWriter{ResponseWriter: w}
if r.Method != "OPTIONS" {
return w, false
}
w.Header().Set("Content-Length", "0")
w.Header().Set("Content-Type", "text/plain")
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, OPTIONS, TRACE, PATCH, HEAD, DELETE")
return w, true
}
func getProxyAuth(r *http.Request) (string, string) {
proxyAuthHeader := r.Header.Get("Proxy-Authorization")
proxyAuthB64 := strings.TrimPrefix(proxyAuthHeader, "Basic ")
@@ -325,39 +130,3 @@ func getProxyAuth(r *http.Request) (string, string) {
proxyAuthSplit := strings.Split(proxyAuth, ":")
return proxyAuthSplit[0], proxyAuthSplit[1]
}
func (s *Server) alt() {
switch getScheme() {
case schemeHTTP:
case schemeHTTPS:
default:
return
}
foo := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
r.URL.Scheme = getScheme().String()
if hostname := r.URL.Hostname(); hostname != "" {
r.URL.Host = r.URL.Hostname() + s.addr
} else if hostname := r.URL.Host; hostname != "" {
r.URL.Host = r.URL.Host + s.addr
} else {
u := url.URL{Host: r.Host}
r.URL.Host = u.Hostname() + s.addr
}
http.Redirect(w, r, r.URL.String(), http.StatusSeeOther)
})
log.Println("redirecting from", s.altaddr)
if err := http.ListenAndServe(s.altaddr, foo); err != nil {
panic(err)
}
}
func getScheme() listenerScheme {
scheme := schemeHTTP
if _, _, ok := config.GetSSL(); ok {
scheme = schemeHTTPS
}
if _, ok := config.GetTCP(); ok {
scheme = schemeTCP
}
return scheme
}

11
server/server_test.go Executable file → Normal file
View File

@@ -3,7 +3,6 @@ package server
import (
"context"
"fmt"
"local/rproxy3/config"
"local/rproxy3/storage"
"net/http"
"net/http/httptest"
@@ -16,10 +15,7 @@ import (
func TestServerStart(t *testing.T) {
server := mockServer()
p := config.Proxy{
To: "http://hello.localhost" + server.addr,
}
if err := server.Route("world", p); err != nil {
if err := server.Route("world", "http://hello.localhost"+server.addr); err != nil {
t.Fatalf("cannot add route: %v", err)
}
@@ -52,10 +48,7 @@ func mockServer() *Server {
func TestServerRoute(t *testing.T) {
server := mockServer()
p := config.Proxy{
To: "http://hello.localhost" + server.addr,
}
if err := server.Route("world", p); err != nil {
if err := server.Route("world", "http://hello.localhost"+server.addr); err != nil {
t.Fatalf("cannot add route: %v", err)
}
w := httptest.NewRecorder()

0
storage/db.go Executable file → Normal file
View File

0
storage/db_test.go Executable file → Normal file
View File

0
storage/map.go Executable file → Normal file
View File

0
storage/packable/packable.go Executable file → Normal file
View File

0
storage/packable/packable_test.go Executable file → Normal file
View File

0
testdata/Bserver.crt vendored Executable file → Normal file
View File

0
testdata/Bserver.key vendored Executable file → Normal file
View File

0
testdata/Bserver.pkcs12 vendored Executable file → Normal file
View File

36
testdata/index.html vendored
View File

@@ -1,36 +0,0 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<meta http-equiv="X-UA-Compatible" content="ie=edge" />
<title>Go WebSocket Tutorial</title>
</head>
<body>
<h2>Hello World</h2>
<script>
let socket = new WebSocket("ws://a.bel.test:51555/ws");
document.getElementsByTagName("body")[0].innerHTML += "<br>connecting";
socket.onopen = () => {
document.getElementsByTagName("body")[0].innerHTML += "<br>connected";
socket.send("Hi From the Client!")
};
socket.onclose = event => {
document.getElementsByTagName("body")[0].innerHTML += "<br>disconnected";
socket.send("Client Closed!")
};
socket.onerror = error => {
document.getElementsByTagName("body")[0].innerHTML += "<br>error:" + error;
console.log("Socket Error: ", error);
};
socket.onmessage = function(msgevent) {
document.getElementsByTagName("body")[0].innerHTML += "<br>got:" + msgevent.data;
};
</script>
</body>
</html>

0
testdata/rproxy3server.crt vendored Executable file → Normal file
View File

0
testdata/rproxy3server.key vendored Executable file → Normal file
View File

0
testdata/rproxy3server.pkcs12 vendored Executable file → Normal file
View File

76
testdata/ws.go vendored
View File

@@ -1,76 +0,0 @@
package main
import (
"fmt"
"io/ioutil"
"log"
"net/http"
"time"
"github.com/gorilla/websocket"
)
func homePage(w http.ResponseWriter, r *http.Request) {
b, _ := ioutil.ReadFile("./index.html")
fmt.Fprintf(w, "%s", b)
}
func setupRoutes() {
http.HandleFunc("/", homePage)
http.HandleFunc("/ws", wsEndpoint)
}
func main() {
fmt.Println("Hello World")
setupRoutes()
log.Fatal(http.ListenAndServe(":8080", nil))
}
var upgrader = websocket.Upgrader{
ReadBufferSize: 1024,
WriteBufferSize: 1024,
}
func reader(conn *websocket.Conn) {
for {
// read in a message
messageType, p, err := conn.ReadMessage()
if err != nil {
log.Println(err)
return
}
// print out that message for clarity
fmt.Println(string(p))
if err := conn.WriteMessage(messageType, p); err != nil {
log.Println(err)
return
}
}
}
func wsEndpoint(w http.ResponseWriter, r *http.Request) {
upgrader.CheckOrigin = func(r *http.Request) bool { return true }
// upgrade this connection to a WebSocket
// connection
ws, err := upgrader.Upgrade(w, r, nil)
if err != nil {
log.Println(err)
}
log.Println("Client Connected")
// listen indefinitely for new messages coming
// through on our WebSocket connection
go reader(ws)
for {
log.Println("writing...")
err = ws.WriteMessage(1, []byte("Hi Client!"))
log.Println("written")
if err != nil {
log.Println(err)
return
}
time.Sleep(time.Second)
}
}

25
vendor/vendor.json vendored Normal file
View File

@@ -0,0 +1,25 @@
{
"comment": "",
"ignore": "test",
"package": [
{
"checksumSHA1": "GtamqiJoL7PGHsN454AoffBFMa8=",
"path": "golang.org/x/net/context",
"revision": "65e2d4e15006aab9813ff8769e768bbf4bb667a0",
"revisionTime": "2019-02-01T23:59:58Z"
},
{
"checksumSHA1": "HoCvrd3hEhsFeBOdEw7cbcfyk50=",
"path": "golang.org/x/time/rate",
"revision": "fbb02b2291d28baffd63558aa44b4b56f178d650",
"revisionTime": "2018-04-12T16:56:04Z"
},
{
"checksumSHA1": "QqDq2x8XOU7IoOR98Cx1eiV5QY8=",
"path": "gopkg.in/yaml.v2",
"revision": "51d6538a90f86fe93ac480b35f37b2be17fef232",
"revisionTime": "2018-11-15T11:05:04Z"
}
],
"rootPath": "local/rproxy3"
}