From 80017bb32ba6a85f1efd2699644f1a9438d1eaa9 Mon Sep 17 00:00:00 2001 From: bel Date: Tue, 22 Oct 2019 02:12:27 +0000 Subject: [PATCH] Rate limit login stuff --- oauth2server/server/authorize.go | 1 + oauth2server/server/server.go | 10 +++++++--- oauth2server/server/users.go | 3 +++ 3 files changed, 11 insertions(+), 3 deletions(-) diff --git a/oauth2server/server/authorize.go b/oauth2server/server/authorize.go index a9c1952..5eb5b2b 100644 --- a/oauth2server/server/authorize.go +++ b/oauth2server/server/authorize.go @@ -11,6 +11,7 @@ import ( ) func (s *Server) authorize(w http.ResponseWriter, r *http.Request) { + s.limiter.Wait(r.Context()) if r.Method != "POST" { http.NotFound(w, r) return diff --git a/oauth2server/server/server.go b/oauth2server/server/server.go index d064ab9..be85262 100644 --- a/oauth2server/server/server.go +++ b/oauth2server/server/server.go @@ -4,6 +4,8 @@ import ( "local/oauth2/oauth2server/config" "local/router" "local/storage" + + "golang.org/x/time/rate" ) var wildcard = router.Wildcard @@ -17,7 +19,8 @@ const ( type Server struct { *router.Router - store storage.DB + store storage.DB + limiter *rate.Limiter } func New() *Server { @@ -27,8 +30,9 @@ func New() *Server { } purgeIssuedCredentials(store) return &Server{ - Router: router.New(), - store: store, + Router: router.New(), + store: store, + limiter: rate.NewLimiter(1, 3), } } diff --git a/oauth2server/server/users.go b/oauth2server/server/users.go index 1089847..1651090 100644 --- a/oauth2server/server/users.go +++ b/oauth2server/server/users.go @@ -13,6 +13,7 @@ import ( ) func (s *Server) usersLog(w http.ResponseWriter, r *http.Request) { + s.limiter.Wait(r.Context()) q := r.URL.Query() fmt.Fprintln(w, ` @@ -27,6 +28,7 @@ func (s *Server) usersLog(w http.ResponseWriter, r *http.Request) { } func (s *Server) usersRegister(w http.ResponseWriter, r *http.Request) { + s.limiter.Wait(r.Context()) fmt.Fprintln(w, ` @@ -40,6 +42,7 @@ func (s *Server) usersRegister(w http.ResponseWriter, r *http.Request) { } func (s *Server) usersSubmit(w http.ResponseWriter, r *http.Request) { + s.limiter.Wait(r.Context()) if r.Method != "POST" { http.NotFound(w, r) return