From 6b18d2031ae8ac439bbcd151f9b206147498a850 Mon Sep 17 00:00:00 2001 From: Bel LaPointe Date: Sun, 6 Nov 2022 07:44:06 -0700 Subject: [PATCH] dedue --- server.go | 27 ++++++++++----------------- 1 file changed, 10 insertions(+), 17 deletions(-) diff --git a/server.go b/server.go index 0334771..1b9cd7f 100644 --- a/server.go +++ b/server.go @@ -55,6 +55,7 @@ func (s *Server) Connect(w http.ResponseWriter, r *http.Request) { } w.WriteHeader(http.StatusOK) + hijacker, ok := w.(http.Hijacker) if !ok { s.Error(w, errors.New("hijack not available")) @@ -67,21 +68,8 @@ func (s *Server) Connect(w http.ResponseWriter, r *http.Request) { return } - xfer := func(dst io.WriteCloser, src io.ReadCloser) { - defer dst.Close() - defer src.Close() - io.Copy( - throttledWriter{ - ctx: context.Background(), - w: dst, - limiter: s.limiter, - }, - src, - ) - } - - go xfer(dest, client) - go xfer(client, dest) + go s.xfer(r.Context(), dest, client) + go s.xfer(r.Context(), client, dest) } func (s *Server) Serve(w http.ResponseWriter, r *http.Request) { @@ -98,12 +86,17 @@ func (s *Server) Serve(w http.ResponseWriter, r *http.Request) { w.Header().Add(k, s) } } + s.xfer(r.Context(), w, resp.Body) +} + +func (s *Server) xfer(ctx context.Context, w io.Writer, r io.ReadCloser) { + defer r.Close() io.Copy( throttledWriter{ - ctx: r.Context(), + ctx: ctx, w: w, limiter: s.limiter, }, - resp.Body, + r, ) }