package main import ( "context" "crypto/tls" "errors" "fmt" "io" "log" "net" "net/http" "strings" "time" "golang.org/x/time/rate" ) type Server struct { Transport http.RoundTripper resolver *net.Resolver limiter *rate.Limiter Timeout time.Duration dnsCache map[string]string } func NewServer(c *Config) *Server { resolver := &net.Resolver{ PreferGo: true, Dial: func(ctx context.Context, network, addr string) (net.Conn, error) { d := net.Dialer{Timeout: c.Timeout} return d.DialContext(ctx, network, c.DNS) }, } if c.DNS == "" { resolver = &net.Resolver{} } transport := &http.Transport{} transport.TLSClientConfig = &tls.Config{ InsecureSkipVerify: c.TLSInsecure, } s := &Server{ limiter: c.Limiter, Transport: transport, Timeout: c.Timeout, resolver: resolver, dnsCache: map[string]string{}, } transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { addr, err := s.dig(ctx, addr) if err != nil { return nil, fmt.Errorf("err digging for dial context: %w", err) } return (&net.Dialer{}).DialContext(ctx, network, addr) } return s } func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { //log.Println(r.Method, r.Host) switch r.Method { case http.MethodConnect: s.Connect(w, r) default: s.Serve(w, r) } } func (s *Server) Error(r *http.Request, w http.ResponseWriter, err error) { log.Printf("err: %s: %v", r.URL.String(), err) http.Error(w, err.Error(), http.StatusServiceUnavailable) return } func (s *Server) Connect(w http.ResponseWriter, r *http.Request) { host, err := s.dig(r.Context(), r.Host) if err != nil { s.Error(r, w, err) return } dest, err := net.DialTimeout("tcp", host, s.Timeout) if err != nil { s.Error(r, w, fmt.Errorf("error dialing w timeout %s=>%s: %w", r.Host, host, err)) return } w.WriteHeader(http.StatusOK) hijacker, ok := w.(http.Hijacker) if !ok { s.Error(r, w, errors.New("hijack not available")) return } client, _, err := hijacker.Hijack() if err != nil { s.Error(r, w, err) return } go s.xfer(r.Context(), dest, client) go s.xfer(r.Context(), client, dest) } func (s *Server) Serve(w http.ResponseWriter, r *http.Request) { resp, err := s.Transport.RoundTrip(r) if err != nil { s.Error(r, w, err) return } defer resp.Body.Close() for k, v := range resp.Header { //log.Println(k, v) for _, s := range v { w.Header().Add(k, s) } } w.WriteHeader(resp.StatusCode) s.xfer(r.Context(), w, resp.Body) } func (s *Server) xfer(ctx context.Context, w io.Writer, r io.ReadCloser) { defer r.Close() io.Copy( throttledWriter{ ctx: ctx, w: w, limiter: s.limiter, }, r, ) } func (s *Server) dig(ctx context.Context, host string) (string, error) { search := host search = strings.TrimPrefix(search, "https://") search = strings.TrimPrefix(search, "http://") if v, ok := s.dnsCache[host]; ok { return v, nil } port := "" if splithost, splitport, err := net.SplitHostPort(host); err == nil { search = strings.TrimPrefix(strings.TrimPrefix(splithost, "http://"), "https://") port = ":" + splitport } ip, err := s.resolver.LookupHost(ctx, search) if err != nil { return "", fmt.Errorf("failed to dns lookup %s: %v", search, err) } if len(ip) == 0 { return "", errors.New("name does not resolve") } result := strings.TrimPrefix(ip[0], "//") + port s.dnsCache[host] = result //log.Printf("dug %s => %s => %s", host, search, result) return result, nil }