Limit incoming request body size for all endpoints and add rate limiting wrappera round storage

This commit is contained in:
breel
2020-07-26 20:25:39 -06:00
parent c3b948556c
commit 36c4ae520d
6 changed files with 86 additions and 5 deletions

View File

@@ -3,6 +3,7 @@ package entity
import (
"encoding/json"
"fmt"
"strings"
"time"
"go.mongodb.org/mongo-driver/bson"
@@ -68,6 +69,12 @@ func (o One) MarshalBSON() ([]byte, error) {
if err := json.Unmarshal(b, &m); err != nil {
return nil, err
}
for k, v := range m {
switch v.(type) {
case string:
m[k] = strings.TrimSpace(v.(string))
}
}
if name, ok := m[JSONName]; ok {
m[Name] = name
delete(m, JSONName)

View File

@@ -23,11 +23,11 @@ func TestIntegration(t *testing.T) {
f.Close()
defer os.Remove(f.Name())
os.Setenv("DBURI", f.Name())
graph := NewGraph()
graph := NewRateLimitedGraph()
ctx, can := context.WithCancel(context.TODO())
defer can()
clean := func() {
graph.driver.Delete(context.TODO(), "col", map[string]string{})
graph.g.driver.Delete(context.TODO(), "col", map[string]string{})
}
clean()
defer clean()
@@ -42,7 +42,7 @@ func TestIntegration(t *testing.T) {
cleanFill := func() {
clean()
for i := range ones {
if err := graph.driver.Insert(context.TODO(), "col", ones[i]); err != nil {
if err := graph.g.driver.Insert(context.TODO(), "col", ones[i]); err != nil {
t.Fatal(err)
}
}

View File

@@ -0,0 +1,63 @@
package storage
import (
"context"
"fmt"
"local/dndex/config"
"local/dndex/storage/entity"
"sync"
"golang.org/x/time/rate"
)
type RateLimitedGraph struct {
g Graph
rps int
limiters *sync.Map
}
func NewRateLimitedGraph() RateLimitedGraph {
return RateLimitedGraph{
g: NewGraph(),
rps: config.New().RPS,
limiters: &sync.Map{},
}
}
func (rlg RateLimitedGraph) limit(ctx context.Context, namespace string) error {
limiter, ok := rlg.limiters.Load(namespace)
if !ok {
config := config.New()
limiter = rate.NewLimiter(rate.Limit(config.RPS), config.RPS)
rlg.limiters.Store(namespace, limiter)
}
limit, ok := limiter.(*rate.Limiter)
if !ok {
return fmt.Errorf("rate limiter is of type %T", limiter)
}
return limit.Wait(ctx)
}
func (rlg RateLimitedGraph) Delete(ctx context.Context, namespace string, filter interface{}) error {
return rlg.g.Delete(ctx, namespace, filter)
}
func (rlg RateLimitedGraph) Insert(ctx context.Context, namespace string, one entity.One) error {
return rlg.g.Insert(ctx, namespace, one)
}
func (rlg RateLimitedGraph) List(ctx context.Context, namespace string, from ...string) ([]entity.One, error) {
return rlg.g.List(ctx, namespace, from...)
}
func (rlg RateLimitedGraph) ListCaseInsensitive(ctx context.Context, namespace string, from ...string) ([]entity.One, error) {
return rlg.g.ListCaseInsensitive(ctx, namespace, from...)
}
func (rlg RateLimitedGraph) Search(ctx context.Context, namespace string, nameContains string) ([]entity.One, error) {
return rlg.g.Search(ctx, namespace, nameContains)
}
func (rlg RateLimitedGraph) Update(ctx context.Context, namespace string, one entity.One, modify interface{}) error {
return rlg.g.Update(ctx, namespace, one, modify)
}