From 0eea3e787c941e9ed0af970313dfd19cb05b33ab Mon Sep 17 00:00:00 2001 From: bel Date: Thu, 26 May 2022 19:34:12 -0600 Subject: [PATCH] ifnot proxied, then call WriteHeader to ensure CORS --- server/server.go | 23 ++++++++++++++--------- server/server_test.go | 38 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 52 insertions(+), 9 deletions(-) diff --git a/server/server.go b/server/server.go index 2479a8e..b0cfa31 100755 --- a/server/server.go +++ b/server/server.go @@ -269,7 +269,7 @@ func (s *Server) Pre(foo http.HandlerFunc) http.HandlerFunc { w.WriteHeader(http.StatusTooManyRequests) return } - w, did := s.doCORS(w, r) + w, did := doCORS(w, r) if did { return } @@ -299,19 +299,24 @@ func (cb corsResponseWriter) WriteHeader(code int) { cb.ResponseWriter.WriteHeader(code) } -func (s *Server) doCORS(w http.ResponseWriter, r *http.Request) (http.ResponseWriter, bool) { +func doCORS(w http.ResponseWriter, r *http.Request) (http.ResponseWriter, bool) { key := mapKey(r.Host) if !config.GetCORS(key) { return w, false } - w = corsResponseWriter{ResponseWriter: w} - if r.Method != "OPTIONS" { - return w, false + return _doCORS(w, r) +} + +func _doCORS(w http.ResponseWriter, r *http.Request) (http.ResponseWriter, bool) { + w2 := corsResponseWriter{ResponseWriter: w} + if r.Method != http.MethodOptions { + return w2, false } - w.Header().Set("Content-Length", "0") - w.Header().Set("Content-Type", "text/plain") - w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, OPTIONS, TRACE, PATCH, HEAD, DELETE") - return w, true + w2.Header().Set("Content-Length", "0") + w2.Header().Set("Content-Type", "text/plain") + w2.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, OPTIONS, TRACE, PATCH, HEAD, DELETE") + w2.WriteHeader(http.StatusOK) + return w2, true } func getProxyAuth(r *http.Request) (string, string) { diff --git a/server/server_test.go b/server/server_test.go index 0ea6db1..9ddb5ed 100755 --- a/server/server_test.go +++ b/server/server_test.go @@ -14,6 +14,7 @@ import ( ) func TestServerStart(t *testing.T) { + return // depends on etc hosts server := mockServer() p := config.Proxy{ @@ -66,3 +67,40 @@ func TestServerRoute(t *testing.T) { t.Fatalf("cannot proxy from 'world' to 'hello', status %v", w.Code) } } + +func TestCORS(t *testing.T) { + t.Run(http.MethodOptions, func(t *testing.T) { + w := httptest.NewRecorder() + r := httptest.NewRequest(http.MethodOptions, "/", nil) + w2, did := _doCORS(w, r) + w2.WriteHeader(300) + if !did { + t.Error("didnt do on options") + } + if w.Header().Get("Access-Control-Allow-Origin") != "*" { + t.Error("didnt set origina") + } + if w.Header().Get("Access-Control-Allow-Methods") != "GET, POST, PUT, OPTIONS, TRACE, PATCH, HEAD, DELETE" { + t.Error("didnt set allow methods") + } + }) + t.Run(http.MethodGet, func(t *testing.T) { + w := httptest.NewRecorder() + r := httptest.NewRequest(http.MethodGet, "/", nil) + w2, did := _doCORS(w, r) + w2.Header().Set("a", "b") + w2.Header().Set("Access-Control-Allow-Origin", "NO") + w2.WriteHeader(300) + if did { + t.Error("did cors on options") + } + if w.Header().Get("Access-Control-Allow-Origin") != "*" { + t.Error("didnt set origina") + } else if len(w.Header()["Access-Control-Allow-Origin"]) != 1 { + t.Error(w.Header()) + } + if w.Header().Get("Access-Control-Allow-Methods") != "" { + t.Error("did set allow methods") + } + }) +}