package server import ( "context" "crypto/tls" "encoding/base64" "encoding/json" "errors" "fmt" "io" "log" "net" "net/http" "net/url" "regexp" "strconv" "strings" "time" "gitea.bel.blue/local/rproxy3/config" "gitea.bel.blue/local/rproxy3/storage" "gitea.bel.blue/local/rproxy3/storage/packable" "github.com/google/uuid" "golang.org/x/time/rate" ) const nsRouting = "routing" type listenerScheme int const ( schemeHTTP listenerScheme = iota schemeHTTPS schemeTCP schemeTCPTLS ) func (ls listenerScheme) String() string { switch ls { case schemeHTTP: return "http" case schemeHTTPS: return "https" case schemeTCP: return "tcp" case schemeTCPTLS: return "tcptls" } return "" } type Server struct { db storage.DB addr string altaddr string username string password string limiter *rate.Limiter } func (s *Server) Route(src string, dst config.Proxy) error { src = strings.TrimPrefix(src, "+") log.Printf("Adding route %q -> %v...\n", src, dst) u, err := url.Parse(dst.To) if err != nil { return err } if err := s.db.Set(nsRouting, src+"//from", packable.NewString(dst.From)); err != nil { return err } if err := s.db.Set(nsRouting, src+"//auth", packable.NewString(dst.Auth)); err != nil { return err } return s.db.Set(nsRouting, src, packable.NewURL(u)) } func (s *Server) Run() error { go s.alt() scheme := getScheme() log.Printf("Listening for %v on %v...\n", scheme, s.addr) switch scheme { case schemeHTTP: return http.ListenAndServe(s.addr, s) case schemeHTTPS: c, k, _ := config.GetSSL() httpsServer := &http.Server{ Addr: s.addr, Handler: s, TLSConfig: &tls.Config{ MinVersion: tls.VersionTLS12, CurvePreferences: []tls.CurveID{tls.CurveP521, tls.CurveP384, tls.CurveP256}, PreferServerCipherSuites: true, CipherSuites: []uint16{ tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, tls.TLS_RSA_WITH_AES_256_GCM_SHA384, tls.TLS_RSA_WITH_AES_256_CBC_SHA, }, }, TLSNextProto: make(map[string]func(*http.Server, *tls.Conn, http.Handler), 0), } return httpsServer.ListenAndServeTLS(c, k) case schemeTCP: addr, _ := config.GetTCP() return s.ServeTCP(addr) case schemeTCPTLS: addr, _ := config.GetTCP() cert, key, _ := config.GetSSL() return s.ServeTCPTLS(addr, cert, key) } return errors.New("did not load server") } func (s *Server) ServeTCPTLS(addr, c, k string) error { certificate, err := tls.LoadX509KeyPair(c, k) if err != nil { return err } certificates := []tls.Certificate{certificate} listen, err := net.Listen("tcp", s.addr) if err != nil { return err } defer listen.Close() tlsListener, err := tls.NewListener(listen, &tls.Config{ Certificates: certificates, MinVersion: tls.VersionTLS12, CurvePreferences: []tls.CurveID{tls.CurveP521, tls.CurveP384, tls.CurveP256}, PreferServerCipherSuites: true, CipherSuites: []uint16{ tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, tls.TLS_RSA_WITH_AES_256_GCM_SHA384, tls.TLS_RSA_WITH_AES_256_CBC_SHA, }, }) return s.serveTCP(addr, listen) } func (s *Server) ServeTCP(addr string) error { listen, err := net.Listen("tcp", s.addr) if err != nil { return err } defer listen.Close() return s.serveTCP(addr, listen) } func (s *Server) serveTCP(addr string, listen net.Listener) error { for { c, err := listen.Accept() if err != nil { return err } go func(c net.Conn) { d, err := net.Dial("tcp", addr) if err != nil { log.Println(err) return } go pipe(c, d) go pipe(d, c) }(c) } } func pipe(a, b net.Conn) { log.Println("open pipe") defer log.Println("close pipe") defer a.Close() defer b.Close() io.Copy(a, b) } func (s *Server) Pre(foo http.HandlerFunc) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { r, flush := withMeta(w, r) defer flush() ctx, can := context.WithTimeout(r.Context(), time.Duration(config.GetTimeout())) defer can() if err := s.limiter.Wait(ctx); err != nil { pushMeta(r, "explain", "limiter exceeded") w.WriteHeader(http.StatusTooManyRequests) return } if r.URL.Scheme == "https" { w.Header().Set("X-Forwarded-Proto", "https") } w, did := doCORS(w, r) if did { pushMeta(r, "explain", "did cors") return } if mapKey(r.Host) == "_" { s.List(w) return } if auth, err := s.lookupAuth(mapKey(r.Host)); err != nil { log.Printf("failed to lookup auth for %s (%s): %v", r.Host, mapKey(r.Host), err) w.Header().Set("WWW-Authenticate", "Basic") http.Error(w, err.Error(), http.StatusUnauthorized) } else if _, p, _ := r.BasicAuth(); auth != "" && auth != p { log.Printf("failed to auth: expected %q but got %q", auth, p) w.Header().Set("WWW-Authenticate", "Basic") http.Error(w, "unexpected basic auth", http.StatusUnauthorized) } else if from, err := s.lookupFrom(mapKey(r.Host)); err != nil { log.Printf("failed to lookup from for %s (%s): %v", r.Host, mapKey(r.Host), err) http.Error(w, err.Error(), http.StatusBadGateway) } else if err := assertFrom(from, r.RemoteAddr); err != nil { log.Printf("failed to from: expected %q but got %q: %v", from, r.RemoteAddr, err) http.Error(w, "unexpected from", http.StatusUnauthorized) } else { foo(w, r) } } } func assertFrom(from, remoteAddr string) error { if from == "" { return nil } pattern := regexp.MustCompile(`[0-9](:[0-9]+)$`).FindStringSubmatchIndex(remoteAddr) if len(pattern) == 4 { remoteAddr = remoteAddr[:pattern[2]] } remoteIP := net.ParseIP(remoteAddr) if remoteIP == nil { return fmt.Errorf("cannot parse remote %q", remoteAddr) } _, net, err := net.ParseCIDR(from) if err != nil { panic(err) } if net.Contains(remoteIP) { return nil } return fmt.Errorf("expected like %q but got like %q", from, remoteAddr) } func withMeta(w http.ResponseWriter, r *http.Request) (*http.Request, func()) { meta := map[string]string{ "ts": strconv.FormatInt(time.Now().Unix(), 10), "method": r.Method, "url": r.URL.String(), "id": uuid.New().String(), } w.Header().Set("meta-id", meta["id"]) ctx := r.Context() ctx = context.WithValue(ctx, "meta", meta) r = r.WithContext(ctx) return r, func() { b, err := json.Marshal(meta) if err != nil { panic(err) } fmt.Printf("[access] %s\n", b) } } func pushMeta(r *http.Request, k, v string) { got := r.Context().Value("meta") if got == nil { return } meta, ok := got.(map[string]string) if !ok || meta == nil { return } meta[k] = v } func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { s.Pre(s.Proxy)(w, r) } func (s *Server) List(w http.ResponseWriter) { keys := s.db.Keys(nsRouting) hostURL := map[string]string{} hostFrom := map[string]string{} for _, key := range keys { u, _ := s.lookup(key) if u != nil && strings.TrimSuffix(key, "//auth") == key { hostURL[key] = u.String() } if u != nil && strings.TrimSuffix(key, "//from") == key { hostFrom[key] = u.String() } } json.NewEncoder(w).Encode(map[string]any{ "hostsToURLs": hostURL, "hostsToFrom": hostFrom, }) } type corsResponseWriter struct { r *http.Request http.ResponseWriter } func (cb corsResponseWriter) WriteHeader(code int) { cb.Header().Set("Access-Control-Allow-Origin", "*") cb.Header().Set("Access-Control-Allow-Headers", "X-Auth-Token, content-type, Content-Type") cb.ResponseWriter.WriteHeader(code) pushMeta(cb.r, "cors", "wrote headers") } func doCORS(w http.ResponseWriter, r *http.Request) (http.ResponseWriter, bool) { key := mapKey(r.Host) if !config.GetCORS(key) { return w, false } pushMeta(r, "do-cors", "enabled for key") return _doCORS(w, r) } func _doCORS(w http.ResponseWriter, r *http.Request) (http.ResponseWriter, bool) { w2 := corsResponseWriter{r: r, ResponseWriter: w} if r.Method != http.MethodOptions { pushMeta(r, "-do-cors", "not options") return w2, false } pushMeta(r, "-do-cors", "options") w2.Header().Set("Content-Length", "0") w2.Header().Set("Content-Type", "text/plain") w2.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, OPTIONS, TRACE, PATCH, HEAD, DELETE") w2.WriteHeader(http.StatusOK) return w2, true } func getProxyAuth(r *http.Request) (string, string) { proxyAuthHeader := r.Header.Get("Proxy-Authorization") proxyAuthB64 := strings.TrimPrefix(proxyAuthHeader, "Basic ") proxyAuthBytes, _ := base64.StdEncoding.DecodeString(proxyAuthB64) proxyAuth := string(proxyAuthBytes) if !strings.Contains(proxyAuth, ":") { return "", "" } proxyAuthSplit := strings.Split(proxyAuth, ":") return proxyAuthSplit[0], proxyAuthSplit[1] } func (s *Server) alt() { switch getScheme() { case schemeHTTP: case schemeHTTPS: default: return } foo := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { r.URL.Scheme = getScheme().String() if hostname := r.URL.Hostname(); hostname != "" { r.URL.Host = r.URL.Hostname() + s.addr } else if hostname := r.URL.Host; hostname != "" { r.URL.Host = r.URL.Host + s.addr } else { u := url.URL{Host: r.Host} r.URL.Host = u.Hostname() + s.addr } http.Redirect(w, r, r.URL.String(), http.StatusSeeOther) }) log.Println("redirecting from", s.altaddr) if err := http.ListenAndServe(s.altaddr, foo); err != nil { panic(err) } } func getScheme() listenerScheme { scheme := schemeHTTP _, _, ssl := config.GetSSL() if ssl { scheme = schemeHTTPS } if _, ok := config.GetTCP(); ok { scheme = schemeTCP if ssl { scheme = schemeTCPTLS } } return scheme }