From 304956da7473edbdfbb5e377d27ff93c4329c1fd Mon Sep 17 00:00:00 2001 From: breel Date: Fri, 7 Aug 2020 16:15:52 -0600 Subject: [PATCH] Storage to uuids --- .DS_Store | Bin 0 -> 6148 bytes server/auth.go | 194 ++++++++++++++++++++++++++++++ server/auth_test.go | 184 ++++++++++++++++++++++++++++ server/const.go | 11 ++ server/middleware.go | 47 ++++++++ server/response.go | 16 +++ server/rest.go | 77 ++++++++++++ server/rest_test.go | 217 ++++++++++++++++++++++++++++++++++ server/version.go | 12 ++ storage/driver/boltdb.go | 4 +- storage/driver/boltdb_test.go | 34 ++---- storage/driver/map_test.go | 13 +- storage/entity/one.go | 79 ++++--------- storage/entity/one_test.go | 92 ++------------ storage/graph.go | 16 ++- storage/graph_test.go | 46 +++++-- storage/ratelimitedgraph.go | 6 +- 17 files changed, 862 insertions(+), 186 deletions(-) create mode 100644 .DS_Store create mode 100644 server/auth.go create mode 100644 server/auth_test.go create mode 100644 server/const.go create mode 100644 server/middleware.go create mode 100644 server/response.go create mode 100644 server/rest.go create mode 100644 server/rest_test.go create mode 100644 server/version.go diff --git a/.DS_Store b/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..ceedc6b7cf72e2b44d0a37f05b482114fcfff093 GIT binary patch literal 6148 zcmeHK&1%~~5Z<+2N2;NyP)GwgEcDQe4|ZGE^d?+8r$9nw>OujAbEhy?5+dG?XjgUGcfy&W@l!#--=x=V~lrlzsZ=z7&AZ-OC~hG z5bQ_YkdpQwa#!}(a24Uo$){(e7#(nhOzhj;N7SE>$lUh^NY)???2QcaQI%OJl42^8!*=R zm^o=2CGi*VuX0wggv0tS64+#ATNE)ai2L37ochH|=UH||9 literal 0 HcmV?d00001 diff --git a/server/auth.go b/server/auth.go new file mode 100644 index 0000000..80ea0fa --- /dev/null +++ b/server/auth.go @@ -0,0 +1,194 @@ +package server + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "encoding/base64" + "encoding/json" + "errors" + "io" + "local/dndex/config" + "local/dndex/storage" + "local/dndex/storage/entity" + "net/http" + "strings" + "time" + + "github.com/google/uuid" +) + +func Auth(g storage.RateLimitedGraph, w http.ResponseWriter, r *http.Request) error { + if !config.New().Auth { + return nil + } + if err := auth(g, w, r); err != nil { + json.NewEncoder(w).Encode(map[string]interface{}{"error": "error when authorizing: " + err.Error()}) + return err + } + return nil +} + +func auth(g storage.RateLimitedGraph, w http.ResponseWriter, r *http.Request) error { + if isPublic(g, r) { + return nil + } + if !hasAuth(r) { + return requestAuth(g, w, r) + } + return checkAuth(g, w, r) +} + +func isPublic(g storage.RateLimitedGraph, r *http.Request) bool { + namespace, err := getAuthNamespace(r) + + if err != nil { + return false + } + ones, err := g.List(r.Context(), namespace, UserKey) + if err != nil { + return false + } + if len(ones) == 0 { + return false + } + return ones[0].Title == "" +} + +func hasAuth(r *http.Request) bool { + _, err := r.Cookie(AuthKey) + return err == nil +} + +func checkAuth(g storage.RateLimitedGraph, w http.ResponseWriter, r *http.Request) error { + namespace, err := getAuthNamespace(r) + if err != nil { + return err + } + token, _ := r.Cookie(AuthKey) + results, err := g.List(r.Context(), namespace, token.Value) + if err != nil { + return err + } + if len(results) != 1 { + return requestAuth(g, w, r) + } + modified := time.Unix(0, results[0].Modified) + if time.Since(modified) > config.New().AuthLifetime { + return requestAuth(g, w, r) + } + return nil +} + +func requestAuth(g storage.RateLimitedGraph, w http.ResponseWriter, r *http.Request) error { + namespace, err := getAuthNamespace(r) + if err != nil { + http.Error(w, `{"error": "namespace required"}`, http.StatusBadRequest) + return err + } + + ones, err := g.List(r.Context(), namespace, UserKey) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return err + } + if len(ones) != 1 { + http.NotFound(w, r) + return errors.New("namespace not established") + } + userKey := ones[0] + + id := uuid.New().String() + token := entity.One{ + ID: id, + Name: id, + Title: namespace, + } + if err := g.Insert(r.Context(), namespace, token); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return err + } + + encodedToken, err := aesEnc(userKey.Title, token.Name) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return err + } + http.SetCookie(w, &http.Cookie{Name: NewAuthKey, Value: encodedToken}) + + http.Redirect(w, r, r.URL.String(), http.StatusSeeOther) + return errors.New("auth requested") +} + +func aesEnc(key, payload string) (string, error) { + if len(key) == 0 { + return "", errors.New("key required") + } + key = strings.Repeat(key, 32)[:32] + + block, err := aes.NewCipher([]byte(key)) + if err != nil { + return "", err + } + gcm, err := cipher.NewGCM(block) + if err != nil { + return "", err + } + nonce := make([]byte, gcm.NonceSize()) + if _, err = io.ReadFull(rand.Reader, nonce); err != nil { + return "", err + } + b := gcm.Seal(nonce, nonce, []byte(payload), nil) + + return base64.StdEncoding.EncodeToString(b), nil +} + +func aesDec(key, payload string) (string, error) { + if len(key) == 0 { + return "", errors.New("key required") + } + key = strings.Repeat(key, 32)[:32] + + ciphertext, err := base64.StdEncoding.DecodeString(payload) + if err != nil { + return "", err + } + + block, err := aes.NewCipher([]byte(key)) + if err != nil { + return "", err + } + gcm, err := cipher.NewGCM(block) + if err != nil { + return "", err + } + if len(ciphertext) < gcm.NonceSize() { + return "", errors.New("short ciphertext") + } + b, err := gcm.Open(nil, ciphertext[:gcm.NonceSize()], ciphertext[gcm.NonceSize():], nil) + return string(b), err +} + +func getAuthNamespace(r *http.Request) (string, error) { + namespace, err := getNamespace(r) + return strings.Join([]string{namespace, AuthKey}, "."), err +} + +func getNamespace(r *http.Request) (string, error) { + if strings.HasPrefix(r.URL.Path, config.New().FilePrefix) { + path := strings.TrimPrefix(r.URL.Path, config.New().FilePrefix+"/") + if path == r.URL.Path { + return "", errors.New("no namespace on files") + } + path = strings.Split(path, "/")[0] + if path == "" { + return "", errors.New("empty namespace on files") + } + return path, nil + } + namespace := r.URL.Query().Get("namespace") + if len(namespace) == 0 { + return "", errors.New("no namespace found") + } + return namespace, nil +} diff --git a/server/auth_test.go b/server/auth_test.go new file mode 100644 index 0000000..10fe0be --- /dev/null +++ b/server/auth_test.go @@ -0,0 +1,184 @@ +package server + +import ( + "context" + "fmt" + "local/dndex/storage/entity" + "net/http" + "net/http/httptest" + "os" + "strings" + "testing" + "time" + + "github.com/google/uuid" +) + +func TestAuth(t *testing.T) { + os.Args = os.Args[:1] + + rest, clean := testREST(t) + defer clean() + + handler := rest.router + g := rest.g + + os.Setenv("AUTH", "true") + defer os.Setenv("AUTH", "false") + + if err := g.Insert(context.TODO(), "col."+AuthKey, entity.One{ID: UserKey, Name: UserKey, Title: "password"}); err != nil { + t.Fatal(err) + } + + t.Run("auth: no namespace", func(t *testing.T) { + r := httptest.NewRequest(http.MethodGet, "/who", nil) + w := httptest.NewRecorder() + handler.ServeHTTP(w, r) + if w.Code != http.StatusBadRequest { + t.Fatalf("%d: %s", w.Code, w.Body.Bytes()) + } + }) + + t.Run("auth: bad provided", func(t *testing.T) { + r := httptest.NewRequest(http.MethodGet, "/who?namespace=col", nil) + r.Header.Set("Cookie", fmt.Sprintf("%s=not-a-real-token", AuthKey)) + w := httptest.NewRecorder() + handler.ServeHTTP(w, r) + if w.Code != http.StatusSeeOther { + t.Fatalf("%d: %s", w.Code, w.Body.Bytes()) + } + }) + + t.Run("auth: expired provided", func(t *testing.T) { + os.Setenv("AUTHLIFETIME", "1ms") + defer os.Setenv("AUTHLIFETIME", "1h") + one := entity.One{ID: uuid.New().String(), Name: uuid.New().String(), Title: "title"} + if err := g.Insert(context.TODO(), "col", one); err != nil { + t.Fatal(err) + } + time.Sleep(time.Millisecond * 50) + r := httptest.NewRequest(http.MethodGet, "/who?namespace=col", nil) + r.Header.Set("Cookie", fmt.Sprintf("%s=%s", AuthKey, one.ID)) + w := httptest.NewRecorder() + handler.ServeHTTP(w, r) + if w.Code != http.StatusSeeOther { + t.Fatalf("%d: %s", w.Code, w.Body.Bytes()) + } + }) + + t.Run("auth: none provided: who", func(t *testing.T) { + r := httptest.NewRequest(http.MethodGet, "/who?namespace=col", nil) + w := httptest.NewRecorder() + handler.ServeHTTP(w, r) + if w.Code != http.StatusSeeOther { + t.Fatalf("%d: %s", w.Code, w.Body.Bytes()) + } + }) + + t.Run("auth: none provided: files", func(t *testing.T) { + r := httptest.NewRequest(http.MethodGet, "/__files__/col/myfile", nil) + w := httptest.NewRecorder() + handler.ServeHTTP(w, r) + if w.Code != http.StatusSeeOther { + t.Fatalf("%d: %s", w.Code, w.Body.Bytes()) + } + }) + + t.Run("auth: provided", func(t *testing.T) { + os.Setenv("AUTHLIFETIME", "1h") + one := entity.One{ID: uuid.New().String(), Name: uuid.New().String(), Title: "title"} + if err := g.Insert(context.TODO(), "col."+AuthKey, one); err != nil { + t.Fatal(err) + } + r := httptest.NewRequest(http.MethodTrace, "/who?namespace=col", nil) + r.Header.Set("Cookie", fmt.Sprintf("%s=%s", AuthKey, one.Name)) + w := httptest.NewRecorder() + handler.ServeHTTP(w, r) + if w.Code != http.StatusOK { + t.Fatalf("%d: %s", w.Code, w.Body.Bytes()) + } + }) + + t.Run("auth: request unknown namespace", func(t *testing.T) { + os.Setenv("AUTHLIFETIME", "1h") + r := httptest.NewRequest(http.MethodTrace, "/who?namespace=not-col", nil) + w := httptest.NewRecorder() + handler.ServeHTTP(w, r) + if w.Code != http.StatusNotFound { + t.Fatalf("%d: %s", w.Code, w.Body.Bytes()) + } + }) + + t.Run("auth: request", func(t *testing.T) { + os.Setenv("AUTHLIFETIME", "1h") + r := httptest.NewRequest(http.MethodTrace, "/who?namespace=col", nil) + w := httptest.NewRecorder() + handler.ServeHTTP(w, r) + if w.Code != http.StatusSeeOther { + t.Fatalf("%d: %s", w.Code, w.Body.Bytes()) + } + + rawtoken := getCookie(NewAuthKey, w.Header()) + if rawtoken == "" { + t.Fatal(w.Header()) + } + token, err := aesDec("password", rawtoken) + if err != nil { + t.Fatal(err) + } + + r = httptest.NewRequest(http.MethodTrace, "/who?namespace=col", nil) + w = httptest.NewRecorder() + r.Header.Set("Cookie", fmt.Sprintf("%s=%s", AuthKey, token)) + handler.ServeHTTP(w, r) + if w.Code != http.StatusOK { + t.Fatalf("%d: %s", w.Code, w.Body.Bytes()) + } + + r = httptest.NewRequest(http.MethodTrace, "/__files__/col/myfile", nil) + w = httptest.NewRecorder() + r.Header.Set("Cookie", fmt.Sprintf("%s=%s", AuthKey, token)) + handler.ServeHTTP(w, r) + if w.Code != http.StatusNotFound { + t.Fatalf("%d: %s", w.Code, w.Body.Bytes()) + } + }) +} + +func TestAES(t *testing.T) { + for _, plaintext := range []string{"", "payload!", "a really long payload here"} { + key := "password" + + enc, err := aesEnc(key, plaintext) + if err != nil { + t.Fatal("cannot enc:", err) + } + if enc == plaintext { + t.Fatal(enc) + } + + dec, err := aesDec(key, enc) + if err != nil { + t.Fatal("cannot dec:", err) + } + if dec != plaintext { + t.Fatalf("want decrypted %q, got %q", plaintext, dec) + } + } +} + +func getCookie(key string, header http.Header) string { + cookies, _ := header["Set-Cookie"] + if len(cookies) == 0 { + cookies, _ = header["Cookie"] + } + for i := range cookies { + value := strings.Split(cookies[i], ";")[0] + k := value[:strings.Index(value, "=")] + v := value[strings.Index(value, "=")+1:] + if k == key { + return v + } + } + return "" +} diff --git a/server/const.go b/server/const.go new file mode 100644 index 0000000..5bc92f4 --- /dev/null +++ b/server/const.go @@ -0,0 +1,11 @@ +package server + +const ( + AuthKey = "DnDex-Auth" + UserKey = "DnDex-User" +) + +var ( + NewAuthKey = "New-" + AuthKey + GitCommit string +) diff --git a/server/middleware.go b/server/middleware.go new file mode 100644 index 0000000..0813392 --- /dev/null +++ b/server/middleware.go @@ -0,0 +1,47 @@ +package server + +import ( + "io" + "local/dndex/config" + "local/gziphttp" + "net/http" + "time" +) + +func (rest *REST) delay(foo http.HandlerFunc) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + select { + case <-time.After(config.New().Delay): + foo(w, r) + case <-r.Context().Done(): + http.Error(w, r.Context().Err().Error(), 499) + } + } +} + +func (rest *REST) defend(foo http.HandlerFunc) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + if gziphttp.Can(r) { + gz := gziphttp.New(w) + defer gz.Close() + w = gz + } + r.Body = struct { + io.Reader + io.Closer + }{ + Reader: io.LimitReader(r.Body, config.New().MaxFileSize), + Closer: r.Body, + } + foo(w, r) + } +} + +func (rest *REST) auth(foo http.HandlerFunc) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + if err := Auth(rest.g, w, r); err != nil { + return + } + foo(w, r) + } +} diff --git a/server/response.go b/server/response.go new file mode 100644 index 0000000..895e00b --- /dev/null +++ b/server/response.go @@ -0,0 +1,16 @@ +package server + +import ( + "encoding/json" + "net/http" +) + +func (rest *REST) respMap(w http.ResponseWriter, key string, value interface{}) { + rest.resp(w, map[string]interface{}{key: value}) +} + +func (rest *REST) resp(w http.ResponseWriter, body interface{}) { + enc := json.NewEncoder(w) + enc.SetIndent("", " ") + enc.Encode(body) +} diff --git a/server/rest.go b/server/rest.go new file mode 100644 index 0000000..1fab3b8 --- /dev/null +++ b/server/rest.go @@ -0,0 +1,77 @@ +package server + +import ( + "fmt" + "local/dndex/config" + "local/dndex/storage" + "local/router" + "net/http" +) + +type REST struct { + port int + router *router.Router + g storage.RateLimitedGraph +} + +type RESTScope struct { + entity scope + user scope +} +type scope struct { + name string + id string +} + +func Listen(g storage.RateLimitedGraph) error { + rest, err := NewREST(g) + if err != nil { + return err + } + return http.ListenAndServe(fmt.Sprintf(":%d", rest.port), rest.router) +} + +func NewREST(g storage.RateLimitedGraph) (*REST, error) { + rest := &REST{ + g: g, + port: config.New().Port, + router: router.New(), + } + + param := router.Wildcard + params := router.Wildcard + router.Wildcard + _, _ = param, params + + paths := map[string]http.HandlerFunc{ + fmt.Sprintf("version"): rest.version, + fmt.Sprintf("files/%s/%s", config.New().FilePrefix, params): rest.files, + fmt.Sprintf("users"): rest.users, + fmt.Sprintf("entities/%s", params): rest.entities, + } + + for path, foo := range paths { + bar := foo + bar = rest.auth(bar) + bar = rest.defend(bar) + bar = rest.delay(bar) + if err := rest.router.Add(path, bar); err != nil { + return nil, err + } + } + + return rest, nil +} + +func (rest *REST) scope(r *http.Request) RESTScope { + value, _ := r.Context().Value(AuthKey).(RESTScope) + return value +} + +func (rest *REST) files(w http.ResponseWriter, _ *http.Request) { +} + +func (rest *REST) users(w http.ResponseWriter, _ *http.Request) { +} + +func (rest *REST) entities(w http.ResponseWriter, _ *http.Request) { +} diff --git a/server/rest_test.go b/server/rest_test.go new file mode 100644 index 0000000..b940f3f --- /dev/null +++ b/server/rest_test.go @@ -0,0 +1,217 @@ +package server + +import ( + "context" + "fmt" + "io/ioutil" + "local/dndex/config" + "local/dndex/storage" + "local/dndex/storage/entity" + "net/http" + "net/http/httptest" + "os" + "path" + "strings" + "testing" + "time" + + "github.com/google/uuid" +) + +var ( + testNamespaceName = "col-name-" + uuid.New().String()[:10] + testNamespaceID = "col-id-" + uuid.New().String()[:10] + testEntityName = "ent-name-" + uuid.New().String()[:10] + testEntityID = "ent-id-" + uuid.New().String()[:10] + testFilename = "filename-" + uuid.New().String()[:10] + testContent = "content-" + uuid.New().String()[:10] +) + +func TestRESTRouter(t *testing.T) { + rest, clean := testREST(t) + defer clean() + cases := map[string]struct { + method string + is404 bool + }{ + "/version": { + method: http.MethodGet, + }, + fmt.Sprintf(`%s`, config.New().FilePrefix): { + method: http.MethodGet, + is404: true, + }, + fmt.Sprintf(`%s/`, config.New().FilePrefix): { + method: http.MethodGet, + is404: true, + }, + fmt.Sprintf(`%s/%s`, config.New().FilePrefix, testFilename): { + method: http.MethodGet, + }, + fmt.Sprintf(`%s/fake.fake`, config.New().FilePrefix): { + method: http.MethodGet, + is404: true, + }, + fmt.Sprintf("/users/%s", testNamespaceID): { + method: http.MethodGet, + is404: true, + }, + fmt.Sprintf("/users/%s", testNamespaceID): { + method: http.MethodPost, + is404: true, + }, + "/users/": { + method: http.MethodGet, + is404: true, + }, + "/users": { + method: http.MethodPost, + }, + "/users?": { + method: http.MethodGet, + }, + "/entities": { + method: http.MethodGet, + }, + "/entities/": { + method: http.MethodGet, + }, + fmt.Sprintf("/entities/fake-%s", testEntityID): { + method: http.MethodGet, + is404: true, + }, + fmt.Sprintf("/entities/%s", testEntityID): { + method: http.MethodGet, + }, + fmt.Sprintf("/entities/%s", testEntityID): { + method: http.MethodPatch, + }, + fmt.Sprintf("/entities/%s", testEntityID): { + method: http.MethodPost, + }, + fmt.Sprintf("/entities/%s", testEntityID): { + method: http.MethodPut, + }, + fmt.Sprintf("/entities/%s/connections/uuid", testEntityID): { + method: http.MethodPatch, + }, + fmt.Sprintf("/entities/%s/connections/uuid", testEntityID): { + method: http.MethodPut, + }, + fmt.Sprintf("/entities/%s/connections/uuid", testEntityID): { + method: http.MethodPost, + is404: true, + }, + fmt.Sprintf("/entities/%s/connections/", testEntityID): { + method: http.MethodPost, + is404: true, + }, + fmt.Sprintf("/entities/%s/connections/", testEntityID): { + method: http.MethodPut, + is404: true, + }, + fmt.Sprintf("/entities/%s/connections/", testEntityID): { + method: http.MethodPatch, + is404: true, + }, + fmt.Sprintf("/entities/%s/connections", testEntityID): { + method: http.MethodPost, + is404: true, + }, + fmt.Sprintf("/entities/%s/connections", testEntityID): { + method: http.MethodPut, + }, + fmt.Sprintf("/entities/%s/connections", testEntityID): { + method: http.MethodPatch, + }, + fmt.Sprintf("/entities/%s/connections/uuid", testEntityID): { + method: http.MethodDelete, + }, + fmt.Sprintf("/entities/%s/connections/", testEntityID): { + method: http.MethodDelete, + is404: true, + }, + fmt.Sprintf("/entities/%s/connections", testEntityID): { + method: http.MethodDelete, + }, + fmt.Sprintf("/entities/%s/", testEntityID): { + method: http.MethodDelete, + is404: true, + }, + fmt.Sprintf("/entities/%s", testEntityID): { + method: http.MethodDelete, + }, + } + + for name, d := range cases { + c := d + path := name + t.Run(name, func(t *testing.T) { + r := httptest.NewRequest(c.method, path, strings.NewReader(``)) + w := httptest.NewRecorder() + rest.router.ServeHTTP(w, r) + if (w.Code == http.StatusNotFound) != c.is404 { + t.Fatalf("want 404==%v, got %d: %+v: %s", c.is404, w.Code, c, w.Body.Bytes()) + } + }) + } +} + +func testREST(t *testing.T) (*REST, func()) { + d, err := ioutil.TempDir(os.TempDir(), "tempdir.*") + if err != nil { + t.Fatal(err) + } + if err := ioutil.WriteFile(path.Join(d, testFilename), []byte(testContent), os.ModePerm); err != nil { + t.Fatal(err) + } + os.Setenv("FILEROOT", d) + os.Setenv("DRIVER_TYPE", "map") + os.Setenv("AUTH", "false") + os.Args = os.Args[:1] + os.Args = os.Args[:1] + rest, err := NewREST(storage.NewRateLimitedGraph("")) + if err != nil { + t.Fatal(err) + } + ctx, can := context.WithTimeout(context.Background(), time.Second*5) + defer can() + one := randomOne() + one.Name = testEntityName + one.ID = testEntityID + if err := rest.g.Insert(ctx, testNamespaceID, one); err != nil { + t.Fatal(err) + } + if err := rest.g.Insert(ctx, testNamespaceID+"."+AuthKey, entity.One{ + Name: testNamespaceName, + ID: testNamespaceID, + Title: "title", + }); err != nil { + t.Fatal(err) + } + if err := rest.g.Insert(ctx, testNamespaceID+"."+AuthKey, entity.One{ + Name: UserKey, + ID: UserKey, + Title: "", + }); err != nil { + t.Fatal(err) + } + return rest, func() { + os.RemoveAll(d) + } +} + +func randomOne() entity.One { + return entity.One{ + ID: "iddd-" + uuid.New().String()[:5], + Name: "name-" + uuid.New().String()[:5], + Type: "type-" + uuid.New().String()[:5], + Title: "titl-" + uuid.New().String()[:5], + Text: "text-" + uuid.New().String()[:5], + Modified: time.Now().UnixNano(), + Connections: map[string]entity.One{}, + Attachments: map[string]string{ + uuid.New().String()[:5]: uuid.New().String()[:5], + }, + } +} diff --git a/server/version.go b/server/version.go new file mode 100644 index 0000000..335c7e3 --- /dev/null +++ b/server/version.go @@ -0,0 +1,12 @@ +package server + +import ( + "encoding/json" + "net/http" +) + +func (rest *REST) version(w http.ResponseWriter, _ *http.Request) { + enc := json.NewEncoder(w) + enc.SetIndent("", " ") + enc.Encode(map[string]string{"version": GitCommit}) +} diff --git a/storage/driver/boltdb.go b/storage/driver/boltdb.go index cb75741..920b30c 100644 --- a/storage/driver/boltdb.go +++ b/storage/driver/boltdb.go @@ -289,7 +289,7 @@ func applyUnset(doc, operator bson.M) (bson.M, error) { ok = pmok } if !ok { - return nil, fmt.Errorf("subpath cannot be followed for non object: %s (%s): %+v (%T)", k, nesting[0], mInterface, mInterface) + return nil, fmt.Errorf("subpath of %v (%v) cannot be followed for non object: %s (%s): %+v (%T)", doc, doc[nesting[0]], k, nesting[0], mInterface, mInterface) } subdoc, err := applyUnset(bson.M(m), bson.M{strings.Join(nesting[1:], "."): ""}) if err != nil { @@ -322,7 +322,7 @@ func applySet(doc, operator bson.M) (bson.M, error) { ok = pmok } if !ok { - return nil, fmt.Errorf("subpath cannot be followed for non object: %s (%s): %+v (%T)", k, nesting[0], mInterface, mInterface) + return nil, fmt.Errorf("subpath of %v (%v) cannot be followed for non object: %s (%s): %+v (%T)", doc, doc[nesting[0]], k, nesting[0], mInterface, mInterface) } subdoc, err := applySet(bson.M(m), bson.M{strings.Join(nesting[1:], "."): v}) if err != nil { diff --git a/storage/driver/boltdb_test.go b/storage/driver/boltdb_test.go index bc53ebe..d93d5d1 100644 --- a/storage/driver/boltdb_test.go +++ b/storage/driver/boltdb_test.go @@ -143,9 +143,6 @@ func TestBoltDBFind(t *testing.T) { if o.Text == "" { t.Error(o.Text) } - if o.Relationship != "" { - t.Error(o.Relationship) - } if o.Modified == 0 { t.Error(o.Modified) } @@ -156,18 +153,9 @@ func TestBoltDBFind(t *testing.T) { t.Error(o.Connections) } for k := range o.Connections { - if o.Connections[k].Name == "" { - t.Error(o.Connections[k]) - } - if o.Connections[k].Title == "" { - t.Error(o.Connections[k]) - } if o.Connections[k].Relationship == "" { t.Error(o.Connections[k]) } - if o.Connections[k].Type == "" { - t.Error(o.Connections[k]) - } } } if n != testN { @@ -317,13 +305,6 @@ func TestBoltDBInsert(t *testing.T) { if _, ok := ones[0].Connections[k]; !ok { t.Fatalf("db had more connections than real: %s", k) } - c := o.Connections[k] - c.Modified = 0 - o.Connections[k] = c - - c = ones[0].Connections[k] - c.Modified = 0 - ones[0].Connections[k] = c } o.Modified = 0 ones[0].Modified = 0 @@ -372,7 +353,7 @@ func TestBoltDBDelete(t *testing.T) { t.Fatal(err) } if n != wantN { - t.Error(n, filter) + t.Error(wantN, n, filter) } } }) @@ -401,11 +382,10 @@ func fillBoltDB(t *testing.T, bdb *BoltDB) { } for i := 0; i < testN; i++ { p := entity.One{ - ID: "iddd-" + uuid.New().String()[:5], - Name: "name-" + uuid.New().String()[:5], - Type: "type-" + uuid.New().String()[:5], - Relationship: "rshp-" + uuid.New().String()[:5], - Title: "titl-" + uuid.New().String()[:5], + ID: "iddd-" + uuid.New().String()[:5], + Name: "name-" + uuid.New().String()[:5], + Type: "type-" + uuid.New().String()[:5], + Title: "titl-" + uuid.New().String()[:5], } o := entity.One{ ID: "iddd-" + uuid.New().String()[:5], @@ -414,8 +394,8 @@ func fillBoltDB(t *testing.T, bdb *BoltDB) { Title: "titl-" + uuid.New().String()[:5], Text: "text-" + uuid.New().String()[:5], Modified: time.Now().UnixNano(), - Connections: map[string]entity.One{p.ID: p}, - Attachments: map[string]string{"filename": "/path/to/file"}, + Connections: map[string]entity.Connection{p.ID: entity.Connection{p.Name}}, + Attachments: map[string]entity.Attachment{"filename": {"/path/to/file"}}, } b, err := bson.Marshal(o) if err != nil { diff --git a/storage/driver/map_test.go b/storage/driver/map_test.go index 3eea872..e5657e4 100644 --- a/storage/driver/map_test.go +++ b/storage/driver/map_test.go @@ -14,11 +14,10 @@ func tempMap(t *testing.T) *Map { mp.db[testNS] = map[string][]byte{} for i := 0; i < testN; i++ { p := entity.One{ - ID: "iddd-" + uuid.New().String()[:5], - Name: "name-" + uuid.New().String()[:5], - Type: "type-" + uuid.New().String()[:5], - Relationship: "rshp-" + uuid.New().String()[:5], - Title: "titl-" + uuid.New().String()[:5], + ID: "iddd-" + uuid.New().String()[:5], + Name: "name-" + uuid.New().String()[:5], + Type: "type-" + uuid.New().String()[:5], + Title: "titl-" + uuid.New().String()[:5], } o := entity.One{ ID: "iddd-" + uuid.New().String()[:5], @@ -27,8 +26,8 @@ func tempMap(t *testing.T) *Map { Title: "titl-" + uuid.New().String()[:5], Text: "text-" + uuid.New().String()[:5], Modified: time.Now().UnixNano(), - Connections: map[string]entity.One{p.ID: p}, - Attachments: map[string]string{"filename": "/path/to/file"}, + Connections: map[string]entity.Connection{p.ID: entity.Connection{p.Name}}, + Attachments: map[string]entity.Attachment{"filename": {"/path/to/file"}}, } b, err := bson.Marshal(o) if err != nil { diff --git a/storage/entity/one.go b/storage/entity/one.go index e948aa0..09ce9a9 100644 --- a/storage/entity/one.go +++ b/storage/entity/one.go @@ -19,42 +19,40 @@ const ( Modified = "modified" Connections = "connections" Attachments = "attachments" + Location = "location" ) type One struct { - ID string `bson:"_id,omitempty" json:"_id,omitempty"` - Name string `bson:"name,omitempty" json:"name,omitempty"` - Type string `bson:"type,omitempty" json:"type,omitempty"` - Title string `bson:"title,omitempty" json:"title,omitempty"` - Text string `bson:"text,omitempty" json:"text,omitempty"` - Relationship string `bson:"relationship,omitempty" json:"relationship,omitempty"` - Modified int64 `bson:"modified,omitempty" json:"modified,omitempty"` - Connections map[string]One `bson:"connections" json:"connections,omitempty"` - Attachments map[string]string `bson:"attachments" json:"attachments,omitempty"` + ID string `bson:"_id,omitempty" json:"_id"` + Name string `bson:"name,omitempty" json:"name"` + Type string `bson:"type,omitempty" json:"type"` + Title string `bson:"title,omitempty" json:"title"` + Text string `bson:"text,omitempty" json:"text"` + Modified int64 `bson:"modified,omitempty" json:"modified"` + Connections map[string]Connection `bson:"connections" json:"connections"` + Attachments map[string]Attachment `bson:"attachments" json:"attachments"` } -func (o One) Query() One { - return One{Name: o.Name} +type Connection struct { + Relationship string `bson:"relationship,omitempty" json:"relationship"` } -func (o One) Peer() One { - return One{ - Name: o.Name, - Type: o.Type, - Title: o.Title, - Relationship: o.Relationship, - Modified: o.Modified, - } +type Attachment struct { + Location string `bson:"location,omitempty" json:"location"` +} + +func (o One) Query() bson.M { + return bson.M{ID: o.ID} } func (o One) Peers() []string { - names := make([]string, len(o.Connections)) + ids := make([]string, len(o.Connections)) i := 0 for k := range o.Connections { - names[i] = o.Connections[k].Name + ids[i] = k i += 1 } - return names + return ids } func (o One) MarshalBSON() ([]byte, error) { @@ -62,6 +60,12 @@ func (o One) MarshalBSON() ([]byte, error) { if !isMin { o.Modified = time.Now().UnixNano() } + if o.Connections == nil { + o.Connections = make(map[string]Connection) + } + if o.Attachments == nil { + o.Attachments = make(map[string]Attachment) + } b, err := json.Marshal(o) if err != nil { return nil, err @@ -76,36 +80,5 @@ func (o One) MarshalBSON() ([]byte, error) { m[k] = strings.TrimSpace(v.(string)) } } - if !isMin { - connections := map[string]interface{}{} - switch m[Connections].(type) { - case nil: - case map[string]interface{}: - connections = m[Connections].(map[string]interface{}) - default: - return nil, fmt.Errorf("bad connections type %T", m[Connections]) - } - delete(connections, "") - for k := range connections { - if k == "" { - continue - } - if o.Connections[k].Name == "" { - p := o.Connections[k] - p.Name = k - o.Connections[k] = p - } - b, err := bson.Marshal(o.Connections[k]) - if err != nil { - return nil, err - } - m := bson.M{} - if err := bson.Unmarshal(b, &m); err != nil { - return nil, err - } - connections[k] = m - } - m[Connections] = connections - } return bson.Marshal(m) } diff --git a/storage/entity/one_test.go b/storage/entity/one_test.go index 991239c..236041b 100644 --- a/storage/entity/one_test.go +++ b/storage/entity/one_test.go @@ -9,11 +9,11 @@ import ( func TestOne(t *testing.T) { one := One{ - Name: "myname", + ID: "myname", Type: "mytype", } q := one.Query() - if want := fmt.Sprint(One{Name: one.Name}); want != fmt.Sprint(q) { + if want := fmt.Sprint(bson.M{ID: one.ID}); want != fmt.Sprint(q) { t.Error(want, q) } } @@ -21,7 +21,7 @@ func TestOne(t *testing.T) { func TestOneMarshalBSON(t *testing.T) { cases := map[string]struct { sameAsQuery bool - one One + one interface{} }{ "query no modified change": { sameAsQuery: true, @@ -31,18 +31,16 @@ func TestOneMarshalBSON(t *testing.T) { one: One{Name: "hello", Type: "world", Modified: 1}, }, "w/ connections": { - one: One{Name: "hello", Type: "world", Modified: 1, Connections: map[string]One{"hi": One{Name: "hi", Relationship: "mom"}}}, + one: One{Name: "hello", Type: "world", Modified: 1, Connections: map[string]Connection{"hi": Connection{Relationship: "mom"}}}, }, "w/ attachments": { - one: One{Name: "hello", Type: "world", Modified: 1, Attachments: map[string]string{"hello": "/world"}}, + one: One{Name: "hello", Type: "world", Modified: 1, Attachments: map[string]Attachment{"hello": Attachment{"/world"}}}, }, } for name, d := range cases { c := d t.Run(name, func(t *testing.T) { - var bm bson.Marshaler = c.one - t.Log(bm) b, err := bson.Marshal(c.one) if err != nil { t.Fatal(err) @@ -51,85 +49,11 @@ func TestOneMarshalBSON(t *testing.T) { if err := bson.Unmarshal(b, &one); err != nil { t.Fatal(err) } - if c.sameAsQuery && (fmt.Sprint(one) != fmt.Sprint(one.Query()) || fmt.Sprint(one) != fmt.Sprint(c.one)) { - t.Error(c.sameAsQuery, c.one, one) - } else if !c.sameAsQuery { - if c.one.Modified == one.Modified { - t.Error(c.one.Modified, one.Modified) - } - c.one.Modified = 0 - one.Modified = 0 - for k := range one.Connections { - temp := one.Connections[k] - temp.Modified = 0 - one.Connections[k] = temp - } - if fmt.Sprint(c.one) != fmt.Sprint(one) { - t.Error(c.one, one) + if !c.sameAsQuery { + if one.Modified < 2 { + t.Error(one.Modified) } } }) } } - -func TestOneMarshalBSONBadConnections(t *testing.T) { - t.Run("connections has an empty string for a key that should die", func(t *testing.T) { - input := One{Name: "hello", Connections: map[string]One{"": One{Name: "teehee"}}} - - b, err := bson.Marshal(input) - if err != nil { - t.Fatal(err) - } - - output := One{} - if err := bson.Unmarshal(b, &output); err != nil { - t.Fatal(err) - } - - if len(output.Connections) != 0 { - t.Fatal(output.Connections) - } - - input.Connections = nil - output.Connections = nil - input.Modified = 0 - output.Modified = 0 - - if fmt.Sprint(input) != fmt.Sprint(output) { - t.Fatal(input, output) - } - }) - - t.Run("connections has a key but empty name that should correct", func(t *testing.T) { - input := One{Name: "hello", Connections: map[string]One{"teehee": One{Name: ""}}} - - b, err := bson.Marshal(input) - if err != nil { - t.Fatal(err) - } - - output := One{} - if err := bson.Unmarshal(b, &output); err != nil { - t.Fatal(err) - } - - if len(output.Connections) != 1 { - t.Fatal(output.Connections) - } else { - for k := range output.Connections { - if k != output.Connections[k].Name { - t.Fatal(k, output.Connections) - } - } - } - - input.Connections = nil - output.Connections = nil - input.Modified = 0 - output.Modified = 0 - - if fmt.Sprint(input) != fmt.Sprint(output) { - t.Fatal(input, output) - } - }) -} diff --git a/storage/graph.go b/storage/graph.go index 491e840..d9258c6 100644 --- a/storage/graph.go +++ b/storage/graph.go @@ -31,6 +31,20 @@ func (g Graph) ListCaseInsensitive(ctx context.Context, namespace string, from . return g.find(ctx, namespace, filter) } +func (g Graph) Get(ctx context.Context, namespace, id string) (entity.One, error) { + ones, err := g.find(ctx, namespace, bson.M{entity.ID: id}) + if err != nil { + return entity.One{}, err + } + if len(ones) == 0 { + return entity.One{}, errors.New("not found") + } + if len(ones) > 1 { + return entity.One{}, errors.New("primary key collision detected") + } + return ones[0], nil +} + func (g Graph) List(ctx context.Context, namespace string, from ...string) ([]entity.One, error) { filter := operator.NewFilterIn(entity.Name, from) return g.find(ctx, namespace, filter) @@ -68,7 +82,7 @@ func (g Graph) Insert(ctx context.Context, namespace string, one entity.One) err return g.driver.Insert(ctx, namespace, one) } -func (g Graph) Update(ctx context.Context, namespace string, one entity.One, modify interface{}) error { +func (g Graph) Update(ctx context.Context, namespace string, one, modify interface{}) error { return g.driver.Update(ctx, namespace, one, modify) } diff --git a/storage/graph_test.go b/storage/graph_test.go index e5ccf2f..bbf9b36 100644 --- a/storage/graph_test.go +++ b/storage/graph_test.go @@ -29,7 +29,7 @@ func TestIntegration(t *testing.T) { randomOne(), randomOne(), } - ones[0].Connections = map[string]entity.One{ones[2].Name: entity.One{Name: ones[2].Name, Relationship: ":("}} + ones[0].Connections = map[string]entity.Connection{ones[2].Name: entity.Connection{Relationship: ":("}} ones[1].Name = ones[0].Name[1 : len(ones[0].Name)-1] cleanFill := func() { clean() @@ -62,6 +62,30 @@ func TestIntegration(t *testing.T) { } }) + t.Run("graph.Get 404", func(t *testing.T) { + cleanFill() + _, err := graph.Get(ctx, "col", "fake_here") + if err == nil { + t.Fatal(err) + } + }) + + t.Run("graph.Get", func(t *testing.T) { + cleanFill() + all, err := graph.List(ctx, "col") + if err != nil { + t.Fatal(err) + } + want := all[0] + got, err := graph.Get(ctx, "col", want.ID) + if err != nil { + t.Fatal(err) + } + if got.ID != want.ID { + t.Fatal(got) + } + }) + t.Run("graph.ListCaseInsensitive", func(t *testing.T) { cleanFill() all, err := graph.ListCaseInsensitive(ctx, "col") @@ -201,9 +225,9 @@ func TestIntegration(t *testing.T) { t.Run("graph.Update(foo, +=2); graph.Update(foo, -=1)", func(t *testing.T) { cleanFill() - err := graph.Update(ctx, "col", ones[0].Query(), operator.Set{entity.Connections, map[string]entity.One{ - "hello": entity.One{Name: "hello", Relationship: ":("}, - "world": entity.One{Name: "world", Relationship: ":("}, + err := graph.Update(ctx, "col", ones[0].Query(), operator.Set{entity.Connections, map[string]entity.Connection{ + "hello": entity.Connection{Relationship: ":("}, + "world": entity.Connection{Relationship: ":("}, }}) if err != nil { t.Fatal(err) @@ -240,7 +264,7 @@ func TestIntegration(t *testing.T) { t.Run("graph.Update(new attachment), Update(--new attachment)", func(t *testing.T) { cleanFill() - err := graph.Update(ctx, "col", ones[0].Query(), operator.Set{Key: fmt.Sprintf("%s.new attachment", entity.Attachments), Value: "my new attachment"}) + err := graph.Update(ctx, "col", ones[0].Query(), operator.Set{Key: fmt.Sprintf("%s.new attachment", entity.Attachments), Value: entity.Attachment{Location: "my new attachment"}}) if err != nil { t.Fatal(err) } @@ -255,8 +279,8 @@ func TestIntegration(t *testing.T) { } if v, ok := some1[0].Attachments["new attachment"]; !ok { t.Fatal(ok, some1[0].Attachments) - } else if v != "my new attachment" { - t.Fatal(v, some1[0].Attachments) + } else if v.Location != "my new attachment" { + t.Fatalf("when listing from DB, did not find updated attachment: got %+v from %+v", v, some1[0].Attachments) } err = graph.Update(ctx, "col", ones[0].Query(), operator.Unset(fmt.Sprintf("%s.new attachment", entity.Attachments))) @@ -372,10 +396,10 @@ func randomOne() entity.One { Title: "Biggus", Text: "tee hee xd", Modified: time.Now().UnixNano(), - Connections: map[string]entity.One{}, - Attachments: map[string]string{ - "pdf file": "/path/to.pdf", - "png file": "/path/to.png", + Connections: map[string]entity.Connection{}, + Attachments: map[string]entity.Attachment{ + "pdf file": entity.Attachment{Location: "/path/to.pdf"}, + "png file": entity.Attachment{Location: "/path/to.png"}, }, } } diff --git a/storage/ratelimitedgraph.go b/storage/ratelimitedgraph.go index b99652b..7e6c283 100644 --- a/storage/ratelimitedgraph.go +++ b/storage/ratelimitedgraph.go @@ -46,6 +46,10 @@ func (rlg RateLimitedGraph) Insert(ctx context.Context, namespace string, one en return rlg.g.Insert(ctx, namespace, one) } +func (rlg RateLimitedGraph) Get(ctx context.Context, namespace, id string) (entity.One, error) { + return rlg.g.Get(ctx, namespace, id) +} + func (rlg RateLimitedGraph) List(ctx context.Context, namespace string, from ...string) ([]entity.One, error) { return rlg.g.List(ctx, namespace, from...) } @@ -58,6 +62,6 @@ func (rlg RateLimitedGraph) Search(ctx context.Context, namespace string, nameCo return rlg.g.Search(ctx, namespace, nameContains) } -func (rlg RateLimitedGraph) Update(ctx context.Context, namespace string, one entity.One, modify interface{}) error { +func (rlg RateLimitedGraph) Update(ctx context.Context, namespace string, one, modify interface{}) error { return rlg.g.Update(ctx, namespace, one, modify) }