Compare commits

..

No commits in common. "master" and "v0.8.3" have entirely different histories.

21 changed files with 517 additions and 383 deletions

115
.config/config.go Executable file
View File

@ -0,0 +1,115 @@
package config
import (
"local/rproxy3/storage/packable"
"log"
"strconv"
"strings"
)
func GetPort() string {
v := packable.NewString()
conf.Get(nsConf, flagPort, v)
return ":" + strings.TrimPrefix(v.String(), ":")
}
func GetRoutes() map[string]string {
v := packable.NewString()
conf.Get(nsConf, flagRoutes, v)
m := make(map[string]string)
for _, v := range strings.Split(v.String(), ",") {
if len(v) == 0 {
return m
}
from := v[:strings.Index(v, ":")]
to := v[strings.Index(v, ":")+1:]
m[from] = to
}
return m
}
func GetTCP() (string, bool) {
v := packable.NewString()
conf.Get(nsConf, flagTCP, v)
tcpAddr := v.String()
return tcpAddr, notEmpty(tcpAddr)
}
func GetSSL() (string, string, bool) {
v := packable.NewString()
conf.Get(nsConf, flagCert, v)
certPath := v.String()
conf.Get(nsConf, flagKey, v)
keyPath := v.String()
return certPath, keyPath, notEmpty(certPath, keyPath)
}
func GetAuth() (string, string, bool) {
v := packable.NewString()
conf.Get(nsConf, flagUser, v)
user := v.String()
conf.Get(nsConf, flagPass, v)
pass := v.String()
return user, pass, notEmpty(user, pass)
}
func notEmpty(s ...string) bool {
for i := range s {
if s[i] == "" || s[i] == "/dev/null" {
return false
}
}
return true
}
func GetRate() (int, int) {
r := packable.NewString()
conf.Get(nsConf, flagRate, r)
b := packable.NewString()
conf.Get(nsConf, flagBurst, b)
rate, err := strconv.Atoi(r.String())
if err != nil {
log.Printf("illegal rate: %v", err)
rate = 5
}
burst, _ := strconv.Atoi(b.String())
if err != nil {
log.Printf("illegal burst: %v", err)
burst = 5
}
return rate, burst
}
func GetTimeout() int {
t := packable.NewString()
conf.Get(nsConf, flagTimeout, t)
timeout, err := strconv.Atoi(t.String())
if err != nil || timeout == 5 {
return 5
}
return timeout
}
func GetRewrites(hostMatch string) map[string]string {
v := packable.NewString()
conf.Get(nsConf, flagRewrites, v)
m := make(map[string]string)
for _, v := range strings.Split(v.String(), ",") {
vs := strings.Split(v, ":")
if len(v) < 3 {
continue
}
host := vs[0]
if host != hostMatch {
continue
}
from := vs[1]
to := strings.Join(vs[2:], ":")
m[from] = to
}
return m
}

161
.config/new.go Executable file
View File

