Wannabe oauth implementation
commit
6dc2e074fb
|
|
@ -0,0 +1,6 @@
|
||||||
|
package oauth2
|
||||||
|
|
||||||
|
const (
|
||||||
|
COOKIE = "BOAuthZ"
|
||||||
|
REDIRECT = "BOAuthZ-Redirect"
|
||||||
|
)
|
||||||
|
|
@ -0,0 +1,63 @@
|
||||||
|
package oauth2client
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"local/oauth2"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Authenticate(server string, w http.ResponseWriter, r *http.Request) error {
|
||||||
|
oauth2server, err := url.Parse(server)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
access, err := r.Cookie(oauth2.COOKIE)
|
||||||
|
if err == http.ErrNoCookie {
|
||||||
|
return login(oauth2server, w, r)
|
||||||
|
}
|
||||||
|
return verify(access.Value, oauth2server, w, r)
|
||||||
|
}
|
||||||
|
|
||||||
|
func login(oauth2server *url.URL, w http.ResponseWriter, r *http.Request) error {
|
||||||
|
oauth2server.Path = "/users/log"
|
||||||
|
url := *r.URL
|
||||||
|
url.Host = r.Host
|
||||||
|
if url.Scheme == "" {
|
||||||
|
url.Scheme = "http"
|
||||||
|
}
|
||||||
|
cookie := &http.Cookie{
|
||||||
|
Name: oauth2.REDIRECT,
|
||||||
|
Value: url.String(),
|
||||||
|
}
|
||||||
|
http.SetCookie(w, cookie)
|
||||||
|
http.Redirect(w, r, oauth2server.String(), http.StatusSeeOther)
|
||||||
|
return errors.New("logging in")
|
||||||
|
}
|
||||||
|
|
||||||
|
func verify(access string, oauth2server *url.URL, w http.ResponseWriter, r *http.Request) error {
|
||||||
|
oauth2server.Path = "/verify"
|
||||||
|
data := url.Values{}
|
||||||
|
data.Set("access", access)
|
||||||
|
req, err := http.NewRequest("POST", oauth2server.String(), strings.NewReader(data.Encode()))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
|
||||||
|
req.Header.Add("Content-Length", strconv.Itoa(len(data.Encode())))
|
||||||
|
c := &http.Client{
|
||||||
|
Timeout: 5 * time.Second,
|
||||||
|
}
|
||||||
|
resp, err := c.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return login(oauth2server, w, r)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,53 @@
|
||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"local/args"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
Port string
|
||||||
|
Store string
|
||||||
|
StoreAddr string
|
||||||
|
StoreUser string
|
||||||
|
StorePass string
|
||||||
|
SecretA string
|
||||||
|
SecretB string
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
Refresh()
|
||||||
|
}
|
||||||
|
|
||||||
|
func Refresh() {
|
||||||
|
if strings.Contains(fmt.Sprint(os.Args), "-test") {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if err := recover(); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
as := args.NewArgSet()
|
||||||
|
as.Append(args.STRING, "port", "port to listen on", "23456")
|
||||||
|
as.Append(args.STRING, "secretA", "secret A", "secret")
|
||||||
|
as.Append(args.STRING, "secretB", "secret B", "secret")
|
||||||
|
as.Append(args.STRING, "store", "type of DB", "map")
|
||||||
|
as.Append(args.STRING, "storeAddr", "addr of DB", "/tmp/oauth2server.db")
|
||||||
|
as.Append(args.STRING, "storeUser", "user of DB", "")
|
||||||
|
as.Append(args.STRING, "storePass", "pass of DB", "")
|
||||||
|
if err := as.Parse(); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
Port = ":" + strings.TrimPrefix(as.Get("port").GetString(), ":")
|
||||||
|
Store = as.Get("store").GetString()
|
||||||
|
StoreAddr = as.Get("storeaddr").GetString()
|
||||||
|
StoreUser = as.Get("storeuser").GetString()
|
||||||
|
StorePass = as.Get("storepass").GetString()
|
||||||
|
SecretA = as.Get("secreta").GetString()
|
||||||
|
SecretB = as.Get("secretb").GetString()
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,19 @@
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"local/oauth2/oauth2server/config"
|
||||||
|
"local/oauth2/oauth2server/server"
|
||||||
|
"log"
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
s := server.New()
|
||||||
|
if err := s.Routes(); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
log.Println("listening on", config.Port)
|
||||||
|
if err := http.ListenAndServe(config.Port, s); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
Binary file not shown.
|
|
@ -0,0 +1,59 @@
|
||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"local/oauth2"
|
||||||
|
"local/storage"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (s *Server) authorize(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Method != "POST" {
|
||||||
|
http.NotFound(w, r)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
id := r.FormValue("username")
|
||||||
|
user, ok := s.getUser(id)
|
||||||
|
if !ok {
|
||||||
|
http.Error(w, "unknown user", http.StatusForbidden)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
access, ok := s.getAccess(user)
|
||||||
|
if !ok {
|
||||||
|
http.Error(w, "no oauth for user", http.StatusForbidden)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
cookie := &http.Cookie{
|
||||||
|
Name: oauth2.COOKIE,
|
||||||
|
Value: access,
|
||||||
|
SameSite: http.SameSiteLaxMode,
|
||||||
|
}
|
||||||
|
http.SetCookie(w, cookie)
|
||||||
|
redirectCookie, err := r.Cookie(oauth2.REDIRECT)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
http.Redirect(w, r, redirectCookie.Value, http.StatusSeeOther)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) genAuth(user string) {
|
||||||
|
access := uuid.New().String()
|
||||||
|
token := uuid.New().String()
|
||||||
|
s.store.Set(user, []byte(access), ACCESS)
|
||||||
|
s.store.Set(access, []byte(token), TOKEN)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) getAccess(user string) (string, bool) {
|
||||||
|
access, err := s.store.Get(user, ACCESS)
|
||||||
|
if err == storage.ErrNotFound {
|
||||||
|
s.genAuth(user)
|
||||||
|
access, err = s.store.Get(user, ACCESS)
|
||||||
|
}
|
||||||
|
return string(access), err == nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) getToken(access string) (string, bool) {
|
||||||
|
token, err := s.store.Get(access, TOKEN)
|
||||||
|
return string(token), err == nil
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,41 @@
|
||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (s *Server) Routes() error {
|
||||||
|
endpoints := []struct {
|
||||||
|
path string
|
||||||
|
handler http.HandlerFunc
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
path: fmt.Sprintf("authorize"),
|
||||||
|
handler: s.authorize,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
path: fmt.Sprintf("verify"),
|
||||||
|
handler: s.verify,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
path: fmt.Sprintf("users/log"),
|
||||||
|
handler: s.usersLog,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
path: fmt.Sprintf("users/register"),
|
||||||
|
handler: s.usersRegister,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
path: fmt.Sprintf("users/submit"),
|
||||||
|
handler: s.usersSubmit,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, endpoint := range endpoints {
|
||||||
|
if err := s.Add(endpoint.path, endpoint.handler); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,44 @@
|
||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"local/oauth2/oauth2server/config"
|
||||||
|
"local/router"
|
||||||
|
"local/storage"
|
||||||
|
)
|
||||||
|
|
||||||
|
var wildcard = router.Wildcard
|
||||||
|
|
||||||
|
const (
|
||||||
|
USERS = "users"
|
||||||
|
ACCESS = "access"
|
||||||
|
TOKEN = "token"
|
||||||
|
SALT = "salt"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Server struct {
|
||||||
|
*router.Router
|
||||||
|
store storage.DB
|
||||||
|
}
|
||||||
|
|
||||||
|
func New() *Server {
|
||||||
|
store, err := storage.New(storage.TypeFromString(config.Store), config.StoreAddr, config.StoreUser, config.StorePass)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
purgeIssuedCredentials(store)
|
||||||
|
return &Server{
|
||||||
|
Router: router.New(),
|
||||||
|
store: store,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func purgeIssuedCredentials(store storage.DB) {
|
||||||
|
accesses, _ := store.List([]string{ACCESS})
|
||||||
|
for _, access := range accesses {
|
||||||
|
store.Set(access, nil, ACCESS)
|
||||||
|
}
|
||||||
|
tokens, _ := store.List([]string{TOKEN})
|
||||||
|
for _, token := range tokens {
|
||||||
|
store.Set(token, nil, TOKEN)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,82 @@
|
||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/hmac"
|
||||||
|
"crypto/md5"
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/hex"
|
||||||
|
"fmt"
|
||||||
|
"local/oauth2/oauth2server/config"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (s *Server) usersLog(w http.ResponseWriter, r *http.Request) {
|
||||||
|
fmt.Fprintln(w, `
|
||||||
|
<html>
|
||||||
|
<body>
|
||||||
|
<form method="post" action="/authorize">
|
||||||
|
<input type="text" name="username"></input>
|
||||||
|
<input type="submit"></input>
|
||||||
|
</form>
|
||||||
|
</body>
|
||||||
|
</html>
|
||||||
|
`)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) usersRegister(w http.ResponseWriter, r *http.Request) {
|
||||||
|
fmt.Fprintln(w, `
|
||||||
|
<html>
|
||||||
|
<body>
|
||||||
|
<form method="post" action="/users/submit">
|
||||||
|
<input type="text" name="username"></input>
|
||||||
|
<input type="submit"></input>
|
||||||
|
</form>
|
||||||
|
</body>
|
||||||
|
</html>
|
||||||
|
`)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) usersSubmit(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Method != "POST" {
|
||||||
|
http.NotFound(w, r)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
id := r.FormValue("username")
|
||||||
|
if _, ok := s.getUser(id); ok {
|
||||||
|
http.Error(w, "user already exists", http.StatusConflict)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.genUser(id)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) genUser(id string) {
|
||||||
|
user := uuid.New().String()
|
||||||
|
salt := uuid.New().String()
|
||||||
|
s.store.Set(id, []byte(salt), SALT)
|
||||||
|
obscured := s.obscureID(id)
|
||||||
|
s.store.Set(obscured, []byte(user), USERS)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) getUser(id string) (string, bool) {
|
||||||
|
obscured := s.obscureID(id)
|
||||||
|
user, err := s.store.Get(obscured, USERS)
|
||||||
|
return string(user), err == nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) obscureID(id string) string {
|
||||||
|
salt, _ := s.store.Get(id, SALT)
|
||||||
|
a := s.obscure(string(salt)+id, config.SecretA)
|
||||||
|
b := s.obscure(string(salt)+id, config.SecretB)
|
||||||
|
return a + b
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) obscure(payload, secret string) string {
|
||||||
|
hash := md5.New()
|
||||||
|
hash.Write([]byte(secret))
|
||||||
|
key := hash.Sum(nil)
|
||||||
|
sig := hmac.New(sha256.New, key)
|
||||||
|
sig.Write([]byte(payload))
|
||||||
|
return hex.EncodeToString(sig.Sum(nil))
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,19 @@
|
||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (s *Server) verify(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Method != "POST" {
|
||||||
|
http.NotFound(w, r)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
access := r.FormValue("access")
|
||||||
|
token, ok := s.getToken(access)
|
||||||
|
if !ok {
|
||||||
|
http.Error(w, "unknown access", http.StatusUnauthorized)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.Write([]byte(token))
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,162 @@
|
||||||
|
package oauth2
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"local/oauth2/oauth2client"
|
||||||
|
"local/oauth2/oauth2server/config"
|
||||||
|
"local/oauth2/oauth2server/server"
|
||||||
|
"net/http"
|
||||||
|
"net/http/cookiejar"
|
||||||
|
"net/http/httptest"
|
||||||
|
"regexp"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestAll(t *testing.T) {
|
||||||
|
oauth2server, err := launchServer()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer oauth2server.Close()
|
||||||
|
|
||||||
|
s := dummyServer(oauth2server.URL)
|
||||||
|
defer s.Close()
|
||||||
|
|
||||||
|
if err := createUser(oauth2server.URL); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := logUser(oauth2server.URL); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := shouldRedir(s.URL); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := testAuth(oauth2server.URL, s.URL); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func launchServer() (*httptest.Server, error) {
|
||||||
|
config.Store = "map"
|
||||||
|
|
||||||
|
oauth2server := server.New()
|
||||||
|
err := oauth2server.Routes()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
s := httptest.NewServer(oauth2server)
|
||||||
|
|
||||||
|
re := regexp.MustCompile(":[0-9]*")
|
||||||
|
port := re.FindString(s.URL)
|
||||||
|
config.Port = port
|
||||||
|
|
||||||
|
return s, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func dummyServer(oauth2server string) *httptest.Server {
|
||||||
|
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
err := oauth2client.Authenticate(oauth2server, w, r)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
fmt.Fprintln(w, "dummy server serving authenticated")
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
|
func createUser(oauth2server string) error {
|
||||||
|
resp, err := http.Post(oauth2server+"/users/submit", "application/x-www-form-urlencoded", strings.NewReader("username=abc"))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return errors.New("bad status " + resp.Status)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func logUser(oauth2server string) error {
|
||||||
|
resp, err := http.Post(oauth2server+"/authorize", "application/x-www-form-urlencoded", strings.NewReader("username=abc"))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return errors.New("bad status " + resp.Status)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func shouldRedir(dummy string) error {
|
||||||
|
c := makeClient()
|
||||||
|
return clientShouldRedir(c, dummy)
|
||||||
|
}
|
||||||
|
|
||||||
|
func clientShouldRedir(c *http.Client, dummy string) error {
|
||||||
|
req, _ := http.NewRequest("GET", dummy, nil)
|
||||||
|
resp, err := c.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
if resp.Request.URL.Path != "/users/log" {
|
||||||
|
return fmt.Errorf("did not need redir without auth: %v", resp.Request.URL)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func clientShouldNotRedir(c *http.Client, dummy string) error {
|
||||||
|
req, _ := http.NewRequest("GET", dummy, nil)
|
||||||
|
resp, err := c.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
if resp.Request.URL.Path == "/users/log" {
|
||||||
|
return fmt.Errorf("did redir with auth: %v", resp.Request.URL.Path)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func testAuth(oauth2server, dummy string) error {
|
||||||
|
c := makeClient()
|
||||||
|
if err := clientShouldRedir(c, dummy); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := clientLogin(c, oauth2server); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := clientShouldNotRedir(c, dummy); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func clientLogin(c *http.Client, oauth2server string) error {
|
||||||
|
req, _ := http.NewRequest("POST", oauth2server+"/authorize", strings.NewReader("username=abc"))
|
||||||
|
req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
|
||||||
|
resp, err := c.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return fmt.Errorf("bad status; %v", resp.StatusCode)
|
||||||
|
}
|
||||||
|
if resp.Request.URL.Path != "/" {
|
||||||
|
return fmt.Errorf("login response path wrong: %v", resp.Request.URL)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func makeClient() *http.Client {
|
||||||
|
jar, _ := cookiejar.New(&cookiejar.Options{})
|
||||||
|
return &http.Client{
|
||||||
|
Jar: jar,
|
||||||
|
}
|
||||||
|
}
|
||||||
Loading…
Reference in New Issue