diff --git a/config.go b/config.go index 034660e..088771b 100644 --- a/config.go +++ b/config.go @@ -14,6 +14,7 @@ type Config struct { Timeout time.Duration TLSInsecure bool Limiter *rate.Limiter + DNS string } func NewConfig() *Config { @@ -24,6 +25,8 @@ func NewConfig() *Config { as.Append(args.BOOL, "tls-insecure", "permit tls insecure", false) as.Append(args.DURATION, "t", "timeout", time.Minute) + as.Append(args.STRING, "dns", "dns ip:port", "1.1.1.1:53") + if err := as.Parse(); err != nil { panic(err) } @@ -39,5 +42,6 @@ func NewConfig() *Config { Timeout: as.GetDuration("t"), TLSInsecure: as.GetBool("tls-insecure"), Limiter: limiter, + DNS: as.GetString("dns"), } } diff --git a/server.go b/server.go index c1f8d46..4d82266 100644 --- a/server.go +++ b/server.go @@ -18,6 +18,7 @@ type Server struct { resolver *net.Resolver limiter *rate.Limiter Timeout time.Duration + dnsCache map[string]string } func NewServer(c *Config) *Server { @@ -25,7 +26,7 @@ func NewServer(c *Config) *Server { PreferGo: true, Dial: func(ctx context.Context, network, addr string) (net.Conn, error) { d := net.Dialer{Timeout: time.Second * 10} - return d.DialContext(ctx, network, "1.1.1.1:53") + return d.DialContext(ctx, network, c.DNS) }, } transport := &http.Transport{} @@ -37,6 +38,7 @@ func NewServer(c *Config) *Server { Transport: transport, Timeout: c.Timeout, resolver: resolver, + dnsCache: map[string]string{}, } transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { addr, err := s.dig(ctx, addr) @@ -126,17 +128,24 @@ func (s *Server) xfer(ctx context.Context, w io.Writer, r io.ReadCloser) { } func (s *Server) dig(ctx context.Context, host string) (string, error) { + if v, ok := s.dnsCache[host]; ok { + return v, nil + } + search := host port := "" if splithost, splitport, err := net.SplitHostPort(host); err == nil { - host = splithost + search = splithost port = ":" + splitport } - ip, err := s.resolver.LookupHost(ctx, host) + ip, err := s.resolver.LookupHost(ctx, search) if err != nil { return "", err } if len(ip) == 0 { return "", errors.New("name does not resolve") } - return ip[0] + port, nil + result := ip[0] + port + s.dnsCache[host] = result + //log.Printf("dug %s => %s => %s", host, search, result) + return result, nil }