258 lines
6.5 KiB
Go
Executable File
258 lines
6.5 KiB
Go
Executable File
package oauth2
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"net/http"
|
|
"net/http/cookiejar"
|
|
"net/http/httptest"
|
|
"net/url"
|
|
"regexp"
|
|
"strings"
|
|
"testing"
|
|
|
|
"gitea.inhome.blapointe.com/local/oauth2"
|
|
"gitea.inhome.blapointe.com/local/oauth2/oauth2client"
|
|
"gitea.inhome.blapointe.com/local/oauth2/oauth2server/config"
|
|
"gitea.inhome.blapointe.com/local/oauth2/oauth2server/server"
|
|
)
|
|
|
|
func TestAll(t *testing.T) {
|
|
oauth2client.HTTPClient.Transport = makeTransport()
|
|
|
|
oauth2server, err := launchServer()
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
defer oauth2server.Close()
|
|
oauth2server.URL = strings.ReplaceAll(oauth2server.URL, "127.0.0.1", "echo.belbox.dev")
|
|
|
|
s := dummyServer(t, oauth2server.URL)
|
|
defer s.Close()
|
|
|
|
t.Log("createUser...")
|
|
if err := createUser(oauth2server.URL); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
t.Log("loginAsUser...")
|
|
if err := loginAsuser(oauth2server.URL); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
t.Log("shouldRedir...")
|
|
if err := shouldRedir(s.URL); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
t.Log("testAuth...")
|
|
if err := testAuth(oauth2server.URL, s.URL); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
t.Log("testAuthViaBadBasicAuth...")
|
|
if err := testAuthViaBadBasicAuth(s.URL); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
t.Log("testAuthViaBasicAuth...")
|
|
if err := testAuthViaBasicAuth(s.URL); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
}
|
|
|
|
func launchServer() (*httptest.Server, error) {
|
|
config.Store = "map"
|
|
config.UserRegistration = true
|
|
|
|
oauth2server := server.New()
|
|
err := oauth2server.Routes()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
s := httptest.NewServer(oauth2server)
|
|
|
|
re := regexp.MustCompile(":[0-9]*")
|
|
port := re.FindString(s.URL)
|
|
config.Port = port
|
|
|
|
return s, err
|
|
}
|
|
|
|
func dummyServer(t *testing.T, oauth2server string) *httptest.Server {
|
|
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
err := oauth2client.Authenticate(oauth2server, "scope", w, r)
|
|
if err != nil {
|
|
t.Logf("dummy: %s: %v", r.URL.Path, err)
|
|
return
|
|
}
|
|
t.Logf("dummy: %s: :D", r.URL.Path)
|
|
fmt.Fprintln(w, "dummy server serving authenticated")
|
|
}))
|
|
}
|
|
|
|
func createUser(oauth2server string) error {
|
|
req, _ := http.NewRequest(http.MethodPost, oauth2server+"/users/submit/scope", strings.NewReader("username=abc"))
|
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
|
resp, err := makeClient().Do(req)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer resp.Body.Close()
|
|
if resp.StatusCode != http.StatusOK {
|
|
return errors.New("bad status " + resp.Status)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func loginAsuser(oauth2server string) error {
|
|
req, _ := http.NewRequest(http.MethodPost, oauth2server+"/authorize/scope", strings.NewReader("username=abc"))
|
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
|
resp, err := makeClient().Do(req)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer resp.Body.Close()
|
|
if resp.StatusCode != http.StatusOK {
|
|
return errors.New("bad status " + resp.Status)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func shouldRedir(dummy string) error {
|
|
c := makeClient()
|
|
return clientShouldRedir(c, dummy)
|
|
}
|
|
|
|
// clientShouldRedir sends a GET to dummy using client c. It succeeds only if the
// response is 403 Forbidden and carries a non-empty WWW-Authenticate header,
// which is how the dummy server signals an unauthenticated request.
//
// It returns an error if the URL cannot be turned into a request, if the
// request fails, or if the response does not match that expectation.
func clientShouldRedir(c *http.Client, dummy string) error {
	req, err := http.NewRequest(http.MethodGet, dummy, nil)
	if err != nil {
		// A nil req would otherwise make c.Do panic.
		return fmt.Errorf("building request for %q: %w", dummy, err)
	}
	resp, err := c.Do(req)
	if err != nil {
		return err
	}
	defer resp.Body.Close()

	challenge := resp.Header.Get("WWW-Authenticate")
	if resp.StatusCode != http.StatusForbidden || challenge == "" {
		return fmt.Errorf("did not need redir without auth: (%d) %q", resp.StatusCode, challenge)
	}
	return nil
}
|
|
|
|
// clientShouldNotRedir sends a GET to dummy using client c, following redirects
// according to c's policy. It fails if the final request landed on
// /users/log/scope, presumably the login page, which would mean the client was
// not recognized as authenticated.
//
// It returns an error if the URL cannot be turned into a request or if the
// request fails.
func clientShouldNotRedir(c *http.Client, dummy string) error {
	req, err := http.NewRequest(http.MethodGet, dummy, nil)
	if err != nil {
		// A nil req would otherwise make c.Do panic.
		return fmt.Errorf("building request for %q: %w", dummy, err)
	}
	resp, err := c.Do(req)
	if err != nil {
		return err
	}
	defer resp.Body.Close()

	// resp.Request is the request that produced this response, i.e. the last
	// hop of any redirect chain the client followed.
	if path := resp.Request.URL.Path; path == "/users/log/scope" {
		return fmt.Errorf("did redir with auth: %v", path)
	}
	return nil
}
|
|
|
|
func testAuth(oauth2server, dummy string) error {
|
|
c := makeClient()
|
|
if err := clientShouldRedir(c, dummy); err != nil {
|
|
return err
|
|
}
|
|
access, err := clientLogin(c, oauth2server)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if err := clientShouldNotRedir(c, dummy+"?"+oauth2.NEWCOOKIE+"="+access); err != nil {
|
|
return err
|
|
}
|
|
if !strings.Contains(fmt.Sprint(c.Jar), oauth2.COOKIE) {
|
|
return errors.New("cookie jar empty:" + fmt.Sprint(c.Jar))
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func clientLogin(c *http.Client, oauth2server string) (string, error) {
|
|
req, _ := http.NewRequest("POST", oauth2server+"/authorize/scope?"+oauth2.REDIRECT+"="+oauth2server+"/", strings.NewReader("username=abc"))
|
|
req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
|
|
resp, err := c.Do(req)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
defer resp.Body.Close()
|
|
if resp.StatusCode == http.StatusUnauthorized {
|
|
return "", fmt.Errorf("bad status; %v", resp.StatusCode)
|
|
}
|
|
if resp.Request.URL.Path != "/" {
|
|
return "", fmt.Errorf("login response path wrong: %v", resp.Request.URL.Path)
|
|
}
|
|
a := resp.Request.URL.Query().Get(oauth2.NEWCOOKIE)
|
|
if a == "" {
|
|
cookies := c.Jar.Cookies(&url.URL{Scheme: "http", Path: "/", Host: "echo.belbox.dev"})
|
|
for i := range cookies {
|
|
if cookies[i].Name == oauth2.NEWCOOKIE {
|
|
a = cookies[i].Value
|
|
}
|
|
}
|
|
}
|
|
if a == "" {
|
|
return "", fmt.Errorf("login and redir didnt set cookie: %v", a)
|
|
}
|
|
return a, nil
|
|
}
|
|
|
|
func testAuthViaBadBasicAuth(dummy string) error {
|
|
c := makeClient()
|
|
req, _ := http.NewRequest(http.MethodGet, dummy, nil)
|
|
req.SetBasicAuth("u", "p")
|
|
resp, err := c.Do(req)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer resp.Body.Close()
|
|
if resp.StatusCode != 403 {
|
|
b, _ := io.ReadAll(resp.Body)
|
|
return fmt.Errorf("got through with bad basic auth set: (%d) %s", resp.StatusCode, b)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func testAuthViaBasicAuth(dummy string) error {
|
|
c := makeClient()
|
|
req, _ := http.NewRequest(http.MethodGet, dummy, nil)
|
|
req.SetBasicAuth("", "abc")
|
|
resp, err := c.Do(req)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer resp.Body.Close()
|
|
if resp.StatusCode != http.StatusOK {
|
|
b, _ := io.ReadAll(resp.Body)
|
|
return fmt.Errorf("failed to get through with basic auth set: (%d) %s", resp.StatusCode, b)
|
|
}
|
|
if !strings.Contains(fmt.Sprint(c.Jar), oauth2.COOKIE) {
|
|
return errors.New("cookie jar empty:" + fmt.Sprint(c.Jar))
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func makeClient() *http.Client {
|
|
jar, _ := cookiejar.New(&cookiejar.Options{})
|
|
return &http.Client{
|
|
Jar: jar,
|
|
Transport: makeTransport(),
|
|
}
|
|
}
|
|
|
|
// makeTransport returns an *http.Transport that ignores the host part of every
// dial address. It always connects over IPv4 to 127.0.0.1 on the requested
// port, or port 80 if the address carries no usable port. This lets the tests
// address local httptest servers by a fake hostname such as echo.belbox.dev.
func makeTransport() *http.Transport {
	return &http.Transport{
		DialContext: func(ctx context.Context, network string, addr string) (net.Conn, error) {
			// SplitHostPort handles "host:port" and "[v6]:port". It fails cleanly,
			// rather than indexing past the end, when addr has no port.
			port := "80"
			if _, p, err := net.SplitHostPort(addr); err == nil && p != "" {
				port = p
			}
			return (&net.Dialer{}).DialContext(ctx, "tcp4", net.JoinHostPort("127.0.0.1", port))
		},
	}
}
|