use custom resolver for always 1.1.1.1 dns

master
Bel LaPointe 2022-11-06 08:17:04 -07:00
parent d0d5af0ac1
commit 85698ad0ae
1 changed files with 42 additions and 3 deletions

View File

@ -15,20 +15,37 @@ import (
type Server struct { type Server struct {
Transport http.RoundTripper Transport http.RoundTripper
resolver *net.Resolver
limiter *rate.Limiter limiter *rate.Limiter
Timeout time.Duration Timeout time.Duration
} }
func NewServer(c *Config) *Server { func NewServer(c *Config) *Server {
resolver := &net.Resolver{
PreferGo: true,
Dial: func(ctx context.Context, network, addr string) (net.Conn, error) {
d := net.Dialer{Timeout: time.Second * 10}
return d.DialContext(ctx, network, "1.1.1.1:53")
},
}
transport := &http.Transport{} transport := &http.Transport{}
transport.TLSClientConfig = &tls.Config{ transport.TLSClientConfig = &tls.Config{
InsecureSkipVerify: c.TLSInsecure, InsecureSkipVerify: c.TLSInsecure,
} }
return &Server{ s := &Server{
limiter: c.Limiter, limiter: c.Limiter,
Transport: transport, Transport: transport,
Timeout: c.Timeout, Timeout: c.Timeout,
resolver: resolver,
} }
transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
addr, err := s.dig(ctx, addr)
if err != nil {
return nil, err
}
return (&net.Dialer{}).DialContext(ctx, network, addr)
}
return s
} }
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
@ -48,7 +65,13 @@ func (s *Server) Error(w http.ResponseWriter, err error) {
} }
func (s *Server) Connect(w http.ResponseWriter, r *http.Request) { func (s *Server) Connect(w http.ResponseWriter, r *http.Request) {
dest, err := net.DialTimeout("tcp", r.Host, 30*time.Second) host, err := s.dig(r.Context(), r.Host)
if err != nil {
s.Error(w, err)
return
}
dest, err := net.DialTimeout("tcp", host, 30*time.Second)
if err != nil { if err != nil {
s.Error(w, err) s.Error(w, err)
return return
@ -81,7 +104,7 @@ func (s *Server) Serve(w http.ResponseWriter, r *http.Request) {
defer resp.Body.Close() defer resp.Body.Close()
for k, v := range resp.Header { for k, v := range resp.Header {
log.Println(k, v) //log.Println(k, v)
for _, s := range v { for _, s := range v {
w.Header().Add(k, s) w.Header().Add(k, s)
} }
@ -101,3 +124,19 @@ func (s *Server) xfer(ctx context.Context, w io.Writer, r io.ReadCloser) {
r, r,
) )
} }
func (s *Server) dig(ctx context.Context, host string) (string, error) {
port := ""
if splithost, splitport, err := net.SplitHostPort(host); err == nil {
host = splithost
port = ":" + splitport
}
ip, err := s.resolver.LookupHost(ctx, host)
if err != nil {
return "", err
}
if len(ip) == 0 {
return "", errors.New("name does not resolve")
}
return ip[0] + port, nil
}