time limit dns caching, and tcp6 if tcp6 specifically
parent
899f42d5f8
commit
c4a1c9ce98
51
server.go
51
server.go
|
|
@ -20,7 +20,21 @@ type Server struct {
|
||||||
resolver *net.Resolver
|
resolver *net.Resolver
|
||||||
limiter *rate.Limiter
|
limiter *rate.Limiter
|
||||||
Timeout time.Duration
|
Timeout time.Duration
|
||||||
dnsCache map[string]string
|
dnsCache map[string]dns
|
||||||
|
}
|
||||||
|
|
||||||
|
type dns struct {
|
||||||
|
result string
|
||||||
|
err error
|
||||||
|
ts time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
func (dns dns) ok() bool {
|
||||||
|
dur := time.Minute * 10
|
||||||
|
if dns.err != nil {
|
||||||
|
dur = time.Minute * 1
|
||||||
|
}
|
||||||
|
return time.Since(dns.ts) < dur
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewServer(c *Config) *Server {
|
func NewServer(c *Config) *Server {
|
||||||
|
|
@ -28,12 +42,12 @@ func NewServer(c *Config) *Server {
|
||||||
PreferGo: true,
|
PreferGo: true,
|
||||||
Dial: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
Dial: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||||
d := net.Dialer{Timeout: c.Timeout}
|
d := net.Dialer{Timeout: c.Timeout}
|
||||||
return d.DialContext(ctx, network, c.DNS)
|
if c.DNS != "" {
|
||||||
|
addr = c.DNS
|
||||||
|
}
|
||||||
|
return d.DialContext(ctx, network, addr)
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
if c.DNS == "" {
|
|
||||||
resolver = &net.Resolver{}
|
|
||||||
}
|
|
||||||
transport := &http.Transport{}
|
transport := &http.Transport{}
|
||||||
transport.TLSClientConfig = &tls.Config{
|
transport.TLSClientConfig = &tls.Config{
|
||||||
InsecureSkipVerify: c.TLSInsecure,
|
InsecureSkipVerify: c.TLSInsecure,
|
||||||
|
|
@ -43,7 +57,7 @@ func NewServer(c *Config) *Server {
|
||||||
Transport: transport,
|
Transport: transport,
|
||||||
Timeout: c.Timeout,
|
Timeout: c.Timeout,
|
||||||
resolver: resolver,
|
resolver: resolver,
|
||||||
dnsCache: map[string]string{},
|
dnsCache: map[string]dns{},
|
||||||
}
|
}
|
||||||
transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
|
transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||||
addr, err := s.dig(ctx, addr)
|
addr, err := s.dig(ctx, addr)
|
||||||
|
|
@ -84,12 +98,18 @@ func (s *Server) Connect(w http.ResponseWriter, r *http.Request) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
dest, err := net.DialTimeout("tcp", host, s.Timeout)
|
network := "tcp"
|
||||||
|
if n := strings.Count(host, ":"); n > 3 {
|
||||||
|
network = "tcp6"
|
||||||
|
}
|
||||||
|
dest, err := net.DialTimeout(network, host, s.Timeout)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.Error(r, w, fmt.Errorf("error dialing w timeout %s=>%s: %w", r.Host, host, err))
|
s.Error(r, w, fmt.Errorf("error dialing w timeout %s=>%s: %w", r.Host, host, err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
|
||||||
client, _, err := hijacker.Hijack()
|
client, _, err := hijacker.Hijack()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.Error(r, w, err)
|
s.Error(r, w, err)
|
||||||
|
|
@ -134,8 +154,8 @@ func (s *Server) dig(ctx context.Context, host string) (string, error) {
|
||||||
search := host
|
search := host
|
||||||
search = strings.TrimPrefix(search, "https://")
|
search = strings.TrimPrefix(search, "https://")
|
||||||
search = strings.TrimPrefix(search, "http://")
|
search = strings.TrimPrefix(search, "http://")
|
||||||
if v, ok := s.dnsCache[host]; ok {
|
if v, ok := s.dnsCache[host]; ok && v.ok() {
|
||||||
return v, nil
|
return v.result, v.err
|
||||||
}
|
}
|
||||||
port := ""
|
port := ""
|
||||||
if splithost, splitport, err := net.SplitHostPort(host); err == nil {
|
if splithost, splitport, err := net.SplitHostPort(host); err == nil {
|
||||||
|
|
@ -143,14 +163,13 @@ func (s *Server) dig(ctx context.Context, host string) (string, error) {
|
||||||
port = ":" + splitport
|
port = ":" + splitport
|
||||||
}
|
}
|
||||||
ip, err := s.resolver.LookupHost(ctx, search)
|
ip, err := s.resolver.LookupHost(ctx, search)
|
||||||
if err != nil {
|
result := ""
|
||||||
|
if len(ip) > 0 {
|
||||||
|
result = strings.TrimPrefix(ip[0], "//") + port
|
||||||
|
}
|
||||||
|
s.dnsCache[host] = dns{result: result, err: err, ts: time.Now()}
|
||||||
|
if err != nil || result == "" {
|
||||||
return "", fmt.Errorf("failed to dns lookup %s: %v", search, err)
|
return "", fmt.Errorf("failed to dns lookup %s: %v", search, err)
|
||||||
}
|
}
|
||||||
if len(ip) == 0 {
|
|
||||||
return "", errors.New("name does not resolve")
|
|
||||||
}
|
|
||||||
result := strings.TrimPrefix(ip[0], "//") + port
|
|
||||||
s.dnsCache[host] = result
|
|
||||||
//log.Printf("dug %s => %s => %s", host, search, result)
|
|
||||||
return result, nil
|
return result, nil
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue