Limit incoming request body size for all endpoints and add rate limiting wrappera round storage

master
breel 2020-07-26 20:25:39 -06:00
parent c3b948556c
commit 36c4ae520d
6 changed files with 86 additions and 5 deletions

View File

@ -17,6 +17,7 @@ type Config struct {
Auth bool Auth bool
AuthLifetime time.Duration AuthLifetime time.Duration
MaxFileSize int64 MaxFileSize int64
RPS int
} }
func New() Config { func New() Config {
@ -37,6 +38,7 @@ func New() Config {
as.Append(args.BOOL, "auth", "check for authorized access", false) as.Append(args.BOOL, "auth", "check for authorized access", false)
as.Append(args.DURATION, "authlifetime", "duration auth is valid for", time.Hour) as.Append(args.DURATION, "authlifetime", "duration auth is valid for", time.Hour)
as.Append(args.INT, "max-file-size", "max file size for uploads in bytes", 50*(1<<20)) as.Append(args.INT, "max-file-size", "max file size for uploads in bytes", 50*(1<<20))
as.Append(args.INT, "rps", "rps per namespace", 5)
if err := as.Parse(); err != nil { if err := as.Parse(); err != nil {
os.Remove(f.Name()) os.Remove(f.Name())
@ -53,5 +55,6 @@ func New() Config {
Auth: as.GetBool("auth"), Auth: as.GetBool("auth"),
AuthLifetime: as.GetDuration("authlifetime"), AuthLifetime: as.GetDuration("authlifetime"),
MaxFileSize: int64(as.GetInt("max-file-size")), MaxFileSize: int64(as.GetInt("max-file-size")),
RPS: as.GetInt("rps"),
} }
} }

View File

@ -3,6 +3,7 @@ package entity
import ( import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"strings"
"time" "time"
"go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/bson"
@ -68,6 +69,12 @@ func (o One) MarshalBSON() ([]byte, error) {
if err := json.Unmarshal(b, &m); err != nil { if err := json.Unmarshal(b, &m); err != nil {
return nil, err return nil, err
} }
for k, v := range m {
switch v.(type) {
case string:
m[k] = strings.TrimSpace(v.(string))
}
}
if name, ok := m[JSONName]; ok { if name, ok := m[JSONName]; ok {
m[Name] = name m[Name] = name
delete(m, JSONName) delete(m, JSONName)

View File

@ -23,11 +23,11 @@ func TestIntegration(t *testing.T) {
f.Close() f.Close()
defer os.Remove(f.Name()) defer os.Remove(f.Name())
os.Setenv("DBURI", f.Name()) os.Setenv("DBURI", f.Name())
graph := NewGraph() graph := NewRateLimitedGraph()
ctx, can := context.WithCancel(context.TODO()) ctx, can := context.WithCancel(context.TODO())
defer can() defer can()
clean := func() { clean := func() {
graph.driver.Delete(context.TODO(), "col", map[string]string{}) graph.g.driver.Delete(context.TODO(), "col", map[string]string{})
} }
clean() clean()
defer clean() defer clean()
@ -42,7 +42,7 @@ func TestIntegration(t *testing.T) {
cleanFill := func() { cleanFill := func() {
clean() clean()
for i := range ones { for i := range ones {
if err := graph.driver.Insert(context.TODO(), "col", ones[i]); err != nil { if err := graph.g.driver.Insert(context.TODO(), "col", ones[i]); err != nil {
t.Fatal(err) t.Fatal(err)
} }
} }

View File

@ -0,0 +1,63 @@
package storage
import (
"context"
"fmt"
"local/dndex/config"
"local/dndex/storage/entity"
"sync"
"golang.org/x/time/rate"
)
type RateLimitedGraph struct {
g Graph
rps int
limiters *sync.Map
}
func NewRateLimitedGraph() RateLimitedGraph {
return RateLimitedGraph{
g: NewGraph(),
rps: config.New().RPS,
limiters: &sync.Map{},
}
}
func (rlg RateLimitedGraph) limit(ctx context.Context, namespace string) error {
limiter, ok := rlg.limiters.Load(namespace)
if !ok {
config := config.New()
limiter = rate.NewLimiter(rate.Limit(config.RPS), config.RPS)
rlg.limiters.Store(namespace, limiter)
}
limit, ok := limiter.(*rate.Limiter)
if !ok {
return fmt.Errorf("rate limiter is of type %T", limiter)
}
return limit.Wait(ctx)
}
func (rlg RateLimitedGraph) Delete(ctx context.Context, namespace string, filter interface{}) error {
return rlg.g.Delete(ctx, namespace, filter)
}
func (rlg RateLimitedGraph) Insert(ctx context.Context, namespace string, one entity.One) error {
return rlg.g.Insert(ctx, namespace, one)
}
func (rlg RateLimitedGraph) List(ctx context.Context, namespace string, from ...string) ([]entity.One, error) {
return rlg.g.List(ctx, namespace, from...)
}
func (rlg RateLimitedGraph) ListCaseInsensitive(ctx context.Context, namespace string, from ...string) ([]entity.One, error) {
return rlg.g.ListCaseInsensitive(ctx, namespace, from...)
}
func (rlg RateLimitedGraph) Search(ctx context.Context, namespace string, nameContains string) ([]entity.One, error) {
return rlg.g.Search(ctx, namespace, nameContains)
}
func (rlg RateLimitedGraph) Update(ctx context.Context, namespace string, one entity.One, modify interface{}) error {
return rlg.g.Update(ctx, namespace, one, modify)
}

View File

@ -82,7 +82,7 @@ func filesPostFromDirectLink(w http.ResponseWriter, r *http.Request) error {
return err return err
} }
defer f.Close() defer f.Close()
_, err = io.Copy(f, io.LimitReader(resp.Body, config.New().MaxFileSize)) _, err = io.Copy(f, resp.Body)
return err return err
} }
@ -109,7 +109,7 @@ func filesPostFromUpload(w http.ResponseWriter, r *http.Request) error {
return err return err
} }
defer file.Close() defer file.Close()
if _, err := io.Copy(f, io.LimitReader(file, config.New().MaxFileSize)); err != nil { if _, err := io.Copy(f, file); err != nil {
return err return err
} }
return json.NewEncoder(w).Encode(map[string]interface{}{"status": "ok"}) return json.NewEncoder(w).Encode(map[string]interface{}{"status": "ok"})

View File

@ -4,6 +4,7 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"io"
"local/dndex/config" "local/dndex/config"
"local/dndex/storage" "local/dndex/storage"
"local/gziphttp" "local/gziphttp"
@ -72,6 +73,13 @@ func jsonHandler(g storage.Graph) http.Handler {
defer gz.Close() defer gz.Close()
w = gz w = gz
} }
r.Body = struct {
io.Reader
io.Closer
}{
Reader: io.LimitReader(r.Body, config.New().MaxFileSize),
Closer: r.Body,
}
mux.ServeHTTP(w, r) mux.ServeHTTP(w, r)
}) })
} }