diff --git a/server/server.go b/server/server.go index 5c43963..2479a8e 100755 --- a/server/server.go +++ b/server/server.go @@ -269,7 +269,8 @@ func (s *Server) Pre(foo http.HandlerFunc) http.HandlerFunc { w.WriteHeader(http.StatusTooManyRequests) return } - if did := s.doCORS(w, r); did { + w, did := s.doCORS(w, r) + if did { return } if s.auth.BOAuthZ { @@ -288,20 +289,29 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { s.Pre(s.Proxy)(w, r) } -func (s *Server) doCORS(w http.ResponseWriter, r *http.Request) bool { +type corsResponseWriter struct { + http.ResponseWriter +} + +func (cb corsResponseWriter) WriteHeader(code int) { + cb.Header().Set("Access-Control-Allow-Origin", "*") + cb.Header().Set("Access-Control-Allow-Headers", "X-Auth-Token, content-type, Content-Type") + cb.ResponseWriter.WriteHeader(code) +} + +func (s *Server) doCORS(w http.ResponseWriter, r *http.Request) (http.ResponseWriter, bool) { key := mapKey(r.Host) if !config.GetCORS(key) { - return false + return w, false } - w.Header().Set("Access-Control-Allow-Origin", "*") - w.Header().Set("Access-Control-Allow-Headers", "X-Auth-Token, content-type, Content-Type") + w = corsResponseWriter{ResponseWriter: w} if r.Method != "OPTIONS" { - return false + return w, false } w.Header().Set("Content-Length", "0") w.Header().Set("Content-Type", "text/plain") w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, OPTIONS, TRACE, PATCH, HEAD, DELETE") - return true + return w, true } func getProxyAuth(r *http.Request) (string, string) {