diff --git a/storage/graph_test.go b/storage/graph_test.go index 0f02009..28b237e 100644 --- a/storage/graph_test.go +++ b/storage/graph_test.go @@ -126,6 +126,51 @@ func TestIntegration(t *testing.T) { } }) + t.Run("graph.Delete(case insensitives() => 0)", func(t *testing.T) { + cleanFill() + err := graph.Delete(ctx, "col", operator.CaseInsensitives{Key: entity.Name, Values: []string{}}) + if err != nil { + t.Fatal(err) + } + ones, err := graph.List(ctx, "col") + if err != nil { + t.Fatal(err) + } + if len(ones) != 0 { + t.Fatal(len(ones)) + } + }) + + t.Run("graph.Delete(case insensitives(.*) => 0)", func(t *testing.T) { + cleanFill() + err := graph.Delete(ctx, "col", operator.CaseInsensitives{Key: entity.Name, Values: []string{".*"}}) + if err != nil { + t.Fatal(err) + } + ones, err := graph.List(ctx, "col") + if err != nil { + t.Fatal(err) + } + if len(ones) < 0 { + t.Fatal(len(ones)) + } + }) + + t.Run("graph.Delete(case insensitive(.*) => 0)", func(t *testing.T) { + cleanFill() + err := graph.Delete(ctx, "col", operator.CaseInsensitive{Key: entity.Name, Value: ".*"}) + if err != nil { + t.Fatal(err) + } + ones, err := graph.List(ctx, "col") + if err != nil { + t.Fatal(err) + } + if len(ones) < 0 { + t.Fatal(len(ones)) + } + }) + t.Run("graph.Search(foo => *)", func(t *testing.T) { cleanFill() some, err := graph.Search(ctx, "col", ones[0].Name[:3]) diff --git a/storage/operator/filter.go b/storage/operator/filter.go index c2f4cfd..fdd5ad1 100644 --- a/storage/operator/filter.go +++ b/storage/operator/filter.go @@ -2,7 +2,6 @@ package operator import ( "fmt" - "regexp" "strings" "go.mongodb.org/mongo-driver/bson" @@ -26,14 +25,10 @@ func (cis CaseInsensitives) MarshalBSON() ([]byte, error) { values := cis.Values if len(cis.Values) == 0 { values = []string{".*"} - } else { - for i := range values { - values[i] = escapeRegex(values[i]) - } } ci := CaseInsensitive{ Key: cis.Key, - Value: fmt.Sprintf("^(%s)$", strings.Join(values, "|")), + Value: fmt.Sprintf("(%s)", strings.Join(values, "|")), } return bson.Marshal(ci) } @@ -46,11 +41,9 @@ type CaseInsensitive struct { func (ci CaseInsensitive) MarshalBSON() ([]byte, error) { value := ci.Value if value == "" { - value = "^$" - } else { - value = escapeRegex(value) + value = ".*" } - return bson.Marshal(Regex{Key: ci.Key, Value: "(?i)" + ci.Value}) + return bson.Marshal(Regex{Key: ci.Key, Value: "^(?i)" + value + "$"}) } type FilterIn struct { @@ -107,9 +100,3 @@ func filterMarshal(op, key string, value interface{}) ([]byte, error) { } return bson.Marshal(m) } - -func escapeRegex(s string) string { - re := regexp.MustCompile(`[^a-zA-Z0-9]`) - s = re.ReplaceAllString(s, `.`) - return s -} diff --git a/view/who.go b/view/who.go index ef23a7e..85ce58a 100644 --- a/view/who.go +++ b/view/who.go @@ -9,6 +9,7 @@ import ( "local/dndex/storage/entity" "local/dndex/storage/operator" "net/http" + "regexp" "sort" "strings" @@ -43,7 +44,7 @@ func who(g storage.Graph, w http.ResponseWriter, r *http.Request) error { } func whoGet(namespace string, g storage.Graph, w http.ResponseWriter, r *http.Request) error { - id, err := getID(r) + id, err := getCleanID(r) if err != nil { return whoTrace(namespace, g, w, r) } @@ -145,7 +146,7 @@ func whoPost(namespace string, g storage.Graph, w http.ResponseWriter, r *http.R } func whoDelete(namespace string, g storage.Graph, w http.ResponseWriter, r *http.Request) error { - id, err := getID(r) + id, err := getCleanID(r) if err != nil { w.WriteHeader(http.StatusBadRequest) return json.NewEncoder(w).Encode(map[string]string{"error": err.Error()}) @@ -201,6 +202,11 @@ func whoTrace(namespace string, g storage.Graph, w http.ResponseWriter, r *http. return enc.Encode(names) } +func getCleanID(r *http.Request) (string, error) { + id, err := getID(r) + return sanitize(id), err +} + func getID(r *http.Request) (string, error) { id := r.URL.Query().Get("id") if id == "" { @@ -210,11 +216,11 @@ func getID(r *http.Request) (string, error) { } func sortOnes(ones []entity.One, r *http.Request) []entity.One { - sorting := r.URL.Query().Get("sort") + sorting := sanitize(r.URL.Query().Get("sort")) if sorting == "" { sorting = entity.Modified } - order := r.URL.Query().Get("order") + order := sanitize(r.URL.Query().Get("order")) if order == "" { order = "-1" } @@ -236,3 +242,9 @@ func sortOnes(ones []entity.One, r *http.Request) []entity.One { }) return ones } + +func sanitize(s string) string { + re := regexp.MustCompile(`[^a-zA-Z0-9- _]`) + s = re.ReplaceAllString(s, `.`) + return s +} diff --git a/view/who_test.go b/view/who_test.go index e108beb..f2b1e5e 100644 --- a/view/who_test.go +++ b/view/who_test.go @@ -246,6 +246,30 @@ func TestWho(t *testing.T) { } }) + t.Run("delete regexp should be sanitized", func(t *testing.T) { + r := httptest.NewRequest(http.MethodDelete, "/who?namespace=col&id=.*", nil) + w := httptest.NewRecorder() + handler.ServeHTTP(w, r) + if w.Code != http.StatusOK { + t.Fatalf("%d: %s", w.Code, w.Body.Bytes()) + } + + r = httptest.NewRequest(http.MethodTrace, "/who?namespace=col", nil) + w = httptest.NewRecorder() + handler.ServeHTTP(w, r) + if w.Code != http.StatusOK { + t.Fatalf("%d: %s", w.Code, w.Body.Bytes()) + } + var v []string + if err := json.Unmarshal(w.Body.Bytes(), &v); err != nil { + t.Fatalf("%v: %s", err, w.Body.Bytes()) + } + if len(v) < 5 { + t.Fatal(len(v)) + } + t.Logf("%+v", v) + }) + t.Run("patch fake", func(t *testing.T) { r := httptest.NewRequest(http.MethodPatch, "/who?namespace=col&id=FAKER"+want.Name, nil) w := httptest.NewRecorder()