package oauth2

import (
	"context"
	"errors"
	"fmt"
	"io"
	"net"
	"net/http"
	"net/http/cookiejar"
	"net/http/httptest"
	"net/url"
	"regexp"
	"strings"
	"testing"
	"time"

	"gitea.inhome.blapointe.com/local/oauth2"
	"gitea.inhome.blapointe.com/local/oauth2/oauth2client"
	"gitea.inhome.blapointe.com/local/oauth2/oauth2server/config"
	"gitea.inhome.blapointe.com/local/oauth2/oauth2server/server"
)

const (
	// testUser is the username registered by createUser and reused by every
	// later login / basic-auth step.
	testUser = "abc"

	// testHost replaces 127.0.0.1 in the oauth2 server URL. makeTransport
	// dials loopback regardless of host, so the name never has to resolve.
	testHost = "echo.belbox.dev"

	// clientTimeout bounds every request made by test clients so a stuck
	// server fails the test instead of hanging it.
	clientTimeout = 30 * time.Second
)

// portPattern extracts the ":port" suffix from an httptest URL such as
// "http://127.0.0.1:12345". The "+" and "$" matter: ":[0-9]*" would match the
// bare colon after the "http" scheme first and yield just ":".
var portPattern = regexp.MustCompile(`:[0-9]+$`)

// TestAll exercises the end-to-end flow between a live oauth2 server and a
// dummy resource server that gates requests through oauth2client.Authenticate.
// The steps share server state (the user created first is reused by later
// steps), so they run in order and the test stops at the first failure.
func TestAll(t *testing.T) {
	// Route the library's client through the loopback-only transport, and put
	// the global back afterwards so other tests are unaffected.
	prevTransport := oauth2client.HTTPClient.Transport
	oauth2client.HTTPClient.Transport = makeTransport()
	defer func() { oauth2client.HTTPClient.Transport = prevTransport }()

	oauth2server, err := launchServer()
	if err != nil {
		t.Fatal(err)
	}
	defer oauth2server.Close()
	// NOTE(review): presumably done so cookies are scoped to a domain name
	// rather than an IP literal — confirm.
	oauth2server.URL = strings.ReplaceAll(oauth2server.URL, "127.0.0.1", testHost)

	s := dummyServer(t, oauth2server.URL)
	defer s.Close()

	steps := []struct {
		name string
		run  func() error
	}{
		{"createUser", func() error { return createUser(oauth2server.URL) }},
		{"loginAsUser", func() error { return loginAsuser(oauth2server.URL) }},
		{"shouldRedir", func() error { return shouldRedir(s.URL) }},
		{"testAuth", func() error { return testAuth(oauth2server.URL, s.URL) }},
		{"testAuthViaBadBasicAuth", func() error { return testAuthViaBadBasicAuth(s.URL) }},
		{"testAuthViaBasicAuth", func() error { return testAuthViaBasicAuth(s.URL) }},
	}
	for _, step := range steps {
		t.Log(step.name + "...")
		if err := step.run(); err != nil {
			t.Fatalf("%s: %v", step.name, err)
		}
	}
}

// launchServer configures the oauth2 server with the "map" store and open
// user registration, starts it on an httptest listener, and records the
// listener's ":port" in config.Port. The caller must Close the returned server.
func launchServer() (*httptest.Server, error) {
	config.Store = "map"
	config.UserRegistration = true

	oauth2server := server.New()
	if err := oauth2server.Routes(); err != nil {
		return nil, err
	}

	s := httptest.NewServer(oauth2server)
	port := portPattern.FindString(s.URL)
	if port == "" {
		s.Close()
		return nil, fmt.Errorf("no port in test server URL %q", s.URL)
	}
	config.Port = port
	return s, nil
}

