package server import ( "context" "crypto/tls" "encoding/base64" "errors" "local/rproxy3/config" "local/rproxy3/storage" "local/rproxy3/storage/packable" "log" "net/http" "net/url" "strings" "time" "golang.org/x/time/rate" ) const nsRouting = "routing" type listenerScheme int const ( schemeHTTP listenerScheme = iota schemeHTTPS listenerScheme = iota ) func (ls listenerScheme) String() string { switch ls { case schemeHTTP: return "http" case schemeHTTPS: return "https" } return "" } type Server struct { db storage.DB addr string username string password string limiter *rate.Limiter } func (s *Server) Route(src, dst string) error { log.Printf("Adding route %q -> %q...\n", src, dst) u, err := url.Parse(dst) if err != nil { return err } return s.db.Set(nsRouting, src, packable.NewURL(u)) } func (s *Server) Run() error { scheme := schemeHTTP if _, _, ok := config.GetSSL(); ok { scheme = schemeHTTPS } log.Printf("Listening for %v on %v...\n", scheme, s.addr) switch scheme { case schemeHTTP: log.Printf("Serve http") return http.ListenAndServe(s.addr, s) case schemeHTTPS: log.Printf("Serve https") c, k, _ := config.GetSSL() httpsServer := &http.Server{ Addr: s.addr, Handler: s, TLSConfig: &tls.Config{ MinVersion: tls.VersionTLS12, CurvePreferences: []tls.CurveID{tls.CurveP521, tls.CurveP384, tls.CurveP256}, PreferServerCipherSuites: true, CipherSuites: []uint16{ tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, tls.TLS_RSA_WITH_AES_256_GCM_SHA384, tls.TLS_RSA_WITH_AES_256_CBC_SHA, }, }, TLSNextProto: make(map[string]func(*http.Server, *tls.Conn, http.Handler), 0), } return httpsServer.ListenAndServeTLS(c, k) } return errors.New("did not load server") } func (s *Server) doAuth(foo http.HandlerFunc) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { rusr, rpwd, ok := config.GetAuth() if ok { //usr, pwd := getProxyAuth(r) usr, pwd, ok := r.BasicAuth() if !ok || rusr != usr || rpwd != pwd { w.WriteHeader(http.StatusUnauthorized) log.Printf("denying proxy basic auth") return } } foo(w, r) } } func (s *Server) Pre(foo http.HandlerFunc) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { ctx, can := context.WithTimeout(r.Context(), time.Second*time.Duration(config.GetTimeout())) defer can() if err := s.limiter.Wait(ctx); err != nil { w.WriteHeader(http.StatusTooManyRequests) return } s.doAuth(foo)(w, r) } } func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { s.Pre(s.Proxy)(w, r) } func getProxyAuth(r *http.Request) (string, string) { proxyAuthHeader := r.Header.Get("Proxy-Authorization") proxyAuthB64 := strings.TrimPrefix(proxyAuthHeader, "Basic ") proxyAuthBytes, _ := base64.StdEncoding.DecodeString(proxyAuthB64) proxyAuth := string(proxyAuthBytes) if !strings.Contains(proxyAuth, ":") { return "", "" } proxyAuthSplit := strings.Split(proxyAuth, ":") return proxyAuthSplit[0], proxyAuthSplit[1] }