diff --git a/server.go b/server.go index 8553643..15cbf9b 100644 --- a/server.go +++ b/server.go @@ -34,7 +34,7 @@ func NewServer(c *Config) *Server { }, } if c.DNS == "" { - resolver = nil + resolver = &net.Resolver{} } transport := &http.Transport{} transport.TLSClientConfig = &tls.Config{ @@ -74,15 +74,15 @@ func (s *Server) Error(r *http.Request, w http.ResponseWriter, err error) { } func (s *Server) Connect(w http.ResponseWriter, r *http.Request) { - host, err := s.dig(r.Context(), r.Host) - if err != nil { - s.Error(r, w, err) + hijacker, ok := w.(http.Hijacker) + if !ok { + s.Error(r, w, errors.New("cannot hijack")) return } - hijacker, ok := w.(http.Hijacker) - if host == r.Host || !ok { - s.connectHTTPReverseProxy(w, r) + host, err := s.dig(r.Context(), r.Host) + if err != nil { + s.Error(r, w, err) return } @@ -92,14 +92,14 @@ func (s *Server) Connect(w http.ResponseWriter, r *http.Request) { return } - w.WriteHeader(http.StatusOK) - client, _, err := hijacker.Hijack() if err != nil { s.Error(r, w, err) return } + w.WriteHeader(http.StatusOK) + go s.xfer(r.Context(), dest, client) go s.xfer(r.Context(), client, dest) } @@ -148,9 +148,6 @@ func (s *Server) xfer(ctx context.Context, w io.Writer, r io.ReadCloser) { } func (s *Server) dig(ctx context.Context, host string) (string, error) { - if s.resolver == nil { - return host, nil - } search := host search = strings.TrimPrefix(search, "https://") search = strings.TrimPrefix(search, "http://")