// dummyServer starts a resource server that authenticates every request via
// oauth2client.Authenticate against oauth2server for "scope". Authenticated
// requests get a fixed body. Rejections are only logged here: nothing else is
// written, so the 403 that clientShouldRedir expects must come from
// Authenticate itself.
func dummyServer(t *testing.T, oauth2server string) *httptest.Server {
	return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if err := oauth2client.Authenticate(oauth2server, "scope", w, r); err != nil {
			t.Logf("dummy: %s: %v", r.URL.Path, err)
			return
		}
		t.Logf("dummy: %s: :D", r.URL.Path)
		fmt.Fprintln(w, "dummy server serving authenticated")
	}))
}

// postFormExpectOK POSTs body as application/x-www-form-urlencoded to target
// using a fresh client (redirects are followed) and returns an error unless
// the final response is 200 OK.
func postFormExpectOK(target, body string) error {
	req, err := http.NewRequest(http.MethodPost, target, strings.NewReader(body))
	if err != nil {
		return err
	}
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	resp, err := makeClient().Do(req)
	if err != nil {
		return err
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		return errors.New("bad status " + resp.Status)
	}
	return nil
}

// createUser registers testUser through the server's user-submission form.
func createUser(oauth2server string) error {
	return postFormExpectOK(oauth2server+"/users/submit/scope", "username="+testUser)
}

// loginAsuser submits the authorize form for testUser and expects 200 OK.
func loginAsuser(oauth2server string) error {
	return postFormExpectOK(oauth2server+"/authorize/scope", "username="+testUser)
}

// shouldRedir checks that a cookie-less client is rejected by the dummy server.
func shouldRedir(dummy string) error {
	return clientShouldRedir(makeClient(), dummy)
}

