test middleware
parent
2174361d8e
commit
64772166cc
|
|
@ -6,6 +6,7 @@ import (
|
|||
"local/dndex/server/auth"
|
||||
"local/gziphttp"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
|
|
@ -62,6 +63,9 @@ func (rest *REST) shift(foo http.HandlerFunc) http.HandlerFunc {
|
|||
i++
|
||||
}
|
||||
r.URL.Path = r.URL.Path[i:]
|
||||
if !strings.HasPrefix(r.URL.Path, "/") {
|
||||
r.URL.Path = "/" + r.URL.Path
|
||||
}
|
||||
foo(w, r)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,134 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestMiddlewareScoped(t *testing.T) {
|
||||
rest, doauth, clean := testREST(t)
|
||||
defer clean()
|
||||
|
||||
foo := rest.scoped(func(w http.ResponseWriter, r *http.Request) {})
|
||||
|
||||
r := httptest.NewRequest(http.MethodGet, "/entities/"+testEntityID, nil)
|
||||
doauth(r)
|
||||
w := httptest.NewRecorder()
|
||||
foo(w, r)
|
||||
|
||||
scope := rest.scope(r)
|
||||
if scope.Namespace == "" {
|
||||
t.Fatalf("scope ns: %q", scope.Namespace)
|
||||
}
|
||||
if scope.EntityID == "" {
|
||||
t.Fatalf("entity id: %q", scope.EntityID)
|
||||
}
|
||||
if scope.EntityName == "" {
|
||||
t.Fatalf("entity name: %q", scope.EntityName)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMiddlewareDelay(t *testing.T) {
|
||||
os.Setenv("DELAY", "100ms")
|
||||
defer os.Setenv("DELAY", "0ms")
|
||||
rest, _, clean := testREST(t)
|
||||
defer clean()
|
||||
|
||||
foo := rest.delay(func(w http.ResponseWriter, r *http.Request) {})
|
||||
|
||||
r := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
start := time.Now()
|
||||
foo(w, r)
|
||||
|
||||
if time.Since(start) < time.Millisecond*100 {
|
||||
t.Fatal(time.Since(start))
|
||||
}
|
||||
|
||||
ctx, can := context.WithCancel(context.TODO())
|
||||
can()
|
||||
r = r.WithContext(ctx)
|
||||
|
||||
start = time.Now()
|
||||
foo(w, r)
|
||||
if time.Since(start) > time.Millisecond*80 {
|
||||
t.Fatal(time.Since(start))
|
||||
}
|
||||
}
|
||||
|
||||
func TestMiddlewareDefend(t *testing.T) {
|
||||
os.Setenv("MAX_FILE_SIZE", "20")
|
||||
defer os.Unsetenv("MAX_FILE_SIZE")
|
||||
rest, _, clean := testREST(t)
|
||||
defer clean()
|
||||
|
||||
var b []byte
|
||||
foo := rest.defend(func(w http.ResponseWriter, r *http.Request) {
|
||||
b, _ = ioutil.ReadAll(r.Body)
|
||||
})
|
||||
|
||||
r := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(strings.Repeat("a", 30)))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
foo(w, r)
|
||||
|
||||
if len(b) != 20 {
|
||||
t.Fatal(len(b))
|
||||
}
|
||||
t.Log(len(b))
|
||||
}
|
||||
|
||||
func TestMiddlewareShift(t *testing.T) {
|
||||
cases := map[string]struct {
|
||||
input string
|
||||
output string
|
||||
}{
|
||||
"root": {
|
||||
input: "/",
|
||||
output: "/",
|
||||
},
|
||||
"root w/ query param": {
|
||||
input: "/?a=b",
|
||||
output: "/?a=b",
|
||||
},
|
||||
"base": {
|
||||
input: "/base",
|
||||
output: "/",
|
||||
},
|
||||
"base w/ query param": {
|
||||
input: "/base?a=b",
|
||||
output: "/?a=b",
|
||||
},
|
||||
"sub": {
|
||||
input: "/base/sub",
|
||||
output: "/sub",
|
||||
},
|
||||
"sub w/ query param": {
|
||||
input: "/base/sub?a=b",
|
||||
output: "/sub?a=b",
|
||||
},
|
||||
}
|
||||
|
||||
for name, d := range cases {
|
||||
c := d
|
||||
t.Run(name, func(t *testing.T) {
|
||||
r := httptest.NewRequest(http.MethodGet, c.input, nil)
|
||||
w := httptest.NewRecorder()
|
||||
rest := &REST{}
|
||||
foo := rest.shift(func(http.ResponseWriter, *http.Request) {})
|
||||
log.Printf("%+v", r.URL)
|
||||
foo(w, r)
|
||||
if r.URL.String() != c.output {
|
||||
t.Fatalf("from %q, want %q, got %q", c.input, c.output, r.URL.String())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
@ -59,5 +59,5 @@ func NewREST(g storage.RateLimitedGraph) (*REST, error) {
|
|||
}
|
||||
|
||||
func (rest *REST) scope(r *http.Request) auth.Scope {
|
||||
return auth.GetScope(r)
|
||||
return auth.GetScope(r, rest.g)
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue