use custom resolver for always 1.1.1.1 dns
parent
d0d5af0ac1
commit
85698ad0ae
45
server.go
45
server.go
|
|
@ -15,20 +15,37 @@ import (
|
||||||
|
|
||||||
type Server struct {
|
type Server struct {
|
||||||
Transport http.RoundTripper
|
Transport http.RoundTripper
|
||||||
|
resolver *net.Resolver
|
||||||
limiter *rate.Limiter
|
limiter *rate.Limiter
|
||||||
Timeout time.Duration
|
Timeout time.Duration
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewServer(c *Config) *Server {
|
func NewServer(c *Config) *Server {
|
||||||
|
resolver := &net.Resolver{
|
||||||
|
PreferGo: true,
|
||||||
|
Dial: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||||
|
d := net.Dialer{Timeout: time.Second * 10}
|
||||||
|
return d.DialContext(ctx, network, "1.1.1.1:53")
|
||||||
|
},
|
||||||
|
}
|
||||||
transport := &http.Transport{}
|
transport := &http.Transport{}
|
||||||
transport.TLSClientConfig = &tls.Config{
|
transport.TLSClientConfig = &tls.Config{
|
||||||
InsecureSkipVerify: c.TLSInsecure,
|
InsecureSkipVerify: c.TLSInsecure,
|
||||||
}
|
}
|
||||||
return &Server{
|
s := &Server{
|
||||||
limiter: c.Limiter,
|
limiter: c.Limiter,
|
||||||
Transport: transport,
|
Transport: transport,
|
||||||
Timeout: c.Timeout,
|
Timeout: c.Timeout,
|
||||||
|
resolver: resolver,
|
||||||
}
|
}
|
||||||
|
transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||||
|
addr, err := s.dig(ctx, addr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return (&net.Dialer{}).DialContext(ctx, network, addr)
|
||||||
|
}
|
||||||
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|
@ -48,7 +65,13 @@ func (s *Server) Error(w http.ResponseWriter, err error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) Connect(w http.ResponseWriter, r *http.Request) {
|
func (s *Server) Connect(w http.ResponseWriter, r *http.Request) {
|
||||||
dest, err := net.DialTimeout("tcp", r.Host, 30*time.Second)
|
host, err := s.dig(r.Context(), r.Host)
|
||||||
|
if err != nil {
|
||||||
|
s.Error(w, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
dest, err := net.DialTimeout("tcp", host, 30*time.Second)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.Error(w, err)
|
s.Error(w, err)
|
||||||
return
|
return
|
||||||
|
|
@ -81,7 +104,7 @@ func (s *Server) Serve(w http.ResponseWriter, r *http.Request) {
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
|
|
||||||
for k, v := range resp.Header {
|
for k, v := range resp.Header {
|
||||||
log.Println(k, v)
|
//log.Println(k, v)
|
||||||
for _, s := range v {
|
for _, s := range v {
|
||||||
w.Header().Add(k, s)
|
w.Header().Add(k, s)
|
||||||
}
|
}
|
||||||
|
|
@ -101,3 +124,19 @@ func (s *Server) xfer(ctx context.Context, w io.Writer, r io.ReadCloser) {
|
||||||
r,
|
r,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Server) dig(ctx context.Context, host string) (string, error) {
|
||||||
|
port := ""
|
||||||
|
if splithost, splitport, err := net.SplitHostPort(host); err == nil {
|
||||||
|
host = splithost
|
||||||
|
port = ":" + splitport
|
||||||
|
}
|
||||||
|
ip, err := s.resolver.LookupHost(ctx, host)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
if len(ip) == 0 {
|
||||||
|
return "", errors.New("name does not resolve")
|
||||||
|
}
|
||||||
|
return ip[0] + port, nil
|
||||||
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue