Limit incoming request body size for all endpoints and add rate limiting wrappera round storage

master
breel 2020-07-26 20:25:39 -06:00
parent c3b948556c
commit 36c4ae520d
6 changed files with 86 additions and 5 deletions

View File

@ -17,6 +17,7 @@ type Config struct {
Auth bool
AuthLifetime time.Duration
MaxFileSize int64
RPS int
}
func New() Config {
@ -37,6 +38,7 @@ func New() Config {
as.Append(args.BOOL, "auth", "check for authorized access", false)
as.Append(args.DURATION, "authlifetime", "duration auth is valid for", time.Hour)
as.Append(args.INT, "max-file-size", "max file size for uploads in bytes", 50*(1<<20))
as.Append(args.INT, "rps", "rps per namespace", 5)
if err := as.Parse(); err != nil {
os.Remove(f.Name())
@ -53,5 +55,6 @@ func New() Config {
Auth: as.GetBool("auth"),
AuthLifetime: as.GetDuration("authlifetime"),
MaxFileSize: int64(as.GetInt("max-file-size")),
RPS: as.GetInt("rps"),
}
}

View File

@ -3,6 +3,7 @@ package entity
import (
"encoding/json"
"fmt"
"strings"
"time"
"go.mongodb.org/mongo-driver/bson"
@ -68,6 +69,12 @@ func (o One) MarshalBSON() ([]byte, error) {
if err := json.Unmarshal(b, &m); err != nil {
return nil, err
}
for k, v := range m {
switch v.(type) {
case string:
m[k] = strings.TrimSpace(v.(string))
}
}
if name, ok := m[JSONName]; ok {
m[Name] = name
delete(m, JSONName)

View File

@ -23,11 +23,11 @@ func TestIntegration(t *testing.T) {
f.Close()
defer os.Remove(f.Name())
os.Setenv("DBURI", f.Name())
graph := NewGraph()
graph := NewRateLimitedGraph()
ctx, can := context.WithCancel(context.TODO())
defer can()
clean := func() {
graph.driver.Delete(context.TODO(), "col", map[string]string{})
graph.g.driver.Delete(context.TODO(), "col", map[string]string{})
}
clean()
defer clean()
@ -42,7 +42,7 @@ func TestIntegration(t *testing.T) {
cleanFill := func() {
clean()
for i := range ones {
if err := graph.driver.Insert(context.TODO(), "col", ones[i]); err != nil {
if err := graph.g.driver.Insert(context.TODO(), "col", ones[i]); err != nil {
t.Fatal(err)
}
}

View File

@ -0,0 +1,63 @@
package storage
import (
"context"
"fmt"
"local/dndex/config"
"local/dndex/storage/entity"
"sync"
"golang.org/x/time/rate"
)
type RateLimitedGraph struct {
g Graph
rps int
limiters *sync.Map
}
func NewRateLimitedGraph() RateLimitedGraph {
return RateLimitedGraph{
g: NewGraph(),
rps: config.New().RPS,
limiters: &sync.Map{},
}
}
func (rlg RateLimitedGraph) limit(ctx context.Context, namespace string) error {
limiter, ok := rlg.limiters.Load(namespace)
if !ok {
config := config.New()
limiter = rate.NewLimiter(rate.Limit(config.RPS), config.RPS)
rlg.limiters.Store(namespace, limiter)
}
limit, ok := limiter.(*rate.Limiter)
if !ok {
return fmt.Errorf("rate limiter is of type %T", limiter)
}
return limit.Wait(ctx)
}
func (rlg RateLimitedGraph) Delete(ctx context.Context, namespace string, filter interface{}) error {
return rlg.g.Delete(ctx, namespace, filter)
}
func (rlg RateLimitedGraph) Insert(ctx context.Context, namespace string, one entity.One) error {
return rlg.g.Insert(ctx, namespace, one)
}
func (rlg RateLimitedGraph) List(ctx context.Context, namespace string, from ...string) ([]entity.One, error) {
return rlg.g.List(ctx, namespace, from...)
}
func (rlg RateLimitedGraph) ListCaseInsensitive(ctx context.Context, namespace string, from ...string) ([]entity.One, error) {
return rlg.g.ListCaseInsensitive(ctx, namespace, from...)
}
func (rlg RateLimitedGraph) Search(ctx context.Context, namespace string, nameContains string) ([]entity.One, error) {
return rlg.g.Search(ctx, namespace, nameContains)
}
func (rlg RateLimitedGraph) Update(ctx context.Context, namespace string, one entity.One, modify interface{}) error {
return rlg.g.Update(ctx, namespace, one, modify)
}

View File

@ -82,7 +82,7 @@ func filesPostFromDirectLink(w http.ResponseWriter, r *http.Request) error {
return err
}
defer f.Close()
_, err = io.Copy(f, io.LimitReader(resp.Body, config.New().MaxFileSize))
_, err = io.Copy(f, resp.Body)
return err
}
@ -109,7 +109,7 @@ func filesPostFromUpload(w http.ResponseWriter, r *http.Request) error {
return err
}
defer file.Close()
if _, err := io.Copy(f, io.LimitReader(file, config.New().MaxFileSize)); err != nil {
if _, err := io.Copy(f, file); err != nil {
return err
}
return json.NewEncoder(w).Encode(map[string]interface{}{"status": "ok"})

View File

@ -4,6 +4,7 @@ import (
"encoding/json"
"errors"
"fmt"
"io"
"local/dndex/config"
"local/dndex/storage"
"local/gziphttp"
@ -72,6 +73,13 @@ func jsonHandler(g storage.Graph) http.Handler {
defer gz.Close()
w = gz
}
r.Body = struct {
io.Reader
io.Closer
}{
Reader: io.LimitReader(r.Body, config.New().MaxFileSize),
Closer: r.Body,
}
mux.ServeHTTP(w, r)
})
}