package main import ( "context" "crypto/tls" "errors" "io" "log" "net" "net/http" "time" "golang.org/x/time/rate" ) type Server struct { Transport http.RoundTripper limiter *rate.Limiter Timeout time.Duration } func NewServer(c *Config) *Server { transport := &http.Transport{} transport.TLSClientConfig = &tls.Config{ InsecureSkipVerify: c.TLSInsecure, } return &Server{ limiter: c.Limiter, Transport: transport, Timeout: c.Timeout, } } func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { //log.Println(r.Method, r.Host) switch r.Method { case http.MethodConnect: s.Connect(w, r) default: s.Serve(w, r) } } func (s *Server) Error(w http.ResponseWriter, err error) { log.Println(err) http.Error(w, err.Error(), http.StatusServiceUnavailable) return } func (s *Server) Connect(w http.ResponseWriter, r *http.Request) { dest, err := net.DialTimeout("tcp", r.Host, 30*time.Second) if err != nil { s.Error(w, err) return } w.WriteHeader(http.StatusOK) hijacker, ok := w.(http.Hijacker) if !ok { s.Error(w, errors.New("hijack not available")) return } client, _, err := hijacker.Hijack() if err != nil { s.Error(w, err) return } xfer := func(dst io.WriteCloser, src io.ReadCloser) { defer dst.Close() defer src.Close() io.Copy( throttledWriter{ ctx: context.Background(), w: dst, limiter: s.limiter, }, src, ) } go xfer(dest, client) go xfer(client, dest) } func (s *Server) Serve(w http.ResponseWriter, r *http.Request) { resp, err := s.Transport.RoundTrip(r) if err != nil { s.Error(w, err) return } defer resp.Body.Close() w.WriteHeader(resp.StatusCode) for k, v := range resp.Header { for _, s := range v { w.Header().Add(k, s) } } io.Copy( throttledWriter{ ctx: r.Context(), w: w, limiter: s.limiter, }, resp.Body, ) }