package server import ( "context" "crypto/tls" "encoding/base64" "errors" "fmt" "io" "local/oauth2/oauth2client" "local/rproxy3/config" "local/rproxy3/storage" "local/rproxy3/storage/packable" "log" "net" "net/http" "net/url" "strings" "time" "golang.org/x/time/rate" ) const nsRouting = "routing" const nsBOAuthZ = "oauth" type listenerScheme int const ( schemeHTTP listenerScheme = iota schemeHTTPS listenerScheme = iota schemeTCP listenerScheme = iota ) func (ls listenerScheme) String() string { switch ls { case schemeHTTP: return "http" case schemeHTTPS: return "https" case schemeTCP: return "tcp" } return "" } type Server struct { db storage.DB addr string altaddr string username string password string limiter *rate.Limiter } func (s *Server) Route(src string, dst config.Proxy) error { hasOAuth := strings.HasPrefix(src, "+") src = strings.TrimPrefix(src, "+") log.Printf("Adding route %q -> %v...\n", src, dst) u, err := url.Parse(dst.To) if err != nil { return err } s.db.Set(nsBOAuthZ, src, packable.NewString(fmt.Sprint(hasOAuth))) return s.db.Set(nsRouting, src, packable.NewURL(u)) } func (s *Server) Run() error { go s.alt() scheme := getScheme() log.Printf("Listening for %v on %v...\n", scheme, s.addr) switch scheme { case schemeHTTP: return http.ListenAndServe(s.addr, s) case schemeHTTPS: c, k, _ := config.GetSSL() httpsServer := &http.Server{ Addr: s.addr, Handler: s, TLSConfig: &tls.Config{ MinVersion: tls.VersionTLS12, CurvePreferences: []tls.CurveID{tls.CurveP521, tls.CurveP384, tls.CurveP256}, PreferServerCipherSuites: true, CipherSuites: []uint16{ tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, tls.TLS_RSA_WITH_AES_256_GCM_SHA384, tls.TLS_RSA_WITH_AES_256_CBC_SHA, }, }, TLSNextProto: make(map[string]func(*http.Server, *tls.Conn, http.Handler), 0), } return httpsServer.ListenAndServeTLS(c, k) case schemeTCP: addr, _ := config.GetTCP() return s.ServeTCP(addr) } return errors.New("did not load server") } func (s *Server) doAuth(foo http.HandlerFunc) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { rusr, rpwd, ok := config.GetAuth() if ok { usr, pwd, ok := r.BasicAuth() if !ok || rusr != usr || rpwd != pwd { w.WriteHeader(http.StatusUnauthorized) log.Printf("denying proxy basic auth") return } } key := mapKey(r.Host) ok, err := s.lookupBOAuthZ(key) if err != nil { w.WriteHeader(http.StatusInternalServerError) return } if url, exists := config.GetBOAuthZ(); ok && exists { err := oauth2client.Authenticate(url, key, w, r) if err != nil { return } } foo(w, r) } } func (s *Server) ServeTCP(addr string) error { listen, err := net.Listen("tcp", s.addr) if err != nil { return err } for { c, err := listen.Accept() if err != nil { return err } go func(c net.Conn) { d, err := net.Dial("tcp", addr) if err != nil { log.Println(err) return } go pipe(c, d) go pipe(d, c) }(c) } } func pipe(a, b net.Conn) { log.Println("open pipe") defer log.Println("close pipe") defer a.Close() defer b.Close() io.Copy(a, b) } func (s *Server) Pre(foo http.HandlerFunc) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { ctx, can := context.WithTimeout(r.Context(), time.Duration(config.GetTimeout())) defer can() if err := s.limiter.Wait(ctx); err != nil { w.WriteHeader(http.StatusTooManyRequests) return } s.doAuth(foo)(w, r) } } func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { s.Pre(s.Proxy)(w, r) } func getProxyAuth(r *http.Request) (string, string) { proxyAuthHeader := r.Header.Get("Proxy-Authorization") proxyAuthB64 := strings.TrimPrefix(proxyAuthHeader, "Basic ") proxyAuthBytes, _ := base64.StdEncoding.DecodeString(proxyAuthB64) proxyAuth := string(proxyAuthBytes) if !strings.Contains(proxyAuth, ":") { return "", "" } proxyAuthSplit := strings.Split(proxyAuth, ":") return proxyAuthSplit[0], proxyAuthSplit[1] } func (s *Server) alt() { switch getScheme() { case schemeHTTP: case schemeHTTPS: default: return } foo := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { r.URL.Scheme = getScheme().String() if hostname := r.URL.Hostname(); hostname != "" { r.URL.Host = r.URL.Hostname() + s.addr } else if hostname := r.URL.Host; hostname != "" { r.URL.Host = r.URL.Host + s.addr } else { u := url.URL{Host: r.Host} r.URL.Host = u.Hostname() + s.addr } http.Redirect(w, r, r.URL.String(), http.StatusSeeOther) }) log.Println("redirecting from", s.altaddr) if err := http.ListenAndServe(s.altaddr, foo); err != nil { panic(err) } } func getScheme() listenerScheme { scheme := schemeHTTP if _, _, ok := config.GetSSL(); ok { scheme = schemeHTTPS } if _, ok := config.GetTCP(); ok { scheme = schemeTCP } return scheme }