Sanitize at API level
This commit is contained in:
@@ -126,6 +126,51 @@ func TestIntegration(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("graph.Delete(case insensitives() => 0)", func(t *testing.T) {
|
||||
cleanFill()
|
||||
err := graph.Delete(ctx, "col", operator.CaseInsensitives{Key: entity.Name, Values: []string{}})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
ones, err := graph.List(ctx, "col")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(ones) != 0 {
|
||||
t.Fatal(len(ones))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("graph.Delete(case insensitives(.*) => 0)", func(t *testing.T) {
|
||||
cleanFill()
|
||||
err := graph.Delete(ctx, "col", operator.CaseInsensitives{Key: entity.Name, Values: []string{".*"}})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
ones, err := graph.List(ctx, "col")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(ones) < 0 {
|
||||
t.Fatal(len(ones))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("graph.Delete(case insensitive(.*) => 0)", func(t *testing.T) {
|
||||
cleanFill()
|
||||
err := graph.Delete(ctx, "col", operator.CaseInsensitive{Key: entity.Name, Value: ".*"})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
ones, err := graph.List(ctx, "col")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(ones) < 0 {
|
||||
t.Fatal(len(ones))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("graph.Search(foo => *)", func(t *testing.T) {
|
||||
cleanFill()
|
||||
some, err := graph.Search(ctx, "col", ones[0].Name[:3])
|
||||
|
||||
@@ -2,7 +2,6 @@ package operator
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"go.mongodb.org/mongo-driver/bson"
|
||||
@@ -26,14 +25,10 @@ func (cis CaseInsensitives) MarshalBSON() ([]byte, error) {
|
||||
values := cis.Values
|
||||
if len(cis.Values) == 0 {
|
||||
values = []string{".*"}
|
||||
} else {
|
||||
for i := range values {
|
||||
values[i] = escapeRegex(values[i])
|
||||
}
|
||||
}
|
||||
ci := CaseInsensitive{
|
||||
Key: cis.Key,
|
||||
Value: fmt.Sprintf("^(%s)$", strings.Join(values, "|")),
|
||||
Value: fmt.Sprintf("(%s)", strings.Join(values, "|")),
|
||||
}
|
||||
return bson.Marshal(ci)
|
||||
}
|
||||
@@ -46,11 +41,9 @@ type CaseInsensitive struct {
|
||||
func (ci CaseInsensitive) MarshalBSON() ([]byte, error) {
|
||||
value := ci.Value
|
||||
if value == "" {
|
||||
value = "^$"
|
||||
} else {
|
||||
value = escapeRegex(value)
|
||||
value = ".*"
|
||||
}
|
||||
return bson.Marshal(Regex{Key: ci.Key, Value: "(?i)" + ci.Value})
|
||||
return bson.Marshal(Regex{Key: ci.Key, Value: "^(?i)" + value + "$"})
|
||||
}
|
||||
|
||||
type FilterIn struct {
|
||||
@@ -107,9 +100,3 @@ func filterMarshal(op, key string, value interface{}) ([]byte, error) {
|
||||
}
|
||||
return bson.Marshal(m)
|
||||
}
|
||||
|
||||
func escapeRegex(s string) string {
|
||||
re := regexp.MustCompile(`[^a-zA-Z0-9]`)
|
||||
s = re.ReplaceAllString(s, `.`)
|
||||
return s
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user