dndex/server/middleware_test.go

135 lines
2.6 KiB
Go

package server
import (
"context"
"io/ioutil"
"log"
"net/http"
"net/http/httptest"
"os"
"strings"
"testing"
"time"
)
// TestMiddlewareScoped sends a GET for an entity path through rest.scoped,
// after applying doauth to the request, and then checks that the scope
// derived from that request has a non-empty namespace, entity ID and entity
// name.
func TestMiddlewareScoped(t *testing.T) {
	rest, doauth, clean := testREST(t)
	defer clean()

	handler := rest.scoped(func(w http.ResponseWriter, r *http.Request) {})

	req := httptest.NewRequest(http.MethodGet, "/entities/"+testEntityID, nil)
	doauth(req)
	rec := httptest.NewRecorder()
	handler(rec, req)

	// Report only the first missing component, in namespace -> ID -> name order.
	switch got := rest.scope(req); {
	case got.Namespace == "":
		t.Fatalf("scope ns: %q", got.Namespace)
	case got.EntityID == "":
		t.Fatalf("entity id: %q", got.EntityID)
	case got.EntityName == "":
		t.Fatalf("entity name: %q", got.EntityName)
	}
}
// TestMiddlewareDelay checks two things about rest.delay:
//   - With DELAY set, a request is held for at least that long before the
//     handler returns.
//   - A request whose context is already cancelled returns well before the
//     configured delay elapses.
func TestMiddlewareDelay(t *testing.T) {
	const delay = 100 * time.Millisecond

	// Set DELAY for this test only, then restore whatever was there before
	// (including "not set"). Later tests then see the same environment that
	// earlier tests saw.
	prev, hadPrev := os.LookupEnv("DELAY")
	if err := os.Setenv("DELAY", delay.String()); err != nil {
		t.Fatalf("setting DELAY: %v", err)
	}
	defer func() {
		var err error
		if hadPrev {
			err = os.Setenv("DELAY", prev)
		} else {
			err = os.Unsetenv("DELAY")
		}
		if err != nil {
			t.Errorf("restoring DELAY: %v", err)
		}
	}()

	// testREST is built after DELAY is set. Presumably it reads the env at
	// construction time; confirm against testREST if this ordering changes.
	rest, _, clean := testREST(t)
	defer clean()
	handler := rest.delay(func(w http.ResponseWriter, r *http.Request) {})

	req := httptest.NewRequest(http.MethodGet, "/", nil)
	rec := httptest.NewRecorder()

	start := time.Now()
	handler(rec, req)
	if elapsed := time.Since(start); elapsed < delay {
		t.Fatalf("live context: handler returned after %v, want at least %v", elapsed, delay)
	}

	// With a cancelled context the wait must not run to completion.
	// The 80ms bound leaves headroom under the configured delay for
	// scheduling jitter.
	const cancelledMax = 80 * time.Millisecond
	ctx, cancel := context.WithCancel(context.Background())
	cancel()
	req = req.WithContext(ctx)

	start = time.Now()
	handler(rec, req)
	if elapsed := time.Since(start); elapsed > cancelledMax {
		t.Fatalf("cancelled context: handler returned after %v, want at most %v", elapsed, cancelledMax)
	}
}
// TestMiddlewareDefend checks that rest.defend caps the request body at the
// size configured through MAX_FILE_SIZE. The wrapped handler is sent a
// 30-byte body and must see exactly 20 bytes.
func TestMiddlewareDefend(t *testing.T) {
	// limit must match the MAX_FILE_SIZE value set just below.
	const limit = 20

	// Set MAX_FILE_SIZE for this test only, then restore the previous value
	// (or its absence) so the setting does not leak into other tests.
	prev, hadPrev := os.LookupEnv("MAX_FILE_SIZE")
	if err := os.Setenv("MAX_FILE_SIZE", "20"); err != nil {
		t.Fatalf("setting MAX_FILE_SIZE: %v", err)
	}
	defer func() {
		var err error
		if hadPrev {
			err = os.Setenv("MAX_FILE_SIZE", prev)
		} else {
			err = os.Unsetenv("MAX_FILE_SIZE")
		}
		if err != nil {
			t.Errorf("restoring MAX_FILE_SIZE: %v", err)
		}
	}()

	rest, _, clean := testREST(t)
	defer clean()

	var (
		called bool
		body   []byte
	)
	handler := rest.defend(func(w http.ResponseWriter, r *http.Request) {
		called = true
		// The read error is deliberately ignored. Reading past the cap may
		// legitimately fail, depending on how defend limits the body.
		// ReadAll still returns the bytes read before the error, and those
		// are all this test inspects.
		body, _ = ioutil.ReadAll(r.Body)
	})

	req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(strings.Repeat("a", limit+10)))
	rec := httptest.NewRecorder()
	handler(rec, req)

	if !called {
		t.Fatal("defend did not invoke the wrapped handler")
	}
	if len(body) != limit {
		t.Fatalf("handler read %d body bytes, want %d", len(body), limit)
	}
}
// TestMiddlewareShift checks that rest.shift drops the first path segment from
// the request URL and preserves any query string. A bare root path stays "/".
func TestMiddlewareShift(t *testing.T) {
	type shiftCase struct {
		in   string
		want string
	}
	cases := map[string]shiftCase{
		"root":                {in: "/", want: "/"},
		"root w/ query param": {in: "/?a=b", want: "/?a=b"},
		"base":                {in: "/base", want: "/"},
		"base w/ query param": {in: "/base?a=b", want: "/?a=b"},
		"sub":                 {in: "/base/sub", want: "/sub"},
		"sub w/ query param":  {in: "/base/sub?a=b", want: "/sub?a=b"},
	}

	for name, tc := range cases {
		// t.Run blocks until the subtest finishes (no t.Parallel), so the
		// closure safely observes this iteration's tc.
		t.Run(name, func(t *testing.T) {
			req := httptest.NewRequest(http.MethodGet, tc.in, nil)
			rec := httptest.NewRecorder()
			handler := (&REST{}).shift(func(http.ResponseWriter, *http.Request) {})

			log.Printf("%+v", req.URL)
			handler(rec, req)

			// shift is expected to rewrite req.URL in place.
			if got := req.URL.String(); got != tc.want {
				t.Fatalf("from %q, want %q, got %q", tc.in, tc.want, got)
			}
		})
	}
}