accept dns from cli, cache dns

master
Bel LaPointe 2022-11-06 08:21:27 -07:00
parent 85698ad0ae
commit c9658e9dcf
2 changed files with 17 additions and 4 deletions

View File

@ -14,6 +14,7 @@ type Config struct {
Timeout time.Duration Timeout time.Duration
TLSInsecure bool TLSInsecure bool
Limiter *rate.Limiter Limiter *rate.Limiter
DNS string
} }
func NewConfig() *Config { func NewConfig() *Config {
@ -24,6 +25,8 @@ func NewConfig() *Config {
as.Append(args.BOOL, "tls-insecure", "permit tls insecure", false) as.Append(args.BOOL, "tls-insecure", "permit tls insecure", false)
as.Append(args.DURATION, "t", "timeout", time.Minute) as.Append(args.DURATION, "t", "timeout", time.Minute)
as.Append(args.STRING, "dns", "dns ip:port", "1.1.1.1:53")
if err := as.Parse(); err != nil { if err := as.Parse(); err != nil {
panic(err) panic(err)
} }
@ -39,5 +42,6 @@ func NewConfig() *Config {
Timeout: as.GetDuration("t"), Timeout: as.GetDuration("t"),
TLSInsecure: as.GetBool("tls-insecure"), TLSInsecure: as.GetBool("tls-insecure"),
Limiter: limiter, Limiter: limiter,
DNS: as.GetString("dns"),
} }
} }

View File

@ -18,6 +18,7 @@ type Server struct {
resolver *net.Resolver resolver *net.Resolver
limiter *rate.Limiter limiter *rate.Limiter
Timeout time.Duration Timeout time.Duration
dnsCache map[string]string
} }
func NewServer(c *Config) *Server { func NewServer(c *Config) *Server {
@ -25,7 +26,7 @@ func NewServer(c *Config) *Server {
PreferGo: true, PreferGo: true,
Dial: func(ctx context.Context, network, addr string) (net.Conn, error) { Dial: func(ctx context.Context, network, addr string) (net.Conn, error) {
d := net.Dialer{Timeout: time.Second * 10} d := net.Dialer{Timeout: time.Second * 10}
return d.DialContext(ctx, network, "1.1.1.1:53") return d.DialContext(ctx, network, c.DNS)
}, },
} }
transport := &http.Transport{} transport := &http.Transport{}
@ -37,6 +38,7 @@ func NewServer(c *Config) *Server {
Transport: transport, Transport: transport,
Timeout: c.Timeout, Timeout: c.Timeout,
resolver: resolver, resolver: resolver,
dnsCache: map[string]string{},
} }
transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
addr, err := s.dig(ctx, addr) addr, err := s.dig(ctx, addr)
@ -126,17 +128,24 @@ func (s *Server) xfer(ctx context.Context, w io.Writer, r io.ReadCloser) {
} }
func (s *Server) dig(ctx context.Context, host string) (string, error) { func (s *Server) dig(ctx context.Context, host string) (string, error) {
if v, ok := s.dnsCache[host]; ok {
return v, nil
}
search := host
port := "" port := ""
if splithost, splitport, err := net.SplitHostPort(host); err == nil { if splithost, splitport, err := net.SplitHostPort(host); err == nil {
host = splithost search = splithost
port = ":" + splitport port = ":" + splitport
} }
ip, err := s.resolver.LookupHost(ctx, host) ip, err := s.resolver.LookupHost(ctx, search)
if err != nil { if err != nil {
return "", err return "", err
} }
if len(ip) == 0 { if len(ip) == 0 {
return "", errors.New("name does not resolve") return "", errors.New("name does not resolve")
} }
return ip[0] + port, nil result := ip[0] + port
s.dnsCache[host] = result
//log.Printf("dug %s => %s => %s", host, search, result)
return result, nil
} }