From a2e84379a9b5fb9b9ae2be7b7f54734cfe514dc0 Mon Sep 17 00:00:00 2001 From: bel Date: Sun, 10 Mar 2024 10:41:31 -0600 Subject: [PATCH] too much effort into the garbage --- oauth2client/client.go | 48 +++++++------- oauth2server/server/authorize.go | 20 +++--- oauth2server/server/server.go | 6 +- oauth2server/server/verify.go | 7 +-- oauth2test/package_test.go | 105 ++++++++++++++++++++++++++----- 5 files changed, 128 insertions(+), 58 deletions(-) diff --git a/oauth2client/client.go b/oauth2client/client.go index c0438a3..0a7d96b 100755 --- a/oauth2client/client.go +++ b/oauth2client/client.go @@ -3,12 +3,13 @@ package oauth2client import ( "crypto/tls" "errors" - "gitea.inhome.blapointe.com/local/oauth2" "net/http" "net/url" "strconv" "strings" "time" + + "gitea.inhome.blapointe.com/local/oauth2" ) type cached struct { @@ -25,7 +26,7 @@ func Authenticate(server, scope string, w http.ResponseWriter, r *http.Request) } access, exists := findAccess(w, r) if !exists { - return login(oauth2server, scope, w, r) + return login(scope, w, r) } return verify(access, oauth2server, scope, w, r) } @@ -44,12 +45,20 @@ func findAccessFresh(w http.ResponseWriter, r *http.Request) (string, bool) { if !found { access, found = findAccessFreshCookie(w, r) } + if !found { + access, found = findAccessFreshBasicAuth(w, r) + } if found { setCookie(oauth2.COOKIE, access, "", w) } return access, found } +func findAccessFreshBasicAuth(w http.ResponseWriter, r *http.Request) (string, bool) { + _, p, ok := r.BasicAuth() + return p, ok +} + func findAccessFreshQueryParam(w http.ResponseWriter, r *http.Request) (string, bool) { q := r.URL.Query() access := q.Get(oauth2.NEWCOOKIE) @@ -88,21 +97,17 @@ func findAccessStable(w http.ResponseWriter, r *http.Request) (string, bool) { return access.Value, true } -func login(oauth2server *url.URL, scope string, w http.ResponseWriter, r *http.Request) error { - oauth2server.Path = "/users/log/" + scope - url := *r.URL - url.Host = r.Host - if url.Scheme == "" { - url.Scheme = oauth2server.Scheme - } - if url.Scheme == "" { - url.Scheme = "https" - } - q := oauth2server.Query() - q.Set(oauth2.REDIRECT, url.String()) - oauth2server.RawQuery = q.Encode() - http.Redirect(w, r, oauth2server.String(), http.StatusSeeOther) - return errors.New("logging in") +func login(scope string, w http.ResponseWriter, r *http.Request) error { + w.Header().Set("WWW-Authenticate", "Basic") + w.WriteHeader(403) + return errors.New("login pls") +} + +var HTTPClient = &http.Client{ + Timeout: 5 * time.Second, + Transport: &http.Transport{ + TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, + }, } func verify(access string, oauth2server *url.URL, scope string, w http.ResponseWriter, r *http.Request) error { @@ -118,19 +123,14 @@ func verify(access string, oauth2server *url.URL, scope string, w http.ResponseW } req.Header.Add("Content-Type", "application/x-www-form-urlencoded") req.Header.Add("Content-Length", strconv.Itoa(len(data.Encode()))) - c := &http.Client{ - Timeout: 5 * time.Second, - Transport: &http.Transport{ - TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, - }, - } + c := HTTPClient resp, err := c.Do(req) if err != nil { return err } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { - return login(oauth2server, scope, w, r) + return login(scope, w, r) } cache[scope] = cached{ access: access, diff --git a/oauth2server/server/authorize.go b/oauth2server/server/authorize.go index 56adad4..c853b31 100755 --- a/oauth2server/server/authorize.go +++ b/oauth2server/server/authorize.go @@ -2,13 +2,14 @@ package server import ( "fmt" - "gitea.inhome.blapointe.com/local/oauth2" - "gitea.inhome.blapointe.com/local/router" - "gitea.inhome.blapointe.com/local/storage" "net/http" "net/url" "strings" + "gitea.inhome.blapointe.com/local/oauth2" + "gitea.inhome.blapointe.com/local/router" + "gitea.inhome.blapointe.com/local/storage" + "github.com/google/uuid" ) @@ -55,9 +56,9 @@ func (s *Server) authorize(w http.ResponseWriter, r *http.Request) { func (s *Server) genAuth(scope, user string) { access := uuid.New().String() - token := uuid.New().String() s.store.Set(user, []byte(access), ACCESS) - s.store.Set(scope+"."+access, []byte(token), TOKEN) + s.store.Set(user, []byte(user), ACCESS) + s.store.Set(access, []byte(user), ACCESS) } func (s *Server) getAccess(scope, user string) (string, bool) { @@ -69,7 +70,10 @@ func (s *Server) getAccess(scope, user string) (string, bool) { return string(access), err == nil } -func (s *Server) getToken(scope, access string) (string, bool) { - token, err := s.store.Get(scope+"."+access, TOKEN) - return string(token), err == nil +func (s *Server) verifyAccess(access string) error { + _, err := s.store.Get(access, ACCESS) + if err != nil { + return fmt.Errorf("access not found: %s", access) + } + return nil } diff --git a/oauth2server/server/server.go b/oauth2server/server/server.go index b3738ed..1d68fdc 100755 --- a/oauth2server/server/server.go +++ b/oauth2server/server/server.go @@ -2,6 +2,7 @@ package server import ( "fmt" + "gitea.inhome.blapointe.com/local/oauth2/oauth2server/config" "gitea.inhome.blapointe.com/local/router" "gitea.inhome.blapointe.com/local/storage" @@ -14,7 +15,6 @@ var wildcard = router.Wildcard const ( USERS = "users" ACCESS = "access" - TOKEN = "token" SALT = "salt" ) @@ -42,10 +42,6 @@ func purgeIssuedCredentials(store storage.DB) { for _, access := range accesses { store.Set(access, nil, ACCESS) } - tokens, _ := store.List([]string{TOKEN}) - for _, token := range tokens { - store.Set(token, nil, TOKEN) - } } func wrapBody(title, body string) string { diff --git a/oauth2server/server/verify.go b/oauth2server/server/verify.go index 2d20668..eeecd6d 100755 --- a/oauth2server/server/verify.go +++ b/oauth2server/server/verify.go @@ -1,8 +1,9 @@ package server import ( - "gitea.inhome.blapointe.com/local/router" "net/http" + + "gitea.inhome.blapointe.com/local/router" ) func (s *Server) verify(w http.ResponseWriter, r *http.Request) { @@ -13,10 +14,8 @@ func (s *Server) verify(w http.ResponseWriter, r *http.Request) { return } access := r.FormValue("access") - token, ok := s.getToken(scope, access) - if !ok { + if err := s.verifyAccess(access); err != nil { http.Error(w, "unknown access", http.StatusUnauthorized) return } - w.Write([]byte(token)) } diff --git a/oauth2test/package_test.go b/oauth2test/package_test.go index 968a189..718b8a6 100755 --- a/oauth2test/package_test.go +++ b/oauth2test/package_test.go @@ -1,13 +1,11 @@ package oauth2 import ( + "context" "errors" "fmt" - "gitea.inhome.blapointe.com/local/oauth2" - "gitea.inhome.blapointe.com/local/oauth2/oauth2client" - "gitea.inhome.blapointe.com/local/oauth2/oauth2server/config" - "gitea.inhome.blapointe.com/local/oauth2/oauth2server/server" - "log" + "io" + "net" "net/http" "net/http/cookiejar" "net/http/httptest" @@ -15,9 +13,16 @@ import ( "regexp" "strings" "testing" + + "gitea.inhome.blapointe.com/local/oauth2" + "gitea.inhome.blapointe.com/local/oauth2/oauth2client" + "gitea.inhome.blapointe.com/local/oauth2/oauth2server/config" + "gitea.inhome.blapointe.com/local/oauth2/oauth2server/server" ) func TestAll(t *testing.T) { + oauth2client.HTTPClient.Transport = makeTransport() + oauth2server, err := launchServer() if err != nil { t.Fatal(err) @@ -25,24 +30,38 @@ func TestAll(t *testing.T) { defer oauth2server.Close() oauth2server.URL = strings.ReplaceAll(oauth2server.URL, "127.0.0.1", "echo.belbox.dev") - s := dummyServer(oauth2server.URL) + s := dummyServer(t, oauth2server.URL) defer s.Close() + t.Log("createUser...") if err := createUser(oauth2server.URL); err != nil { t.Fatal(err) } - if err := logUser(oauth2server.URL); err != nil { + t.Log("loginAsUser...") + if err := loginAsuser(oauth2server.URL); err != nil { t.Fatal(err) } + t.Log("shouldRedir...") if err := shouldRedir(s.URL); err != nil { t.Fatal(err) } + t.Log("testAuth...") if err := testAuth(oauth2server.URL, s.URL); err != nil { t.Fatal(err) } + + t.Log("testAuthViaBadBasicAuth...") + if err := testAuthViaBadBasicAuth(s.URL); err != nil { + t.Fatal(err) + } + + t.Log("testAuthViaBasicAuth...") + if err := testAuthViaBasicAuth(s.URL); err != nil { + t.Fatal(err) + } } func launchServer() (*httptest.Server, error) { @@ -63,18 +82,22 @@ func launchServer() (*httptest.Server, error) { return s, err } -func dummyServer(oauth2server string) *httptest.Server { +func dummyServer(t *testing.T, oauth2server string) *httptest.Server { return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { err := oauth2client.Authenticate(oauth2server, "scope", w, r) if err != nil { + t.Logf("dummy: %s: %v", r.URL.Path, err) return } + t.Logf("dummy: %s: :D", r.URL.Path) fmt.Fprintln(w, "dummy server serving authenticated") })) } func createUser(oauth2server string) error { - resp, err := http.Post(oauth2server+"/users/submit/scope", "application/x-www-form-urlencoded", strings.NewReader("username=abc")) + req, _ := http.NewRequest(http.MethodPost, oauth2server+"/users/submit/scope", strings.NewReader("username=abc")) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + resp, err := makeClient().Do(req) if err != nil { return err } @@ -85,8 +108,10 @@ func createUser(oauth2server string) error { return nil } -func logUser(oauth2server string) error { - resp, err := http.Post(oauth2server+"/authorize/scope", "application/x-www-form-urlencoded", strings.NewReader("username=abc")) +func loginAsuser(oauth2server string) error { + req, _ := http.NewRequest(http.MethodPost, oauth2server+"/authorize/scope", strings.NewReader("username=abc")) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + resp, err := makeClient().Do(req) if err != nil { return err } @@ -109,8 +134,8 @@ func clientShouldRedir(c *http.Client, dummy string) error { return err } defer resp.Body.Close() - if resp.Request.URL.Path != "/users/log/scope" { - return fmt.Errorf("did not need redir without auth: %v", resp.Request.URL) + if resp.StatusCode != 403 || resp.Header.Get("WWW-Authenticate") == "" { + return fmt.Errorf("did not need redir without auth: (%d) %q", resp.StatusCode, resp.Header.Get("WWW-Authenticate")) } return nil } @@ -130,16 +155,13 @@ func clientShouldNotRedir(c *http.Client, dummy string) error { func testAuth(oauth2server, dummy string) error { c := makeClient() - log.Println("should redir...") if err := clientShouldRedir(c, dummy); err != nil { return err } - log.Println("client login...") access, err := clientLogin(c, oauth2server) if err != nil { return err } - log.Println("client should not redir...") if err := clientShouldNotRedir(c, dummy+"?"+oauth2.NEWCOOKIE+"="+access); err != nil { return err } @@ -178,9 +200,58 @@ func clientLogin(c *http.Client, oauth2server string) (string, error) { return a, nil } +func testAuthViaBadBasicAuth(dummy string) error { + c := makeClient() + req, _ := http.NewRequest(http.MethodGet, dummy, nil) + req.SetBasicAuth("u", "p") + resp, err := c.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + if resp.StatusCode != 403 { + b, _ := io.ReadAll(resp.Body) + return fmt.Errorf("got through with bad basic auth set: (%d) %s", resp.StatusCode, b) + } + return nil +} + +func testAuthViaBasicAuth(dummy string) error { + c := makeClient() + req, _ := http.NewRequest(http.MethodGet, dummy, nil) + req.SetBasicAuth("", "abc") + resp, err := c.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + b, _ := io.ReadAll(resp.Body) + return fmt.Errorf("failed to get through with basic auth set: (%d) %s", resp.StatusCode, b) + } + if !strings.Contains(fmt.Sprint(c.Jar), oauth2.COOKIE) { + return errors.New("cookie jar empty:" + fmt.Sprint(c.Jar)) + } + return nil +} + func makeClient() *http.Client { jar, _ := cookiejar.New(&cookiejar.Options{}) return &http.Client{ - Jar: jar, + Jar: jar, + Transport: makeTransport(), + } +} + +func makeTransport() *http.Transport { + return &http.Transport{ + DialContext: func(ctx context.Context, network string, addr string) (net.Conn, error) { + parts := strings.Split(addr, ":") + port := "80" + if len(parts) > 0 { + port = parts[1] + } + return (&net.Dialer{}).DialContext(ctx, "tcp4", fmt.Sprintf("127.0.0.1:%s", port)) + }, } }