diff --git a/server.go b/server.go index 01f86ba..c1f8d46 100644 --- a/server.go +++ b/server.go @@ -15,20 +15,37 @@ import ( type Server struct { Transport http.RoundTripper + resolver *net.Resolver limiter *rate.Limiter Timeout time.Duration } func NewServer(c *Config) *Server { + resolver := &net.Resolver{ + PreferGo: true, + Dial: func(ctx context.Context, network, addr string) (net.Conn, error) { + d := net.Dialer{Timeout: time.Second * 10} + return d.DialContext(ctx, network, "1.1.1.1:53") + }, + } transport := &http.Transport{} transport.TLSClientConfig = &tls.Config{ InsecureSkipVerify: c.TLSInsecure, } - return &Server{ + s := &Server{ limiter: c.Limiter, Transport: transport, Timeout: c.Timeout, + resolver: resolver, } + transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { + addr, err := s.dig(ctx, addr) + if err != nil { + return nil, err + } + return (&net.Dialer{}).DialContext(ctx, network, addr) + } + return s } func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { @@ -48,7 +65,13 @@ func (s *Server) Error(w http.ResponseWriter, err error) { } func (s *Server) Connect(w http.ResponseWriter, r *http.Request) { - dest, err := net.DialTimeout("tcp", r.Host, 30*time.Second) + host, err := s.dig(r.Context(), r.Host) + if err != nil { + s.Error(w, err) + return + } + + dest, err := net.DialTimeout("tcp", host, 30*time.Second) if err != nil { s.Error(w, err) return @@ -81,7 +104,7 @@ func (s *Server) Serve(w http.ResponseWriter, r *http.Request) { defer resp.Body.Close() for k, v := range resp.Header { - log.Println(k, v) + //log.Println(k, v) for _, s := range v { w.Header().Add(k, s) } @@ -101,3 +124,19 @@ func (s *Server) xfer(ctx context.Context, w io.Writer, r io.ReadCloser) { r, ) } + +func (s *Server) dig(ctx context.Context, host string) (string, error) { + port := "" + if splithost, splitport, err := net.SplitHostPort(host); err == nil { + host = splithost + port = ":" + splitport + } + ip, err := s.resolver.LookupHost(ctx, host) + if err != nil { + return "", err + } + if len(ip) == 0 { + return "", errors.New("name does not resolve") + } + return ip[0] + port, nil +}