time limit dns caching, and tcp6 if tcp6 specifically

master
Bel LaPointe 2022-12-26 14:34:51 -05:00
parent 899f42d5f8
commit c4a1c9ce98
1 changed files with 35 additions and 16 deletions

View File

@ -20,7 +20,21 @@ type Server struct {
resolver *net.Resolver resolver *net.Resolver
limiter *rate.Limiter limiter *rate.Limiter
Timeout time.Duration Timeout time.Duration
dnsCache map[string]string dnsCache map[string]dns
}
type dns struct {
result string
err error
ts time.Time
}
func (dns dns) ok() bool {
dur := time.Minute * 10
if dns.err != nil {
dur = time.Minute * 1
}
return time.Since(dns.ts) < dur
} }
func NewServer(c *Config) *Server { func NewServer(c *Config) *Server {
@ -28,12 +42,12 @@ func NewServer(c *Config) *Server {
PreferGo: true, PreferGo: true,
Dial: func(ctx context.Context, network, addr string) (net.Conn, error) { Dial: func(ctx context.Context, network, addr string) (net.Conn, error) {
d := net.Dialer{Timeout: c.Timeout} d := net.Dialer{Timeout: c.Timeout}
return d.DialContext(ctx, network, c.DNS) if c.DNS != "" {
addr = c.DNS
}
return d.DialContext(ctx, network, addr)
}, },
} }
if c.DNS == "" {
resolver = &net.Resolver{}
}
transport := &http.Transport{} transport := &http.Transport{}
transport.TLSClientConfig = &tls.Config{ transport.TLSClientConfig = &tls.Config{
InsecureSkipVerify: c.TLSInsecure, InsecureSkipVerify: c.TLSInsecure,
@ -43,7 +57,7 @@ func NewServer(c *Config) *Server {
Transport: transport, Transport: transport,
Timeout: c.Timeout, Timeout: c.Timeout,
resolver: resolver, resolver: resolver,
dnsCache: map[string]string{}, dnsCache: map[string]dns{},
} }
transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
addr, err := s.dig(ctx, addr) addr, err := s.dig(ctx, addr)
@ -84,12 +98,18 @@ func (s *Server) Connect(w http.ResponseWriter, r *http.Request) {
return return
} }
dest, err := net.DialTimeout("tcp", host, s.Timeout) network := "tcp"
if n := strings.Count(host, ":"); n > 3 {
network = "tcp6"
}
dest, err := net.DialTimeout(network, host, s.Timeout)
if err != nil { if err != nil {
s.Error(r, w, fmt.Errorf("error dialing w timeout %s=>%s: %w", r.Host, host, err)) s.Error(r, w, fmt.Errorf("error dialing w timeout %s=>%s: %w", r.Host, host, err))
return return
} }
w.WriteHeader(http.StatusOK)
client, _, err := hijacker.Hijack() client, _, err := hijacker.Hijack()
if err != nil { if err != nil {
s.Error(r, w, err) s.Error(r, w, err)
@ -134,8 +154,8 @@ func (s *Server) dig(ctx context.Context, host string) (string, error) {
search := host search := host
search = strings.TrimPrefix(search, "https://") search = strings.TrimPrefix(search, "https://")
search = strings.TrimPrefix(search, "http://") search = strings.TrimPrefix(search, "http://")
if v, ok := s.dnsCache[host]; ok { if v, ok := s.dnsCache[host]; ok && v.ok() {
return v, nil return v.result, v.err
} }
port := "" port := ""
if splithost, splitport, err := net.SplitHostPort(host); err == nil { if splithost, splitport, err := net.SplitHostPort(host); err == nil {
@ -143,14 +163,13 @@ func (s *Server) dig(ctx context.Context, host string) (string, error) {
port = ":" + splitport port = ":" + splitport
} }
ip, err := s.resolver.LookupHost(ctx, search) ip, err := s.resolver.LookupHost(ctx, search)
if err != nil { result := ""
if len(ip) > 0 {
result = strings.TrimPrefix(ip[0], "//") + port
}
s.dnsCache[host] = dns{result: result, err: err, ts: time.Now()}
if err != nil || result == "" {
return "", fmt.Errorf("failed to dns lookup %s: %v", search, err) return "", fmt.Errorf("failed to dns lookup %s: %v", search, err)
} }
if len(ip) == 0 {
return "", errors.New("name does not resolve")
}
result := strings.TrimPrefix(ip[0], "//") + port
s.dnsCache[host] = result
//log.Printf("dug %s => %s => %s", host, search, result)
return result, nil return result, nil
} }