diff --git a/server/middleware.go b/server/middleware.go index 85a503f..8e71493 100644 --- a/server/middleware.go +++ b/server/middleware.go @@ -6,6 +6,7 @@ import ( "local/dndex/server/auth" "local/gziphttp" "net/http" + "strings" "time" ) @@ -62,6 +63,9 @@ func (rest *REST) shift(foo http.HandlerFunc) http.HandlerFunc { i++ } r.URL.Path = r.URL.Path[i:] + if !strings.HasPrefix(r.URL.Path, "/") { + r.URL.Path = "/" + r.URL.Path + } foo(w, r) } } diff --git a/server/middleware_test.go b/server/middleware_test.go new file mode 100644 index 0000000..604fa03 --- /dev/null +++ b/server/middleware_test.go @@ -0,0 +1,134 @@ +package server + +import ( + "context" + "io/ioutil" + "log" + "net/http" + "net/http/httptest" + "os" + "strings" + "testing" + "time" +) + +func TestMiddlewareScoped(t *testing.T) { + rest, doauth, clean := testREST(t) + defer clean() + + foo := rest.scoped(func(w http.ResponseWriter, r *http.Request) {}) + + r := httptest.NewRequest(http.MethodGet, "/entities/"+testEntityID, nil) + doauth(r) + w := httptest.NewRecorder() + foo(w, r) + + scope := rest.scope(r) + if scope.Namespace == "" { + t.Fatalf("scope ns: %q", scope.Namespace) + } + if scope.EntityID == "" { + t.Fatalf("entity id: %q", scope.EntityID) + } + if scope.EntityName == "" { + t.Fatalf("entity name: %q", scope.EntityName) + } +} + +func TestMiddlewareDelay(t *testing.T) { + os.Setenv("DELAY", "100ms") + defer os.Setenv("DELAY", "0ms") + rest, _, clean := testREST(t) + defer clean() + + foo := rest.delay(func(w http.ResponseWriter, r *http.Request) {}) + + r := httptest.NewRequest(http.MethodGet, "/", nil) + w := httptest.NewRecorder() + + start := time.Now() + foo(w, r) + + if time.Since(start) < time.Millisecond*100 { + t.Fatal(time.Since(start)) + } + + ctx, can := context.WithCancel(context.TODO()) + can() + r = r.WithContext(ctx) + + start = time.Now() + foo(w, r) + if time.Since(start) > time.Millisecond*80 { + t.Fatal(time.Since(start)) + } +} + +func TestMiddlewareDefend(t *testing.T) { + os.Setenv("MAX_FILE_SIZE", "20") + defer os.Unsetenv("MAX_FILE_SIZE") + rest, _, clean := testREST(t) + defer clean() + + var b []byte + foo := rest.defend(func(w http.ResponseWriter, r *http.Request) { + b, _ = ioutil.ReadAll(r.Body) + }) + + r := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(strings.Repeat("a", 30))) + w := httptest.NewRecorder() + + foo(w, r) + + if len(b) != 20 { + t.Fatal(len(b)) + } + t.Log(len(b)) +} + +func TestMiddlewareShift(t *testing.T) { + cases := map[string]struct { + input string + output string + }{ + "root": { + input: "/", + output: "/", + }, + "root w/ query param": { + input: "/?a=b", + output: "/?a=b", + }, + "base": { + input: "/base", + output: "/", + }, + "base w/ query param": { + input: "/base?a=b", + output: "/?a=b", + }, + "sub": { + input: "/base/sub", + output: "/sub", + }, + "sub w/ query param": { + input: "/base/sub?a=b", + output: "/sub?a=b", + }, + } + + for name, d := range cases { + c := d + t.Run(name, func(t *testing.T) { + r := httptest.NewRequest(http.MethodGet, c.input, nil) + w := httptest.NewRecorder() + rest := &REST{} + foo := rest.shift(func(http.ResponseWriter, *http.Request) {}) + log.Printf("%+v", r.URL) + foo(w, r) + if r.URL.String() != c.output { + t.Fatalf("from %q, want %q, got %q", c.input, c.output, r.URL.String()) + } + }) + } +} diff --git a/server/rest.go b/server/rest.go index 73a1baf..e1800ce 100644 --- a/server/rest.go +++ b/server/rest.go @@ -59,5 +59,5 @@ func NewREST(g storage.RateLimitedGraph) (*REST, error) { } func (rest *REST) scope(r *http.Request) auth.Scope { - return auth.GetScope(r) + return auth.GetScope(r, rest.g) }