diff --git a/oauth2server/config/config.go b/oauth2server/config/config.go index a99cdbd..c622fcb 100644 --- a/oauth2server/config/config.go +++ b/oauth2server/config/config.go @@ -8,13 +8,14 @@ import ( ) var ( - Port string - Store string - StoreAddr string - StoreUser string - StorePass string - SecretA string - SecretB string + Port string + Store string + StoreAddr string + StoreUser string + StorePass string + SecretA string + SecretB string + UserRegistration bool ) func init() { @@ -39,6 +40,7 @@ func Refresh() { as.Append(args.STRING, "storeAddr", "addr of DB", "/tmp/oauth2server.db") as.Append(args.STRING, "storeUser", "user of DB", "") as.Append(args.STRING, "storePass", "pass of DB", "") + as.Append(args.BOOL, "users", "allow user registration", false) if err := as.Parse(); err != nil { panic(err) } @@ -50,4 +52,5 @@ func Refresh() { StorePass = as.Get("storepass").GetString() SecretA = as.Get("secreta").GetString() SecretB = as.Get("secretb").GetString() + UserRegistration = as.Get("users").GetBool() || Store == "map" } diff --git a/oauth2server/server/routes.go b/oauth2server/server/routes.go index 9187d31..1727b34 100644 --- a/oauth2server/server/routes.go +++ b/oauth2server/server/routes.go @@ -2,11 +2,13 @@ package server import ( "fmt" + "local/oauth2/oauth2server/config" "net/http" ) func (s *Server) Routes() error { endpoints := []struct { + skip bool path string handler http.HandlerFunc }{ @@ -23,16 +25,21 @@ func (s *Server) Routes() error { handler: s.usersLog, }, { + skip: !config.UserRegistration, path: fmt.Sprintf("users/register"), handler: s.usersRegister, }, { + skip: !config.UserRegistration, path: fmt.Sprintf("users/submit"), handler: s.usersSubmit, }, } for _, endpoint := range endpoints { + if endpoint.skip { + continue + } if err := s.Add(endpoint.path, endpoint.handler); err != nil { return err }