// clientShouldRedir GETs dummy with c and requires a 403 carrying a
// WWW-Authenticate header, i.e. the client is not yet authenticated.
func clientShouldRedir(c *http.Client, dummy string) error {
	req, err := http.NewRequest(http.MethodGet, dummy, nil)
	if err != nil {
		return err
	}
	resp, err := c.Do(req)
	if err != nil {
		return err
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusForbidden || resp.Header.Get("WWW-Authenticate") == "" {
		return fmt.Errorf("did not need redir without auth: (%d) %q", resp.StatusCode, resp.Header.Get("WWW-Authenticate"))
	}
	return nil
}

// clientShouldNotRedir GETs dummy with c and fails if the redirect chain
// ended on the server's login page ("/users/log/scope").
func clientShouldNotRedir(c *http.Client, dummy string) error {
	req, err := http.NewRequest(http.MethodGet, dummy, nil)
	if err != nil {
		return err
	}
	resp, err := c.Do(req)
	if err != nil {
		return err
	}
	defer resp.Body.Close()
	// resp.Request is the last request of the redirect chain.
	if resp.Request.URL.Path == "/users/log/scope" {
		return fmt.Errorf("did redir with auth: %v", resp.Request.URL.Path)
	}
	return nil
}

// testAuth walks the full login flow with one cookie-carrying client. It
// checks that the client is rejected first, then logs in to obtain an access
// value, presents it to the dummy server as the oauth2.NEWCOOKIE query
// parameter, and finally expects oauth2.COOKIE to appear in the jar.
func testAuth(oauth2server, dummy string) error {
	c := makeClient()
	if err := clientShouldRedir(c, dummy); err != nil {
		return err
	}
	access, err := clientLogin(c, oauth2server)
	if err != nil {
		return err
	}
	// Escape the token: an unescaped '+' would decode as a space and '&'
	// would split the query.
	withToken := dummy + "?" + oauth2.NEWCOOKIE + "=" + url.QueryEscape(access)
	if err := clientShouldNotRedir(c, withToken); err != nil {
		return err
	}
	// Crude check: matches the cookie name in the jar's printed internals.
	if !strings.Contains(fmt.Sprint(c.Jar), oauth2.COOKIE) {
		return errors.New("cookie jar empty:" + fmt.Sprint(c.Jar))
	}
	return nil
}

// clientLogin logs testUser in through /authorize/scope, asking to be
// redirected back to oauth2server's root. It returns the oauth2.NEWCOOKIE
// value, taken from the final redirect URL's query or, failing that, from
// the cookie jar for the oauth2 server's host.
func clientLogin(c *http.Client, oauth2server string) (string, error) {
	serverURL, err := url.Parse(oauth2server)
	if err != nil {
		return "", err
	}
	target := oauth2server + "/authorize/scope?" + oauth2.REDIRECT + "=" + url.QueryEscape(oauth2server+"/")
	req, err := http.NewRequest(http.MethodPost, target, strings.NewReader("username="+testUser))
	if err != nil {
		return "", err
	}
	req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
	resp, err := c.Do(req)
	if err != nil {
		return "", err
	}
	defer resp.Body.Close()
	if resp.StatusCode == http.StatusUnauthorized {
		return "", fmt.Errorf("bad status; %v", resp.StatusCode)
	}
	if resp.Request.URL.Path != "/" {
		return "", fmt.Errorf("login response path wrong: %v", resp.Request.URL.Path)
	}

	access := resp.Request.URL.Query().Get(oauth2.NEWCOOKIE)
	if access == "" {
		// The cookie jar ignores the port, so the server's host:port is fine here.
		jarURL := &url.URL{Scheme: serverURL.Scheme, Host: serverURL.Host, Path: "/"}
		for _, cookie := range c.Jar.Cookies(jarURL) {
			if cookie.Name == oauth2.NEWCOOKIE {
				access = cookie.Value
			}
		}
	}
	if access == "" {
		return "", fmt.Errorf("login and redir to %s didnt set %s", resp.Request.URL, oauth2.NEWCOOKIE)
	}
	return access, nil
}

// testAuthViaBadBasicAuth expects the dummy server to reject unknown
// basic-auth credentials with 403.
func testAuthViaBadBasicAuth(dummy string) error {
	req, err := http.NewRequest(http.MethodGet, dummy, nil)
	if err != nil {
		return err
	}
	req.SetBasicAuth("u", "p")
	resp, err := makeClient().Do(req)
	if err != nil {
		return err
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusForbidden {
		b, _ := io.ReadAll(resp.Body) // best effort: body only enriches the error
		return fmt.Errorf("got through with bad basic auth set: (%d) %s", resp.StatusCode, b)
	}
	return nil
}

// testAuthViaBasicAuth sends testUser as the basic-auth password (with an
// empty username) and expects the dummy server to accept it with 200 OK and
// to leave oauth2.COOKIE in the client's jar.
func testAuthViaBasicAuth(dummy string) error {
	c := makeClient()
	req, err := http.NewRequest(http.MethodGet, dummy, nil)
	if err != nil {
		return err
	}
	req.SetBasicAuth("", testUser)
	resp, err := c.Do(req)
	if err != nil {
		return err
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		b, _ := io.ReadAll(resp.Body) // best effort: body only enriches the error
		return fmt.Errorf("failed to get through with basic auth set: (%d) %s", resp.StatusCode, b)
	}
	if !strings.Contains(fmt.Sprint(c.Jar), oauth2.COOKIE) {
		return errors.New("cookie jar empty:" + fmt.Sprint(c.Jar))
	}
	return nil
}

// makeClient returns a client with its own empty cookie jar, the loopback
// transport from makeTransport, and a request timeout.
func makeClient() *http.Client {
	// cookiejar.New never returns a non-nil error in the standard library.
	jar, _ := cookiejar.New(&cookiejar.Options{})
	return &http.Client{
		Jar:       jar,
		Transport: makeTransport(),
		Timeout:   clientTimeout,
	}
}

// makeTransport returns a transport that ignores the requested host and
// dials 127.0.0.1 over tcp4 on the requested port (80 when none is given).
// This lets fake hostnames such as testHost reach the local test servers.
func makeTransport() *http.Transport {
	return &http.Transport{
		DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
			port := "80"
			if _, p, err := net.SplitHostPort(addr); err == nil && p != "" {
				port = p
			}
			return (&net.Dialer{}).DialContext(ctx, "tcp4", net.JoinHostPort("127.0.0.1", port))
		},
	}
}