diff --git a/storage/driver/boltdb.go b/storage/driver/boltdb.go index 1ea4f61..34dfb3b 100644 --- a/storage/driver/boltdb.go +++ b/storage/driver/boltdb.go @@ -316,7 +316,29 @@ func applySet(doc, operator bson.M) (bson.M, error) { if k == entity.Name { return nil, errModifiedReserved } - doc[k] = v + nesting := strings.Split(k, ".") + if len(nesting) > 1 { + mInterface, ok := doc[nesting[0]] + if !ok { + return nil, fmt.Errorf("path does not exist: %s (%s): %+v", k, nesting[0], doc) + } + m, ok := mInterface.(map[string]interface{}) + if !ok { + pm, pmok := mInterface.(primitive.M) + m = map[string]interface{}(pm) + ok = pmok + } + if !ok { + return nil, fmt.Errorf("subpath cannot be followed for non object: %s (%s): %+v (%T)", k, nesting[0], mInterface, mInterface) + } + subdoc, err := applySet(bson.M(m), bson.M{strings.Join(nesting[1:], "."): v}) + if err != nil { + return nil, err + } + doc[nesting[0]] = subdoc + } else { + doc[k] = v + } } return doc, nil } diff --git a/storage/driver/boltdb_test.go b/storage/driver/boltdb_test.go index 943b58b..04c1a2d 100644 --- a/storage/driver/boltdb_test.go +++ b/storage/driver/boltdb_test.go @@ -396,3 +396,50 @@ func fillBoltDB(t *testing.T, bdb *BoltDB) { t.Fatal(err) } } + +func TestApplySet(t *testing.T) { + cases := map[string]struct { + doc bson.M + operator bson.M + want bson.M + }{ + "noop on empty": {}, + "noop on full": { + doc: bson.M{"hello": "world"}, + want: bson.M{"hello": "world"}, + }, + "add new field on full": { + operator: bson.M{"hi": "mom"}, + doc: bson.M{"hello": "world"}, + want: bson.M{"hello": "world", "hi": "mom"}, + }, + "change only field on full": { + operator: bson.M{"hello": "lol jk not world"}, + doc: bson.M{"hello": "world"}, + want: bson.M{"hello": "lol jk not world"}, + }, + "set existing, nested field": { + operator: bson.M{"hello.world": "hi"}, + doc: bson.M{"hello": bson.M{"world": "not hi"}}, + want: bson.M{"hello": bson.M{"world": "hi"}}, + }, + "add to existing, nested field": { + operator: bson.M{"hello.notworld": "hi"}, + doc: bson.M{"hello": bson.M{"world": "not hi"}}, + want: bson.M{"hello": bson.M{"world": "not hi", "notworld": "hi"}}, + }, + } + + for name, d := range cases { + c := d + t.Run(name, func(t *testing.T) { + out, err := applySet(c.doc, c.operator) + if err != nil { + t.Fatal(err) + } + if fmt.Sprint(out) != fmt.Sprint(c.want) { + t.Fatalf("(%+v, %+v) => want \n%+v\n, got \n%+v", c.doc, c.operator, c.want, out) + } + }) + } +} diff --git a/storage/operator/modify.go b/storage/operator/modify.go index 3e9f9ee..119732f 100644 --- a/storage/operator/modify.go +++ b/storage/operator/modify.go @@ -23,6 +23,14 @@ func (pi PopIf) MarshalBSON() ([]byte, error) { return opMarshal("$pull", pi.Key, pi.Filter) } +type SetMany struct { + Value interface{} +} + +func (s SetMany) MarshalBSON() ([]byte, error) { + return opMarshal("$set", "", s.Value) +} + type Set struct { Key string Value interface{} @@ -49,20 +57,28 @@ func opMarshal(op, key string, value interface{}) ([]byte, error) { return bson.Marshal(marshalable) } -func opMarshalable(op, key string, value interface{}) map[string]map[string]interface{} { - if len(key) == 0 { +func opMarshalable(op, key string, value interface{}) map[string]interface{} { + if len(key) == 0 && value == nil { return nil } - m := map[string]map[string]interface{}{ + m := map[string]interface{}{ op: map[string]interface{}{ key: value, }, } + if len(key) == 0 { + m[op] = value + } if _, ok := m["$set"]; !ok { m["$set"] = map[string]interface{}{} } - if _, ok := m["$set"][entity.Modified]; !ok { - m["$set"][entity.Modified] = time.Now().UnixNano() + switch m["$set"].(type) { + case map[string]interface{}: + m["$set"].(map[string]interface{})[entity.Modified] = time.Now().UnixNano() + case bson.M: + m["$set"].(bson.M)[entity.Modified] = time.Now().UnixNano() + //case primitive.M: + //m["$set"].(primitive.M)[entity.Modified] = time.Now().UnixNano() } return m } diff --git a/view/json.go b/view/json.go index 669155c..cc562cf 100644 --- a/view/json.go +++ b/view/json.go @@ -27,15 +27,7 @@ func jsonHandler(g storage.Graph) http.Handler { }{ { path: "/who/", - foo: httpwho, - }, - { - path: "/meet/", - foo: httpmeet, - }, - { - path: "/isnow/", - foo: httpisnow, + foo: who, }, } @@ -45,8 +37,12 @@ func jsonHandler(g storage.Graph) http.Handler { foo := route.foo mux.HandleFunc(path, func(w http.ResponseWriter, r *http.Request) { if err := foo(g, w, r); err != nil { + status := http.StatusInternalServerError + if strings.Contains(err.Error(), "collision") { + status = http.StatusConflict + } b, _ := json.Marshal(map[string]string{"error": err.Error()}) - http.Error(w, string(b), http.StatusInternalServerError) + http.Error(w, string(b), status) } }) mux.HandleFunc(nopath, http.NotFound) diff --git a/view/who.go b/view/who.go new file mode 100644 index 0000000..def6769 --- /dev/null +++ b/view/who.go @@ -0,0 +1,136 @@ +package view + +import ( + "encoding/json" + "errors" + "io/ioutil" + "local/dndex/storage" + "local/dndex/storage/entity" + "local/dndex/storage/operator" + "net/http" + "path" + "strings" + + "github.com/buger/jsonparser" + "go.mongodb.org/mongo-driver/bson" +) + +func who(g storage.Graph, w http.ResponseWriter, r *http.Request) error { + namespace := strings.TrimLeft(r.URL.Path, path.Dir(r.URL.Path)) + if len(namespace) == 0 { + http.NotFound(w, r) + return nil + } + namespace = strings.Replace(namespace, "/", ".", -1) + + switch r.Method { + case http.MethodGet: + return whoGet(namespace, g, w, r) + case http.MethodPut: + return whoPut(namespace, g, w, r) + case http.MethodPost: + return whoPost(namespace, g, w, r) + default: + http.NotFound(w, r) + return nil + } +} + +func whoGet(namespace string, g storage.Graph, w http.ResponseWriter, r *http.Request) error { + id := r.URL.Query().Get("id") + if id == "" { + http.Error(w, `{"error":"no ?id provided"}`, http.StatusBadRequest) + return nil + } + _, light := r.URL.Query()["light"] + + ones, err := g.List(r.Context(), namespace, id) + if err != nil { + return err + } + if len(ones) == 0 { + http.NotFound(w, r) + return nil + } + if len(ones) > 1 { + return errors.New("more than one result found matching " + id) + } + one := ones[0] + + if !light && len(one.Connections) > 0 { + ones, err := g.List(r.Context(), namespace, one.Peers()...) + if err != nil { + return err + } + for _, another := range ones { + another.Relationship = one.Connections[another.Name].Relationship + one.Connections[another.Name] = another + } + } + + enc := json.NewEncoder(w) + enc.SetIndent("", " ") + return enc.Encode(one) +} + +func whoPut(namespace string, g storage.Graph, w http.ResponseWriter, r *http.Request) error { + id := r.URL.Query().Get("id") + if id == "" { + http.Error(w, `{"error":"no ?id provided"}`, http.StatusBadRequest) + return nil + } + + body, err := ioutil.ReadAll(r.Body) + if err != nil { + return err + } + operation := entity.One{} + if err := json.Unmarshal(body, &operation); err != nil { + return err + } + if operation.Name != "" { + http.Error(w, `{"error":"cannot specify name in request body"}`, http.StatusBadRequest) + return nil + } + if operation.Modified != 0 { + http.Error(w, `{"error":"cannot specify modified in request body"}`, http.StatusBadRequest) + return nil + } + b, err := bson.Marshal(operation) + if err != nil { + return err + } + op := bson.M{} + if err := bson.Unmarshal(b, &op); err != nil { + return err + } + for k := range op { + if _, _, _, err := jsonparser.Get(body, k); err != nil { + delete(op, k) + } + } + if err := g.Update(r.Context(), namespace, entity.One{Name: id}, operator.SetMany{op}); err != nil { + return err + } + + return whoGet(namespace, g, w, r) +} + +func whoPost(namespace string, g storage.Graph, w http.ResponseWriter, r *http.Request) error { + id := r.URL.Query().Get("id") + if id == "" { + http.Error(w, `{"error":"no ?id provided"}`, http.StatusBadRequest) + return nil + } + + one := entity.One{} + if err := json.NewDecoder(r.Body).Decode(&one); err != nil { + return err + } + one.Name = id + if err := g.Insert(r.Context(), namespace, one); err != nil { + return err + } + + return whoGet(namespace, g, w, r) +} diff --git a/view/who_test.go b/view/who_test.go new file mode 100644 index 0000000..6fbde95 --- /dev/null +++ b/view/who_test.go @@ -0,0 +1,155 @@ +package view + +import ( + "bytes" + "encoding/json" + "fmt" + "io/ioutil" + "local/dndex/config" + "local/dndex/storage" + "local/dndex/storage/entity" + "net/http" + "net/http/httptest" + "os" + "strings" + "testing" +) + +func TestWho(t *testing.T) { + os.Args = os.Args[:1] + f, err := ioutil.TempFile(os.TempDir(), "pattern*") + if err != nil { + t.Fatal(err) + } + f.Close() + defer os.Remove(f.Name()) + os.Setenv("DBURI", f.Name()) + + t.Logf("config: %+v", config.New()) + + g := storage.NewGraph() + ones := fillDB(t, g) + want := ones[len(ones)-1] + + handler := jsonHandler(g) + + t.Log(handler, want) + + t.Run("get no namespace is 404", func(t *testing.T) { + iwant := want + r := httptest.NewRequest(http.MethodGet, "/who?id="+iwant.Name, nil) + w := httptest.NewRecorder() + handler.ServeHTTP(w, r) + if w.Code != http.StatusNotFound { + t.Fatalf("%d: %s", w.Code, w.Body.Bytes()) + } + }) + + t.Run("get fake", func(t *testing.T) { + iwant := want + r := httptest.NewRequest(http.MethodGet, "/who/col?id=FAKER"+iwant.Name, nil) + w := httptest.NewRecorder() + handler.ServeHTTP(w, r) + if w.Code != http.StatusNotFound { + t.Fatalf("%d: %s", w.Code, w.Body.Bytes()) + } + }) + + t.Run("get real", func(t *testing.T) { + iwant := want + r := httptest.NewRequest(http.MethodGet, "/who/col?id="+iwant.Name, nil) + w := httptest.NewRecorder() + handler.ServeHTTP(w, r) + if w.Code != http.StatusOK { + t.Fatalf("%d: %s", w.Code, w.Body.Bytes()) + } + o := entity.One{} + if err := json.Unmarshal(w.Body.Bytes(), &o); err != nil { + t.Fatal(err) + } + if fmt.Sprint(o) == fmt.Sprint(iwant) { + t.Fatal(o, iwant) + } + if len(o.Connections) != len(iwant.Connections) { + t.Fatal(len(o.Connections), len(iwant.Connections)) + } + iwant.Connections = o.Connections + iwant.Modified = 0 + o.Modified = 0 + if fmt.Sprint(o) != fmt.Sprint(iwant) { + t.Fatalf("after resolving connections and modified, iwant and got differ: \nwant %+v\n got %+v", iwant, o) + } + b, _ := json.MarshalIndent(o, "", " ") + t.Logf("POST GET:\n%s", b) + }) + + t.Run("put fake", func(t *testing.T) { + iwant := want + r := httptest.NewRequest(http.MethodPut, "/who/col?id=FAKER"+iwant.Name, strings.NewReader(`{"title":"this should fail to find someone"}`)) + w := httptest.NewRecorder() + handler.ServeHTTP(w, r) + if w.Code != http.StatusNotFound { + t.Fatalf("%d: %s", w.Code, w.Body.Bytes()) + } + }) + + t.Run("put real", func(t *testing.T) { + iwant := want + r := httptest.NewRequest(http.MethodPut, "/who/col?id="+iwant.Name, strings.NewReader(`{"title":"this should work"}`)) + w := httptest.NewRecorder() + handler.ServeHTTP(w, r) + if w.Code != http.StatusOK { + t.Fatalf("%d: %s", w.Code, w.Body.Bytes()) + } + o := entity.One{} + if err := json.Unmarshal(w.Body.Bytes(), &o); err != nil { + t.Fatal(err) + } + if len(o.Connections) != len(iwant.Connections) { + t.Fatalf("wrong number of connections returned: want %v, got %v", len(iwant.Connections), len(o.Connections)) + } + if o.Title != "this should work" { + t.Fatalf("failed to PUT a new title: %+v", o) + } + b, _ := json.MarshalIndent(o, "", " ") + t.Logf("POST PUT:\n%s", b) + }) + + t.Run("post exists", func(t *testing.T) { + iwant := want + iwant.Name = "" + r := httptest.NewRequest(http.MethodPost, "/who/col?id="+want.Name, strings.NewReader(`{"title":"this should fail to insert"}`)) + w := httptest.NewRecorder() + handler.ServeHTTP(w, r) + if w.Code != http.StatusConflict { + t.Fatalf("%d: %s", w.Code, w.Body.Bytes()) + } + }) + + t.Run("post real", func(t *testing.T) { + iwant := want + iwant.Name = "" + b, err := json.Marshal(iwant) + if err != nil { + t.Fatal(err) + } + r := httptest.NewRequest(http.MethodPost, "/who/col?id=NEWBIE"+want.Name, bytes.NewReader(b)) + w := httptest.NewRecorder() + handler.ServeHTTP(w, r) + if w.Code != http.StatusOK { + t.Fatalf("%d: %s", w.Code, w.Body.Bytes()) + } + o := entity.One{} + if err := json.Unmarshal(w.Body.Bytes(), &o); err != nil { + t.Fatal(err) + } + if len(o.Connections) != len(iwant.Connections) { + t.Fatalf("wrong number of connections returned: want %v, got %v", len(iwant.Connections), len(o.Connections)) + } + if o.Name != "NEWBIE"+want.Name { + t.Fatalf("failed to POST specified name: %+v", o) + } + b, _ = json.MarshalIndent(o, "", " ") + t.Logf("POST POST:\n%s", b) + }) +}