package server

import (
	"context"
	"crypto/sha256"
	"crypto/subtle"
	"encoding/base64"
	"errors"
	"local/rproxy3/config"
	"local/rproxy3/storage"
	"local/rproxy3/storage/packable"
	"log"
	"net/http"
	"net/url"
	"strings"
	"time"

	"golang.org/x/time/rate"
)

// nsRouting is the storage namespace under which src -> dst routes are kept.
const nsRouting = "routing"

// listenerScheme identifies the protocol the server listens with.
type listenerScheme int

const (
	schemeHTTP listenerScheme = iota
	schemeHTTPS
)

// String returns the lowercase scheme name ("http" or "https"), or "" for an
// unknown value.
func (ls listenerScheme) String() string {
	switch ls {
	case schemeHTTP:
		return "http"
	case schemeHTTPS:
		return "https"
	}
	return ""
}

// Server proxies incoming requests using routes persisted in db, throttled by
// limiter.
//
// NOTE(review): username and password are not read in this file; doAuth takes
// its credentials from config.GetAuth instead — confirm whether they are still
// needed.
type Server struct {
	db       storage.DB
	addr     string
	username string
	password string
	limiter  *rate.Limiter
}

// Route parses dst as a URL and stores it as the destination for src in the
// routing namespace. It returns the URL parse error or the storage error, if
// any.
func (s *Server) Route(src, dst string) error {
	log.Printf("Adding route %q -> %q...\n", src, dst)
	u, err := url.Parse(dst)
	if err != nil {
		return err
	}
	return s.db.Set(nsRouting, src, packable.NewURL(u))
}

// Run listens on s.addr and serves requests with s as the handler. When
// config.GetSSL reports a certificate/key pair it serves HTTPS with them,
// otherwise plain HTTP. It only returns when the listener fails.
func (s *Server) Run() error {
	// Read the TLS configuration once so the scheme decision and the files
	// actually used cannot disagree.
	certFile, keyFile, useTLS := config.GetSSL()
	scheme := schemeHTTP
	if useTLS {
		scheme = schemeHTTPS
	}
	log.Printf("Listening for %v on %v...\n", scheme, s.addr)
	switch scheme {
	case schemeHTTP:
		return http.ListenAndServe(s.addr, s)
	case schemeHTTPS:
		return http.ListenAndServeTLS(s.addr, certFile, keyFile, s)
	}
	return errors.New("did not load server")
}

// doAuth wraps foo with HTTP Basic authentication. When config.GetAuth reports
// that auth is enabled, the request's Authorization credentials must match the
// configured username and password. Otherwise the client receives
// 401 Unauthorized with a Basic challenge and foo is not called.
func (s *Server) doAuth(foo http.HandlerFunc) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		wantUsr, wantPwd, required := config.GetAuth()
		if required {
			usr, pwd, ok := r.BasicAuth()
			// Evaluate both comparisons unconditionally and in constant time so
			// response timing does not reveal which credential was wrong.
			usrOK := secureEqual(usr, wantUsr)
			pwdOK := secureEqual(pwd, wantPwd)
			if !ok || !usrOK || !pwdOK {
				// RFC 7235: a 401 response must carry a challenge, otherwise
				// clients have no way to know which credentials to send.
				w.Header().Set("WWW-Authenticate", `Basic realm="rproxy3"`)
				w.WriteHeader(http.StatusUnauthorized)
				log.Printf("denying proxy basic auth")
				return
			}
		}
		foo(w, r)
	}
}

// secureEqual reports whether a and b are equal. It compares fixed-size
// SHA-256 digests with subtle.ConstantTimeCompare, so the comparison neither
// stops at the first differing byte nor returns early on a length mismatch.
func secureEqual(a, b string) bool {
	ha := sha256.Sum256([]byte(a))
	hb := sha256.Sum256([]byte(b))
	return subtle.ConstantTimeCompare(ha[:], hb[:]) == 1
}

// Pre wraps foo with rate limiting followed by authentication (doAuth). A
// request that cannot obtain a limiter token within config.GetTimeout()
// seconds, or whose own context ends first, gets 429 Too Many Requests.
func (s *Server) Pre(foo http.HandlerFunc) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		ctx, cancel := context.WithTimeout(r.Context(), time.Second*time.Duration(config.GetTimeout()))
		err := s.limiter.Wait(ctx)
		// The timeout only bounds the limiter wait; release its timer now rather
		// than holding it for the whole proxied request.
		cancel()
		if err != nil {
			w.WriteHeader(http.StatusTooManyRequests)
			return
		}
		s.doAuth(foo)(w, r)
	}
}

// ServeHTTP implements http.Handler by running s.Proxy (defined elsewhere in
// this package) behind the rate-limit and auth checks in Pre.
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	s.Pre(s.Proxy)(w, r)
}

// getProxyAuth extracts Basic credentials from the Proxy-Authorization header.
// It returns ("", "") when the header is absent, does not use the Basic scheme
// (matched case-insensitively), is not valid base64, or lacks a ':' separator.
// Only the first ':' separates the username from the password, so passwords
// may themselves contain ':'.
//
// NOTE(review): currently unused; doAuth reads the Authorization header via
// r.BasicAuth instead.
func getProxyAuth(r *http.Request) (string, string) {
	const prefix = "Basic "
	header := r.Header.Get("Proxy-Authorization")
	if len(header) < len(prefix) || !strings.EqualFold(header[:len(prefix)], prefix) {
		return "", ""
	}
	decoded, err := base64.StdEncoding.DecodeString(header[len(prefix):])
	if err != nil {
		return "", ""
	}
	parts := strings.SplitN(string(decoded), ":", 2)
	if len(parts) != 2 {
		return "", ""
	}
	return parts[0], parts[1]
}