135 lines
2.6 KiB
Go
135 lines
2.6 KiB
Go
package server
|
|
|
|
import (
|
|
"context"
|
|
"io/ioutil"
|
|
"log"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"os"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
)
|
|
|
|
func TestMiddlewareScoped(t *testing.T) {
|
|
rest, doauth, clean := testREST(t)
|
|
defer clean()
|
|
|
|
foo := rest.scoped(func(w http.ResponseWriter, r *http.Request) {})
|
|
|
|
r := httptest.NewRequest(http.MethodGet, "/entities/"+testEntityID, nil)
|
|
doauth(r)
|
|
w := httptest.NewRecorder()
|
|
foo(w, r)
|
|
|
|
scope := rest.scope(r)
|
|
if scope.Namespace == "" {
|
|
t.Fatalf("scope ns: %q", scope.Namespace)
|
|
}
|
|
if scope.EntityID == "" {
|
|
t.Fatalf("entity id: %q", scope.EntityID)
|
|
}
|
|
if scope.EntityName == "" {
|
|
t.Fatalf("entity name: %q", scope.EntityName)
|
|
}
|
|
}
|
|
|
|
func TestMiddlewareDelay(t *testing.T) {
|
|
os.Setenv("DELAY", "100ms")
|
|
defer os.Setenv("DELAY", "0ms")
|
|
rest, _, clean := testREST(t)
|
|
defer clean()
|
|
|
|
foo := rest.delay(func(w http.ResponseWriter, r *http.Request) {})
|
|
|
|
r := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
w := httptest.NewRecorder()
|
|
|
|
start := time.Now()
|
|
foo(w, r)
|
|
|
|
if time.Since(start) < time.Millisecond*100 {
|
|
t.Fatal(time.Since(start))
|
|
}
|
|
|
|
ctx, can := context.WithCancel(context.Background())
|
|
can()
|
|
r = r.WithContext(ctx)
|
|
|
|
start = time.Now()
|
|
foo(w, r)
|
|
if time.Since(start) > time.Millisecond*80 {
|
|
t.Fatal(time.Since(start))
|
|
}
|
|
}
|
|
|
|
func TestMiddlewareDefend(t *testing.T) {
|
|
os.Setenv("MAX_FILE_SIZE", "20")
|
|
defer os.Unsetenv("MAX_FILE_SIZE")
|
|
rest, _, clean := testREST(t)
|
|
defer clean()
|
|
|
|
var b []byte
|
|
foo := rest.defend(func(w http.ResponseWriter, r *http.Request) {
|
|
b, _ = ioutil.ReadAll(r.Body)
|
|
})
|
|
|
|
r := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(strings.Repeat("a", 30)))
|
|
w := httptest.NewRecorder()
|
|
|
|
foo(w, r)
|
|
|
|
if len(b) != 20 {
|
|
t.Fatal(len(b))
|
|
}
|
|
t.Log(len(b))
|
|
}
|
|
|
|
func TestMiddlewareShift(t *testing.T) {
|
|
cases := map[string]struct {
|
|
input string
|
|
output string
|
|
}{
|
|
"root": {
|
|
input: "/",
|
|
output: "/",
|
|
},
|
|
"root w/ query param": {
|
|
input: "/?a=b",
|
|
output: "/?a=b",
|
|
},
|
|
"base": {
|
|
input: "/base",
|
|
output: "/",
|
|
},
|
|
"base w/ query param": {
|
|
input: "/base?a=b",
|
|
output: "/?a=b",
|
|
},
|
|
"sub": {
|
|
input: "/base/sub",
|
|
output: "/sub",
|
|
},
|
|
"sub w/ query param": {
|
|
input: "/base/sub?a=b",
|
|
output: "/sub?a=b",
|
|
},
|
|
}
|
|
|
|
for name, d := range cases {
|
|
c := d
|
|
t.Run(name, func(t *testing.T) {
|
|
r := httptest.NewRequest(http.MethodGet, c.input, nil)
|
|
w := httptest.NewRecorder()
|
|
rest := &REST{}
|
|
foo := rest.shift(func(http.ResponseWriter, *http.Request) {})
|
|
log.Printf("%+v", r.URL)
|
|
foo(w, r)
|
|
if r.URL.String() != c.output {
|
|
t.Fatalf("from %q, want %q, got %q", c.input, c.output, r.URL.String())
|
|
}
|
|
})
|
|
}
|
|
}
|