package server import ( "context" "crypto/tls" "encoding/base64" "errors" "fmt" "io" "local/logb" "local/oauth2/oauth2client" "local/rproxy3/config" "local/rproxy3/storage" "local/rproxy3/storage/packable" "log" "net" "net/http" "net/url" "path" "strings" "time" "golang.org/x/time/rate" ) const nsRouting = "routing" const nsBOAuthZ = "oauth" type listenerScheme int const ( schemeHTTP listenerScheme = iota schemeHTTPS listenerScheme = iota schemeTCP listenerScheme = iota ) func (ls listenerScheme) String() string { switch ls { case schemeHTTP: return "http" case schemeHTTPS: return "https" case schemeTCP: return "tcp" } return "" } type Server struct { db storage.DB addr string altaddr string username string password string limiter *rate.Limiter auth struct { BOAuthZ bool Authelia bool } } func (s *Server) Route(src string, dst config.Proxy) error { hasOAuth := strings.HasPrefix(src, "+") src = strings.TrimPrefix(src, "+") log.Printf("Adding route %q -> %v...\n", src, dst) u, err := url.Parse(dst.To) if err != nil { return err } s.db.Set(nsBOAuthZ, src, packable.NewString(fmt.Sprint(hasOAuth))) return s.db.Set(nsRouting, src, packable.NewURL(u)) } func (s *Server) Run() error { go s.alt() scheme := getScheme() log.Printf("Listening for %v on %v...\n", scheme, s.addr) switch scheme { case schemeHTTP: return http.ListenAndServe(s.addr, s) case schemeHTTPS: c, k, _ := config.GetSSL() httpsServer := &http.Server{ Addr: s.addr, Handler: s, TLSConfig: &tls.Config{ MinVersion: tls.VersionTLS12, CurvePreferences: []tls.CurveID{tls.CurveP521, tls.CurveP384, tls.CurveP256}, PreferServerCipherSuites: true, CipherSuites: []uint16{ tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, tls.TLS_RSA_WITH_AES_256_GCM_SHA384, tls.TLS_RSA_WITH_AES_256_CBC_SHA, }, }, TLSNextProto: make(map[string]func(*http.Server, *tls.Conn, http.Handler), 0), } return httpsServer.ListenAndServeTLS(c, k) case schemeTCP: addr, _ := config.GetTCP() return s.ServeTCP(addr) } return errors.New("did not load server") } func (s *Server) doAuthelia(foo http.HandlerFunc) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { authelia, ok := config.GetAuthelia() if !ok { panic("howd i get here") } url, err := url.Parse(authelia) if err != nil { panic(fmt.Sprintf("bad config for authelia url: %v", err)) } url.Path = "/api/verify" req, err := http.NewRequest(http.MethodGet, url.String(), nil) if err != nil { panic(err.Error()) } r2 := r.Clone(r.Context()) if r2.URL.Host == "" { r2.URL.Host = r2.Host } if r2.URL.Scheme == "" { r2.URL.Scheme = "https" } for _, httpreq := range []*http.Request{r, req} { for k, v := range map[string]string{ "X-Original-Url": r2.URL.String(), "X-Forwarded-Proto": r2.URL.Scheme, "X-Forwarded-Host": r2.URL.Host, "X-Forwarded-Uri": r2.URL.String(), } { if _, ok := httpreq.Header[k]; !ok { httpreq.Header.Set(k, v) } } } if cookie, err := r.Cookie("authelia_session"); err == nil { req.AddCookie(cookie) } c := &http.Client{ Timeout: time.Minute, Transport: &http.Transport{ TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, }, } autheliaKey := mapKey(req.Host) if strings.HasPrefix(r.Host, autheliaKey) { logb.Debugf("no authelia for %s because it has prefix %s", r.Host, autheliaKey) foo(w, r) return } resp, err := c.Do(req) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return } logb.Debugf( "authelia: %+v, %+v \n\t-> \n\t(%d) %+v, %+v", req, req.Cookies(), resp.StatusCode, resp.Header, resp.Cookies(), ) defer resp.Body.Close() if resp.StatusCode == http.StatusOK { foo(w, r) return } url.Path = "" q := url.Query() q.Set("rd", r2.URL.String()) url.RawQuery = q.Encode() http.Redirect(w, r, url.String(), http.StatusFound) } } func (s *Server) doBOAuthZ(foo http.HandlerFunc) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { key := mapKey(r.Host) rusr, rpwd, ok := config.GetAuth() if ok { usr, pwd, ok := r.BasicAuth() if !ok || rusr != usr || rpwd != pwd { w.WriteHeader(http.StatusUnauthorized) log.Printf("denying proxy basic auth") return } } ok, err := s.lookupAuth(key) if err != nil { w.WriteHeader(http.StatusInternalServerError) return } if url, exists := config.GetBOAuthZ(); ok && exists { err := oauth2client.Authenticate(url, key, w, r) if err != nil { return } } if config.GetNoPath(key) && path.Ext(r.URL.Path) == "" { r.URL.Path = "/" } foo(w, r) } } func (s *Server) ServeTCP(addr string) error { listen, err := net.Listen("tcp", s.addr) if err != nil { return err } for { c, err := listen.Accept() if err != nil { return err } go func(c net.Conn) { d, err := net.Dial("tcp", addr) if err != nil { log.Println(err) return } go pipe(c, d) go pipe(d, c) }(c) } } func pipe(a, b net.Conn) { log.Println("open pipe") defer log.Println("close pipe") defer a.Close() defer b.Close() io.Copy(a, b) } func (s *Server) Pre(foo http.HandlerFunc) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { ctx, can := context.WithTimeout(r.Context(), time.Duration(config.GetTimeout())) defer can() if err := s.limiter.Wait(ctx); err != nil { w.WriteHeader(http.StatusTooManyRequests) return } if did := s.doCORS(w, r); did { return } if s.auth.BOAuthZ { s.doBOAuthZ(foo)(w, r) } else if s.auth.Authelia { s.doAuthelia(foo)(w, r) } else { foo(w, r) } } } func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { s.Pre(s.Proxy)(w, r) } func (s *Server) doCORS(w http.ResponseWriter, r *http.Request) bool { key := mapKey(r.Host) if !config.GetCORS(key) { return false } w.Header().Set("Access-Control-Allow-Origin", "*") w.Header().Set("Access-Control-Allow-Headers", "X-Auth-Token, content-type, Content-Type") if r.Method != "OPTIONS" { return false } w.Header().Set("Content-Length", "0") w.Header().Set("Content-Type", "text/plain") w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, OPTIONS, TRACE, PATCH, HEAD, DELETE") return true } func getProxyAuth(r *http.Request) (string, string) { proxyAuthHeader := r.Header.Get("Proxy-Authorization") proxyAuthB64 := strings.TrimPrefix(proxyAuthHeader, "Basic ") proxyAuthBytes, _ := base64.StdEncoding.DecodeString(proxyAuthB64) proxyAuth := string(proxyAuthBytes) if !strings.Contains(proxyAuth, ":") { return "", "" } proxyAuthSplit := strings.Split(proxyAuth, ":") return proxyAuthSplit[0], proxyAuthSplit[1] } func (s *Server) alt() { switch getScheme() { case schemeHTTP: case schemeHTTPS: default: return } foo := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { r.URL.Scheme = getScheme().String() if hostname := r.URL.Hostname(); hostname != "" { r.URL.Host = r.URL.Hostname() + s.addr } else if hostname := r.URL.Host; hostname != "" { r.URL.Host = r.URL.Host + s.addr } else { u := url.URL{Host: r.Host} r.URL.Host = u.Hostname() + s.addr } http.Redirect(w, r, r.URL.String(), http.StatusSeeOther) }) log.Println("redirecting from", s.altaddr) if err := http.ListenAndServe(s.altaddr, foo); err != nil { panic(err) } } func getScheme() listenerScheme { scheme := schemeHTTP if _, _, ok := config.GetSSL(); ok { scheme = schemeHTTPS } if _, ok := config.GetTCP(); ok { scheme = schemeTCP } return scheme }