Wannabe oauth implementation

This commit is contained in:
bel
2019-10-20 12:59:50 -06:00
commit 6dc2e074fb
11 changed files with 548 additions and 0 deletions

View File

@@ -0,0 +1,53 @@
package config
import (
"fmt"
"local/args"
"os"
"strings"
)
var (
Port string
Store string
StoreAddr string
StoreUser string
StorePass string
SecretA string
SecretB string
)
func init() {
Refresh()
}
func Refresh() {
if strings.Contains(fmt.Sprint(os.Args), "-test") {
return
}
defer func() {
if err := recover(); err != nil {
panic(err)
}
}()
as := args.NewArgSet()
as.Append(args.STRING, "port", "port to listen on", "23456")
as.Append(args.STRING, "secretA", "secret A", "secret")
as.Append(args.STRING, "secretB", "secret B", "secret")
as.Append(args.STRING, "store", "type of DB", "map")
as.Append(args.STRING, "storeAddr", "addr of DB", "/tmp/oauth2server.db")
as.Append(args.STRING, "storeUser", "user of DB", "")
as.Append(args.STRING, "storePass", "pass of DB", "")
if err := as.Parse(); err != nil {
panic(err)
}
Port = ":" + strings.TrimPrefix(as.Get("port").GetString(), ":")
Store = as.Get("store").GetString()
StoreAddr = as.Get("storeaddr").GetString()
StoreUser = as.Get("storeuser").GetString()
StorePass = as.Get("storepass").GetString()
SecretA = as.Get("secreta").GetString()
SecretB = as.Get("secretb").GetString()
}

19
oauth2server/main.go Normal file
View File

@@ -0,0 +1,19 @@
package main
import (
"local/oauth2/oauth2server/config"
"local/oauth2/oauth2server/server"
"log"
"net/http"
)
func main() {
s := server.New()
if err := s.Routes(); err != nil {
panic(err)
}
log.Println("listening on", config.Port)
if err := http.ListenAndServe(config.Port, s); err != nil {
panic(err)
}
}

BIN
oauth2server/oauth2server Executable file

Binary file not shown.

View File

@@ -0,0 +1,59 @@
package server
import (
"local/oauth2"
"local/storage"
"net/http"
"github.com/google/uuid"
)
func (s *Server) authorize(w http.ResponseWriter, r *http.Request) {
if r.Method != "POST" {
http.NotFound(w, r)
return
}
id := r.FormValue("username")
user, ok := s.getUser(id)
if !ok {
http.Error(w, "unknown user", http.StatusForbidden)
return
}
access, ok := s.getAccess(user)
if !ok {
http.Error(w, "no oauth for user", http.StatusForbidden)
return
}
cookie := &http.Cookie{
Name: oauth2.COOKIE,
Value: access,
SameSite: http.SameSiteLaxMode,
}
http.SetCookie(w, cookie)
redirectCookie, err := r.Cookie(oauth2.REDIRECT)
if err != nil {
return
}
http.Redirect(w, r, redirectCookie.Value, http.StatusSeeOther)
}
func (s *Server) genAuth(user string) {
access := uuid.New().String()
token := uuid.New().String()
s.store.Set(user, []byte(access), ACCESS)
s.store.Set(access, []byte(token), TOKEN)
}
func (s *Server) getAccess(user string) (string, bool) {
access, err := s.store.Get(user, ACCESS)
if err == storage.ErrNotFound {
s.genAuth(user)
access, err = s.store.Get(user, ACCESS)
}
return string(access), err == nil
}
func (s *Server) getToken(access string) (string, bool) {
token, err := s.store.Get(access, TOKEN)
return string(token), err == nil
}

View File

@@ -0,0 +1,41 @@
package server
import (
"fmt"
"net/http"
)
func (s *Server) Routes() error {
endpoints := []struct {
path string
handler http.HandlerFunc
}{
{
path: fmt.Sprintf("authorize"),
handler: s.authorize,
},
{
path: fmt.Sprintf("verify"),
handler: s.verify,
},
{
path: fmt.Sprintf("users/log"),
handler: s.usersLog,
},
{
path: fmt.Sprintf("users/register"),
handler: s.usersRegister,
},
{
path: fmt.Sprintf("users/submit"),
handler: s.usersSubmit,
},
}
for _, endpoint := range endpoints {
if err := s.Add(endpoint.path, endpoint.handler); err != nil {
return err
}
}
return nil
}

View File

@@ -0,0 +1,44 @@
package server
import (
"local/oauth2/oauth2server/config"
"local/router"
"local/storage"
)
var wildcard = router.Wildcard
const (
USERS = "users"
ACCESS = "access"
TOKEN = "token"
SALT = "salt"
)
type Server struct {
*router.Router
store storage.DB
}
func New() *Server {
store, err := storage.New(storage.TypeFromString(config.Store), config.StoreAddr, config.StoreUser, config.StorePass)
if err != nil {
panic(err)
}
purgeIssuedCredentials(store)
return &Server{
Router: router.New(),
store: store,
}
}
func purgeIssuedCredentials(store storage.DB) {
accesses, _ := store.List([]string{ACCESS})
for _, access := range accesses {
store.Set(access, nil, ACCESS)
}
tokens, _ := store.List([]string{TOKEN})
for _, token := range tokens {
store.Set(token, nil, TOKEN)
}
}

View File

@@ -0,0 +1,82 @@
package server
import (
"crypto/hmac"
"crypto/md5"
"crypto/sha256"
"encoding/hex"
"fmt"
"local/oauth2/oauth2server/config"
"net/http"
"github.com/google/uuid"
)
func (s *Server) usersLog(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, `
<html>
<body>
<form method="post" action="/authorize">
<input type="text" name="username"></input>
<input type="submit"></input>
</form>
</body>
</html>
`)
}
func (s *Server) usersRegister(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, `
<html>
<body>
<form method="post" action="/users/submit">
<input type="text" name="username"></input>
<input type="submit"></input>
</form>
</body>
</html>
`)
}
func (s *Server) usersSubmit(w http.ResponseWriter, r *http.Request) {
if r.Method != "POST" {
http.NotFound(w, r)
return
}
id := r.FormValue("username")
if _, ok := s.getUser(id); ok {
http.Error(w, "user already exists", http.StatusConflict)
return
}
s.genUser(id)
}
func (s *Server) genUser(id string) {
user := uuid.New().String()
salt := uuid.New().String()
s.store.Set(id, []byte(salt), SALT)
obscured := s.obscureID(id)
s.store.Set(obscured, []byte(user), USERS)
}
func (s *Server) getUser(id string) (string, bool) {
obscured := s.obscureID(id)
user, err := s.store.Get(obscured, USERS)
return string(user), err == nil
}
func (s *Server) obscureID(id string) string {
salt, _ := s.store.Get(id, SALT)
a := s.obscure(string(salt)+id, config.SecretA)
b := s.obscure(string(salt)+id, config.SecretB)
return a + b
}
func (s *Server) obscure(payload, secret string) string {
hash := md5.New()
hash.Write([]byte(secret))
key := hash.Sum(nil)
sig := hmac.New(sha256.New, key)
sig.Write([]byte(payload))
return hex.EncodeToString(sig.Sum(nil))
}

View File

@@ -0,0 +1,19 @@
package server
import (
"net/http"
)
func (s *Server) verify(w http.ResponseWriter, r *http.Request) {
if r.Method != "POST" {
http.NotFound(w, r)
return
}
access := r.FormValue("access")
token, ok := s.getToken(access)
if !ok {
http.Error(w, "unknown access", http.StatusUnauthorized)
return
}
w.Write([]byte(token))
}