From 1a06d9634b247df39b9b4d76b737df36b0c7c22d Mon Sep 17 00:00:00 2001 From: breel Date: Sat, 8 Aug 2020 09:42:02 -0600 Subject: [PATCH] Test auth and add scope --- server/auth/const.go | 1 + server/auth/generate.go | 24 +++++- server/auth/generate_test.go | 46 +++++++++++ server/auth/scope.go | 75 ++++++++++++++++++ server/auth/scope_test.go | 146 +++++++++++++++++++++++++++++++++++ server/middleware.go | 4 +- server/rest.go | 15 +--- 7 files changed, 294 insertions(+), 17 deletions(-) create mode 100644 server/auth/scope.go create mode 100644 server/auth/scope_test.go diff --git a/server/auth/const.go b/server/auth/const.go index 4030cad..ad0170c 100644 --- a/server/auth/const.go +++ b/server/auth/const.go @@ -4,4 +4,5 @@ const ( AuthKey = "DnDex-Auth" UserKey = "DnDex-User" NewAuthKey = "New-" + AuthKey + ScopeKey = "Scope" ) diff --git a/server/auth/generate.go b/server/auth/generate.go index 4f50627..e1cc8a6 100644 --- a/server/auth/generate.go +++ b/server/auth/generate.go @@ -9,19 +9,35 @@ import ( "github.com/google/uuid" ) -func Generate(g storage.RateLimitedGraph, r *http.Request, salt string) (string, error) { - namespaceRequested := readRequestedNamespace(r) - key, err := getKeyForNamespace(r.Context(), g, namespaceRequested) +func GeneratePlain(g storage.RateLimitedGraph, r *http.Request) (string, error) { + token, _, err := generateToken(g, r) if err != nil { return "", err } - token, err := makeTokenForNamespace(r.Context(), g, namespaceRequested) + return token.Obfuscate() +} + +func Generate(g storage.RateLimitedGraph, r *http.Request, salt string) (string, error) { + token, key, err := generateToken(g, r) if err != nil { return "", err } return token.Encode(salt + key) } +func generateToken(g storage.RateLimitedGraph, r *http.Request) (Token, string, error) { + namespaceRequested := readRequestedNamespace(r) + key, err := getKeyForNamespace(r.Context(), g, namespaceRequested) + if err != nil { + return Token{}, "", err + } + token, err := makeTokenForNamespace(r.Context(), g, namespaceRequested) + if err != nil { + return Token{}, "", err + } + return token, key, nil +} + func readRequestedNamespace(r *http.Request) string { return readRequested(r, UserKey) } diff --git a/server/auth/generate_test.go b/server/auth/generate_test.go index 22b5480..cec99fc 100644 --- a/server/auth/generate_test.go +++ b/server/auth/generate_test.go @@ -2,6 +2,7 @@ package auth import ( "context" + "io" "local/dndex/storage" "local/dndex/storage/entity" "net/http" @@ -48,4 +49,49 @@ func TestGenerate(t *testing.T) { t.Fatal(err) } }) + + t.Run("ok plain", func(t *testing.T) { + g, r, _ := fresh() + obf, err := GeneratePlain(g, r) + if err != nil { + t.Fatal(err) + } + var token Token + if err := token.Deobfuscate(obf); err != nil { + t.Fatal(err) + } + }) + + t.Run("404", func(t *testing.T) { + g, r, _ := fresh() + r.Body = struct { + io.Reader + io.Closer + }{ + Reader: strings.NewReader(UserKey + "=" + uuid.New().String()), + Closer: r.Body, + } + r.ParseForm() + salt := uuid.New().String() + _, err := Generate(g, r, salt) + if err == nil { + t.Fatal(err) + } + }) + + t.Run("404 plain", func(t *testing.T) { + g, r, _ := fresh() + r.Body = struct { + io.Reader + io.Closer + }{ + Reader: strings.NewReader(UserKey + "=" + uuid.New().String()), + Closer: r.Body, + } + r.ParseForm() + _, err := GeneratePlain(g, r) + if err == nil { + t.Fatal(err) + } + }) } diff --git a/server/auth/scope.go b/server/auth/scope.go new file mode 100644 index 0000000..9c38aa3 --- /dev/null +++ b/server/auth/scope.go @@ -0,0 +1,75 @@ +package auth + +import ( + "context" + "local/dndex/storage" + "net/http" + "strings" +) + +type Scope struct { + Namespace string + EntityName string + EntityID string +} + +func GetScope(r *http.Request, gs ...storage.RateLimitedGraph) Scope { + scope := Scope{} + if ok := scope.fromCtx(r.Context()); ok { + return scope + } + scope.fromToken(r) + scope.fromPath(r.URL.Path) + for _, g := range gs { + scope.fromGraph(r.Context(), g) + } + reqWithContextValue(r, ScopeKey, scope) + return scope +} + +func reqWithContextValue(r *http.Request, k string, v interface{}) { + *r = *(r.WithContext(context.WithValue(r.Context(), k, v))) +} + +func (s *Scope) fromCtx(ctx context.Context) bool { + var ok bool + *s, ok = ctx.Value(ScopeKey).(Scope) + return ok +} + +func (s *Scope) fromToken(r *http.Request) bool { + token, ok := getToken(r) + if ok { + s.Namespace = token.Namespace + } + return ok +} + +func (s *Scope) fromPath(path string) bool { + if !strings.HasPrefix(path, "/entity/") { + return false + } + paths := strings.Split(path, "/") + if len(paths) < 2 { + return false + } + path = paths[2] + path = strings.Split(path, "#")[0] + if len(path) == 0 { + return false + } + s.EntityID = path + return true +} + +func (s *Scope) fromGraph(ctx context.Context, g storage.RateLimitedGraph) bool { + if s.EntityID == "" { + return false + } + one, err := g.Get(ctx, s.Namespace, s.EntityID) + ok := err == nil && one.Name != "" + if ok { + s.EntityName = one.Name + } + return ok +} diff --git a/server/auth/scope_test.go b/server/auth/scope_test.go new file mode 100644 index 0000000..c40f5d1 --- /dev/null +++ b/server/auth/scope_test.go @@ -0,0 +1,146 @@ +package auth + +import ( + "context" + "local/dndex/storage" + "local/dndex/storage/entity" + "net/http" + "net/http/httptest" + "testing" + + "github.com/google/uuid" +) + +func TestScopeFromCtx(t *testing.T) { + var scope Scope + if ok := scope.fromCtx(context.TODO()); ok { + t.Fatal(ok) + } + + ctxScope := Scope{ + Namespace: uuid.New().String(), + EntityName: uuid.New().String(), + EntityID: uuid.New().String(), + } + ctx := context.WithValue(context.TODO(), ScopeKey, ctxScope) + if ok := scope.fromCtx(ctx); !ok { + t.Fatal(ok) + } else if ctxScope != scope { + t.Fatal(scope) + } +} + +func TestScopeFromToken(t *testing.T) { + var scope Scope + + r := httptest.NewRequest(http.MethodGet, "/", nil) + if ok := scope.fromToken(r); ok { + t.Fatal(ok) + } + + token := Token{ + Namespace: uuid.New().String(), + } + obf, _ := token.Obfuscate() + r.AddCookie(&http.Cookie{ + Name: AuthKey, + Value: obf, + }) + if ok := scope.fromToken(r); !ok { + t.Fatal(ok) + } +} + +func TestScopeFromPath(t *testing.T) { + cases := map[string]struct { + ok bool + id string + }{ + "/": {}, + "/hello": {}, + "/hello/": {}, + "/hello/entity": {}, + "/entity": {}, + "/entity/": {}, + "/entity/id": { + ok: true, + id: "id", + }, + "/entity/id/": { + ok: true, + id: "id", + }, + "/entity/id/excess": { + ok: true, + id: "id", + }, + "/entity/id#excess": { + ok: true, + id: "id", + }, + "/entity/id?excess": { + ok: true, + id: "id", + }, + } + + for name, d := range cases { + c := d + path := name + t.Run(name, func(t *testing.T) { + r := httptest.NewRequest(http.MethodGet, path, nil) + var scope Scope + ok := scope.fromPath(r.URL.Path) + if ok != c.ok { + t.Fatal(c.ok, ok) + } + if scope.EntityID != c.id { + t.Fatalf("want %q, got %q", c.id, scope.EntityID) + } + }) + } +} + +// func (s *Scope) fromGraph(ctx context.Context, g storage.RateLimitedGraph) bool { +func TestScopeFromGraph(t *testing.T) { + namespace := uuid.New().String() + id := uuid.New().String() + name := uuid.New().String() + + gEmpty := storage.NewRateLimitedGraph() + gOne := storage.NewRateLimitedGraph() + if err := gOne.Insert(context.TODO(), namespace, entity.One{ID: id, Name: name}); err != nil { + t.Fatal(err) + } + + scope := Scope{ + Namespace: namespace, + } + + if ok := scope.fromGraph(context.TODO(), gEmpty); ok { + t.Fatal(ok) + } else if scope.EntityName != "" { + t.Fatal(scope) + } + + if ok := scope.fromGraph(context.TODO(), gOne); ok { + t.Fatal(ok) + } else if scope.EntityName != "" { + t.Fatal(scope) + } + + scope.EntityID = id + + if ok := scope.fromGraph(context.TODO(), gEmpty); ok { + t.Fatal(ok) + } else if scope.EntityName != "" { + t.Fatal(scope) + } + + if ok := scope.fromGraph(context.TODO(), gOne); !ok { + t.Fatal(ok) + } else if scope.EntityName != name { + t.Fatal(scope) + } + +} diff --git a/server/middleware.go b/server/middleware.go index 0813392..be9c188 100644 --- a/server/middleware.go +++ b/server/middleware.go @@ -3,6 +3,7 @@ package server import ( "io" "local/dndex/config" + "local/dndex/server/auth" "local/gziphttp" "net/http" "time" @@ -39,7 +40,8 @@ func (rest *REST) defend(foo http.HandlerFunc) http.HandlerFunc { func (rest *REST) auth(foo http.HandlerFunc) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - if err := Auth(rest.g, w, r); err != nil { + if err := auth.Verify(rest.g, w, r); err != nil { + http.Error(w, err.Error(), http.StatusUnauthorized) return } foo(w, r) diff --git a/server/rest.go b/server/rest.go index fafbebe..ef18786 100644 --- a/server/rest.go +++ b/server/rest.go @@ -3,6 +3,7 @@ package server import ( "fmt" "local/dndex/config" + "local/dndex/server/auth" "local/dndex/storage" "local/router" "net/http" @@ -14,15 +15,6 @@ type REST struct { g storage.RateLimitedGraph } -type RESTScope struct { - entity scope - user scope -} -type scope struct { - name string - id string -} - func Listen(g storage.RateLimitedGraph) error { rest, err := NewREST(g) if err != nil { @@ -62,9 +54,8 @@ func NewREST(g storage.RateLimitedGraph) (*REST, error) { return rest, nil } -func (rest *REST) scope(r *http.Request) RESTScope { - value, _ := r.Context().Value(AuthKey).(RESTScope) - return value +func (rest *REST) scope(r *http.Request) auth.Scope { + return auth.GetScope(r) } func (rest *REST) files(w http.ResponseWriter, _ *http.Request) {