no nil resolver

master
Bel LaPointe 2022-12-26 14:16:40 -05:00
parent 0b42c5fac6
commit 512dc9fcd5
1 changed files with 9 additions and 12 deletions

View File

@ -34,7 +34,7 @@ func NewServer(c *Config) *Server {
}, },
} }
if c.DNS == "" { if c.DNS == "" {
resolver = nil resolver = &net.Resolver{}
} }
transport := &http.Transport{} transport := &http.Transport{}
transport.TLSClientConfig = &tls.Config{ transport.TLSClientConfig = &tls.Config{
@ -74,15 +74,15 @@ func (s *Server) Error(r *http.Request, w http.ResponseWriter, err error) {
} }
func (s *Server) Connect(w http.ResponseWriter, r *http.Request) { func (s *Server) Connect(w http.ResponseWriter, r *http.Request) {
host, err := s.dig(r.Context(), r.Host) hijacker, ok := w.(http.Hijacker)
if err != nil { if !ok {
s.Error(r, w, err) s.Error(r, w, errors.New("cannot hijack"))
return return
} }
hijacker, ok := w.(http.Hijacker) host, err := s.dig(r.Context(), r.Host)
if host == r.Host || !ok { if err != nil {
s.connectHTTPReverseProxy(w, r) s.Error(r, w, err)
return return
} }
@ -92,14 +92,14 @@ func (s *Server) Connect(w http.ResponseWriter, r *http.Request) {
return return
} }
w.WriteHeader(http.StatusOK)
client, _, err := hijacker.Hijack() client, _, err := hijacker.Hijack()
if err != nil { if err != nil {
s.Error(r, w, err) s.Error(r, w, err)
return return
} }
w.WriteHeader(http.StatusOK)
go s.xfer(r.Context(), dest, client) go s.xfer(r.Context(), dest, client)
go s.xfer(r.Context(), client, dest) go s.xfer(r.Context(), client, dest)
} }
@ -148,9 +148,6 @@ func (s *Server) xfer(ctx context.Context, w io.Writer, r io.ReadCloser) {
} }
func (s *Server) dig(ctx context.Context, host string) (string, error) { func (s *Server) dig(ctx context.Context, host string) (string, error) {
if s.resolver == nil {
return host, nil
}
search := host search := host
search = strings.TrimPrefix(search, "https://") search = strings.TrimPrefix(search, "https://")
search = strings.TrimPrefix(search, "http://") search = strings.TrimPrefix(search, "http://")