@ -0,0 +1,161 @@
package config
import (
"flag"
"io/ioutil"
"local/rproxy3/storage"
"local/rproxy3/storage/packable"
"log"
"os"
"strings"
yaml "gopkg.in/yaml.v2"
)
const nsConf = "configuration"
const flagPort = "p"
const flagRoutes = "r"
const flagConf = "c"
const flagCert = "crt"
const flagTCP = "tcp"
const flagKey = "key"
const flagUser = "user"
const flagPass = "pass"
const flagRate = "rate"
const flagBurst = "burst"
const flagTimeout = "timeout"
const flagRewrites = "rw"
var conf = storage.NewMap()
type toBind struct {
flag string
value *string
}
type fileConf struct {
Port string `yaml:"p"`
Routes []string `yaml:"r"`
CertPath string `yaml:"crt"`
TCPPath string `yaml:"tcp"`
KeyPath string `yaml:"key"`
Username string `yaml:"user"`
Password string `yaml:"pass"`
Rate string `yaml:"rate"`
Burst string `yaml:"burst"`
Timeout string `yaml:"timeout"`
Rewrites []string `yaml:"rw"`
}
func Init() error {
log.SetFlags(log.Ldate | log.Ltime | log.Llongfile)
log.SetFlags(log.Ltime | log.Lshortfile)
if err := fromFile(); err != nil {
return err
}
if err := fromFlags(); err != nil {
return err
}
return nil
}
func fromFile() error {
flag.CommandLine = flag.NewFlagSet(os.Args[0], flag.ContinueOnError)
defer func() {
flag.CommandLine = flag.NewFlagSet(os.Args[0], flag.ExitOnError)
}()
flag.String(flagConf, "/dev/null", "yaml config file path")
flag.Parse()
confFlag := flag.Lookup(flagConf)
if confFlag == nil || confFlag.Value.String() == "" {
return nil
}
confBytes, err := ioutil.ReadFile(confFlag.Value.String())
if err != nil {
return err
}
var c fileConf
if err := yaml.Unmarshal(confBytes, &c); err != nil {
return err
}
if err := conf.Set(nsConf, flagPort, packable.NewString(c.Port)); err != nil {
return err
}
if err := conf.Set(nsConf, flagRoutes, packable.NewString(strings.Join(c.Routes, ","))); err != nil {
return err
}
if err := conf.Set(nsConf, flagCert, packable.NewString(c.CertPath)); err != nil {
return err
}
if err := conf.Set(nsConf, flagTCP, packable.NewString(c.TCPPath)); err != nil {
return err
}
if err := conf.Set(nsConf, flagKey, packable.NewString(c.KeyPath)); err != nil {
return err
}
if err := conf.Set(nsConf, flagUser, packable.NewString(c.Username)); err != nil {
return err
}
if err := conf.Set(nsConf, flagPass, packable.NewString(c.Password)); err != nil {
return err
}
if err := conf.Set(nsConf, flagRate, packable.NewString(c.Rate)); err != nil {
return err
}
if err := conf.Set(nsConf, flagBurst, packable.NewString(c.Burst)); err != nil {
return err
}
if err := conf.Set(nsConf, flagTimeout, packable.NewString(c.Timeout)); err != nil {
return err
}
if err := conf.Set(nsConf, flagRewrites, packable.NewString(strings.Join(c.Rewrites, ","))); err != nil {
return err
}
return nil
}
func fromFlags() error {
binds := make([]toBind, 0)
binds = append(binds, addFlag(flagPort, "51555", "port to bind to"))
binds = append(binds, addFlag(flagConf, "", "configuration file path"))
binds = append(binds, addFlag(flagRoutes, "", "comma-separated routes to map, each as from:scheme://to.tld:port"))
binds = append(binds, addFlag(flagCert, "", "path to .crt"))
binds = append(binds, addFlag(flagTCP, "", "tcp addr"))
binds = append(binds, addFlag(flagKey, "", "path to .key"))
binds = append(binds, addFlag(flagUser, "", "basic auth username"))
binds = append(binds, addFlag(flagPass, "", "basic auth password"))
binds = append(binds, addFlag(flagRate, "100", "rate limit per second"))
binds = append(binds, addFlag(flagBurst, "100", "rate limit burst"))
binds = append(binds, addFlag(flagTimeout, "30", "seconds to wait for limiter"))
binds = append(binds, addFlag(flagRewrites, "", "comma-separated from:replace:replacement:oauth to rewrite in response bodies"))
flag.Parse()
for _, bind := range binds {
confFlag := flag.Lookup(bind.flag)
if confFlag == nil || confFlag.Value.String() == "" {
continue
}
if err := conf.Set(nsConf, bind.flag, packable.NewString(*bind.value)); err != nil {
return err
}
}
return nil
}
func addFlag(key, def, help string) toBind {
def = getFlagOrDefault(key, def)
v := flag.String(key, def, help)
return toBind{
flag: key,
value: v,
}
}
func getFlagOrDefault(key, def string) string {
v := packable.NewString()
if err := conf.Get(nsConf, key, v); err != nil {
return def
}
return v.String()
}

46
.config/new_test.go Executable file
View File

@ -0,0 +1,46 @@
package config
import (
"flag"
"os"
"testing"
)
func TestInit(t *testing.T) {
was := os.Args[:]
os.Args = []string{"program"}
flag.CommandLine = flag.NewFlagSet(os.Args[0], flag.ExitOnError)
defer func() {
os.Args = was[:]
}()
if err := Init(); err != nil {
t.Errorf("failed to init: %v", err)
}
}
func TestFromFile(t *testing.T) {
was := os.Args[:]
os.Args = []string{"program"}
flag.CommandLine = flag.NewFlagSet(os.Args[0], flag.ExitOnError)
defer func() {
os.Args = was[:]
}()
if err := fromFile(); err != nil {
t.Errorf("failed from file: %v", err)
}
}
func TestFromFlags(t *testing.T) {
was := os.Args[:]
os.Args = []string{"program"}
flag.CommandLine = flag.NewFlagSet(os.Args[0], flag.ExitOnError)
defer func() {
os.Args = was[:]
}()
if err := fromFlags(); err != nil {
t.Errorf("failed from flags: %v", err)
}
}

View File

@ -1,5 +0,0 @@
#! /usr/bin/env bash
export CGO_ENABLED=1
export CC=x86_64-linux-musl-gcc
exec go build -ldflags="-linkmode external -extldflags '-static'" -o exec-rproxy3

BIN
config/.config.go.un~ Normal file

Binary file not shown.

BIN
config/.new.go.un~ Normal file

Binary file not shown.

View File

