Limit incoming request body size for all endpoints and add rate limiting wrappera round storage
parent
c3b948556c
commit
36c4ae520d
|
|
@ -17,6 +17,7 @@ type Config struct {
|
|||
Auth bool
|
||||
AuthLifetime time.Duration
|
||||
MaxFileSize int64
|
||||
RPS int
|
||||
}
|
||||
|
||||
func New() Config {
|
||||
|
|
@ -37,6 +38,7 @@ func New() Config {
|
|||
as.Append(args.BOOL, "auth", "check for authorized access", false)
|
||||
as.Append(args.DURATION, "authlifetime", "duration auth is valid for", time.Hour)
|
||||
as.Append(args.INT, "max-file-size", "max file size for uploads in bytes", 50*(1<<20))
|
||||
as.Append(args.INT, "rps", "rps per namespace", 5)
|
||||
|
||||
if err := as.Parse(); err != nil {
|
||||
os.Remove(f.Name())
|
||||
|
|
@ -53,5 +55,6 @@ func New() Config {
|
|||
Auth: as.GetBool("auth"),
|
||||
AuthLifetime: as.GetDuration("authlifetime"),
|
||||
MaxFileSize: int64(as.GetInt("max-file-size")),
|
||||
RPS: as.GetInt("rps"),
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ package entity
|
|||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"go.mongodb.org/mongo-driver/bson"
|
||||
|
|
@ -68,6 +69,12 @@ func (o One) MarshalBSON() ([]byte, error) {
|
|||
if err := json.Unmarshal(b, &m); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for k, v := range m {
|
||||
switch v.(type) {
|
||||
case string:
|
||||
m[k] = strings.TrimSpace(v.(string))
|
||||
}
|
||||
}
|
||||
if name, ok := m[JSONName]; ok {
|
||||
m[Name] = name
|
||||
delete(m, JSONName)
|
||||
|
|
|
|||
|
|
@ -23,11 +23,11 @@ func TestIntegration(t *testing.T) {
|
|||
f.Close()
|
||||
defer os.Remove(f.Name())
|
||||
os.Setenv("DBURI", f.Name())
|
||||
graph := NewGraph()
|
||||
graph := NewRateLimitedGraph()
|
||||
ctx, can := context.WithCancel(context.TODO())
|
||||
defer can()
|
||||
clean := func() {
|
||||
graph.driver.Delete(context.TODO(), "col", map[string]string{})
|
||||
graph.g.driver.Delete(context.TODO(), "col", map[string]string{})
|
||||
}
|
||||
clean()
|
||||
defer clean()
|
||||
|
|
@ -42,7 +42,7 @@ func TestIntegration(t *testing.T) {
|
|||
cleanFill := func() {
|
||||
clean()
|
||||
for i := range ones {
|
||||
if err := graph.driver.Insert(context.TODO(), "col", ones[i]); err != nil {
|
||||
if err := graph.g.driver.Insert(context.TODO(), "col", ones[i]); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,63 @@
|
|||
package storage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"local/dndex/config"
|
||||
"local/dndex/storage/entity"
|
||||
"sync"
|
||||
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
type RateLimitedGraph struct {
|
||||
g Graph
|
||||
rps int
|
||||
limiters *sync.Map
|
||||
}
|
||||
|
||||
func NewRateLimitedGraph() RateLimitedGraph {
|
||||
return RateLimitedGraph{
|
||||
g: NewGraph(),
|
||||
rps: config.New().RPS,
|
||||
limiters: &sync.Map{},
|
||||
}
|
||||
}
|
||||
|
||||
func (rlg RateLimitedGraph) limit(ctx context.Context, namespace string) error {
|
||||
limiter, ok := rlg.limiters.Load(namespace)
|
||||
if !ok {
|
||||
config := config.New()
|
||||
limiter = rate.NewLimiter(rate.Limit(config.RPS), config.RPS)
|
||||
rlg.limiters.Store(namespace, limiter)
|
||||
}
|
||||
limit, ok := limiter.(*rate.Limiter)
|
||||
if !ok {
|
||||
return fmt.Errorf("rate limiter is of type %T", limiter)
|
||||
}
|
||||
return limit.Wait(ctx)
|
||||
}
|
||||
|
||||
func (rlg RateLimitedGraph) Delete(ctx context.Context, namespace string, filter interface{}) error {
|
||||
return rlg.g.Delete(ctx, namespace, filter)
|
||||
}
|
||||
|
||||
func (rlg RateLimitedGraph) Insert(ctx context.Context, namespace string, one entity.One) error {
|
||||
return rlg.g.Insert(ctx, namespace, one)
|
||||
}
|
||||
|
||||
func (rlg RateLimitedGraph) List(ctx context.Context, namespace string, from ...string) ([]entity.One, error) {
|
||||
return rlg.g.List(ctx, namespace, from...)
|
||||
}
|
||||
|
||||
func (rlg RateLimitedGraph) ListCaseInsensitive(ctx context.Context, namespace string, from ...string) ([]entity.One, error) {
|
||||
return rlg.g.ListCaseInsensitive(ctx, namespace, from...)
|
||||
}
|
||||
|
||||
func (rlg RateLimitedGraph) Search(ctx context.Context, namespace string, nameContains string) ([]entity.One, error) {
|
||||
return rlg.g.Search(ctx, namespace, nameContains)
|
||||
}
|
||||
|
||||
func (rlg RateLimitedGraph) Update(ctx context.Context, namespace string, one entity.One, modify interface{}) error {
|
||||
return rlg.g.Update(ctx, namespace, one, modify)
|
||||
}
|
||||
|
|
@ -82,7 +82,7 @@ func filesPostFromDirectLink(w http.ResponseWriter, r *http.Request) error {
|
|||
return err
|
||||
}
|
||||
defer f.Close()
|
||||
_, err = io.Copy(f, io.LimitReader(resp.Body, config.New().MaxFileSize))
|
||||
_, err = io.Copy(f, resp.Body)
|
||||
return err
|
||||
}
|
||||
|
||||
|
|
@ -109,7 +109,7 @@ func filesPostFromUpload(w http.ResponseWriter, r *http.Request) error {
|
|||
return err
|
||||
}
|
||||
defer file.Close()
|
||||
if _, err := io.Copy(f, io.LimitReader(file, config.New().MaxFileSize)); err != nil {
|
||||
if _, err := io.Copy(f, file); err != nil {
|
||||
return err
|
||||
}
|
||||
return json.NewEncoder(w).Encode(map[string]interface{}{"status": "ok"})
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ import (
|
|||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"local/dndex/config"
|
||||
"local/dndex/storage"
|
||||
"local/gziphttp"
|
||||
|
|
@ -72,6 +73,13 @@ func jsonHandler(g storage.Graph) http.Handler {
|
|||
defer gz.Close()
|
||||
w = gz
|
||||
}
|
||||
r.Body = struct {
|
||||
io.Reader
|
||||
io.Closer
|
||||
}{
|
||||
Reader: io.LimitReader(r.Body, config.New().MaxFileSize),
|
||||
Closer: r.Body,
|
||||
}
|
||||
mux.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue