Optional oauth via + flag

master
Bel LaPointe 2019-11-03 07:55:38 -07:00
parent 7d3d6d88f6
commit 01b7b06971
4 changed files with 9 additions and 7 deletions

View File

@ -7,5 +7,5 @@ crt: ""
key: "" key: ""
tcp: "" tcp: ""
timeout: 1m timeout: 1m
proxy: a,http://localhost:41912,,b,http://localhost:41912 proxy: a,http://localhost:41912,,+b,http://localhost:41912
oauth: http://localhost:23456 oauth: http://localhost:23456

View File

@ -45,7 +45,7 @@ func parseArgs() (*args.ArgSet, error) {
as.Append(args.STRING, "key", "path to key for ssl", "") as.Append(args.STRING, "key", "path to key for ssl", "")
as.Append(args.STRING, "tcp", "address for tcp only tunnel", "") as.Append(args.STRING, "tcp", "address for tcp only tunnel", "")
as.Append(args.DURATION, "timeout", "timeout for tunnel", time.Minute) as.Append(args.DURATION, "timeout", "timeout for tunnel", time.Minute)
as.Append(args.STRING, "proxy", "double-comma separated from,scheme://to.tld:port,oauth,,", "") as.Append(args.STRING, "proxy", "double-comma separated (+ if oauth)from,scheme://to.tld:port,oauth,,", "")
as.Append(args.STRING, "oauth", "url for boauthz", "") as.Append(args.STRING, "oauth", "url for boauthz", "")
err := as.Parse() err := as.Parse()

View File

@ -52,7 +52,7 @@ func (s *Server) lookup(host string) (*url.URL, error) {
func (s *Server) lookupBOAuthZ(host string) (bool, error) { func (s *Server) lookupBOAuthZ(host string) (bool, error) {
v := packable.NewString() v := packable.NewString()
err := s.db.Get(nsBOAuthZ, host, v) err := s.db.Get(nsBOAuthZ, host, v)
return v.String() != "", err return v.String() == "true", err
} }
func mapKey(host string) string { func mapKey(host string) string {

View File

@ -53,12 +53,14 @@ type Server struct {
} }
func (s *Server) Route(src string, dst config.Proxy) error { func (s *Server) Route(src string, dst config.Proxy) error {
hasOAuth := strings.HasPrefix(src, "+")
src = strings.TrimPrefix(src, "+")
log.Printf("Adding route %q -> %v...\n", src, dst) log.Printf("Adding route %q -> %v...\n", src, dst)
u, err := url.Parse(dst.To) u, err := url.Parse(dst.To)
if err != nil { if err != nil {
return err return err
} }
s.db.Set(nsBOAuthZ, src, packable.NewString(fmt.Sprint(dst.BOAuthZ))) s.db.Set(nsBOAuthZ, src, packable.NewString(fmt.Sprint(hasOAuth)))
return s.db.Set(nsRouting, src, packable.NewURL(u)) return s.db.Set(nsRouting, src, packable.NewURL(u))
} }
@ -114,14 +116,14 @@ func (s *Server) doAuth(foo http.HandlerFunc) http.HandlerFunc {
return return
} }
} }
ok, err := s.lookupBOAuthZ(mapKey(r.Host)) key := mapKey(r.Host)
ok, err := s.lookupBOAuthZ(key)
if err != nil { if err != nil {
w.WriteHeader(http.StatusInternalServerError) w.WriteHeader(http.StatusInternalServerError)
return return
} }
if url, exists := config.GetBOAuthZ(); ok && exists { if url, exists := config.GetBOAuthZ(); ok && exists {
name := mapKey(r.Host) err := oauth2client.Authenticate(url, key, w, r)
err := oauth2client.Authenticate(url, name, w, r)
if err != nil { if err != nil {
return return
} }