189 lines
4.2 KiB
Go
Executable File
189 lines
4.2 KiB
Go
Executable File
package server
|
|
|
|
import (
	"context"
	"crypto/subtle"
	"crypto/tls"
	"encoding/base64"
	"errors"
	"fmt"
	"io"
	"local/oauth2/oauth2client"
	"local/rproxy3/config"
	"local/rproxy3/storage"
	"local/rproxy3/storage/packable"
	"log"
	"net"
	"net/http"
	"net/url"
	"strings"
	"time"

	"golang.org/x/time/rate"
)
|
|
|
|
// Namespaces used as the first argument to storage.DB.Set.
const (
	// nsRouting stores, per source key, the parsed target URL of a route.
	nsRouting = "routing"
	// nsBOAuthZ stores, per source key, the route's BOAuthZ setting
	// formatted with fmt.Sprint.
	nsBOAuthZ = "oauth"
)
|
|
|
|
// listenerScheme selects which kind of listener Run starts.
type listenerScheme int

// Listener kinds, in iota order starting at zero.
const (
	schemeHTTP  listenerScheme = iota // plain HTTP
	schemeHTTPS                       // HTTP over TLS
	schemeTCP                         // raw TCP forwarding
)

// String implements fmt.Stringer and is used when logging the active
// scheme. Values outside the defined constants render as "".
func (ls listenerScheme) String() string {
	names := [...]string{
		schemeHTTP:  "http",
		schemeHTTPS: "https",
		schemeTCP:   "tcp",
	}
	if ls < 0 || int(ls) >= len(names) {
		return ""
	}
	return names[ls]
}
|
|
|
|
type Server struct {
|
|
db storage.DB
|
|
addr string
|
|
username string
|
|
password string
|
|
limiter *rate.Limiter
|
|
}
|
|
|
|
// Route registers a proxy route for the source key src.
//
// dst.To must parse as a URL; the parsed URL is stored under nsRouting for
// src. The value of dst.BOAuthZ, formatted with fmt.Sprint, is stored under
// nsBOAuthZ for the same key.
//
// It returns the URL parse error, or the error from either storage write.
// The OAuth setting is written first; if that write fails the routing entry
// is not written at all.
func (s *Server) Route(src string, dst config.Proxy) error {
	log.Printf("Adding route %q -> %v...\n", src, dst)
	u, err := url.Parse(dst.To)
	if err != nil {
		return err
	}
	// Fail before registering the route so it never becomes reachable
	// without its OAuth setting having been stored.
	if err := s.db.Set(nsBOAuthZ, src, packable.NewString(fmt.Sprint(dst.BOAuthZ))); err != nil {
		return fmt.Errorf("storing oauth setting for route %q: %w", src, err)
	}
	return s.db.Set(nsRouting, src, packable.NewURL(u))
}
|
|
|
|
func (s *Server) Run() error {
|
|
scheme := schemeHTTP
|
|
if _, _, ok := config.GetSSL(); ok {
|
|
scheme = schemeHTTPS
|
|
}
|
|
if _, ok := config.GetTCP(); ok {
|
|
scheme = schemeTCP
|
|
}
|
|
log.Printf("Listening for %v on %v...\n", scheme, s.addr)
|
|
switch scheme {
|
|
case schemeHTTP:
|
|
log.Printf("Serve http")
|
|
return http.ListenAndServe(s.addr, s)
|
|
case schemeHTTPS:
|
|
log.Printf("Serve https")
|
|
c, k, _ := config.GetSSL()
|
|
httpsServer := &http.Server{
|
|
Addr: s.addr,
|
|
Handler: s,
|
|
TLSConfig: &tls.Config{
|
|
MinVersion: tls.VersionTLS12,
|
|
CurvePreferences: []tls.CurveID{tls.CurveP521, tls.CurveP384, tls.CurveP256},
|
|
PreferServerCipherSuites: true,
|
|
CipherSuites: []uint16{
|
|
tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
|
|
tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA,
|
|
tls.TLS_RSA_WITH_AES_256_GCM_SHA384,
|
|
tls.TLS_RSA_WITH_AES_256_CBC_SHA,
|
|
},
|
|
},
|
|
TLSNextProto: make(map[string]func(*http.Server, *tls.Conn, http.Handler), 0),
|
|
}
|
|
return httpsServer.ListenAndServeTLS(c, k)
|
|
case schemeTCP:
|
|
log.Printf("Serve tcp")
|
|
addr, _ := config.GetTCP()
|
|
return s.ServeTCP(addr)
|
|
}
|
|
return errors.New("did not load server")
|
|
}
|
|
|
|
// doAuth wraps foo with the proxy's authentication checks.
//
// For each request, in order:
//  1. If config.GetAuth reports configured credentials, the request's HTTP
//     Basic credentials (Authorization header) must match them. Otherwise
//     the request gets 401 with a Basic challenge.
//  2. s.lookupBOAuthZ is consulted for the request host. A lookup error
//     gets 500.
//  3. If that lookup says the host needs OAuth and config.GetBOAuthZ
//     reports OAuth is enabled, oauth2client.Authenticate must return nil.
//
// Only then is foo called.
func (s *Server) doAuth(foo http.HandlerFunc) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		if wantUser, wantPass, authRequired := config.GetAuth(); authRequired {
			user, pass, hasCreds := r.BasicAuth()
			// Constant-time comparison, and both are always evaluated, so
			// response timing does not reveal how much of a guess matched.
			userOK := subtle.ConstantTimeCompare([]byte(user), []byte(wantUser)) == 1
			passOK := subtle.ConstantTimeCompare([]byte(pass), []byte(wantPass)) == 1
			if !hasCreds || !userOK || !passOK {
				// RFC 7235 requires a 401 to carry a challenge; without it
				// browsers never prompt for credentials.
				w.Header().Set("WWW-Authenticate", `Basic realm="restricted"`)
				w.WriteHeader(http.StatusUnauthorized)
				log.Printf("denying proxy basic auth")
				return
			}
		}

		needsOAuth, err := s.lookupBOAuthZ(mapKey(r.Host))
		if err != nil {
			log.Printf("looking up oauth requirement for %q: %v", r.Host, err)
			w.WriteHeader(http.StatusInternalServerError)
			return
		}

		if boauthz, useOAuth := config.GetBOAuthZ(); needsOAuth && useOAuth {
			if err := oauth2client.Authenticate(boauthz, w, r); err != nil {
				// NOTE(review): presumably Authenticate has already written a
				// response (redirect/error) when it fails; confirm.
				return
			}
		}

		foo(w, r)
	}
}
|
|
|
|
// ServeTCP listens for TCP connections on s.addr and forwards each accepted
// connection to addr, copying bytes in both directions with pipe.
//
// It returns when the listener cannot be created or Accept fails; the
// listener is closed on return. A failed upstream dial is logged, and that
// client connection is closed.
func (s *Server) ServeTCP(addr string) error {
	listener, err := net.Listen("tcp", s.addr)
	if err != nil {
		return err
	}
	defer listener.Close()

	for {
		client, err := listener.Accept()
		if err != nil {
			return err
		}
		go func(client net.Conn) {
			upstream, err := net.Dial("tcp", addr)
			if err != nil {
				log.Println(err)
				// Nothing else owns the accepted connection yet; close it so
				// it does not leak.
				client.Close()
				return
			}
			// pipe(a, b) copies b into a, so these are the two directions.
			go pipe(client, upstream)
			go pipe(upstream, client)
		}(client)
	}
}
|
|
|
|
// pipe copies everything read from b into a until b reaches EOF or either
// side errors. It then closes b, closes a, and logs the closure. Copy errors
// are deliberately ignored: the connection pair is torn down either way.
func pipe(a, b net.Conn) {
	log.Println("open pipe")
	defer func() {
		b.Close()
		a.Close()
		log.Println("close pipe")
	}()
	_, _ = io.Copy(a, b)
}
|
|
|
|
// Pre wraps foo with request pre-processing. The request first waits on
// s.limiter for up to config.GetTimeout() seconds, getting 429 if no token
// is obtained in time. It then passes through doAuth, which calls foo if
// the request is authorized.
//
// NOTE(review): if config.GetTimeout() returns 0 the context is expired
// before Wait runs, which appears to reject every request; confirm the
// config guarantees a positive value.
func (s *Server) Pre(foo http.HandlerFunc) http.HandlerFunc {
	authorized := s.doAuth(foo)
	return func(w http.ResponseWriter, r *http.Request) {
		limit := time.Duration(config.GetTimeout()) * time.Second
		waitCtx, cancel := context.WithTimeout(r.Context(), limit)
		defer cancel()

		if s.limiter.Wait(waitCtx) != nil {
			w.WriteHeader(http.StatusTooManyRequests)
			return
		}
		authorized(w, r)
	}
}
|
|
|
|
// ServeHTTP implements http.Handler. Every request goes through Pre (rate
// limiting, then authentication) and, if admitted, on to s.Proxy.
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	handler := s.Pre(s.Proxy)
	handler.ServeHTTP(w, r)
}
|
|
|
|
func getProxyAuth(r *http.Request) (string, string) {
|
|
proxyAuthHeader := r.Header.Get("Proxy-Authorization")
|
|
proxyAuthB64 := strings.TrimPrefix(proxyAuthHeader, "Basic ")
|
|
proxyAuthBytes, _ := base64.StdEncoding.DecodeString(proxyAuthB64)
|
|
proxyAuth := string(proxyAuthBytes)
|
|
if !strings.Contains(proxyAuth, ":") {
|
|
return "", ""
|
|
}
|
|
proxyAuthSplit := strings.Split(proxyAuth, ":")
|
|
return proxyAuthSplit[0], proxyAuthSplit[1]
|
|
}
|