From a0bf41e04e764424eceb551661f6e15d2f91a839 Mon Sep 17 00:00:00 2001 From: bel Date: Sun, 20 Oct 2019 17:13:00 -0600 Subject: [PATCH] Working cross domain too --- oauth2client/client.go | 64 +++++++++++++++++++++++++++----- oauth2server/server/authorize.go | 30 ++++++++------- oauth2server/server/users.go | 3 +- oauth2test/package_test.go | 31 +++++++++++----- 4 files changed, 95 insertions(+), 33 deletions(-) diff --git a/oauth2client/client.go b/oauth2client/client.go index 1b270b3..e1df825 100644 --- a/oauth2client/client.go +++ b/oauth2client/client.go @@ -15,11 +15,46 @@ func Authenticate(server string, w http.ResponseWriter, r *http.Request) error { if err != nil { return err } - access, err := r.Cookie(oauth2.COOKIE) - if err == http.ErrNoCookie { + access, exists := findAccess(w, r) + if !exists { return login(oauth2server, w, r) } - return verify(access.Value, oauth2server, w, r) + return verify(access, oauth2server, w, r) +} + +func findAccess(w http.ResponseWriter, r *http.Request) (string, bool) { + fresh, exists := findAccessFresh(w, r) + if exists { + return fresh, true + } + stable, exists := findAccessStable(w, r) + return stable, exists +} + +func findAccessFresh(w http.ResponseWriter, r *http.Request) (string, bool) { + q := r.URL.Query() + access := q.Get(oauth2.COOKIE) + q.Del(oauth2.COOKIE) + r.URL.RawQuery = q.Encode() + if access == "" { + return "", false + } + cookie := &http.Cookie{ + Name: oauth2.COOKIE, + Value: access, + SameSite: http.SameSiteLaxMode, + Path: "/", + } + http.SetCookie(w, cookie) + return access, true +} + +func findAccessStable(w http.ResponseWriter, r *http.Request) (string, bool) { + access, err := r.Cookie(oauth2.COOKIE) + if err == http.ErrNoCookie { + return "", false + } + return access.Value, true } func login(oauth2server *url.URL, w http.ResponseWriter, r *http.Request) error { @@ -29,13 +64,9 @@ func login(oauth2server *url.URL, w http.ResponseWriter, r *http.Request) error if url.Scheme == "" { url.Scheme = "http" } - cookie := &http.Cookie{ - Name: oauth2.REDIRECT, - Value: url.String(), - SameSite: http.SameSiteLaxMode, - Path: "/authorize", - } - http.SetCookie(w, cookie) + q := oauth2server.Query() + q.Set(oauth2.REDIRECT, url.String()) + oauth2server.RawQuery = q.Encode() http.Redirect(w, r, oauth2server.String(), http.StatusSeeOther) return errors.New("logging in") } @@ -63,3 +94,16 @@ func verify(access string, oauth2server *url.URL, w http.ResponseWriter, r *http } return nil } + +func setCookie(access string, w http.ResponseWriter) { + cookie := &http.Cookie{ + Name: oauth2.COOKIE, + Value: access, + SameSite: http.SameSiteLaxMode, + Path: "/", + } + if access == "" { + cookie.Expires = time.Now().Add(-1 * time.Hour) + } + http.SetCookie(w, cookie) +} diff --git a/oauth2server/server/authorize.go b/oauth2server/server/authorize.go index bb4eb3e..a9c1952 100644 --- a/oauth2server/server/authorize.go +++ b/oauth2server/server/authorize.go @@ -1,10 +1,11 @@ package server import ( + "fmt" "local/oauth2" "local/storage" - "log" "net/http" + "net/url" "github.com/google/uuid" ) @@ -25,19 +26,22 @@ func (s *Server) authorize(w http.ResponseWriter, r *http.Request) { http.Error(w, "no oauth for user", http.StatusForbidden) return } - cookie := &http.Cookie{ - Name: oauth2.COOKIE, - Value: access, - SameSite: http.SameSiteLaxMode, + q := r.URL.Query() + redirect := q.Get(oauth2.REDIRECT) + q.Del(oauth2.REDIRECT) + r.URL.RawQuery = q.Encode() + if redirect != "" { + url, _ := url.Parse(redirect) + if url.Scheme == "" { + url.Scheme = "http" + } + values := url.Query() + values.Set(oauth2.COOKIE, access) + url.RawQuery = values.Encode() + http.Redirect(w, r, url.String(), http.StatusSeeOther) + } else { + fmt.Fprintln(w, "OK") } - http.SetCookie(w, cookie) - redirectCookie, err := r.Cookie(oauth2.REDIRECT) - log.Printf("REDIR COOKIE", err, redirectCookie) - log.Println(r.Cookies()) - if err != nil { - return - } - http.Redirect(w, r, redirectCookie.Value, http.StatusSeeOther) } func (s *Server) genAuth(user string) { diff --git a/oauth2server/server/users.go b/oauth2server/server/users.go index 19397b2..b831773 100644 --- a/oauth2server/server/users.go +++ b/oauth2server/server/users.go @@ -13,10 +13,11 @@ import ( ) func (s *Server) usersLog(w http.ResponseWriter, r *http.Request) { + q := r.URL.Query() fmt.Fprintln(w, ` -
+
diff --git a/oauth2test/package_test.go b/oauth2test/package_test.go index bb85b3e..fbd0673 100644 --- a/oauth2test/package_test.go +++ b/oauth2test/package_test.go @@ -3,9 +3,11 @@ package oauth2 import ( "errors" "fmt" + "local/oauth2" "local/oauth2/oauth2client" "local/oauth2/oauth2server/config" "local/oauth2/oauth2server/server" + "log" "net/http" "net/http/cookiejar" "net/http/httptest" @@ -125,33 +127,44 @@ func clientShouldNotRedir(c *http.Client, dummy string) error { func testAuth(oauth2server, dummy string) error { c := makeClient() + log.Println("should redir...") if err := clientShouldRedir(c, dummy); err != nil { return err } - if err := clientLogin(c, oauth2server); err != nil { + log.Println("client login...") + access, err := clientLogin(c, oauth2server) + if err != nil { return err } - if err := clientShouldNotRedir(c, dummy); err != nil { + log.Println("client should not redir...") + if err := clientShouldNotRedir(c, dummy+"?"+oauth2.COOKIE+"="+access); err != nil { return err } + if !strings.Contains(fmt.Sprint(c.Jar), oauth2.COOKIE) { + return errors.New("cookie jar empty:" + fmt.Sprint(c.Jar)) + } return nil } -func clientLogin(c *http.Client, oauth2server string) error { - req, _ := http.NewRequest("POST", oauth2server+"/authorize", strings.NewReader("username=abc")) +func clientLogin(c *http.Client, oauth2server string) (string, error) { + req, _ := http.NewRequest("POST", oauth2server+"/authorize?"+oauth2.REDIRECT+"="+oauth2server+"/", strings.NewReader("username=abc")) req.Header.Add("Content-Type", "application/x-www-form-urlencoded") resp, err := c.Do(req) if err != nil { - return err + return "", err } defer resp.Body.Close() - if resp.StatusCode != http.StatusOK { - return fmt.Errorf("bad status; %v", resp.StatusCode) + if resp.StatusCode == http.StatusUnauthorized { + return "", fmt.Errorf("bad status; %v", resp.StatusCode) } if resp.Request.URL.Path != "/" { - return fmt.Errorf("login response path wrong: %v", resp.Request.URL) + return "", fmt.Errorf("login response path wrong: %v", resp.Request.URL.Path) } - return nil + a := resp.Request.URL.Query().Get(oauth2.COOKIE) + if a == "" { + return "", fmt.Errorf("login and redir didnt set cookie: %v", a) + } + return a, nil } func makeClient() *http.Client {