Limit incoming request body size for all endpoints and add rate limiting wrappera round storage
This commit is contained in:
@@ -3,6 +3,7 @@ package entity
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"go.mongodb.org/mongo-driver/bson"
|
||||
@@ -68,6 +69,12 @@ func (o One) MarshalBSON() ([]byte, error) {
|
||||
if err := json.Unmarshal(b, &m); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for k, v := range m {
|
||||
switch v.(type) {
|
||||
case string:
|
||||
m[k] = strings.TrimSpace(v.(string))
|
||||
}
|
||||
}
|
||||
if name, ok := m[JSONName]; ok {
|
||||
m[Name] = name
|
||||
delete(m, JSONName)
|
||||
|
||||
@@ -23,11 +23,11 @@ func TestIntegration(t *testing.T) {
|
||||
f.Close()
|
||||
defer os.Remove(f.Name())
|
||||
os.Setenv("DBURI", f.Name())
|
||||
graph := NewGraph()
|
||||
graph := NewRateLimitedGraph()
|
||||
ctx, can := context.WithCancel(context.TODO())
|
||||
defer can()
|
||||
clean := func() {
|
||||
graph.driver.Delete(context.TODO(), "col", map[string]string{})
|
||||
graph.g.driver.Delete(context.TODO(), "col", map[string]string{})
|
||||
}
|
||||
clean()
|
||||
defer clean()
|
||||
@@ -42,7 +42,7 @@ func TestIntegration(t *testing.T) {
|
||||
cleanFill := func() {
|
||||
clean()
|
||||
for i := range ones {
|
||||
if err := graph.driver.Insert(context.TODO(), "col", ones[i]); err != nil {
|
||||
if err := graph.g.driver.Insert(context.TODO(), "col", ones[i]); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
63
storage/ratelimitedgraph.go
Normal file
63
storage/ratelimitedgraph.go
Normal file
@@ -0,0 +1,63 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"local/dndex/config"
|
||||
"local/dndex/storage/entity"
|
||||
"sync"
|
||||
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
type RateLimitedGraph struct {
|
||||
g Graph
|
||||
rps int
|
||||
limiters *sync.Map
|
||||
}
|
||||
|
||||
func NewRateLimitedGraph() RateLimitedGraph {
|
||||
return RateLimitedGraph{
|
||||
g: NewGraph(),
|
||||
rps: config.New().RPS,
|
||||
limiters: &sync.Map{},
|
||||
}
|
||||
}
|
||||
|
||||
func (rlg RateLimitedGraph) limit(ctx context.Context, namespace string) error {
|
||||
limiter, ok := rlg.limiters.Load(namespace)
|
||||
if !ok {
|
||||
config := config.New()
|
||||
limiter = rate.NewLimiter(rate.Limit(config.RPS), config.RPS)
|
||||
rlg.limiters.Store(namespace, limiter)
|
||||
}
|
||||
limit, ok := limiter.(*rate.Limiter)
|
||||
if !ok {
|
||||
return fmt.Errorf("rate limiter is of type %T", limiter)
|
||||
}
|
||||
return limit.Wait(ctx)
|
||||
}
|
||||
|
||||
func (rlg RateLimitedGraph) Delete(ctx context.Context, namespace string, filter interface{}) error {
|
||||
return rlg.g.Delete(ctx, namespace, filter)
|
||||
}
|
||||
|
||||
func (rlg RateLimitedGraph) Insert(ctx context.Context, namespace string, one entity.One) error {
|
||||
return rlg.g.Insert(ctx, namespace, one)
|
||||
}
|
||||
|
||||
func (rlg RateLimitedGraph) List(ctx context.Context, namespace string, from ...string) ([]entity.One, error) {
|
||||
return rlg.g.List(ctx, namespace, from...)
|
||||
}
|
||||
|
||||
func (rlg RateLimitedGraph) ListCaseInsensitive(ctx context.Context, namespace string, from ...string) ([]entity.One, error) {
|
||||
return rlg.g.ListCaseInsensitive(ctx, namespace, from...)
|
||||
}
|
||||
|
||||
func (rlg RateLimitedGraph) Search(ctx context.Context, namespace string, nameContains string) ([]entity.One, error) {
|
||||
return rlg.g.Search(ctx, namespace, nameContains)
|
||||
}
|
||||
|
||||
func (rlg RateLimitedGraph) Update(ctx context.Context, namespace string, one entity.One, modify interface{}) error {
|
||||
return rlg.g.Update(ctx, namespace, one, modify)
|
||||
}
|
||||
Reference in New Issue
Block a user