diff --git a/server/dump.go b/server/dump.go new file mode 100644 index 0000000..a29979a --- /dev/null +++ b/server/dump.go @@ -0,0 +1,62 @@ +package server + +import ( + "encoding/json" + "local/dndex/server/auth" + "local/dndex/storage/entity" + "local/dndex/storage/operator" + "net/http" + + "gopkg.in/mgo.v2/bson" +) + +func (rest *REST) dump(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case http.MethodGet: + rest.dumpOut(w, r) + case http.MethodPost: + rest.dumpIn(w, r) + default: + rest.respNotFound(w) + } +} + +func (rest *REST) dumpOut(w http.ResponseWriter, r *http.Request) { + scope := rest.scope(r) + entities, err := rest.g.List(r.Context(), scope.Namespace) + if err != nil { + rest.respError(w, err) + } + for i := len(entities) - 1; i >= 0; i-- { + if entities[i].ID == auth.UserKey { + if i < len(entities)-1 { + entities = append(entities[:i], entities[i+1:]...) + } else { + entities = entities[:i] + } + } + } + rest.respMap(w, scope.Namespace, entities) +} + +func (rest *REST) dumpIn(w http.ResponseWriter, r *http.Request) { + scope := rest.scope(r) + var request map[string][]entity.One + if err := json.NewDecoder(r.Body).Decode(&request); err != nil { + rest.respBadRequest(w, err.Error()) + return + } + for namespace, ones := range request { + if namespace == scope.Namespace { + for _, one := range ones { + if err := rest.g.Insert(r.Context(), scope.Namespace, one); err != nil { + if err := rest.g.Update(r.Context(), scope.Namespace, bson.M{entity.ID: one.ID}, operator.SetMany{Value: one}); err != nil { + rest.respError(w, err) + return + } + } + } + } + } + rest.entitiesGetN(w, r) +} diff --git a/server/dump_test.go b/server/dump_test.go new file mode 100644 index 0000000..f5bc71d --- /dev/null +++ b/server/dump_test.go @@ -0,0 +1,93 @@ +package server + +import ( + "bytes" + "encoding/json" + "fmt" + "local/dndex/storage/entity" + "net/http" + "net/http/httptest" + "testing" +) + +func TestDump(t *testing.T) { + cases := map[string]func(*testing.T, *REST, func(*http.Request)){ + "dump out": func(t *testing.T, rest *REST, scope func(r *http.Request)) { + r := httptest.NewRequest(http.MethodGet, "/dump", nil) + w := httptest.NewRecorder() + scope(r) + rest.dump(w, r) + + if w.Code != http.StatusOK { + t.Fatal(w.Code) + } + var dump map[string][]entity.One + if err := json.Unmarshal(w.Body.Bytes(), &dump); err != nil { + t.Fatal(err) + } + if len(dump) == 0 { + t.Fatal(dump) + } + for _, ones := range dump { + for _, one := range ones { + if fmt.Sprint(one) == fmt.Sprint(entity.One{}) { + t.Fatal(one) + } + if len(one.Attachments) == 0 { + t.Fatal(one) + } + } + } + t.Logf("%d: %s: %+v", w.Code, w.Body.Bytes(), dump) + }, + "dump in": func(t *testing.T, rest *REST, scope func(r *http.Request)) { + oneA := randomOne() + oneB := randomOne() + b, err := json.MarshalIndent( + map[string][]entity.One{testNamespace: []entity.One{oneA, oneB}}, + "", + " ", + ) + if err != nil { + t.Fatal(err) + } + t.Logf("dumping in %s", b) + r := httptest.NewRequest(http.MethodPost, "/dump", bytes.NewReader(b)) + w := httptest.NewRecorder() + scope(r) + rest.dump(w, r) + if w.Code != http.StatusOK { + t.Fatal(w.Code) + } + foundA := false + foundB := false + b = w.Body.Bytes() + testEntitiesGetNResponse(t, w.Body, func(one shortEntity) bool { + foundA = foundA || one.ID == oneA.ID + foundB = foundB || one.ID == oneB.ID + return foundA && foundB + }) + for _, one := range []entity.One{oneA, oneB} { + t.Logf("looking for %s", one.Name) + w = testEntitiesMethod(t, scope, rest, http.MethodGet, "/"+one.ID, ``) + if w.Code != http.StatusOK { + t.Fatal(w.Code, ":", string(w.Body.Bytes())) + } + testEntitiesGetOneResponse(t, w.Body, func(got entity.One) bool { + got.Modified = 0 + one.Modified = 0 + return fmt.Sprint(one) == fmt.Sprint(got) + }) + } + }, + } + + for name, foo := range cases { + bar := foo + t.Run(name, func(t *testing.T) { + rest, scope, clean := testREST(t) + bar(t, rest, scope) + defer clean() + }) + } +} diff --git a/server/rest.go b/server/rest.go index 5c852aa..3aa39f1 100644 --- a/server/rest.go +++ b/server/rest.go @@ -39,13 +39,17 @@ func NewREST(g storage.RateLimitedGraph) (*REST, error) { fmt.Sprintf("%s/%s", config.New().FilePrefix, params): rest.files, fmt.Sprintf("users/%s", param): rest.users, fmt.Sprintf("entities/%s", params): rest.entities, + fmt.Sprintf("dump"): rest.dump, } for path, foo := range paths { bar := foo bar = rest.shift(bar) bar = rest.scoped(bar) - if !strings.HasPrefix(path, "users/") && path != "version" { + switch strings.Split(path, "/")[0] { + case "users": + case "version": + default: bar = rest.auth(bar) } bar = rest.defend(bar)