@ -4,20 +4,15 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"log" "log"
"regexp"
"strings" "strings"
"time" "time"
"gopkg.in/yaml.v2"
) )
type Proxy struct { type Proxy struct {
Auth string
From string
To string To string
} }
func parseOneProxyCSV(s string) (string, Proxy) { func parseProxy(s string) (string, Proxy) {
p := Proxy{} p := Proxy{}
key := "" key := ""
l := strings.Split(s, ",") l := strings.Split(s, ",")
@ -30,16 +25,22 @@ func parseOneProxyCSV(s string) (string, Proxy) {
return key, p return key, p
} }
func GetAuthelia() (string, bool) {
authelia := conf.Get("authelia").GetString()
return authelia, authelia != ""
}
func GetBOAuthZ() (string, bool) {
boauthz := conf.Get("oauth").GetString()
return boauthz, boauthz != ""
}
func GetAuth() (string, string, bool) { func GetAuth() (string, string, bool) {
user := conf.Get("user").GetString() user := conf.Get("user").GetString()
pass := conf.Get("pass").GetString() pass := conf.Get("pass").GetString()
return user, pass, user != "" && pass != "" return user, pass, user != "" && pass != ""
} }
func GetTrim() string {
return conf.Get("trim").GetString()
}
func GetPort() string { func GetPort() string {
port := conf.Get("p").GetInt() port := conf.Get("p").GetInt()
return ":" + fmt.Sprint(port) return ":" + fmt.Sprint(port)
@ -58,31 +59,11 @@ func GetRate() (int, int) {
} }
func GetRoutes() map[string]Proxy { func GetRoutes() map[string]Proxy {
s := conf.Get("proxy2").GetString() list := conf.Get("proxy").GetString()
var dict map[string]string
if err := yaml.Unmarshal([]byte(s), &dict); err == nil && len(s) > 0 {
pattern := regexp.MustCompile(`(([^:]*):)?(([^:]*):)?([a-z0-9]*:.*)`)
result := map[string]Proxy{}
for k, v := range dict {
submatches := pattern.FindAllStringSubmatch(v, -1)
log.Printf("%+v", submatches)
result[k] = Proxy{
Auth: submatches[0][2],
From: submatches[0][4],
To: submatches[0][5],
}
}
return result
}
return getRoutesCSV()
}
func getRoutesCSV() map[string]Proxy {
list := conf.Get("proxy2").GetString()
definitions := strings.Split(list, ",,") definitions := strings.Split(list, ",,")
routes := make(map[string]Proxy) routes := make(map[string]Proxy)
for _, definition := range definitions { for _, definition := range definitions {
k, v := parseOneProxyCSV(definition) k, v := parseProxy(definition)
routes[k] = v routes[k] = v
} }
return routes return routes

View File

@ -2,13 +2,12 @@ package config
import ( import (
"fmt" "fmt"
"local/args"
"local/logb"
"log" "log"
"os" "os"
"strings" "strings"
"time" "time"
"gitea.bel.blue/local/args"
"gitea.bel.blue/local/logb"
) )
var conf *args.ArgSet var conf *args.ArgSet
@ -48,10 +47,11 @@ func parseArgs() (*args.ArgSet, error) {
as.Append(args.BOOL, "compress", "enable compression", true) as.Append(args.BOOL, "compress", "enable compression", true)
as.Append(args.STRING, "crt", "path to crt for ssl", "") as.Append(args.STRING, "crt", "path to crt for ssl", "")
as.Append(args.STRING, "key", "path to key for ssl", "") as.Append(args.STRING, "key", "path to key for ssl", "")
as.Append(args.STRING, "trim", "path prefix to trim, like '/abc' to change '/abc/def' to '/def'", "")
as.Append(args.STRING, "tcp", "address for tcp only tunnel", "") as.Append(args.STRING, "tcp", "address for tcp only tunnel", "")
as.Append(args.DURATION, "timeout", "timeout for tunnel", time.Minute) as.Append(args.DURATION, "timeout", "timeout for tunnel", time.Minute)
as.Append(args.STRING, "proxy2", "double-comma separated 'from,scheme://to.tld:port,,' OR a yaml dictionary of 'from: (password:)scheme://to.tld:port'", "") as.Append(args.STRING, "proxy", "double-comma separated (+ if auth)from,scheme://to.tld:port,,", "")
as.Append(args.STRING, "oauth", "url for boauthz", "")
as.Append(args.STRING, "authelia", "url for authelia", "")
as.Append(args.STRING, "cors", "json dict key:true for keys to set CORS permissive headers, like {\"from\":true}", "{}") as.Append(args.STRING, "cors", "json dict key:true for keys to set CORS permissive headers, like {\"from\":true}", "{}")
as.Append(args.STRING, "nopath", "json dict key:true for keys to remove all path info from forwarded request, like -cors", "{}") as.Append(args.STRING, "nopath", "json dict key:true for keys to remove all path info from forwarded request, like -cors", "{}")
as.Append(args.STRING, "level", "log level", "info") as.Append(args.STRING, "level", "log level", "info")

View File

@ -7,7 +7,5 @@ crt: ""
key: "" key: ""
tcp: "" tcp: ""
timeout: 1m timeout: 1m
proxy2: | proxy: a,http://localhost:41912,,+b,http://localhost:41912
a: http://localhost:41912
b: password:http://localhost:41912
oauth: http://localhost:23456 oauth: http://localhost:23456

17
go.mod
View File

@ -1,17 +0,0 @@
module gitea.bel.blue/local/rproxy3
go 1.18
require (
gitea.bel.blue/local/args v0.0.0-20251121001304-83c57f856714
gitea.bel.blue/local/logb v0.0.0-20251121001353-d45d53fbaae9
github.com/google/uuid v1.3.0
golang.org/x/time v0.1.0
)
require gopkg.in/yaml.v2 v2.4.0
require (
github.com/kr/pretty v0.1.0 // indirect
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 // indirect
)

18
go.sum
View File

@ -1,18 +0,0 @@
gitea.bel.blue/local/args v0.0.0-20251121001304-83c57f856714 h1:JHV86INH1QmPJoyIhdrDLJq7OKta+fJAwbK0pnxI4Hc=
gitea.bel.blue/local/args v0.0.0-20251121001304-83c57f856714/go.mod h1:GCzui3GPhOgKgGYNqtW55YkI3vIWCQEHPydGjFhaXV0=
gitea.bel.blue/local/logb v0.0.0-20251121001353-d45d53fbaae9 h1:lBkQPYgWZnPxt6CvsSwVh9EZtuvi2lIbGOHPqe/gn1Y=
gitea.bel.blue/local/logb v0.0.0-20251121001353-d45d53fbaae9/go.mod h1:+8sJb8UksdadKy43czL7/3TcfBwCkuYT6hFY+RaxP48=
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
golang.org/x/time v0.1.0 h1:xYY+Bajn2a7VBmTM5GikTmnK8ZuX8YgnQCqZpbBNtmA=
golang.org/x/time v0.1.0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=

View File

@ -1,8 +1,8 @@
package main package main
import ( import (
"gitea.bel.blue/local/rproxy3/config" "local/rproxy3/config"
"gitea.bel.blue/local/rproxy3/server" "local/rproxy3/server"
) )
func main() { func main() {

View File

@ -34,7 +34,7 @@ func TestHTTPSMain(t *testing.T) {
"username", "username",
"-pass", "-pass",
"password", "password",
"-proxy2", "-proxy",
"hello," + addr, "hello," + addr,
"-crt", "-crt",
"./testdata/rproxy3server.crt", "./testdata/rproxy3server.crt",
@ -89,7 +89,7 @@ func TestHTTPMain(t *testing.T) {
"username", "username",
"-pass", "-pass",
"password", "password",
"-proxy2", "-proxy",
"hello," + addr, "hello," + addr,
} }
main() main()

View File

@ -1,8 +1,8 @@
package server package server
import ( import (
"gitea.bel.blue/local/rproxy3/config" "local/rproxy3/config"
"gitea.bel.blue/local/rproxy3/storage" "local/rproxy3/storage"
"golang.org/x/time/rate" "golang.org/x/time/rate"
) )
@ -17,5 +17,7 @@ func New() *Server {
altaddr: altport, altaddr: altport,
limiter: rate.NewLimiter(rate.Limit(r), b), limiter: rate.NewLimiter(rate.Limit(r), b),
} }
_, server.auth.BOAuthZ = config.GetBOAuthZ()
_, server.auth.Authelia = config.GetAuthelia()
return server return server
} }

View File

@ -4,14 +4,12 @@ import (
"bytes" "bytes"
"crypto/tls" "crypto/tls"
"io" "io"
"local/rproxy3/storage/packable"
"log" "log"
"net/http" "net/http"
"net/http/httputil" "net/http/httputil"
"net/url" "net/url"
"strings" "strings"
"gitea.bel.blue/local/rproxy3/config"
"gitea.bel.blue/local/rproxy3/storage/packable"
) )
type redirPurge struct { type redirPurge struct {
@ -27,7 +25,6 @@ type rewrite struct {
func (s *Server) Proxy(w http.ResponseWriter, r *http.Request) { func (s *Server) Proxy(w http.ResponseWriter, r *http.Request) {
newURL, err := s.lookup(mapKey(r.Host)) newURL, err := s.lookup(mapKey(r.Host))
r.URL.Path = strings.TrimPrefix(r.URL.Path, config.GetTrim())
var transport http.RoundTripper var transport http.RoundTripper
http.DefaultTransport.(*http.Transport).TLSClientConfig = &tls.Config{InsecureSkipVerify: true} http.DefaultTransport.(*http.Transport).TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
transport = &redirPurge{ transport = &redirPurge{
@ -52,16 +49,10 @@ func (s *Server) lookup(host string) (*url.URL, error) {
return v.URL(), err return v.URL(), err
} }
func (s *Server) lookupAuth(host string) (string, error) { func (s *Server) lookupAuth(host string) (bool, error) {
v := packable.NewString() v := packable.NewString()
err := s.db.Get(nsRouting, host+"//auth", v) err := s.db.Get(nsBOAuthZ, host, v)
return v.String(), err return v.String() == "true", err
}
func (s *Server) lookupFrom(host string) (string, error) {
v := packable.NewString()
err := s.db.Get(nsRouting, host+"//from", v)
return v.String(), err
} }
func mapKey(host string) string { func mapKey(host string) string {

View File

@ -1,7 +1,7 @@
package server package server
import ( import (
"gitea.bel.blue/local/rproxy3/config" "local/rproxy3/config"
) )
func (s *Server) Routes() error { func (s *Server) Routes() error {

View File

@ -4,36 +4,34 @@ import (
"context" "context"
"crypto/tls" "crypto/tls"
"encoding/base64" "encoding/base64"
"encoding/json"
"errors" "errors"
"fmt" "fmt"
"io" "io"
"local/logb"
"local/oauth2/oauth2client"
"local/rproxy3/config"
"local/rproxy3/storage"
"local/rproxy3/storage/packable"
"log" "log"
"net" "net"
"net/http" "net/http"
"net/url" "net/url"
"regexp" "path"
"strconv"
"strings" "strings"
"time" "time"
"gitea.bel.blue/local/rproxy3/config"
"gitea.bel.blue/local/rproxy3/storage"
"gitea.bel.blue/local/rproxy3/storage/packable"
"github.com/google/uuid"
"golang.org/x/time/rate" "golang.org/x/time/rate"
) )
const nsRouting = "routing" const nsRouting = "routing"
const nsBOAuthZ = "oauth"
type listenerScheme int type listenerScheme int
const ( const (
schemeHTTP listenerScheme = iota schemeHTTP listenerScheme = iota
schemeHTTPS schemeHTTPS listenerScheme = iota
schemeTCP schemeTCP listenerScheme = iota
schemeTCPTLS
) )
func (ls listenerScheme) String() string { func (ls listenerScheme) String() string {
@ -44,8 +42,6 @@ func (ls listenerScheme) String() string {
return "https" return "https"
case schemeTCP: case schemeTCP:
return "tcp" return "tcp"
case schemeTCPTLS:
return "tcptls"
} }
return "" return ""
} }
@ -57,21 +53,21 @@ type Server struct {
username string username string
password string password string
limiter *rate.Limiter limiter *rate.Limiter
auth struct {
BOAuthZ bool
Authelia bool
}
} }
func (s *Server) Route(src string, dst config.Proxy) error { func (s *Server) Route(src string, dst config.Proxy) error {
hasOAuth := strings.HasPrefix(src, "+")
src = strings.TrimPrefix(src, "+") src = strings.TrimPrefix(src, "+")
log.Printf("Adding route %q -> %v...\n", src, dst) log.Printf("Adding route %q -> %v...\n", src, dst)
u, err := url.Parse(dst.To) u, err := url.Parse(dst.To)
if err != nil { if err != nil {
return err return err
} }
if err := s.db.Set(nsRouting, src+"//from", packable.NewString(dst.From)); err != nil { s.db.Set(nsBOAuthZ, src, packable.NewString(fmt.Sprint(hasOAuth)))
return err
}
if err := s.db.Set(nsRouting, src+"//auth", packable.NewString(dst.Auth)); err != nil {
return err
}
return s.db.Set(nsRouting, src, packable.NewURL(u)) return s.db.Set(nsRouting, src, packable.NewURL(u))
} }
@ -104,40 +100,135 @@ func (s *Server) Run() error {
case schemeTCP: case schemeTCP:
addr, _ := config.GetTCP() addr, _ := config.GetTCP()
return s.ServeTCP(addr) return s.ServeTCP(addr)
case schemeTCPTLS:
addr, _ := config.GetTCP()
cert, key, _ := config.GetSSL()
return s.ServeTCPTLS(addr, cert, key)
} }
return errors.New("did not load server") return errors.New("did not load server")
} }
func (s *Server) ServeTCPTLS(addr, c, k string) error { func (s *Server) doAuthelia(foo http.HandlerFunc) http.HandlerFunc {
certificate, err := tls.LoadX509KeyPair(c, k) return func(w http.ResponseWriter, r *http.Request) {
if err != nil { authelia, ok := config.GetAuthelia()
return err if !ok {
panic("howd i get here")
} }
certificates := []tls.Certificate{certificate} url, err := url.Parse(authelia)
listen, err := net.Listen("tcp", s.addr)
if err != nil { if err != nil {
return err panic(fmt.Sprintf("bad config for authelia url: %v", err))
} }
defer listen.Close() url.Path = "/api/verify"
config := &tls.Config{ logb.Verbosef("authelia @ %s", url.String())
Certificates: certificates, req, err := http.NewRequest(http.MethodGet, url.String(), nil)
MinVersion: tls.VersionTLS12, if err != nil {
CurvePreferences: []tls.CurveID{tls.CurveP521, tls.CurveP384, tls.CurveP256}, panic(err.Error())
PreferServerCipherSuites: true, }
CipherSuites: []uint16{ r2 := r.Clone(r.Context())
tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, if r2.URL.Host == "" {
tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, r2.URL.Host = r2.Host
tls.TLS_RSA_WITH_AES_256_GCM_SHA384, }
tls.TLS_RSA_WITH_AES_256_CBC_SHA, if r2.URL.Scheme == "" {
r2.URL.Scheme = "https"
}
for _, httpreq := range []*http.Request{r, req} {
for k, v := range map[string]string{
"X-Original-Url": r2.URL.String(),
"X-Forwarded-Proto": r2.URL.Scheme,
"X-Forwarded-Host": r2.URL.Host,
"X-Forwarded-Uri": r2.URL.String(),
} {
if _, ok := httpreq.Header[k]; !ok {
logb.Verbosef("authelia header setting %s:%s", k, v)
httpreq.Header.Set(k, v)
}
}
}
if cookie, err := r.Cookie("authelia_session"); err == nil {
logb.Verbosef("authelia session found in cookies; %+v", cookie)
req.AddCookie(cookie)
}
c := &http.Client{
Timeout: time.Minute,
Transport: &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
}, },
} }
config.BuildNameToCertificate()
tlsListener := tls.NewListener(listen, config) autheliaKey := mapKey(req.Host)
return s.serveTCP(addr, tlsListener) logb.Verbosef("request to %s is authelia %s? %v", r.Host, autheliaKey, strings.HasPrefix(r.Host, autheliaKey))
if strings.HasPrefix(r.Host, autheliaKey) {
logb.Debugf("no authelia for %s because it has prefix %s", r.Host, autheliaKey)
foo(w, r)
return
}
resp, err := c.Do(req)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
logb.Debugf(
"authelia: %+v, %+v \n\t-> \n\t(%d) %+v, %+v",
req,
req.Cookies(),
resp.StatusCode,
resp.Header,
resp.Cookies(),
)
defer resp.Body.Close()
if resp.StatusCode == http.StatusOK {
for k := range resp.Header {
if strings.HasPrefix(k, "Remote-") {
cookie := &http.Cookie{
Name: k,
Value: resp.Header.Get(k),
Path: "/",
SameSite: http.SameSiteLaxMode,
Expires: time.Now().Add(24 * time.Hour * 30),
}
logb.Verbosef("setting authelia cookie in response: %+v", cookie)
http.SetCookie(w, cookie)
logb.Verbosef("setting authelia cookie in request: %+v", cookie)
r.AddCookie(cookie)
}
}
foo(w, r)
return
}
url.Path = ""
q := url.Query()
q.Set("rd", r2.URL.String())
url.RawQuery = q.Encode()
logb.Verbosef("authelia status %d, rd'ing %s", resp.StatusCode, url.String())
http.Redirect(w, r, url.String(), http.StatusFound)
}
}
func (s *Server) doBOAuthZ(foo http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
key := mapKey(r.Host)
rusr, rpwd, ok := config.GetAuth()
if ok {
usr, pwd, ok := r.BasicAuth()
if !ok || rusr != usr || rpwd != pwd {
w.WriteHeader(http.StatusUnauthorized)
log.Printf("denying proxy basic auth")
return
}
}
ok, err := s.lookupAuth(key)
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
return
}
if url, exists := config.GetBOAuthZ(); ok && exists {
err := oauth2client.Authenticate(url, key, w, r)
if err != nil {
return
}
}
if config.GetNoPath(key) && path.Ext(r.URL.Path) == "" {
r.URL.Path = "/"
}
foo(w, r)
}
} }
func (s *Server) ServeTCP(addr string) error { func (s *Server) ServeTCP(addr string) error {
@ -145,11 +236,6 @@ func (s *Server) ServeTCP(addr string) error {
if err != nil { if err != nil {
return err return err
} }
defer listen.Close()
return s.serveTCP(addr, listen)
}
func (s *Server) serveTCP(addr string, listen net.Listener) error {
for { for {
c, err := listen.Accept() c, err := listen.Accept()
if err != nil { if err != nil {
@ -177,166 +263,45 @@ func pipe(a, b net.Conn) {
func (s *Server) Pre(foo http.HandlerFunc) http.HandlerFunc { func (s *Server) Pre(foo http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
r, flush := withMeta(w, r)
defer flush()
ctx, can := context.WithTimeout(r.Context(), time.Duration(config.GetTimeout())) ctx, can := context.WithTimeout(r.Context(), time.Duration(config.GetTimeout()))
defer can() defer can()
if err := s.limiter.Wait(ctx); err != nil { if err := s.limiter.Wait(ctx); err != nil {
pushMeta(r, "explain", "limiter exceeded")
w.WriteHeader(http.StatusTooManyRequests) w.WriteHeader(http.StatusTooManyRequests)
return return
} }
if did := s.doCORS(w, r); did {
if r.URL.Scheme == "https" {
w.Header().Set("X-Forwarded-Proto", "https")
}
w, did := doCORS(w, r)
if did {
pushMeta(r, "explain", "did cors")
return return
} }
if s.auth.BOAuthZ {
if mapKey(r.Host) == "_" { logb.Verbosef("doing boauthz for request to %s", r.URL.String())
s.List(w) s.doBOAuthZ(foo)(w, r)
return } else if s.auth.Authelia {
} logb.Verbosef("doing authelia for request to %s", r.URL.String())
s.doAuthelia(foo)(w, r)
if auth, err := s.lookupAuth(mapKey(r.Host)); err != nil {
log.Printf("failed to lookup auth for %s (%s): %v", r.Host, mapKey(r.Host), err)
w.Header().Set("WWW-Authenticate", "Basic")
http.Error(w, err.Error(), http.StatusUnauthorized)
} else if _, p, _ := r.BasicAuth(); auth != "" && auth != p {
log.Printf("failed to auth: expected %q but got %q", auth, p)
w.Header().Set("WWW-Authenticate", "Basic")
http.Error(w, "unexpected basic auth", http.StatusUnauthorized)
} else if from, err := s.lookupFrom(mapKey(r.Host)); err != nil {
log.Printf("failed to lookup from for %s (%s): %v", r.Host, mapKey(r.Host), err)
http.Error(w, err.Error(), http.StatusBadGateway)
} else if err := assertFrom(from, r.RemoteAddr); err != nil {
log.Printf("failed to from: expected %q but got %q: %v", from, r.RemoteAddr, err)
http.Error(w, "unexpected from", http.StatusUnauthorized)
} else { } else {
foo(w, r) foo(w, r)
} }
} }
} }
func assertFrom(from, remoteAddr string) error {
if from == "" {
return nil
}
pattern := regexp.MustCompile(`[0-9](:[0-9]+)$`).FindStringSubmatchIndex(remoteAddr)
if len(pattern) == 4 {
remoteAddr = remoteAddr[:pattern[2]]
}
remoteIP := net.ParseIP(remoteAddr)
if remoteIP == nil {
return fmt.Errorf("cannot parse remote %q", remoteAddr)
}
_, net, err := net.ParseCIDR(from)
if err != nil {
panic(err)
}
if net.Contains(remoteIP) {
return nil
}
return fmt.Errorf("expected like %q but got like %q", from, remoteAddr)
}
func withMeta(w http.ResponseWriter, r *http.Request) (*http.Request, func()) {
meta := map[string]string{
"ts": strconv.FormatInt(time.Now().Unix(), 10),
"method": r.Method,
"url": r.URL.String(),
"id": uuid.New().String(),
}
w.Header().Set("meta-id", meta["id"])
ctx := r.Context()
ctx = context.WithValue(ctx, "meta", meta)
r = r.WithContext(ctx)
return r, func() {
b, err := json.Marshal(meta)
if err != nil {
panic(err)
}
fmt.Printf("[access] %s\n", b)
}
}
func pushMeta(r *http.Request, k, v string) {
got := r.Context().Value("meta")
if got == nil {
return
}
meta, ok := got.(map[string]string)
if !ok || meta == nil {
return
}
meta[k] = v
}
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
s.Pre(s.Proxy)(w, r) s.Pre(s.Proxy)(w, r)
} }
func (s *Server) List(w http.ResponseWriter) { func (s *Server) doCORS(w http.ResponseWriter, r *http.Request) bool {
keys := s.db.Keys(nsRouting)
hostURL := map[string]string{}
hostFrom := map[string]string{}
for _, key := range keys {
u, _ := s.lookup(key)
if u != nil && strings.TrimSuffix(key, "//auth") == key {
hostURL[key] = u.String()
}
if u != nil && strings.TrimSuffix(key, "//from") == key {
hostFrom[key] = u.String()
}
}
json.NewEncoder(w).Encode(map[string]any{
"hostsToURLs": hostURL,
"hostsToFrom": hostFrom,
})
}
type corsResponseWriter struct {
r *http.Request
http.ResponseWriter
}
func (cb corsResponseWriter) WriteHeader(code int) {
cb.Header().Set("Access-Control-Allow-Origin", "*")
cb.Header().Set("Access-Control-Allow-Headers", "X-Auth-Token, content-type, Content-Type")
cb.ResponseWriter.WriteHeader(code)
pushMeta(cb.r, "cors", "wrote headers")
}
func doCORS(w http.ResponseWriter, r *http.Request) (http.ResponseWriter, bool) {
key := mapKey(r.Host) key := mapKey(r.Host)
if !config.GetCORS(key) { if !config.GetCORS(key) {
return w, false return false
} }
pushMeta(r, "do-cors", "enabled for key") w.Header().Set("Access-Control-Allow-Origin", "*")
return _doCORS(w, r) w.Header().Set("Access-Control-Allow-Headers", "X-Auth-Token, content-type, Content-Type")
if r.Method != "OPTIONS" {
return false
} }
w.Header().Set("Content-Length", "0")
func _doCORS(w http.ResponseWriter, r *http.Request) (http.ResponseWriter, bool) { w.Header().Set("Content-Type", "text/plain")
w2 := corsResponseWriter{r: r, ResponseWriter: w} w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, OPTIONS, TRACE, PATCH, HEAD, DELETE")
if r.Method != http.MethodOptions { return true
pushMeta(r, "-do-cors", "not options")
return w2, false
}
pushMeta(r, "-do-cors", "options")
w2.Header().Set("Content-Length", "0")
w2.Header().Set("Content-Type", "text/plain")
w2.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, OPTIONS, TRACE, PATCH, HEAD, DELETE")
w2.WriteHeader(http.StatusOK)
return w2, true
} }
func getProxyAuth(r *http.Request) (string, string) { func getProxyAuth(r *http.Request) (string, string) {
@ -378,15 +343,11 @@ func (s *Server) alt() {
func getScheme() listenerScheme { func getScheme() listenerScheme {
scheme := schemeHTTP scheme := schemeHTTP
_, _, ssl := config.GetSSL() if _, _, ok := config.GetSSL(); ok {
if ssl {
scheme = schemeHTTPS scheme = schemeHTTPS
} }
if _, ok := config.GetTCP(); ok { if _, ok := config.GetTCP(); ok {
scheme = schemeTCP scheme = schemeTCP
if ssl {
scheme = schemeTCPTLS
}
} }
return scheme return scheme
} }

View File

@ -3,19 +3,17 @@ package server
import ( import (
"context" "context"
"fmt" "fmt"
"local/rproxy3/config"
"local/rproxy3/storage"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"strings" "strings"
"testing" "testing"
"gitea.bel.blue/local/rproxy3/config"
"gitea.bel.blue/local/rproxy3/storage"
"golang.org/x/time/rate" "golang.org/x/time/rate"
) )
func TestServerStart(t *testing.T) { func TestServerStart(t *testing.T) {
return // depends on etc hosts
server := mockServer() server := mockServer()
p := config.Proxy{ p := config.Proxy{
@ -68,69 +66,3 @@ func TestServerRoute(t *testing.T) {
t.Fatalf("cannot proxy from 'world' to 'hello', status %v", w.Code) t.Fatalf("cannot proxy from 'world' to 'hello', status %v", w.Code)
} }
} }
func TestCORS(t *testing.T) {
t.Run(http.MethodOptions, func(t *testing.T) {
w := httptest.NewRecorder()
r := httptest.NewRequest(http.MethodOptions, "/", nil)
w2, did := _doCORS(w, r)
w2.WriteHeader(300)
if !did {
t.Error("didnt do on options")
}
if w.Header().Get("Access-Control-Allow-Origin") != "*" {
t.Error("didnt set origina")
}
if w.Header().Get("Access-Control-Allow-Methods") != "GET, POST, PUT, OPTIONS, TRACE, PATCH, HEAD, DELETE" {
t.Error("didnt set allow methods")
}
})
t.Run(http.MethodGet, func(t *testing.T) {
w := httptest.NewRecorder()
r := httptest.NewRequest(http.MethodGet, "/", nil)
w2, did := _doCORS(w, r)
w2.Header().Set("a", "b")
w2.Header().Set("Access-Control-Allow-Origin", "NO")
w2.WriteHeader(300)
if did {
t.Error("did cors on options")
}
if w.Header().Get("Access-Control-Allow-Origin") != "*" {
t.Error("didnt set origina")
} else if len(w.Header()["Access-Control-Allow-Origin"]) != 1 {
t.Error(w.Header())
}
if w.Header().Get("Access-Control-Allow-Methods") != "" {
t.Error("did set allow methods")
}
})
}
func TestAssertFrom(t *testing.T) {
cases := map[string]struct {
from string
remote string
err bool
}{
"empty": {},
"ipv6 localhost": {
from: "::1/128",
remote: "::1:12345",
},
"ipv4 localhost": {
from: "127.0.0.1/32",
remote: "127.0.0.1:12345",
},
}
for name, d := range cases {
c := d
t.Run(name, func(t *testing.T) {
err := assertFrom(c.from, c.remote)
got := err != nil
if got != c.err {
t.Errorf("expected err=%v but got %v", c.err, err)
}
})
}
}

View File

@ -2,8 +2,7 @@ package storage
import ( import (
"errors" "errors"
"local/rproxy3/storage/packable"
"gitea.bel.blue/local/rproxy3/storage/packable"
) )
var ErrNotFound = errors.New("not found") var ErrNotFound = errors.New("not found")
@ -11,6 +10,5 @@ var ErrNotFound = errors.New("not found")
type DB interface { type DB interface {
Get(string, string, packable.Packable) error Get(string, string, packable.Packable) error
Set(string, string, packable.Packable) error Set(string, string, packable.Packable) error
Keys(string) []string
Close() error Close() error
} }

View File

@ -1,10 +1,9 @@
package storage package storage
import ( import (
"local/rproxy3/storage/packable"
"os" "os"
"testing" "testing"
"gitea.bel.blue/local/rproxy3/storage/packable"
) )
func TestDB(t *testing.T) { func TestDB(t *testing.T) {

View File

@ -2,8 +2,7 @@ package storage
import ( import (
"fmt" "fmt"
"local/rproxy3/storage/packable"
"gitea.bel.blue/local/rproxy3/storage/packable"
) )
type Map map[string]map[string][]byte type Map map[string]map[string][]byte
@ -41,15 +40,6 @@ func (m Map) Close() error {
return nil return nil
} }
func (m Map) Keys(ns string) []string {
m2, _ := m[ns]
result := make([]string, 0, len(m2))
for k := range m2 {
result = append(result, k)
}
return result
}
func (m Map) Get(ns, key string, value packable.Packable) error { func (m Map) Get(ns, key string, value packable.Packable) error {
if _, ok := m[ns]; !ok { if _, ok := m[ns]; !ok {
m[ns] = make(map[string][]byte) m[ns] = make(map[string][]byte)