package server import ( "context" "io/ioutil" "net/http" "net/http/httptest" "os" "strings" "testing" "time" ) func TestMiddlewareScoped(t *testing.T) { rest, doauth, clean := testREST(t) defer clean() foo := rest.scoped(func(w http.ResponseWriter, r *http.Request) {}) r := httptest.NewRequest(http.MethodGet, "/entities/"+testEntityID, nil) doauth(r) w := httptest.NewRecorder() foo(w, r) scope := rest.scope(r) if scope.Namespace == "" { t.Fatalf("scope ns: %q", scope.Namespace) } if scope.EntityID == "" { t.Fatalf("entity id: %q", scope.EntityID) } if scope.EntityName == "" { t.Fatalf("entity name: %q", scope.EntityName) } } func TestMiddlewareDelay(t *testing.T) { os.Setenv("DELAY", "100ms") defer os.Setenv("DELAY", "0ms") rest, _, clean := testREST(t) defer clean() foo := rest.delay(func(w http.ResponseWriter, r *http.Request) {}) r := httptest.NewRequest(http.MethodGet, "/", nil) w := httptest.NewRecorder() start := time.Now() foo(w, r) if time.Since(start) < time.Millisecond*100 { t.Fatal(time.Since(start)) } ctx, can := context.WithCancel(context.Background()) can() r = r.WithContext(ctx) start = time.Now() foo(w, r) if time.Since(start) > time.Millisecond*80 { t.Fatal(time.Since(start)) } } func TestMiddlewareDefend(t *testing.T) { os.Setenv("MAX_FILE_SIZE", "20") defer os.Unsetenv("MAX_FILE_SIZE") rest, _, clean := testREST(t) defer clean() var b []byte foo := rest.defend(func(w http.ResponseWriter, r *http.Request) { b, _ = ioutil.ReadAll(r.Body) }) r := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(strings.Repeat("a", 30))) w := httptest.NewRecorder() foo(w, r) if len(b) != 20 { t.Fatal(len(b)) } t.Log(len(b)) } func TestMiddlewareShift(t *testing.T) { cases := map[string]struct { input string output string }{ "root": { input: "/", output: "/", }, "root w/ query param": { input: "/?a=b", output: "/?a=b", }, "base": { input: "/base", output: "/", }, "base w/ query param": { input: "/base?a=b", output: "/?a=b", }, "sub": { input: "/base/sub", output: "/sub", }, "sub w/ query param": { input: "/base/sub?a=b", output: "/sub?a=b", }, } for name, d := range cases { c := d t.Run(name, func(t *testing.T) { r := httptest.NewRequest(http.MethodGet, c.input, nil) w := httptest.NewRecorder() rest := &REST{} foo := rest.shift(func(http.ResponseWriter, *http.Request) {}) t.Logf("%+v", r.URL) foo(w, r) if r.URL.String() != c.output { t.Fatalf("from %q, want %q, got %q", c.input, c.output, r.URL.String()) } }) } }