From d3ac4f5c22fc2f313473b3456c507ce81c19c56f Mon Sep 17 00:00:00 2001 From: Bel LaPointe Date: Fri, 24 Jul 2020 14:45:03 -0600 Subject: [PATCH] Auth implemented ish --- config/config.go | 32 +++++---- view/auth.go | 134 ++++++++++++++++++++++++++++++++---- view/auth_test.go | 169 ++++++++++++++++++++++++++++++++++++++++++++++ view/json.go | 38 ++++++++--- view/who.go | 52 ++++++++------ 5 files changed, 366 insertions(+), 59 deletions(-) create mode 100644 view/auth_test.go diff --git a/config/config.go b/config/config.go index b976de7..73c6877 100644 --- a/config/config.go +++ b/config/config.go @@ -4,16 +4,18 @@ import ( "io/ioutil" "local/args" "os" + "time" ) type Config struct { - Port int - DBURI string - Database string - DriverType string - FilePrefix string - FileRoot string - Auth bool + Port int + DBURI string + Database string + DriverType string + FilePrefix string + FileRoot string + Auth bool + AuthLifetime time.Duration } func New() Config { @@ -32,18 +34,20 @@ func New() Config { as.Append(args.STRING, "database", "database name to use", "db") as.Append(args.STRING, "drivertype", "database driver to use", "boltdb") as.Append(args.BOOL, "auth", "check for authorized access", false) + as.Append(args.DURATION, "authlifetime", "duration auth is valid for", time.Hour) if err := as.Parse(); err != nil { panic(err) } return Config{ - Port: as.GetInt("p"), - DBURI: as.GetString("dburi"), - FilePrefix: as.GetString("fileprefix"), - FileRoot: as.GetString("fileroot"), - Database: as.GetString("database"), - DriverType: as.GetString("drivertype"), - Auth: as.GetBool("auth"), + Port: as.GetInt("p"), + DBURI: as.GetString("dburi"), + FilePrefix: as.GetString("fileprefix"), + FileRoot: as.GetString("fileroot"), + Database: as.GetString("database"), + DriverType: as.GetString("drivertype"), + Auth: as.GetBool("auth"), + AuthLifetime: as.GetDuration("authlifetime"), } } diff --git a/view/auth.go b/view/auth.go index 55a8a88..2504515 100644 --- a/view/auth.go +++ b/view/auth.go @@ -1,15 +1,27 @@ package view import ( + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "encoding/base64" "encoding/json" "errors" + "io" "local/dndex/config" "local/dndex/storage" + "local/dndex/storage/entity" "net/http" + "strings" + "time" + + "github.com/google/uuid" ) const ( - AuthKey = "DnDex-Auth" + AuthKey = "DnDex-Auth" + NewAuthKey = "New-" + AuthKey + UserKey = "DnDex-User" ) func Auth(g storage.Graph, w http.ResponseWriter, r *http.Request) error { @@ -25,27 +37,119 @@ func Auth(g storage.Graph, w http.ResponseWriter, r *http.Request) error { func auth(g storage.Graph, w http.ResponseWriter, r *http.Request) error { if !hasAuth(r) { - if err := requestAuth(g, w, r); err != nil { - return err - } - return errors.New("auth requested") + return requestAuth(g, w, r) } - return checkAuth(g, r) + return checkAuth(g, w, r) } func hasAuth(r *http.Request) bool { - _, ok := r.Cookie(AuthKey) - return ok == nil + _, err := r.Cookie(AuthKey) + return err == nil } -func checkAuth(g storage.Graph, r *http.Request) error { - panic(nil) - /* - token, _ := r.Cookie(AuthKey) - return errors.New("not impl") - */ +func checkAuth(g storage.Graph, w http.ResponseWriter, r *http.Request) error { + namespace, err := getAuthNamespace(r) + if err != nil { + return err + } + token, _ := r.Cookie(AuthKey) + results, err := g.List(r.Context(), namespace, token.Value) + if err != nil { + return err + } + if len(results) != 1 { + return requestAuth(g, w, r) + } + modified := time.Unix(0, results[0].Modified) + if time.Since(modified) > config.New().AuthLifetime { + return requestAuth(g, w, r) + } + return nil } func requestAuth(g storage.Graph, w http.ResponseWriter, r *http.Request) error { - return errors.New("not impl") + namespace, err := getAuthNamespace(r) + if err != nil { + http.Error(w, `{"error": "namespace required"}`, http.StatusBadRequest) + return err + } + + ones, err := g.List(r.Context(), namespace, UserKey) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return err + } + if len(ones) != 1 { + http.NotFound(w, r) + return errors.New("namespace not established") + } + userKey := ones[0] + + token := entity.One{ + Name: uuid.New().String(), + Title: namespace, + } + if err := g.Insert(r.Context(), namespace, token); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return err + } + + encodedToken, err := aesEnc(userKey.Title, token.Name) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return err + } + http.SetCookie(w, &http.Cookie{Name: NewAuthKey, Value: encodedToken}) + + http.Redirect(w, r, r.URL.String(), http.StatusSeeOther) + return errors.New("auth requested") +} + +func aesEnc(key, payload string) (string, error) { + if len(key) == 0 { + return "", errors.New("key required") + } + key = strings.Repeat(key, 32)[:32] + + block, err := aes.NewCipher([]byte(key)) + if err != nil { + return "", err + } + gcm, err := cipher.NewGCM(block) + if err != nil { + return "", err + } + nonce := make([]byte, gcm.NonceSize()) + if _, err = io.ReadFull(rand.Reader, nonce); err != nil { + return "", err + } + b := gcm.Seal(nonce, nonce, []byte(payload), nil) + + return base64.StdEncoding.EncodeToString(b), nil +} + +func aesDec(key, payload string) (string, error) { + if len(key) == 0 { + return "", errors.New("key required") + } + key = strings.Repeat(key, 32)[:32] + + ciphertext, err := base64.StdEncoding.DecodeString(payload) + if err != nil { + return "", err + } + + block, err := aes.NewCipher([]byte(key)) + if err != nil { + return "", err + } + gcm, err := cipher.NewGCM(block) + if err != nil { + return "", err + } + if len(ciphertext) < gcm.NonceSize() { + return "", errors.New("short ciphertext") + } + b, err := gcm.Open(nil, ciphertext[:gcm.NonceSize()], ciphertext[gcm.NonceSize():], nil) + return string(b), err } diff --git a/view/auth_test.go b/view/auth_test.go new file mode 100644 index 0000000..032ad41 --- /dev/null +++ b/view/auth_test.go @@ -0,0 +1,169 @@ +package view + +import ( + "context" + "fmt" + "io/ioutil" + "local/dndex/storage" + "local/dndex/storage/entity" + "net/http" + "net/http/httptest" + "os" + "strings" + "testing" + "time" + + "github.com/google/uuid" +) + +func TestAuth(t *testing.T) { + os.Args = os.Args[:1] + f, err := ioutil.TempFile(os.TempDir(), "pattern*") + if err != nil { + t.Fatal(err) + } + f.Close() + defer os.Remove(f.Name()) + os.Setenv("DBURI", f.Name()) + + os.Setenv("AUTH", "true") + defer os.Setenv("AUTH", "false") + + g := storage.NewGraph() + handler := jsonHandler(g) + + if err := g.Insert(context.TODO(), "col."+AuthKey, entity.One{Name: UserKey, Title: "password"}); err != nil { + t.Fatal(err) + } + + t.Run("auth: no namespace", func(t *testing.T) { + r := httptest.NewRequest(http.MethodGet, "/who", nil) + w := httptest.NewRecorder() + handler.ServeHTTP(w, r) + if w.Code != http.StatusBadRequest { + t.Fatalf("%d: %s", w.Code, w.Body.Bytes()) + } + }) + + t.Run("auth: bad provided", func(t *testing.T) { + r := httptest.NewRequest(http.MethodGet, "/who?namespace=col", nil) + r.Header.Set("Cookie", fmt.Sprintf("%s=not-a-real-token", AuthKey)) + w := httptest.NewRecorder() + handler.ServeHTTP(w, r) + if w.Code != http.StatusSeeOther { + t.Fatalf("%d: %s", w.Code, w.Body.Bytes()) + } + }) + + t.Run("auth: expired provided", func(t *testing.T) { + os.Setenv("AUTHLIFETIME", "1ms") + defer os.Setenv("AUTHLIFETIME", "1h") + one := entity.One{Name: uuid.New().String(), Title: "title"} + if err := g.Insert(context.TODO(), "col", one); err != nil { + t.Fatal(err) + } + time.Sleep(time.Millisecond * 50) + r := httptest.NewRequest(http.MethodGet, "/who?namespace=col", nil) + r.Header.Set("Cookie", fmt.Sprintf("%s=%s", AuthKey, one.Name)) + w := httptest.NewRecorder() + handler.ServeHTTP(w, r) + if w.Code != http.StatusSeeOther { + t.Fatalf("%d: %s", w.Code, w.Body.Bytes()) + } + }) + + t.Run("auth: none provided", func(t *testing.T) { + r := httptest.NewRequest(http.MethodGet, "/who?namespace=col", nil) + w := httptest.NewRecorder() + handler.ServeHTTP(w, r) + if w.Code != http.StatusSeeOther { + t.Fatalf("%d: %s", w.Code, w.Body.Bytes()) + } + }) + + t.Run("auth: provided", func(t *testing.T) { + os.Setenv("AUTHLIFETIME", "1h") + one := entity.One{Name: uuid.New().String(), Title: "title"} + if err := g.Insert(context.TODO(), "col."+AuthKey, one); err != nil { + t.Fatal(err) + } + r := httptest.NewRequest(http.MethodTrace, "/who?namespace=col", nil) + r.Header.Set("Cookie", fmt.Sprintf("%s=%s", AuthKey, one.Name)) + w := httptest.NewRecorder() + handler.ServeHTTP(w, r) + if w.Code != http.StatusOK { + t.Fatalf("%d: %s", w.Code, w.Body.Bytes()) + } + }) + + t.Run("auth: request unknown namespace", func(t *testing.T) { + os.Setenv("AUTHLIFETIME", "1h") + r := httptest.NewRequest(http.MethodTrace, "/who?namespace=not-col", nil) + w := httptest.NewRecorder() + handler.ServeHTTP(w, r) + if w.Code != http.StatusNotFound { + t.Fatalf("%d: %s", w.Code, w.Body.Bytes()) + } + }) + + t.Run("auth: request", func(t *testing.T) { + os.Setenv("AUTHLIFETIME", "1h") + r := httptest.NewRequest(http.MethodTrace, "/who?namespace=col", nil) + w := httptest.NewRecorder() + handler.ServeHTTP(w, r) + if w.Code != http.StatusSeeOther { + t.Fatalf("%d: %s", w.Code, w.Body.Bytes()) + } + cookies := w.Header()["Set-Cookie"] + if len(cookies) == 0 { + t.Fatal(w.Header()) + } + var rawtoken string + for i := range cookies { + value := strings.Split(cookies[i], ";")[0] + key := value[:strings.Index(value, "=")] + value = value[strings.Index(value, "=")+1:] + if key == NewAuthKey { + rawtoken = value + } + } + if rawtoken == "" { + t.Fatal(w.Header()) + } + + token, err := aesDec("password", rawtoken) + if err != nil { + t.Fatal(err) + } + + r = httptest.NewRequest(http.MethodTrace, "/who?namespace=col", nil) + w = httptest.NewRecorder() + r.Header.Set("Cookie", fmt.Sprintf("%s=%s", AuthKey, token)) + handler.ServeHTTP(w, r) + if w.Code != http.StatusOK { + t.Fatalf("%d: %s", w.Code, w.Body.Bytes()) + } + }) +} + +func TestAES(t *testing.T) { + for _, plaintext := range []string{"", "payload!", "a really long payload here"} { + key := "password" + + enc, err := aesEnc(key, plaintext) + if err != nil { + t.Fatal("cannot enc:", err) + } + if enc == plaintext { + t.Fatal(enc) + } + + dec, err := aesDec(key, enc) + if err != nil { + t.Fatal("cannot dec:", err) + } + if dec != plaintext { + t.Fatalf("want decrypted %q, got %q", plaintext, dec) + } + } +} diff --git a/view/json.go b/view/json.go index 14a1012..0333301 100644 --- a/view/json.go +++ b/view/json.go @@ -2,6 +2,7 @@ package view import ( "encoding/json" + "errors" "fmt" "local/dndex/config" "local/dndex/storage" @@ -22,22 +23,31 @@ func jsonHandler(g storage.Graph) http.Handler { mux := http.NewServeMux() routes := []struct { - path string - foo func(g storage.Graph, w http.ResponseWriter, r *http.Request) error + path string + foo func(g storage.Graph, w http.ResponseWriter, r *http.Request) error + noauth bool }{ { path: "/who", foo: who, }, { - path: config.New().FilePrefix + "/", - foo: files, + path: config.New().FilePrefix + "/", + foo: files, + noauth: true, }, } for _, route := range routes { foo := route.foo + auth := !route.noauth mux.HandleFunc(route.path, func(w http.ResponseWriter, r *http.Request) { + if auth { + if err := Auth(g, w, r); err != nil { + log.Println(err) + return + } + } if err := foo(g, w, r); err != nil { status := http.StatusInternalServerError if strings.Contains(err.Error(), "collision") { @@ -50,10 +60,6 @@ func jsonHandler(g storage.Graph) http.Handler { } return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if err := Auth(g, w, r); err != nil { - log.Println(err) - return - } if gziphttp.Can(r) { gz := gziphttp.New(w) defer gz.Close() @@ -62,3 +68,19 @@ func jsonHandler(g storage.Graph) http.Handler { mux.ServeHTTP(w, r) }) } + +func getAuthNamespace(r *http.Request) (string, error) { + namespace := r.URL.Query().Get("namespace") + if len(namespace) == 0 { + return "", errors.New("no namespace found") + } + return strings.Join([]string{namespace, AuthKey}, "."), nil +} + +func getNamespace(r *http.Request) (string, error) { + namespace := r.URL.Query().Get("namespace") + if len(namespace) == 0 { + return "", errors.New("no namespace found") + } + return namespace, nil +} diff --git a/view/who.go b/view/who.go index b352a7f..da2e680 100644 --- a/view/who.go +++ b/view/who.go @@ -15,8 +15,8 @@ import ( ) func who(g storage.Graph, w http.ResponseWriter, r *http.Request) error { - namespace := r.URL.Query().Get("namespace") - if len(namespace) == 0 { + namespace, err := getNamespace(r) + if err != nil { http.NotFound(w, r) return nil } @@ -41,10 +41,10 @@ func who(g storage.Graph, w http.ResponseWriter, r *http.Request) error { } func whoGet(namespace string, g storage.Graph, w http.ResponseWriter, r *http.Request) error { - id := r.URL.Query().Get("id") - if id == "" { - http.Error(w, `{"error":"no ?id provided"}`, http.StatusBadRequest) - return nil + id, err := getID(r) + if err != nil { + w.WriteHeader(http.StatusBadRequest) + return json.NewEncoder(w).Encode(map[string]string{"error": err.Error()}) } _, light := r.URL.Query()["light"] @@ -78,10 +78,10 @@ func whoGet(namespace string, g storage.Graph, w http.ResponseWriter, r *http.Re } func whoPut(namespace string, g storage.Graph, w http.ResponseWriter, r *http.Request) error { - id := r.URL.Query().Get("id") - if id == "" { - http.Error(w, `{"error":"no ?id provided"}`, http.StatusBadRequest) - return nil + id, err := getID(r) + if err != nil { + w.WriteHeader(http.StatusBadRequest) + return json.NewEncoder(w).Encode(map[string]string{"error": err.Error()}) } body, err := ioutil.ReadAll(r.Body) @@ -121,10 +121,10 @@ func whoPut(namespace string, g storage.Graph, w http.ResponseWriter, r *http.Re } func whoPost(namespace string, g storage.Graph, w http.ResponseWriter, r *http.Request) error { - id := r.URL.Query().Get("id") - if id == "" { - http.Error(w, `{"error":"no ?id provided"}`, http.StatusBadRequest) - return nil + id, err := getID(r) + if err != nil { + w.WriteHeader(http.StatusBadRequest) + return json.NewEncoder(w).Encode(map[string]string{"error": err.Error()}) } one := entity.One{} @@ -140,10 +140,10 @@ func whoPost(namespace string, g storage.Graph, w http.ResponseWriter, r *http.R } func whoDelete(namespace string, g storage.Graph, w http.ResponseWriter, r *http.Request) error { - id := r.URL.Query().Get("id") - if id == "" { - http.Error(w, `{"error":"no ?id provided"}`, http.StatusBadRequest) - return nil + id, err := getID(r) + if err != nil { + w.WriteHeader(http.StatusBadRequest) + return json.NewEncoder(w).Encode(map[string]string{"error": err.Error()}) } if err := g.Delete(r.Context(), namespace, entity.One{Name: id}); err != nil { @@ -154,10 +154,10 @@ func whoDelete(namespace string, g storage.Graph, w http.ResponseWriter, r *http } func whoPatch(namespace string, g storage.Graph, w http.ResponseWriter, r *http.Request) error { - id := r.URL.Query().Get("id") - if id == "" { - http.Error(w, `{"error":"no ?id provided"}`, http.StatusBadRequest) - return nil + id, err := getID(r) + if err != nil { + w.WriteHeader(http.StatusBadRequest) + return json.NewEncoder(w).Encode(map[string]string{"error": err.Error()}) } one := entity.One{} @@ -194,3 +194,11 @@ func whoTrace(namespace string, g storage.Graph, w http.ResponseWriter, r *http. enc.SetIndent("", " ") return enc.Encode(names) } + +func getID(r *http.Request) (string, error) { + id := r.URL.Query().Get("id") + if id == "" { + return "", errors.New("no id provided") + } + return id, nil +}