Disable user registration with a flag

master
bel 2019-10-22 05:16:56 +00:00
parent 80017bb32b
commit ba44094eb9
2 changed files with 17 additions and 7 deletions

View File

@ -8,13 +8,14 @@ import (
) )
var ( var (
Port string Port string
Store string Store string
StoreAddr string StoreAddr string
StoreUser string StoreUser string
StorePass string StorePass string
SecretA string SecretA string
SecretB string SecretB string
UserRegistration bool
) )
func init() { func init() {
@ -39,6 +40,7 @@ func Refresh() {
as.Append(args.STRING, "storeAddr", "addr of DB", "/tmp/oauth2server.db") as.Append(args.STRING, "storeAddr", "addr of DB", "/tmp/oauth2server.db")
as.Append(args.STRING, "storeUser", "user of DB", "") as.Append(args.STRING, "storeUser", "user of DB", "")
as.Append(args.STRING, "storePass", "pass of DB", "") as.Append(args.STRING, "storePass", "pass of DB", "")
as.Append(args.BOOL, "users", "allow user registration", false)
if err := as.Parse(); err != nil { if err := as.Parse(); err != nil {
panic(err) panic(err)
} }
@ -50,4 +52,5 @@ func Refresh() {
StorePass = as.Get("storepass").GetString() StorePass = as.Get("storepass").GetString()
SecretA = as.Get("secreta").GetString() SecretA = as.Get("secreta").GetString()
SecretB = as.Get("secretb").GetString() SecretB = as.Get("secretb").GetString()
UserRegistration = as.Get("users").GetBool() || Store == "map"
} }

View File

@ -2,11 +2,13 @@ package server
import ( import (
"fmt" "fmt"
"local/oauth2/oauth2server/config"
"net/http" "net/http"
) )
func (s *Server) Routes() error { func (s *Server) Routes() error {
endpoints := []struct { endpoints := []struct {
skip bool
path string path string
handler http.HandlerFunc handler http.HandlerFunc
}{ }{
@ -23,16 +25,21 @@ func (s *Server) Routes() error {
handler: s.usersLog, handler: s.usersLog,
}, },
{ {
skip: !config.UserRegistration,
path: fmt.Sprintf("users/register"), path: fmt.Sprintf("users/register"),
handler: s.usersRegister, handler: s.usersRegister,
}, },
{ {
skip: !config.UserRegistration,
path: fmt.Sprintf("users/submit"), path: fmt.Sprintf("users/submit"),
handler: s.usersSubmit, handler: s.usersSubmit,
}, },
} }
for _, endpoint := range endpoints { for _, endpoint := range endpoints {
if endpoint.skip {
continue
}
if err := s.Add(endpoint.path, endpoint.handler); err != nil { if err := s.Add(endpoint.path, endpoint.handler); err != nil {
return err return err
} }