diff --git a/config/config.go b/config/config.go index 9ca495f..0fef641 100644 --- a/config/config.go +++ b/config/config.go @@ -18,6 +18,7 @@ type Config struct { AuthLifetime time.Duration MaxFileSize int64 RPS int + SysRPS int } func New() Config { @@ -39,6 +40,7 @@ func New() Config { as.Append(args.DURATION, "authlifetime", "duration auth is valid for", time.Hour) as.Append(args.INT, "max-file-size", "max file size for uploads in bytes", 50*(1<<20)) as.Append(args.INT, "rps", "rps per namespace", 5) + as.Append(args.INT, "sys-rps", "rps for the sys", 10) if err := as.Parse(); err != nil { os.Remove(f.Name()) @@ -56,5 +58,6 @@ func New() Config { AuthLifetime: as.GetDuration("authlifetime"), MaxFileSize: int64(as.GetInt("max-file-size")), RPS: as.GetInt("rps"), + SysRPS: as.GetInt("sys-rps"), } } diff --git a/view/json.go b/view/json.go index 9765668..f2534b3 100644 --- a/view/json.go +++ b/view/json.go @@ -11,6 +11,8 @@ import ( "log" "net/http" "strings" + + "golang.org/x/time/rate" ) func JSON(g storage.Graph) error { @@ -67,7 +69,7 @@ func jsonHandler(g storage.Graph) http.Handler { }) } - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + return rateLimited(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if gziphttp.Can(r) { gz := gziphttp.New(w) defer gz.Close() @@ -81,6 +83,18 @@ func jsonHandler(g storage.Graph) http.Handler { Closer: r.Body, } mux.ServeHTTP(w, r) + })) +} + +func rateLimited(foo http.HandlerFunc) http.HandlerFunc { + sysRPS := config.New().SysRPS + limiter := rate.NewLimiter(rate.Limit(sysRPS), sysRPS) + + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if err := limiter.Wait(r.Context()); err != nil { + http.Error(w, err.Error(), http.StatusTooManyRequests) + } + foo(w, r) }) } diff --git a/view/json_test.go b/view/json_test.go index 2d1493b..a1d9300 100644 --- a/view/json_test.go +++ b/view/json_test.go @@ -1,11 +1,14 @@ package view import ( + "context" "local/dndex/config" "net/http" + "net/http/httptest" "net/url" "os" "testing" + "time" ) func TestGetNamespace(t *testing.T) { @@ -59,3 +62,31 @@ func TestGetNamespace(t *testing.T) { }) } } + +func TestRateLimited(t *testing.T) { + os.Args = os.Args[:1] + os.Setenv("SYS_RPS", "10") + foo := rateLimited(func(w http.ResponseWriter, r *http.Request) {}) + + ok := 0 + tooMany := 0 + for i := 0; i < 20; i++ { + w := httptest.NewRecorder() + r := httptest.NewRequest("GET", "/", nil) + ctx, can := context.WithTimeout(r.Context(), time.Millisecond*50) + defer can() + r = r.WithContext(ctx) + foo(w, r) + switch w.Code { + case http.StatusOK: + ok += 1 + case http.StatusTooManyRequests: + tooMany += 1 + default: + t.Fatal("unexpected status", w.Code) + } + } + if ok < 9 || ok > 11 { + t.Fatal(ok, tooMany) + } +}