change conf to argsset and flag for oauth
This commit is contained in:
@@ -4,7 +4,6 @@ import (
|
||||
"bytes"
|
||||
"crypto/tls"
|
||||
"io"
|
||||
"local/rproxy3/config"
|
||||
"local/rproxy3/storage/packable"
|
||||
"log"
|
||||
"net/http"
|
||||
@@ -33,10 +32,6 @@ func (s *Server) Proxy(w http.ResponseWriter, r *http.Request) {
|
||||
targetHost: newURL.Host,
|
||||
baseTransport: http.DefaultTransport,
|
||||
}
|
||||
transport = &rewrite{
|
||||
rewrites: config.GetRewrites(mapKey(r.Host)),
|
||||
baseTransport: transport,
|
||||
}
|
||||
if err != nil {
|
||||
http.NotFound(w, r)
|
||||
log.Printf("unknown host lookup %q", r.Host)
|
||||
@@ -54,6 +49,12 @@ func (s *Server) lookup(host string) (*url.URL, error) {
|
||||
return v.URL(), err
|
||||
}
|
||||
|
||||
func (s *Server) lookupBOAuthZ(host string) (bool, error) {
|
||||
v := packable.NewString()
|
||||
err := s.db.Get(nsBOAuthZ, host, v)
|
||||
return v.String() != "", err
|
||||
}
|
||||
|
||||
func mapKey(host string) string {
|
||||
host = strings.Split(host, ".")[0]
|
||||
host = strings.Split(host, ":")[0]
|
||||
|
||||
@@ -5,7 +5,9 @@ import (
|
||||
"crypto/tls"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"local/oauth2/oauth2client"
|
||||
"local/rproxy3/config"
|
||||
"local/rproxy3/storage"
|
||||
"local/rproxy3/storage/packable"
|
||||
@@ -20,6 +22,7 @@ import (
|
||||
)
|
||||
|
||||
const nsRouting = "routing"
|
||||
const nsBOAuthZ = "oauth"
|
||||
|
||||
type listenerScheme int
|
||||
|
||||
@@ -49,12 +52,13 @@ type Server struct {
|
||||
limiter *rate.Limiter
|
||||
}
|
||||
|
||||
func (s *Server) Route(src, dst string) error {
|
||||
log.Printf("Adding route %q -> %q...\n", src, dst)
|
||||
u, err := url.Parse(dst)
|
||||
func (s *Server) Route(src string, dst config.Proxy) error {
|
||||
log.Printf("Adding route %q -> %v...\n", src, dst)
|
||||
u, err := url.Parse(dst.To)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.db.Set(nsBOAuthZ, src, packable.NewString(fmt.Sprint(dst.BOAuthZ)))
|
||||
return s.db.Set(nsRouting, src, packable.NewURL(u))
|
||||
}
|
||||
|
||||
@@ -103,7 +107,6 @@ func (s *Server) doAuth(foo http.HandlerFunc) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
rusr, rpwd, ok := config.GetAuth()
|
||||
if ok {
|
||||
//usr, pwd := getProxyAuth(r)
|
||||
usr, pwd, ok := r.BasicAuth()
|
||||
if !ok || rusr != usr || rpwd != pwd {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
@@ -111,6 +114,17 @@ func (s *Server) doAuth(foo http.HandlerFunc) http.HandlerFunc {
|
||||
return
|
||||
}
|
||||
}
|
||||
ok, err := s.lookupBOAuthZ(mapKey(r.Host))
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if boauthz, useoauth := config.GetBOAuthZ(); ok && useoauth {
|
||||
err := oauth2client.Authenticate(boauthz, w, r)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
foo(w, r)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package server
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"local/rproxy3/config"
|
||||
"local/rproxy3/storage"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
@@ -15,7 +16,10 @@ import (
|
||||
func TestServerStart(t *testing.T) {
|
||||
server := mockServer()
|
||||
|
||||
if err := server.Route("world", "http://hello.localhost"+server.addr); err != nil {
|
||||
p := config.Proxy{
|
||||
To: "http://hello.localhost" + server.addr,
|
||||
}
|
||||
if err := server.Route("world", p); err != nil {
|
||||
t.Fatalf("cannot add route: %v", err)
|
||||
}
|
||||
|
||||
@@ -48,7 +52,10 @@ func mockServer() *Server {
|
||||
|
||||
func TestServerRoute(t *testing.T) {
|
||||
server := mockServer()
|
||||
if err := server.Route("world", "http://hello.localhost"+server.addr); err != nil {
|
||||
p := config.Proxy{
|
||||
To: "http://hello.localhost" + server.addr,
|
||||
}
|
||||
if err := server.Route("world", p); err != nil {
|
||||
t.Fatalf("cannot add route: %v", err)
|
||||
}
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
Reference in New Issue
Block a user