package main import ( "context" "crypto/tls" "errors" "fmt" "io" "log" "net" "net/http" "strings" "sync" "time" "golang.org/x/time/rate" ) type Server struct { Transport http.RoundTripper resolver *net.Resolver limiter *rate.Limiter Timeout time.Duration dnsCacheLock sync.Mutex dnsCache map[string]dns } type dns struct { result string err error ts time.Time } func (dns dns) ok() bool { dur := time.Minute * 10 if dns.err != nil { dur = time.Minute * 1 } return time.Since(dns.ts) < dur } func NewServer(c *Config) *Server { resolver := &net.Resolver{ PreferGo: true, Dial: func(ctx context.Context, network, addr string) (net.Conn, error) { d := net.Dialer{Timeout: c.Timeout} if c.DNS != "" { addr = c.DNS } return d.DialContext(ctx, network, addr) }, } transport := &http.Transport{} transport.TLSClientConfig = &tls.Config{ InsecureSkipVerify: c.TLSInsecure, } s := &Server{ limiter: c.Limiter, Transport: transport, Timeout: c.Timeout, resolver: resolver, dnsCache: map[string]dns{}, } transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { addr, err := s.dig(ctx, addr) if err != nil { return nil, fmt.Errorf("err digging for dial context: %w", err) } return (&net.Dialer{}).DialContext(ctx, network, addr) } return s } func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { //log.Println(r.Method, r.Host) switch r.Method { case http.MethodConnect: s.Connect(w, r) default: s.Serve(w, r) } } func (s *Server) Error(r *http.Request, w http.ResponseWriter, err error) { log.Printf("err: %s: %v", r.URL.String(), err) http.Error(w, err.Error(), http.StatusServiceUnavailable) return } func (s *Server) Connect(w http.ResponseWriter, r *http.Request) { hijacker, ok := w.(http.Hijacker) if !ok { s.Error(r, w, errors.New("cannot hijack")) return } host, err := s.dig(r.Context(), r.Host) if err != nil { s.Error(r, w, err) return } network := "tcp" if n := strings.Count(host, ":"); n > 3 { network = "tcp6" } dest, err := net.DialTimeout(network, host, s.Timeout) if err != nil { s.Error(r, w, fmt.Errorf("error dialing w timeout %s=>%s: %w", r.Host, host, err)) return } w.WriteHeader(http.StatusOK) client, _, err := hijacker.Hijack() if err != nil { s.Error(r, w, err) return } go s.xfer(r.Context(), dest, client) go s.xfer(r.Context(), client, dest) } func (s *Server) Serve(w http.ResponseWriter, r *http.Request) { resp, err := s.Transport.RoundTrip(r) if err != nil { s.Error(r, w, err) return } defer resp.Body.Close() for k, v := range resp.Header { //log.Println(k, v) for _, s := range v { w.Header().Add(k, s) } } w.WriteHeader(resp.StatusCode) s.xfer(r.Context(), w, resp.Body) } func (s *Server) xfer(ctx context.Context, w io.Writer, r io.ReadCloser) { defer r.Close() io.Copy( throttledWriter{ ctx: ctx, w: w, limiter: s.limiter, }, r, ) } func (s *Server) dig(ctx context.Context, host string) (string, error) { search := host search = strings.TrimPrefix(search, "https://") search = strings.TrimPrefix(search, "http://") s.dnsCacheLock.Lock() defer s.dnsCacheLock.Unlock() if v, ok := s.dnsCache[host]; ok && v.ok() { return v.result, v.err } port := "" if splithost, splitport, err := net.SplitHostPort(host); err == nil { search = strings.TrimPrefix(strings.TrimPrefix(splithost, "http://"), "https://") port = ":" + splitport } ip, err := s.resolver.LookupHost(ctx, search) result := "" if len(ip) > 0 { result = strings.TrimPrefix(ip[0], "//") + port } s.dnsCache[host] = dns{result: result, err: err, ts: time.Now()} if err != nil || result == "" { return "", fmt.Errorf("failed to dns lookup %s: %v", search, err) } return result, nil }