From c4a1c9ce98d009bbefb6bfb319d50548837c3f5b Mon Sep 17 00:00:00 2001 From: Bel LaPointe Date: Mon, 26 Dec 2022 14:34:51 -0500 Subject: [PATCH] time limit dns caching, and tcp6 if tcp6 specifically --- server.go | 51 +++++++++++++++++++++++++++++++++++---------------- 1 file changed, 35 insertions(+), 16 deletions(-) diff --git a/server.go b/server.go index 7c980d2..0af83cf 100644 --- a/server.go +++ b/server.go @@ -20,7 +20,21 @@ type Server struct { resolver *net.Resolver limiter *rate.Limiter Timeout time.Duration - dnsCache map[string]string + dnsCache map[string]dns +} + +type dns struct { + result string + err error + ts time.Time +} + +func (dns dns) ok() bool { + dur := time.Minute * 10 + if dns.err != nil { + dur = time.Minute * 1 + } + return time.Since(dns.ts) < dur } func NewServer(c *Config) *Server { @@ -28,12 +42,12 @@ func NewServer(c *Config) *Server { PreferGo: true, Dial: func(ctx context.Context, network, addr string) (net.Conn, error) { d := net.Dialer{Timeout: c.Timeout} - return d.DialContext(ctx, network, c.DNS) + if c.DNS != "" { + addr = c.DNS + } + return d.DialContext(ctx, network, addr) }, } - if c.DNS == "" { - resolver = &net.Resolver{} - } transport := &http.Transport{} transport.TLSClientConfig = &tls.Config{ InsecureSkipVerify: c.TLSInsecure, @@ -43,7 +57,7 @@ func NewServer(c *Config) *Server { Transport: transport, Timeout: c.Timeout, resolver: resolver, - dnsCache: map[string]string{}, + dnsCache: map[string]dns{}, } transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { addr, err := s.dig(ctx, addr) @@ -84,12 +98,18 @@ func (s *Server) Connect(w http.ResponseWriter, r *http.Request) { return } - dest, err := net.DialTimeout("tcp", host, s.Timeout) + network := "tcp" + if n := strings.Count(host, ":"); n > 3 { + network = "tcp6" + } + dest, err := net.DialTimeout(network, host, s.Timeout) if err != nil { s.Error(r, w, fmt.Errorf("error dialing w timeout %s=>%s: %w", r.Host, host, err)) return } + w.WriteHeader(http.StatusOK) + client, _, err := hijacker.Hijack() if err != nil { s.Error(r, w, err) @@ -134,8 +154,8 @@ func (s *Server) dig(ctx context.Context, host string) (string, error) { search := host search = strings.TrimPrefix(search, "https://") search = strings.TrimPrefix(search, "http://") - if v, ok := s.dnsCache[host]; ok { - return v, nil + if v, ok := s.dnsCache[host]; ok && v.ok() { + return v.result, v.err } port := "" if splithost, splitport, err := net.SplitHostPort(host); err == nil { @@ -143,14 +163,13 @@ func (s *Server) dig(ctx context.Context, host string) (string, error) { port = ":" + splitport } ip, err := s.resolver.LookupHost(ctx, search) - if err != nil { + result := "" + if len(ip) > 0 { + result = strings.TrimPrefix(ip[0], "//") + port + } + s.dnsCache[host] = dns{result: result, err: err, ts: time.Now()} + if err != nil || result == "" { return "", fmt.Errorf("failed to dns lookup %s: %v", search, err) } - if len(ip) == 0 { - return "", errors.New("name does not resolve") - } - result := strings.TrimPrefix(ip[0], "//") + port - s.dnsCache[host] = result - //log.Printf("dug %s => %s => %s", host, search, result) return result, nil }