Sanitize at API level
parent
a1d59a0248
commit
09507d38e9
|
|
@ -126,6 +126,51 @@ func TestIntegration(t *testing.T) {
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
t.Run("graph.Delete(case insensitives() => 0)", func(t *testing.T) {
|
||||||
|
cleanFill()
|
||||||
|
err := graph.Delete(ctx, "col", operator.CaseInsensitives{Key: entity.Name, Values: []string{}})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
ones, err := graph.List(ctx, "col")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if len(ones) != 0 {
|
||||||
|
t.Fatal(len(ones))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("graph.Delete(case insensitives(.*) => 0)", func(t *testing.T) {
|
||||||
|
cleanFill()
|
||||||
|
err := graph.Delete(ctx, "col", operator.CaseInsensitives{Key: entity.Name, Values: []string{".*"}})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
ones, err := graph.List(ctx, "col")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if len(ones) < 0 {
|
||||||
|
t.Fatal(len(ones))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("graph.Delete(case insensitive(.*) => 0)", func(t *testing.T) {
|
||||||
|
cleanFill()
|
||||||
|
err := graph.Delete(ctx, "col", operator.CaseInsensitive{Key: entity.Name, Value: ".*"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
ones, err := graph.List(ctx, "col")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if len(ones) < 0 {
|
||||||
|
t.Fatal(len(ones))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
t.Run("graph.Search(foo => *)", func(t *testing.T) {
|
t.Run("graph.Search(foo => *)", func(t *testing.T) {
|
||||||
cleanFill()
|
cleanFill()
|
||||||
some, err := graph.Search(ctx, "col", ones[0].Name[:3])
|
some, err := graph.Search(ctx, "col", ones[0].Name[:3])
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,6 @@ package operator
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"regexp"
|
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"go.mongodb.org/mongo-driver/bson"
|
"go.mongodb.org/mongo-driver/bson"
|
||||||
|
|
@ -26,14 +25,10 @@ func (cis CaseInsensitives) MarshalBSON() ([]byte, error) {
|
||||||
values := cis.Values
|
values := cis.Values
|
||||||
if len(cis.Values) == 0 {
|
if len(cis.Values) == 0 {
|
||||||
values = []string{".*"}
|
values = []string{".*"}
|
||||||
} else {
|
|
||||||
for i := range values {
|
|
||||||
values[i] = escapeRegex(values[i])
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
ci := CaseInsensitive{
|
ci := CaseInsensitive{
|
||||||
Key: cis.Key,
|
Key: cis.Key,
|
||||||
Value: fmt.Sprintf("^(%s)$", strings.Join(values, "|")),
|
Value: fmt.Sprintf("(%s)", strings.Join(values, "|")),
|
||||||
}
|
}
|
||||||
return bson.Marshal(ci)
|
return bson.Marshal(ci)
|
||||||
}
|
}
|
||||||
|
|
@ -46,11 +41,9 @@ type CaseInsensitive struct {
|
||||||
func (ci CaseInsensitive) MarshalBSON() ([]byte, error) {
|
func (ci CaseInsensitive) MarshalBSON() ([]byte, error) {
|
||||||
value := ci.Value
|
value := ci.Value
|
||||||
if value == "" {
|
if value == "" {
|
||||||
value = "^$"
|
value = ".*"
|
||||||
} else {
|
|
||||||
value = escapeRegex(value)
|
|
||||||
}
|
}
|
||||||
return bson.Marshal(Regex{Key: ci.Key, Value: "(?i)" + ci.Value})
|
return bson.Marshal(Regex{Key: ci.Key, Value: "^(?i)" + value + "$"})
|
||||||
}
|
}
|
||||||
|
|
||||||
type FilterIn struct {
|
type FilterIn struct {
|
||||||
|
|
@ -107,9 +100,3 @@ func filterMarshal(op, key string, value interface{}) ([]byte, error) {
|
||||||
}
|
}
|
||||||
return bson.Marshal(m)
|
return bson.Marshal(m)
|
||||||
}
|
}
|
||||||
|
|
||||||
func escapeRegex(s string) string {
|
|
||||||
re := regexp.MustCompile(`[^a-zA-Z0-9]`)
|
|
||||||
s = re.ReplaceAllString(s, `.`)
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
|
||||||
20
view/who.go
20
view/who.go
|
|
@ -9,6 +9,7 @@ import (
|
||||||
"local/dndex/storage/entity"
|
"local/dndex/storage/entity"
|
||||||
"local/dndex/storage/operator"
|
"local/dndex/storage/operator"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"regexp"
|
||||||
"sort"
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
|
@ -43,7 +44,7 @@ func who(g storage.Graph, w http.ResponseWriter, r *http.Request) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func whoGet(namespace string, g storage.Graph, w http.ResponseWriter, r *http.Request) error {
|
func whoGet(namespace string, g storage.Graph, w http.ResponseWriter, r *http.Request) error {
|
||||||
id, err := getID(r)
|
id, err := getCleanID(r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return whoTrace(namespace, g, w, r)
|
return whoTrace(namespace, g, w, r)
|
||||||
}
|
}
|
||||||
|
|
@ -145,7 +146,7 @@ func whoPost(namespace string, g storage.Graph, w http.ResponseWriter, r *http.R
|
||||||
}
|
}
|
||||||
|
|
||||||
func whoDelete(namespace string, g storage.Graph, w http.ResponseWriter, r *http.Request) error {
|
func whoDelete(namespace string, g storage.Graph, w http.ResponseWriter, r *http.Request) error {
|
||||||
id, err := getID(r)
|
id, err := getCleanID(r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
w.WriteHeader(http.StatusBadRequest)
|
w.WriteHeader(http.StatusBadRequest)
|
||||||
return json.NewEncoder(w).Encode(map[string]string{"error": err.Error()})
|
return json.NewEncoder(w).Encode(map[string]string{"error": err.Error()})
|
||||||
|
|
@ -201,6 +202,11 @@ func whoTrace(namespace string, g storage.Graph, w http.ResponseWriter, r *http.
|
||||||
return enc.Encode(names)
|
return enc.Encode(names)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func getCleanID(r *http.Request) (string, error) {
|
||||||
|
id, err := getID(r)
|
||||||
|
return sanitize(id), err
|
||||||
|
}
|
||||||
|
|
||||||
func getID(r *http.Request) (string, error) {
|
func getID(r *http.Request) (string, error) {
|
||||||
id := r.URL.Query().Get("id")
|
id := r.URL.Query().Get("id")
|
||||||
if id == "" {
|
if id == "" {
|
||||||
|
|
@ -210,11 +216,11 @@ func getID(r *http.Request) (string, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func sortOnes(ones []entity.One, r *http.Request) []entity.One {
|
func sortOnes(ones []entity.One, r *http.Request) []entity.One {
|
||||||
sorting := r.URL.Query().Get("sort")
|
sorting := sanitize(r.URL.Query().Get("sort"))
|
||||||
if sorting == "" {
|
if sorting == "" {
|
||||||
sorting = entity.Modified
|
sorting = entity.Modified
|
||||||
}
|
}
|
||||||
order := r.URL.Query().Get("order")
|
order := sanitize(r.URL.Query().Get("order"))
|
||||||
if order == "" {
|
if order == "" {
|
||||||
order = "-1"
|
order = "-1"
|
||||||
}
|
}
|
||||||
|
|
@ -236,3 +242,9 @@ func sortOnes(ones []entity.One, r *http.Request) []entity.One {
|
||||||
})
|
})
|
||||||
return ones
|
return ones
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func sanitize(s string) string {
|
||||||
|
re := regexp.MustCompile(`[^a-zA-Z0-9- _]`)
|
||||||
|
s = re.ReplaceAllString(s, `.`)
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -246,6 +246,30 @@ func TestWho(t *testing.T) {
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
t.Run("delete regexp should be sanitized", func(t *testing.T) {
|
||||||
|
r := httptest.NewRequest(http.MethodDelete, "/who?namespace=col&id=.*", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(w, r)
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("%d: %s", w.Code, w.Body.Bytes())
|
||||||
|
}
|
||||||
|
|
||||||
|
r = httptest.NewRequest(http.MethodTrace, "/who?namespace=col", nil)
|
||||||
|
w = httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(w, r)
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("%d: %s", w.Code, w.Body.Bytes())
|
||||||
|
}
|
||||||
|
var v []string
|
||||||
|
if err := json.Unmarshal(w.Body.Bytes(), &v); err != nil {
|
||||||
|
t.Fatalf("%v: %s", err, w.Body.Bytes())
|
||||||
|
}
|
||||||
|
if len(v) < 5 {
|
||||||
|
t.Fatal(len(v))
|
||||||
|
}
|
||||||
|
t.Logf("%+v", v)
|
||||||
|
})
|
||||||
|
|
||||||
t.Run("patch fake", func(t *testing.T) {
|
t.Run("patch fake", func(t *testing.T) {
|
||||||
r := httptest.NewRequest(http.MethodPatch, "/who?namespace=col&id=FAKER"+want.Name, nil)
|
r := httptest.NewRequest(http.MethodPatch, "/who?namespace=col&id=FAKER"+want.Name, nil)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue