commit 6dc2e074fb3f0e9ade1a468600a8a62b626e6b05 Author: bel Date: Sun Oct 20 12:59:50 2019 -0600 Wannabe oauth implementation diff --git a/const.go b/const.go new file mode 100644 index 0000000..dd1c1bc --- /dev/null +++ b/const.go @@ -0,0 +1,6 @@ +package oauth2 + +const ( + COOKIE = "BOAuthZ" + REDIRECT = "BOAuthZ-Redirect" +) diff --git a/oauth2client/client.go b/oauth2client/client.go new file mode 100644 index 0000000..433fa17 --- /dev/null +++ b/oauth2client/client.go @@ -0,0 +1,63 @@ +package oauth2client + +import ( + "errors" + "local/oauth2" + "net/http" + "net/url" + "strconv" + "strings" + "time" +) + +func Authenticate(server string, w http.ResponseWriter, r *http.Request) error { + oauth2server, err := url.Parse(server) + if err != nil { + return err + } + access, err := r.Cookie(oauth2.COOKIE) + if err == http.ErrNoCookie { + return login(oauth2server, w, r) + } + return verify(access.Value, oauth2server, w, r) +} + +func login(oauth2server *url.URL, w http.ResponseWriter, r *http.Request) error { + oauth2server.Path = "/users/log" + url := *r.URL + url.Host = r.Host + if url.Scheme == "" { + url.Scheme = "http" + } + cookie := &http.Cookie{ + Name: oauth2.REDIRECT, + Value: url.String(), + } + http.SetCookie(w, cookie) + http.Redirect(w, r, oauth2server.String(), http.StatusSeeOther) + return errors.New("logging in") +} + +func verify(access string, oauth2server *url.URL, w http.ResponseWriter, r *http.Request) error { + oauth2server.Path = "/verify" + data := url.Values{} + data.Set("access", access) + req, err := http.NewRequest("POST", oauth2server.String(), strings.NewReader(data.Encode())) + if err != nil { + return err + } + req.Header.Add("Content-Type", "application/x-www-form-urlencoded") + req.Header.Add("Content-Length", strconv.Itoa(len(data.Encode()))) + c := &http.Client{ + Timeout: 5 * time.Second, + } + resp, err := c.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return login(oauth2server, w, r) + } + return nil +} diff --git a/oauth2server/config/config.go b/oauth2server/config/config.go new file mode 100644 index 0000000..a99cdbd --- /dev/null +++ b/oauth2server/config/config.go @@ -0,0 +1,53 @@ +package config + +import ( + "fmt" + "local/args" + "os" + "strings" +) + +var ( + Port string + Store string + StoreAddr string + StoreUser string + StorePass string + SecretA string + SecretB string +) + +func init() { + Refresh() +} + +func Refresh() { + if strings.Contains(fmt.Sprint(os.Args), "-test") { + return + } + defer func() { + if err := recover(); err != nil { + panic(err) + } + }() + + as := args.NewArgSet() + as.Append(args.STRING, "port", "port to listen on", "23456") + as.Append(args.STRING, "secretA", "secret A", "secret") + as.Append(args.STRING, "secretB", "secret B", "secret") + as.Append(args.STRING, "store", "type of DB", "map") + as.Append(args.STRING, "storeAddr", "addr of DB", "/tmp/oauth2server.db") + as.Append(args.STRING, "storeUser", "user of DB", "") + as.Append(args.STRING, "storePass", "pass of DB", "") + if err := as.Parse(); err != nil { + panic(err) + } + + Port = ":" + strings.TrimPrefix(as.Get("port").GetString(), ":") + Store = as.Get("store").GetString() + StoreAddr = as.Get("storeaddr").GetString() + StoreUser = as.Get("storeuser").GetString() + StorePass = as.Get("storepass").GetString() + SecretA = as.Get("secreta").GetString() + SecretB = as.Get("secretb").GetString() +} diff --git a/oauth2server/main.go b/oauth2server/main.go new file mode 100644 index 0000000..6a95b2d --- /dev/null +++ b/oauth2server/main.go @@ -0,0 +1,19 @@ +package main + +import ( + "local/oauth2/oauth2server/config" + "local/oauth2/oauth2server/server" + "log" + "net/http" +) + +func main() { + s := server.New() + if err := s.Routes(); err != nil { + panic(err) + } + log.Println("listening on", config.Port) + if err := http.ListenAndServe(config.Port, s); err != nil { + panic(err) + } +} diff --git a/oauth2server/oauth2server b/oauth2server/oauth2server new file mode 100755 index 0000000..6412a51 Binary files /dev/null and b/oauth2server/oauth2server differ diff --git a/oauth2server/server/authorize.go b/oauth2server/server/authorize.go new file mode 100644 index 0000000..a416dde --- /dev/null +++ b/oauth2server/server/authorize.go @@ -0,0 +1,59 @@ +package server + +import ( + "local/oauth2" + "local/storage" + "net/http" + + "github.com/google/uuid" +) + +func (s *Server) authorize(w http.ResponseWriter, r *http.Request) { + if r.Method != "POST" { + http.NotFound(w, r) + return + } + id := r.FormValue("username") + user, ok := s.getUser(id) + if !ok { + http.Error(w, "unknown user", http.StatusForbidden) + return + } + access, ok := s.getAccess(user) + if !ok { + http.Error(w, "no oauth for user", http.StatusForbidden) + return + } + cookie := &http.Cookie{ + Name: oauth2.COOKIE, + Value: access, + SameSite: http.SameSiteLaxMode, + } + http.SetCookie(w, cookie) + redirectCookie, err := r.Cookie(oauth2.REDIRECT) + if err != nil { + return + } + http.Redirect(w, r, redirectCookie.Value, http.StatusSeeOther) +} + +func (s *Server) genAuth(user string) { + access := uuid.New().String() + token := uuid.New().String() + s.store.Set(user, []byte(access), ACCESS) + s.store.Set(access, []byte(token), TOKEN) +} + +func (s *Server) getAccess(user string) (string, bool) { + access, err := s.store.Get(user, ACCESS) + if err == storage.ErrNotFound { + s.genAuth(user) + access, err = s.store.Get(user, ACCESS) + } + return string(access), err == nil +} + +func (s *Server) getToken(access string) (string, bool) { + token, err := s.store.Get(access, TOKEN) + return string(token), err == nil +} diff --git a/oauth2server/server/routes.go b/oauth2server/server/routes.go new file mode 100644 index 0000000..9187d31 --- /dev/null +++ b/oauth2server/server/routes.go @@ -0,0 +1,41 @@ +package server + +import ( + "fmt" + "net/http" +) + +func (s *Server) Routes() error { + endpoints := []struct { + path string + handler http.HandlerFunc + }{ + { + path: fmt.Sprintf("authorize"), + handler: s.authorize, + }, + { + path: fmt.Sprintf("verify"), + handler: s.verify, + }, + { + path: fmt.Sprintf("users/log"), + handler: s.usersLog, + }, + { + path: fmt.Sprintf("users/register"), + handler: s.usersRegister, + }, + { + path: fmt.Sprintf("users/submit"), + handler: s.usersSubmit, + }, + } + + for _, endpoint := range endpoints { + if err := s.Add(endpoint.path, endpoint.handler); err != nil { + return err + } + } + return nil +} diff --git a/oauth2server/server/server.go b/oauth2server/server/server.go new file mode 100644 index 0000000..d064ab9 --- /dev/null +++ b/oauth2server/server/server.go @@ -0,0 +1,44 @@ +package server + +import ( + "local/oauth2/oauth2server/config" + "local/router" + "local/storage" +) + +var wildcard = router.Wildcard + +const ( + USERS = "users" + ACCESS = "access" + TOKEN = "token" + SALT = "salt" +) + +type Server struct { + *router.Router + store storage.DB +} + +func New() *Server { + store, err := storage.New(storage.TypeFromString(config.Store), config.StoreAddr, config.StoreUser, config.StorePass) + if err != nil { + panic(err) + } + purgeIssuedCredentials(store) + return &Server{ + Router: router.New(), + store: store, + } +} + +func purgeIssuedCredentials(store storage.DB) { + accesses, _ := store.List([]string{ACCESS}) + for _, access := range accesses { + store.Set(access, nil, ACCESS) + } + tokens, _ := store.List([]string{TOKEN}) + for _, token := range tokens { + store.Set(token, nil, TOKEN) + } +} diff --git a/oauth2server/server/users.go b/oauth2server/server/users.go new file mode 100644 index 0000000..19397b2 --- /dev/null +++ b/oauth2server/server/users.go @@ -0,0 +1,82 @@ +package server + +import ( + "crypto/hmac" + "crypto/md5" + "crypto/sha256" + "encoding/hex" + "fmt" + "local/oauth2/oauth2server/config" + "net/http" + + "github.com/google/uuid" +) + +func (s *Server) usersLog(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, ` + + +
+ + +
+ + + `) +} + +func (s *Server) usersRegister(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, ` + + +
+ + +
+ + + `) +} + +func (s *Server) usersSubmit(w http.ResponseWriter, r *http.Request) { + if r.Method != "POST" { + http.NotFound(w, r) + return + } + id := r.FormValue("username") + if _, ok := s.getUser(id); ok { + http.Error(w, "user already exists", http.StatusConflict) + return + } + s.genUser(id) +} + +func (s *Server) genUser(id string) { + user := uuid.New().String() + salt := uuid.New().String() + s.store.Set(id, []byte(salt), SALT) + obscured := s.obscureID(id) + s.store.Set(obscured, []byte(user), USERS) +} + +func (s *Server) getUser(id string) (string, bool) { + obscured := s.obscureID(id) + user, err := s.store.Get(obscured, USERS) + return string(user), err == nil +} + +func (s *Server) obscureID(id string) string { + salt, _ := s.store.Get(id, SALT) + a := s.obscure(string(salt)+id, config.SecretA) + b := s.obscure(string(salt)+id, config.SecretB) + return a + b +} + +func (s *Server) obscure(payload, secret string) string { + hash := md5.New() + hash.Write([]byte(secret)) + key := hash.Sum(nil) + sig := hmac.New(sha256.New, key) + sig.Write([]byte(payload)) + return hex.EncodeToString(sig.Sum(nil)) +} diff --git a/oauth2server/server/verify.go b/oauth2server/server/verify.go new file mode 100644 index 0000000..5946a7e --- /dev/null +++ b/oauth2server/server/verify.go @@ -0,0 +1,19 @@ +package server + +import ( + "net/http" +) + +func (s *Server) verify(w http.ResponseWriter, r *http.Request) { + if r.Method != "POST" { + http.NotFound(w, r) + return + } + access := r.FormValue("access") + token, ok := s.getToken(access) + if !ok { + http.Error(w, "unknown access", http.StatusUnauthorized) + return + } + w.Write([]byte(token)) +} diff --git a/oauth2test/package_test.go b/oauth2test/package_test.go new file mode 100644 index 0000000..bb85b3e --- /dev/null +++ b/oauth2test/package_test.go @@ -0,0 +1,162 @@ +package oauth2 + +import ( + "errors" + "fmt" + "local/oauth2/oauth2client" + "local/oauth2/oauth2server/config" + "local/oauth2/oauth2server/server" + "net/http" + "net/http/cookiejar" + "net/http/httptest" + "regexp" + "strings" + "testing" +) + +func TestAll(t *testing.T) { + oauth2server, err := launchServer() + if err != nil { + t.Fatal(err) + } + defer oauth2server.Close() + + s := dummyServer(oauth2server.URL) + defer s.Close() + + if err := createUser(oauth2server.URL); err != nil { + t.Fatal(err) + } + + if err := logUser(oauth2server.URL); err != nil { + t.Fatal(err) + } + + if err := shouldRedir(s.URL); err != nil { + t.Fatal(err) + } + + if err := testAuth(oauth2server.URL, s.URL); err != nil { + t.Fatal(err) + } +} + +func launchServer() (*httptest.Server, error) { + config.Store = "map" + + oauth2server := server.New() + err := oauth2server.Routes() + if err != nil { + return nil, err + } + s := httptest.NewServer(oauth2server) + + re := regexp.MustCompile(":[0-9]*") + port := re.FindString(s.URL) + config.Port = port + + return s, err +} + +func dummyServer(oauth2server string) *httptest.Server { + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + err := oauth2client.Authenticate(oauth2server, w, r) + if err != nil { + return + } + fmt.Fprintln(w, "dummy server serving authenticated") + })) +} + +func createUser(oauth2server string) error { + resp, err := http.Post(oauth2server+"/users/submit", "application/x-www-form-urlencoded", strings.NewReader("username=abc")) + if err != nil { + return err + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return errors.New("bad status " + resp.Status) + } + return nil +} + +func logUser(oauth2server string) error { + resp, err := http.Post(oauth2server+"/authorize", "application/x-www-form-urlencoded", strings.NewReader("username=abc")) + if err != nil { + return err + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return errors.New("bad status " + resp.Status) + } + return nil +} + +func shouldRedir(dummy string) error { + c := makeClient() + return clientShouldRedir(c, dummy) +} + +func clientShouldRedir(c *http.Client, dummy string) error { + req, _ := http.NewRequest("GET", dummy, nil) + resp, err := c.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + if resp.Request.URL.Path != "/users/log" { + return fmt.Errorf("did not need redir without auth: %v", resp.Request.URL) + } + return nil +} + +func clientShouldNotRedir(c *http.Client, dummy string) error { + req, _ := http.NewRequest("GET", dummy, nil) + resp, err := c.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + if resp.Request.URL.Path == "/users/log" { + return fmt.Errorf("did redir with auth: %v", resp.Request.URL.Path) + } + return nil +} + +func testAuth(oauth2server, dummy string) error { + c := makeClient() + if err := clientShouldRedir(c, dummy); err != nil { + return err + } + if err := clientLogin(c, oauth2server); err != nil { + return err + } + if err := clientShouldNotRedir(c, dummy); err != nil { + return err + } + return nil +} + +func clientLogin(c *http.Client, oauth2server string) error { + req, _ := http.NewRequest("POST", oauth2server+"/authorize", strings.NewReader("username=abc")) + req.Header.Add("Content-Type", "application/x-www-form-urlencoded") + resp, err := c.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("bad status; %v", resp.StatusCode) + } + if resp.Request.URL.Path != "/" { + return fmt.Errorf("login response path wrong: %v", resp.Request.URL) + } + return nil +} + +func makeClient() *http.Client { + jar, _ := cookiejar.New(&cookiejar.Options{}) + return &http.Client{ + Jar: jar, + } +}