From 36c4ae520da10efbb5af96c30d7b4b36e11186d7 Mon Sep 17 00:00:00 2001 From: breel Date: Sun, 26 Jul 2020 20:25:39 -0600 Subject: [PATCH] Limit incoming request body size for all endpoints and add rate limiting wrappera round storage --- config/config.go | 3 ++ storage/entity/one.go | 7 +++++ storage/graph_test.go | 6 ++-- storage/ratelimitedgraph.go | 63 +++++++++++++++++++++++++++++++++++++ view/files.go | 4 +-- view/json.go | 8 +++++ 6 files changed, 86 insertions(+), 5 deletions(-) create mode 100644 storage/ratelimitedgraph.go diff --git a/config/config.go b/config/config.go index 71a585a..9ca495f 100644 --- a/config/config.go +++ b/config/config.go @@ -17,6 +17,7 @@ type Config struct { Auth bool AuthLifetime time.Duration MaxFileSize int64 + RPS int } func New() Config { @@ -37,6 +38,7 @@ func New() Config { as.Append(args.BOOL, "auth", "check for authorized access", false) as.Append(args.DURATION, "authlifetime", "duration auth is valid for", time.Hour) as.Append(args.INT, "max-file-size", "max file size for uploads in bytes", 50*(1<<20)) + as.Append(args.INT, "rps", "rps per namespace", 5) if err := as.Parse(); err != nil { os.Remove(f.Name()) @@ -53,5 +55,6 @@ func New() Config { Auth: as.GetBool("auth"), AuthLifetime: as.GetDuration("authlifetime"), MaxFileSize: int64(as.GetInt("max-file-size")), + RPS: as.GetInt("rps"), } } diff --git a/storage/entity/one.go b/storage/entity/one.go index 3223878..875f7f5 100644 --- a/storage/entity/one.go +++ b/storage/entity/one.go @@ -3,6 +3,7 @@ package entity import ( "encoding/json" "fmt" + "strings" "time" "go.mongodb.org/mongo-driver/bson" @@ -68,6 +69,12 @@ func (o One) MarshalBSON() ([]byte, error) { if err := json.Unmarshal(b, &m); err != nil { return nil, err } + for k, v := range m { + switch v.(type) { + case string: + m[k] = strings.TrimSpace(v.(string)) + } + } if name, ok := m[JSONName]; ok { m[Name] = name delete(m, JSONName) diff --git a/storage/graph_test.go b/storage/graph_test.go index cbe4a6a..12840fc 100644 --- a/storage/graph_test.go +++ b/storage/graph_test.go @@ -23,11 +23,11 @@ func TestIntegration(t *testing.T) { f.Close() defer os.Remove(f.Name()) os.Setenv("DBURI", f.Name()) - graph := NewGraph() + graph := NewRateLimitedGraph() ctx, can := context.WithCancel(context.TODO()) defer can() clean := func() { - graph.driver.Delete(context.TODO(), "col", map[string]string{}) + graph.g.driver.Delete(context.TODO(), "col", map[string]string{}) } clean() defer clean() @@ -42,7 +42,7 @@ func TestIntegration(t *testing.T) { cleanFill := func() { clean() for i := range ones { - if err := graph.driver.Insert(context.TODO(), "col", ones[i]); err != nil { + if err := graph.g.driver.Insert(context.TODO(), "col", ones[i]); err != nil { t.Fatal(err) } } diff --git a/storage/ratelimitedgraph.go b/storage/ratelimitedgraph.go new file mode 100644 index 0000000..2c344c0 --- /dev/null +++ b/storage/ratelimitedgraph.go @@ -0,0 +1,63 @@ +package storage + +import ( + "context" + "fmt" + "local/dndex/config" + "local/dndex/storage/entity" + "sync" + + "golang.org/x/time/rate" +) + +type RateLimitedGraph struct { + g Graph + rps int + limiters *sync.Map +} + +func NewRateLimitedGraph() RateLimitedGraph { + return RateLimitedGraph{ + g: NewGraph(), + rps: config.New().RPS, + limiters: &sync.Map{}, + } +} + +func (rlg RateLimitedGraph) limit(ctx context.Context, namespace string) error { + limiter, ok := rlg.limiters.Load(namespace) + if !ok { + config := config.New() + limiter = rate.NewLimiter(rate.Limit(config.RPS), config.RPS) + rlg.limiters.Store(namespace, limiter) + } + limit, ok := limiter.(*rate.Limiter) + if !ok { + return fmt.Errorf("rate limiter is of type %T", limiter) + } + return limit.Wait(ctx) +} + +func (rlg RateLimitedGraph) Delete(ctx context.Context, namespace string, filter interface{}) error { + return rlg.g.Delete(ctx, namespace, filter) +} + +func (rlg RateLimitedGraph) Insert(ctx context.Context, namespace string, one entity.One) error { + return rlg.g.Insert(ctx, namespace, one) +} + +func (rlg RateLimitedGraph) List(ctx context.Context, namespace string, from ...string) ([]entity.One, error) { + return rlg.g.List(ctx, namespace, from...) +} + +func (rlg RateLimitedGraph) ListCaseInsensitive(ctx context.Context, namespace string, from ...string) ([]entity.One, error) { + return rlg.g.ListCaseInsensitive(ctx, namespace, from...) +} + +func (rlg RateLimitedGraph) Search(ctx context.Context, namespace string, nameContains string) ([]entity.One, error) { + return rlg.g.Search(ctx, namespace, nameContains) +} + +func (rlg RateLimitedGraph) Update(ctx context.Context, namespace string, one entity.One, modify interface{}) error { + return rlg.g.Update(ctx, namespace, one, modify) +} diff --git a/view/files.go b/view/files.go index e948169..e676610 100644 --- a/view/files.go +++ b/view/files.go @@ -82,7 +82,7 @@ func filesPostFromDirectLink(w http.ResponseWriter, r *http.Request) error { return err } defer f.Close() - _, err = io.Copy(f, io.LimitReader(resp.Body, config.New().MaxFileSize)) + _, err = io.Copy(f, resp.Body) return err } @@ -109,7 +109,7 @@ func filesPostFromUpload(w http.ResponseWriter, r *http.Request) error { return err } defer file.Close() - if _, err := io.Copy(f, io.LimitReader(file, config.New().MaxFileSize)); err != nil { + if _, err := io.Copy(f, file); err != nil { return err } return json.NewEncoder(w).Encode(map[string]interface{}{"status": "ok"}) diff --git a/view/json.go b/view/json.go index d474759..9765668 100644 --- a/view/json.go +++ b/view/json.go @@ -4,6 +4,7 @@ import ( "encoding/json" "errors" "fmt" + "io" "local/dndex/config" "local/dndex/storage" "local/gziphttp" @@ -72,6 +73,13 @@ func jsonHandler(g storage.Graph) http.Handler { defer gz.Close() w = gz } + r.Body = struct { + io.Reader + io.Closer + }{ + Reader: io.LimitReader(r.Body, config.New().MaxFileSize), + Closer: r.Body, + } mux.ServeHTTP(w, r) }) }