package main import ( "context" "crypto/tls" "errors" "io" "log" "net" "net/http" "time" "golang.org/x/time/rate" ) type Server struct { Transport http.RoundTripper resolver *net.Resolver limiter *rate.Limiter Timeout time.Duration } func NewServer(c *Config) *Server { resolver := &net.Resolver{ PreferGo: true, Dial: func(ctx context.Context, network, addr string) (net.Conn, error) { d := net.Dialer{Timeout: time.Second * 10} return d.DialContext(ctx, network, "1.1.1.1:53") }, } transport := &http.Transport{} transport.TLSClientConfig = &tls.Config{ InsecureSkipVerify: c.TLSInsecure, } s := &Server{ limiter: c.Limiter, Transport: transport, Timeout: c.Timeout, resolver: resolver, } transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { addr, err := s.dig(ctx, addr) if err != nil { return nil, err } return (&net.Dialer{}).DialContext(ctx, network, addr) } return s } func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { //log.Println(r.Method, r.Host) switch r.Method { case http.MethodConnect: s.Connect(w, r) default: s.Serve(w, r) } } func (s *Server) Error(w http.ResponseWriter, err error) { log.Println(err) http.Error(w, err.Error(), http.StatusServiceUnavailable) return } func (s *Server) Connect(w http.ResponseWriter, r *http.Request) { host, err := s.dig(r.Context(), r.Host) if err != nil { s.Error(w, err) return } dest, err := net.DialTimeout("tcp", host, 30*time.Second) if err != nil { s.Error(w, err) return } w.WriteHeader(http.StatusOK) hijacker, ok := w.(http.Hijacker) if !ok { s.Error(w, errors.New("hijack not available")) return } client, _, err := hijacker.Hijack() if err != nil { s.Error(w, err) return } go s.xfer(r.Context(), dest, client) go s.xfer(r.Context(), client, dest) } func (s *Server) Serve(w http.ResponseWriter, r *http.Request) { resp, err := s.Transport.RoundTrip(r) if err != nil { s.Error(w, err) return } defer resp.Body.Close() for k, v := range resp.Header { //log.Println(k, v) for _, s := range v { w.Header().Add(k, s) } } w.WriteHeader(resp.StatusCode) s.xfer(r.Context(), w, resp.Body) } func (s *Server) xfer(ctx context.Context, w io.Writer, r io.ReadCloser) { defer r.Close() io.Copy( throttledWriter{ ctx: ctx, w: w, limiter: s.limiter, }, r, ) } func (s *Server) dig(ctx context.Context, host string) (string, error) { port := "" if splithost, splitport, err := net.SplitHostPort(host); err == nil { host = splithost port = ":" + splitport } ip, err := s.resolver.LookupHost(ctx, host) if err != nil { return "", err } if len(ip) == 0 { return "", errors.New("name does not resolve") } return ip[0] + port, nil }