diff --git a/server.go b/server.go index 0334771..1b9cd7f 100644 --- a/server.go +++ b/server.go @@ -55,6 +55,7 @@ func (s *Server) Connect(w http.ResponseWriter, r *http.Request) { } w.WriteHeader(http.StatusOK) + hijacker, ok := w.(http.Hijacker) if !ok { s.Error(w, errors.New("hijack not available")) @@ -67,21 +68,8 @@ func (s *Server) Connect(w http.ResponseWriter, r *http.Request) { return } - xfer := func(dst io.WriteCloser, src io.ReadCloser) { - defer dst.Close() - defer src.Close() - io.Copy( - throttledWriter{ - ctx: context.Background(), - w: dst, - limiter: s.limiter, - }, - src, - ) - } - - go xfer(dest, client) - go xfer(client, dest) + go s.xfer(r.Context(), dest, client) + go s.xfer(r.Context(), client, dest) } func (s *Server) Serve(w http.ResponseWriter, r *http.Request) { @@ -98,12 +86,17 @@ func (s *Server) Serve(w http.ResponseWriter, r *http.Request) { w.Header().Add(k, s) } } + s.xfer(r.Context(), w, resp.Body) +} + +func (s *Server) xfer(ctx context.Context, w io.Writer, r io.ReadCloser) { + defer r.Close() io.Copy( throttledWriter{ - ctx: r.Context(), + ctx: ctx, w: w, limiter: s.limiter, }, - resp.Body, + r, ) }