test middleware
This commit is contained in:
@@ -6,6 +6,7 @@ import (
|
|||||||
"local/dndex/server/auth"
|
"local/dndex/server/auth"
|
||||||
"local/gziphttp"
|
"local/gziphttp"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -62,6 +63,9 @@ func (rest *REST) shift(foo http.HandlerFunc) http.HandlerFunc {
|
|||||||
i++
|
i++
|
||||||
}
|
}
|
||||||
r.URL.Path = r.URL.Path[i:]
|
r.URL.Path = r.URL.Path[i:]
|
||||||
|
if !strings.HasPrefix(r.URL.Path, "/") {
|
||||||
|
r.URL.Path = "/" + r.URL.Path
|
||||||
|
}
|
||||||
foo(w, r)
|
foo(w, r)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
134
server/middleware_test.go
Normal file
134
server/middleware_test.go
Normal file
@@ -0,0 +1,134 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"io/ioutil"
|
||||||
|
"log"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestMiddlewareScoped(t *testing.T) {
|
||||||
|
rest, doauth, clean := testREST(t)
|
||||||
|
defer clean()
|
||||||
|
|
||||||
|
foo := rest.scoped(func(w http.ResponseWriter, r *http.Request) {})
|
||||||
|
|
||||||
|
r := httptest.NewRequest(http.MethodGet, "/entities/"+testEntityID, nil)
|
||||||
|
doauth(r)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
foo(w, r)
|
||||||
|
|
||||||
|
scope := rest.scope(r)
|
||||||
|
if scope.Namespace == "" {
|
||||||
|
t.Fatalf("scope ns: %q", scope.Namespace)
|
||||||
|
}
|
||||||
|
if scope.EntityID == "" {
|
||||||
|
t.Fatalf("entity id: %q", scope.EntityID)
|
||||||
|
}
|
||||||
|
if scope.EntityName == "" {
|
||||||
|
t.Fatalf("entity name: %q", scope.EntityName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMiddlewareDelay(t *testing.T) {
|
||||||
|
os.Setenv("DELAY", "100ms")
|
||||||
|
defer os.Setenv("DELAY", "0ms")
|
||||||
|
rest, _, clean := testREST(t)
|
||||||
|
defer clean()
|
||||||
|
|
||||||
|
foo := rest.delay(func(w http.ResponseWriter, r *http.Request) {})
|
||||||
|
|
||||||
|
r := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
start := time.Now()
|
||||||
|
foo(w, r)
|
||||||
|
|
||||||
|
if time.Since(start) < time.Millisecond*100 {
|
||||||
|
t.Fatal(time.Since(start))
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, can := context.WithCancel(context.TODO())
|
||||||
|
can()
|
||||||
|
r = r.WithContext(ctx)
|
||||||
|
|
||||||
|
start = time.Now()
|
||||||
|
foo(w, r)
|
||||||
|
if time.Since(start) > time.Millisecond*80 {
|
||||||
|
t.Fatal(time.Since(start))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMiddlewareDefend(t *testing.T) {
|
||||||
|
os.Setenv("MAX_FILE_SIZE", "20")
|
||||||
|
defer os.Unsetenv("MAX_FILE_SIZE")
|
||||||
|
rest, _, clean := testREST(t)
|
||||||
|
defer clean()
|
||||||
|
|
||||||
|
var b []byte
|
||||||
|
foo := rest.defend(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
b, _ = ioutil.ReadAll(r.Body)
|
||||||
|
})
|
||||||
|
|
||||||
|
r := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(strings.Repeat("a", 30)))
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
foo(w, r)
|
||||||
|
|
||||||
|
if len(b) != 20 {
|
||||||
|
t.Fatal(len(b))
|
||||||
|
}
|
||||||
|
t.Log(len(b))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMiddlewareShift(t *testing.T) {
|
||||||
|
cases := map[string]struct {
|
||||||
|
input string
|
||||||
|
output string
|
||||||
|
}{
|
||||||
|
"root": {
|
||||||
|
input: "/",
|
||||||
|
output: "/",
|
||||||
|
},
|
||||||
|
"root w/ query param": {
|
||||||
|
input: "/?a=b",
|
||||||
|
output: "/?a=b",
|
||||||
|
},
|
||||||
|
"base": {
|
||||||
|
input: "/base",
|
||||||
|
output: "/",
|
||||||
|
},
|
||||||
|
"base w/ query param": {
|
||||||
|
input: "/base?a=b",
|
||||||
|
output: "/?a=b",
|
||||||
|
},
|
||||||
|
"sub": {
|
||||||
|
input: "/base/sub",
|
||||||
|
output: "/sub",
|
||||||
|
},
|
||||||
|
"sub w/ query param": {
|
||||||
|
input: "/base/sub?a=b",
|
||||||
|
output: "/sub?a=b",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for name, d := range cases {
|
||||||
|
c := d
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
r := httptest.NewRequest(http.MethodGet, c.input, nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
rest := &REST{}
|
||||||
|
foo := rest.shift(func(http.ResponseWriter, *http.Request) {})
|
||||||
|
log.Printf("%+v", r.URL)
|
||||||
|
foo(w, r)
|
||||||
|
if r.URL.String() != c.output {
|
||||||
|
t.Fatalf("from %q, want %q, got %q", c.input, c.output, r.URL.String())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -59,5 +59,5 @@ func NewREST(g storage.RateLimitedGraph) (*REST, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (rest *REST) scope(r *http.Request) auth.Scope {
|
func (rest *REST) scope(r *http.Request) auth.Scope {
|
||||||
return auth.GetScope(r)
|
return auth.GetScope(r, rest.g)
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user