parent
a2e84379a9
commit
6ae4b401b1
|
|
@ -3,13 +3,12 @@ package oauth2client
|
||||||
import (
|
import (
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"errors"
|
"errors"
|
||||||
|
"gitea.inhome.blapointe.com/local/oauth2"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"gitea.inhome.blapointe.com/local/oauth2"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type cached struct {
|
type cached struct {
|
||||||
|
|
@ -26,7 +25,7 @@ func Authenticate(server, scope string, w http.ResponseWriter, r *http.Request)
|
||||||
}
|
}
|
||||||
access, exists := findAccess(w, r)
|
access, exists := findAccess(w, r)
|
||||||
if !exists {
|
if !exists {
|
||||||
return login(scope, w, r)
|
return login(oauth2server, scope, w, r)
|
||||||
}
|
}
|
||||||
return verify(access, oauth2server, scope, w, r)
|
return verify(access, oauth2server, scope, w, r)
|
||||||
}
|
}
|
||||||
|
|
@ -45,20 +44,12 @@ func findAccessFresh(w http.ResponseWriter, r *http.Request) (string, bool) {
|
||||||
if !found {
|
if !found {
|
||||||
access, found = findAccessFreshCookie(w, r)
|
access, found = findAccessFreshCookie(w, r)
|
||||||
}
|
}
|
||||||
if !found {
|
|
||||||
access, found = findAccessFreshBasicAuth(w, r)
|
|
||||||
}
|
|
||||||
if found {
|
if found {
|
||||||
setCookie(oauth2.COOKIE, access, "", w)
|
setCookie(oauth2.COOKIE, access, "", w)
|
||||||
}
|
}
|
||||||
return access, found
|
return access, found
|
||||||
}
|
}
|
||||||
|
|
||||||
func findAccessFreshBasicAuth(w http.ResponseWriter, r *http.Request) (string, bool) {
|
|
||||||
_, p, ok := r.BasicAuth()
|
|
||||||
return p, ok
|
|
||||||
}
|
|
||||||
|
|
||||||
func findAccessFreshQueryParam(w http.ResponseWriter, r *http.Request) (string, bool) {
|
func findAccessFreshQueryParam(w http.ResponseWriter, r *http.Request) (string, bool) {
|
||||||
q := r.URL.Query()
|
q := r.URL.Query()
|
||||||
access := q.Get(oauth2.NEWCOOKIE)
|
access := q.Get(oauth2.NEWCOOKIE)
|
||||||
|
|
@ -97,17 +88,21 @@ func findAccessStable(w http.ResponseWriter, r *http.Request) (string, bool) {
|
||||||
return access.Value, true
|
return access.Value, true
|
||||||
}
|
}
|
||||||
|
|
||||||
func login(scope string, w http.ResponseWriter, r *http.Request) error {
|
func login(oauth2server *url.URL, scope string, w http.ResponseWriter, r *http.Request) error {
|
||||||
w.Header().Set("WWW-Authenticate", "Basic")
|
oauth2server.Path = "/users/log/" + scope
|
||||||
w.WriteHeader(403)
|
url := *r.URL
|
||||||
return errors.New("login pls")
|
url.Host = r.Host
|
||||||
}
|
if url.Scheme == "" {
|
||||||
|
url.Scheme = oauth2server.Scheme
|
||||||
var HTTPClient = &http.Client{
|
}
|
||||||
Timeout: 5 * time.Second,
|
if url.Scheme == "" {
|
||||||
Transport: &http.Transport{
|
url.Scheme = "https"
|
||||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
}
|
||||||
},
|
q := oauth2server.Query()
|
||||||
|
q.Set(oauth2.REDIRECT, url.String())
|
||||||
|
oauth2server.RawQuery = q.Encode()
|
||||||
|
http.Redirect(w, r, oauth2server.String(), http.StatusSeeOther)
|
||||||
|
return errors.New("logging in")
|
||||||
}
|
}
|
||||||
|
|
||||||
func verify(access string, oauth2server *url.URL, scope string, w http.ResponseWriter, r *http.Request) error {
|
func verify(access string, oauth2server *url.URL, scope string, w http.ResponseWriter, r *http.Request) error {
|
||||||
|
|
@ -123,14 +118,19 @@ func verify(access string, oauth2server *url.URL, scope string, w http.ResponseW
|
||||||
}
|
}
|
||||||
req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
|
req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
|
||||||
req.Header.Add("Content-Length", strconv.Itoa(len(data.Encode())))
|
req.Header.Add("Content-Length", strconv.Itoa(len(data.Encode())))
|
||||||
c := HTTPClient
|
c := &http.Client{
|
||||||
|
Timeout: 5 * time.Second,
|
||||||
|
Transport: &http.Transport{
|
||||||
|
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
||||||
|
},
|
||||||
|
}
|
||||||
resp, err := c.Do(req)
|
resp, err := c.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
if resp.StatusCode != http.StatusOK {
|
if resp.StatusCode != http.StatusOK {
|
||||||
return login(scope, w, r)
|
return login(oauth2server, scope, w, r)
|
||||||
}
|
}
|
||||||
cache[scope] = cached{
|
cache[scope] = cached{
|
||||||
access: access,
|
access: access,
|
||||||
|
|
|
||||||
|
|
@ -2,13 +2,12 @@ package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
|
||||||
"net/url"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"gitea.inhome.blapointe.com/local/oauth2"
|
"gitea.inhome.blapointe.com/local/oauth2"
|
||||||
"gitea.inhome.blapointe.com/local/router"
|
"gitea.inhome.blapointe.com/local/router"
|
||||||
"gitea.inhome.blapointe.com/local/storage"
|
"gitea.inhome.blapointe.com/local/storage"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
)
|
)
|
||||||
|
|
@ -56,9 +55,9 @@ func (s *Server) authorize(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|
||||||
func (s *Server) genAuth(scope, user string) {
|
func (s *Server) genAuth(scope, user string) {
|
||||||
access := uuid.New().String()
|
access := uuid.New().String()
|
||||||
|
token := uuid.New().String()
|
||||||
s.store.Set(user, []byte(access), ACCESS)
|
s.store.Set(user, []byte(access), ACCESS)
|
||||||
s.store.Set(user, []byte(user), ACCESS)
|
s.store.Set(scope+"."+access, []byte(token), TOKEN)
|
||||||
s.store.Set(access, []byte(user), ACCESS)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) getAccess(scope, user string) (string, bool) {
|
func (s *Server) getAccess(scope, user string) (string, bool) {
|
||||||
|
|
@ -70,10 +69,7 @@ func (s *Server) getAccess(scope, user string) (string, bool) {
|
||||||
return string(access), err == nil
|
return string(access), err == nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) verifyAccess(access string) error {
|
func (s *Server) getToken(scope, access string) (string, bool) {
|
||||||
_, err := s.store.Get(access, ACCESS)
|
token, err := s.store.Get(scope+"."+access, TOKEN)
|
||||||
if err != nil {
|
return string(token), err == nil
|
||||||
return fmt.Errorf("access not found: %s", access)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,6 @@ package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"gitea.inhome.blapointe.com/local/oauth2/oauth2server/config"
|
"gitea.inhome.blapointe.com/local/oauth2/oauth2server/config"
|
||||||
"gitea.inhome.blapointe.com/local/router"
|
"gitea.inhome.blapointe.com/local/router"
|
||||||
"gitea.inhome.blapointe.com/local/storage"
|
"gitea.inhome.blapointe.com/local/storage"
|
||||||
|
|
@ -15,6 +14,7 @@ var wildcard = router.Wildcard
|
||||||
const (
|
const (
|
||||||
USERS = "users"
|
USERS = "users"
|
||||||
ACCESS = "access"
|
ACCESS = "access"
|
||||||
|
TOKEN = "token"
|
||||||
SALT = "salt"
|
SALT = "salt"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -42,6 +42,10 @@ func purgeIssuedCredentials(store storage.DB) {
|
||||||
for _, access := range accesses {
|
for _, access := range accesses {
|
||||||
store.Set(access, nil, ACCESS)
|
store.Set(access, nil, ACCESS)
|
||||||
}
|
}
|
||||||
|
tokens, _ := store.List([]string{TOKEN})
|
||||||
|
for _, token := range tokens {
|
||||||
|
store.Set(token, nil, TOKEN)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func wrapBody(title, body string) string {
|
func wrapBody(title, body string) string {
|
||||||
|
|
|
||||||
|
|
@ -1,9 +1,8 @@
|
||||||
package server
|
package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net/http"
|
|
||||||
|
|
||||||
"gitea.inhome.blapointe.com/local/router"
|
"gitea.inhome.blapointe.com/local/router"
|
||||||
|
"net/http"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (s *Server) verify(w http.ResponseWriter, r *http.Request) {
|
func (s *Server) verify(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|
@ -14,8 +13,10 @@ func (s *Server) verify(w http.ResponseWriter, r *http.Request) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
access := r.FormValue("access")
|
access := r.FormValue("access")
|
||||||
if err := s.verifyAccess(access); err != nil {
|
token, ok := s.getToken(scope, access)
|
||||||
|
if !ok {
|
||||||
http.Error(w, "unknown access", http.StatusUnauthorized)
|
http.Error(w, "unknown access", http.StatusUnauthorized)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
w.Write([]byte(token))
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,11 +1,13 @@
|
||||||
package oauth2
|
package oauth2
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"gitea.inhome.blapointe.com/local/oauth2"
|
||||||
"net"
|
"gitea.inhome.blapointe.com/local/oauth2/oauth2client"
|
||||||
|
"gitea.inhome.blapointe.com/local/oauth2/oauth2server/config"
|
||||||
|
"gitea.inhome.blapointe.com/local/oauth2/oauth2server/server"
|
||||||
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/cookiejar"
|
"net/http/cookiejar"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
|
|
@ -13,16 +15,9 @@ import (
|
||||||
"regexp"
|
"regexp"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"gitea.inhome.blapointe.com/local/oauth2"
|
|
||||||
"gitea.inhome.blapointe.com/local/oauth2/oauth2client"
|
|
||||||
"gitea.inhome.blapointe.com/local/oauth2/oauth2server/config"
|
|
||||||
"gitea.inhome.blapointe.com/local/oauth2/oauth2server/server"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestAll(t *testing.T) {
|
func TestAll(t *testing.T) {
|
||||||
oauth2client.HTTPClient.Transport = makeTransport()
|
|
||||||
|
|
||||||
oauth2server, err := launchServer()
|
oauth2server, err := launchServer()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
|
|
@ -30,38 +25,24 @@ func TestAll(t *testing.T) {
|
||||||
defer oauth2server.Close()
|
defer oauth2server.Close()
|
||||||
oauth2server.URL = strings.ReplaceAll(oauth2server.URL, "127.0.0.1", "echo.belbox.dev")
|
oauth2server.URL = strings.ReplaceAll(oauth2server.URL, "127.0.0.1", "echo.belbox.dev")
|
||||||
|
|
||||||
s := dummyServer(t, oauth2server.URL)
|
s := dummyServer(oauth2server.URL)
|
||||||
defer s.Close()
|
defer s.Close()
|
||||||
|
|
||||||
t.Log("createUser...")
|
|
||||||
if err := createUser(oauth2server.URL); err != nil {
|
if err := createUser(oauth2server.URL); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
t.Log("loginAsUser...")
|
if err := logUser(oauth2server.URL); err != nil {
|
||||||
if err := loginAsuser(oauth2server.URL); err != nil {
|
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
t.Log("shouldRedir...")
|
|
||||||
if err := shouldRedir(s.URL); err != nil {
|
if err := shouldRedir(s.URL); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
t.Log("testAuth...")
|
|
||||||
if err := testAuth(oauth2server.URL, s.URL); err != nil {
|
if err := testAuth(oauth2server.URL, s.URL); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
t.Log("testAuthViaBadBasicAuth...")
|
|
||||||
if err := testAuthViaBadBasicAuth(s.URL); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
t.Log("testAuthViaBasicAuth...")
|
|
||||||
if err := testAuthViaBasicAuth(s.URL); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func launchServer() (*httptest.Server, error) {
|
func launchServer() (*httptest.Server, error) {
|
||||||
|
|
@ -82,22 +63,18 @@ func launchServer() (*httptest.Server, error) {
|
||||||
return s, err
|
return s, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func dummyServer(t *testing.T, oauth2server string) *httptest.Server {
|
func dummyServer(oauth2server string) *httptest.Server {
|
||||||
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
err := oauth2client.Authenticate(oauth2server, "scope", w, r)
|
err := oauth2client.Authenticate(oauth2server, "scope", w, r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Logf("dummy: %s: %v", r.URL.Path, err)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
t.Logf("dummy: %s: :D", r.URL.Path)
|
|
||||||
fmt.Fprintln(w, "dummy server serving authenticated")
|
fmt.Fprintln(w, "dummy server serving authenticated")
|
||||||
}))
|
}))
|
||||||
}
|
}
|
||||||
|
|
||||||
func createUser(oauth2server string) error {
|
func createUser(oauth2server string) error {
|
||||||
req, _ := http.NewRequest(http.MethodPost, oauth2server+"/users/submit/scope", strings.NewReader("username=abc"))
|
resp, err := http.Post(oauth2server+"/users/submit/scope", "application/x-www-form-urlencoded", strings.NewReader("username=abc"))
|
||||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
|
||||||
resp, err := makeClient().Do(req)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
@ -108,10 +85,8 @@ func createUser(oauth2server string) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func loginAsuser(oauth2server string) error {
|
func logUser(oauth2server string) error {
|
||||||
req, _ := http.NewRequest(http.MethodPost, oauth2server+"/authorize/scope", strings.NewReader("username=abc"))
|
resp, err := http.Post(oauth2server+"/authorize/scope", "application/x-www-form-urlencoded", strings.NewReader("username=abc"))
|
||||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
|
||||||
resp, err := makeClient().Do(req)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
@ -134,8 +109,8 @@ func clientShouldRedir(c *http.Client, dummy string) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
if resp.StatusCode != 403 || resp.Header.Get("WWW-Authenticate") == "" {
|
if resp.Request.URL.Path != "/users/log/scope" {
|
||||||
return fmt.Errorf("did not need redir without auth: (%d) %q", resp.StatusCode, resp.Header.Get("WWW-Authenticate"))
|
return fmt.Errorf("did not need redir without auth: %v", resp.Request.URL)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
@ -155,13 +130,16 @@ func clientShouldNotRedir(c *http.Client, dummy string) error {
|
||||||
|
|
||||||
func testAuth(oauth2server, dummy string) error {
|
func testAuth(oauth2server, dummy string) error {
|
||||||
c := makeClient()
|
c := makeClient()
|
||||||
|
log.Println("should redir...")
|
||||||
if err := clientShouldRedir(c, dummy); err != nil {
|
if err := clientShouldRedir(c, dummy); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
log.Println("client login...")
|
||||||
access, err := clientLogin(c, oauth2server)
|
access, err := clientLogin(c, oauth2server)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
log.Println("client should not redir...")
|
||||||
if err := clientShouldNotRedir(c, dummy+"?"+oauth2.NEWCOOKIE+"="+access); err != nil {
|
if err := clientShouldNotRedir(c, dummy+"?"+oauth2.NEWCOOKIE+"="+access); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
@ -200,58 +178,9 @@ func clientLogin(c *http.Client, oauth2server string) (string, error) {
|
||||||
return a, nil
|
return a, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func testAuthViaBadBasicAuth(dummy string) error {
|
|
||||||
c := makeClient()
|
|
||||||
req, _ := http.NewRequest(http.MethodGet, dummy, nil)
|
|
||||||
req.SetBasicAuth("u", "p")
|
|
||||||
resp, err := c.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
if resp.StatusCode != 403 {
|
|
||||||
b, _ := io.ReadAll(resp.Body)
|
|
||||||
return fmt.Errorf("got through with bad basic auth set: (%d) %s", resp.StatusCode, b)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func testAuthViaBasicAuth(dummy string) error {
|
|
||||||
c := makeClient()
|
|
||||||
req, _ := http.NewRequest(http.MethodGet, dummy, nil)
|
|
||||||
req.SetBasicAuth("", "abc")
|
|
||||||
resp, err := c.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
|
||||||
b, _ := io.ReadAll(resp.Body)
|
|
||||||
return fmt.Errorf("failed to get through with basic auth set: (%d) %s", resp.StatusCode, b)
|
|
||||||
}
|
|
||||||
if !strings.Contains(fmt.Sprint(c.Jar), oauth2.COOKIE) {
|
|
||||||
return errors.New("cookie jar empty:" + fmt.Sprint(c.Jar))
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func makeClient() *http.Client {
|
func makeClient() *http.Client {
|
||||||
jar, _ := cookiejar.New(&cookiejar.Options{})
|
jar, _ := cookiejar.New(&cookiejar.Options{})
|
||||||
return &http.Client{
|
return &http.Client{
|
||||||
Jar: jar,
|
Jar: jar,
|
||||||
Transport: makeTransport(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func makeTransport() *http.Transport {
|
|
||||||
return &http.Transport{
|
|
||||||
DialContext: func(ctx context.Context, network string, addr string) (net.Conn, error) {
|
|
||||||
parts := strings.Split(addr, ":")
|
|
||||||
port := "80"
|
|
||||||
if len(parts) > 0 {
|
|
||||||
port = parts[1]
|
|
||||||
}
|
|
||||||
return (&net.Dialer{}).DialContext(ctx, "tcp4", fmt.Sprintf("127.0.0.1:%s", port))
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue