diff --git a/config/config.go b/config/config.go new file mode 100644 index 0000000..e5d6f21 --- /dev/null +++ b/config/config.go @@ -0,0 +1,54 @@ +package config + +import ( + "local/rproxy3/storage/packable" + "strings" +) + +func GetPort() string { + v := packable.NewString() + conf.Get(nsConf, flagPort, v) + return v.String() +} + +func GetRoutes() map[string]string { + v := packable.NewString() + conf.Get(nsConf, flagRoutes, v) + m := make(map[string]string) + for _, v := range strings.Split(v.String(), ",") { + if len(v) == 0 { + return m + } + from := v[:strings.Index(v, ":")] + to := v[strings.Index(v, ":")+1:] + m[from] = to + } + return m +} + +func GetSSL() (string, string, bool) { + v := packable.NewString() + conf.Get(nsConf, flagCert, v) + certPath := v.String() + conf.Get(nsConf, flagKey, v) + keyPath := v.String() + return certPath, keyPath, notEmpty(certPath, keyPath) +} + +func GetAuth() (string, string, bool) { + v := packable.NewString() + conf.Get(nsConf, flagUser, v) + user := v.String() + conf.Get(nsConf, flagPass, v) + pass := v.String() + return user, pass, notEmpty(user, pass) +} + +func notEmpty(s ...string) bool { + for i := range s { + if s[i] == "" || s[i] == "/dev/null" { + return false + } + } + return true +} diff --git a/config/new.go b/config/new.go new file mode 100644 index 0000000..2513f7c --- /dev/null +++ b/config/new.go @@ -0,0 +1,114 @@ +package config + +import ( + "flag" + "io/ioutil" + "local/rproxy3/storage" + "local/rproxy3/storage/packable" + "log" + "strings" + + yaml "gopkg.in/yaml.v2" +) + +const nsConf = "configuration" +const flagPort = "p" +const flagRoutes = "r" +const flagConf = "c" +const flagCert = "crt" +const flagKey = "key" +const flagUser = "user" +const flagPass = "pass" + +var conf = storage.NewMap() + +type toBind struct { + flag string + value *string +} + +type fileConf struct { + Port string `yaml:"port"` + Routes []string `yaml:"routes"` + CertPath string `yaml:"cert"` + KeyPath string `yaml:"key"` + Username string `yaml:"user"` + Password string `yaml:"pass"` +} + +func Init() error { + log.SetFlags(log.Ldate | log.Ltime | log.Llongfile) + if err := fromFile(); err != nil { + return err + } + if err := fromFlags(); err != nil { + return err + } + return nil +} + +func fromFile() error { + flag.String(flagConf, "/dev/null", "yaml config file path") + confFlag := flag.Lookup(flagConf) + if confFlag == nil || confFlag.Value.String() == "" { + return nil + } + confBytes, err := ioutil.ReadFile(confFlag.Value.String()) + if err != nil { + return err + } + var c fileConf + if err := yaml.Unmarshal(confBytes, &c); err != nil { + return err + } + if err := conf.Set(nsConf, flagPort, packable.NewString(c.Port)); err != nil { + return err + } + if err := conf.Set(nsConf, flagRoutes, packable.NewString(strings.Join(c.Routes, ","))); err != nil { + return err + } + if err := conf.Set(nsConf, flagCert, packable.NewString(c.CertPath)); err != nil { + return err + } + if err := conf.Set(nsConf, flagKey, packable.NewString(c.KeyPath)); err != nil { + return err + } + if err := conf.Set(nsConf, flagUser, packable.NewString(c.Username)); err != nil { + return err + } + if err := conf.Set(nsConf, flagPass, packable.NewString(c.Password)); err != nil { + return err + } + return nil +} + +func fromFlags() error { + binds := make([]toBind, 0) + binds = append(binds, addFlag(flagPort, "51555", "port to bind to")) + binds = append(binds, addFlag(flagRoutes, "", "comma-separated routes to map, each as from:scheme://to.tld:port")) + binds = append(binds, addFlag(flagCert, "", "path to .crt")) + binds = append(binds, addFlag(flagKey, "", "path to .key")) + binds = append(binds, addFlag(flagUser, "", "basic auth username")) + binds = append(binds, addFlag(flagPass, "", "basic auth password")) + flag.Parse() + + for _, bind := range binds { + confFlag := flag.Lookup(bind.flag) + if confFlag == nil || confFlag.Value.String() == "" { + continue + } + if err := conf.Set(nsConf, bind.flag, packable.NewString(*bind.value)); err != nil { + return err + } + } + + return nil +} + +func addFlag(key, def, help string) toBind { + v := flag.String(key, def, help) + return toBind{ + flag: key, + value: v, + } +} diff --git a/config/new_test.go b/config/new_test.go new file mode 100644 index 0000000..8e1c7b3 --- /dev/null +++ b/config/new_test.go @@ -0,0 +1,46 @@ +package config + +import ( + "flag" + "os" + "testing" +) + +func TestInit(t *testing.T) { + was := os.Args[:] + os.Args = []string{"program"} + flag.CommandLine = flag.NewFlagSet(os.Args[0], flag.ExitOnError) + defer func() { + os.Args = was[:] + }() + + if err := Init(); err != nil { + t.Errorf("failed to init: %v", err) + } +} + +func TestFromFile(t *testing.T) { + was := os.Args[:] + os.Args = []string{"program"} + flag.CommandLine = flag.NewFlagSet(os.Args[0], flag.ExitOnError) + defer func() { + os.Args = was[:] + }() + + if err := fromFile(); err != nil { + t.Errorf("failed from file: %v", err) + } +} + +func TestFromFlags(t *testing.T) { + was := os.Args[:] + os.Args = []string{"program"} + flag.CommandLine = flag.NewFlagSet(os.Args[0], flag.ExitOnError) + defer func() { + os.Args = was[:] + }() + + if err := fromFlags(); err != nil { + t.Errorf("failed from flags: %v", err) + } +} diff --git a/main.go b/main.go index d39e298..902cc6d 100644 --- a/main.go +++ b/main.go @@ -1,17 +1,20 @@ package main import ( - "flag" - "local/s2sa/s2sa/server" + "local/rproxy3/config" + "local/rproxy3/server" ) func main() { - flag.Parse() + if err := config.Init(); err != nil { + panic(err) + } - server := server.New("") + server := server.New() if err := server.Routes(); err != nil { panic(err) } + if err := server.Run(); err != nil { panic(err) } diff --git a/main_test.go b/main_test.go new file mode 100644 index 0000000..9a3deb4 --- /dev/null +++ b/main_test.go @@ -0,0 +1,131 @@ +package main + +import ( + "crypto/tls" + "flag" + "fmt" + "net/http" + "net/http/httptest" + "os" + "strings" + "testing" + "time" +) + +func TestHTTPSMain(t *testing.T) { + was := os.Args[:] + os.Args = []string{"program"} + flag.CommandLine = flag.NewFlagSet(os.Args[0], flag.ExitOnError) + defer func() { + os.Args = was[:] + }() + + addr, stop := echoServer() + defer stop() + ported := make(chan string) + go func() { + p := getPort() + ported <- p + os.Args = []string{ + "foobar", + "-p", + p, + "-user", + "username", + "-pass", + "password", + "-r", + "hello:" + addr, + "-crt", + "./testdata/rproxy3server.crt", + "-key", + "./testdata/rproxy3server.key", + } + main() + }() + port := <-ported + time.Sleep(time.Millisecond * 100) + + client := &http.Client{ + Transport: &http.Transport{ + TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, + }, + } + r, _ := http.NewRequest("GET", "https://hello.localhost"+port, nil) + + if resp, err := client.Do(r); err != nil { + t.Fatalf("client failed: %v", err) + } else if resp.StatusCode != http.StatusUnauthorized { + t.Errorf("proxy failed: code %v != %v", resp.StatusCode, http.StatusUnauthorized) + } + + r.SetBasicAuth("username", "password") + if resp, err := client.Do(r); err != nil { + t.Fatalf("client failed: %v", err) + } else if resp.StatusCode != http.StatusOK { + t.Errorf("proxy failed: code %v != %v", resp.StatusCode, http.StatusOK) + } +} + +func TestHTTPMain(t *testing.T) { + was := os.Args[:] + os.Args = []string{"program"} + flag.CommandLine = flag.NewFlagSet(os.Args[0], flag.ExitOnError) + defer func() { + os.Args = was[:] + }() + + addr, stop := echoServer() + defer stop() + ported := make(chan string) + go func() { + p := getPort() + ported <- p + os.Args = []string{ + "foobar", + "-p", + p, + "-user", + "username", + "-pass", + "password", + "-r", + "hello:" + addr, + } + main() + }() + port := <-ported + time.Sleep(time.Millisecond * 100) + + client := &http.Client{} + r, _ := http.NewRequest("GET", "http://hello.localhost"+port, nil) + + if resp, err := client.Do(r); err != nil { + t.Fatalf("client failed: %v", err) + } else if resp.StatusCode != http.StatusUnauthorized { + t.Errorf("proxy failed: code %v != %v", resp.StatusCode, http.StatusUnauthorized) + } + + r.SetBasicAuth("username", "password") + if resp, err := client.Do(r); err != nil { + t.Fatalf("client failed: %v", err) + } else if resp.StatusCode != http.StatusOK { + t.Errorf("proxy failed: code %v != %v", resp.StatusCode, http.StatusOK) + } +} + +func echoServer() (string, func()) { + h := func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, "hello") + } + portsrv := httptest.NewServer(http.HandlerFunc(h)) + return portsrv.URL, func() { + portsrv.Close() + } +} + +func getPort() string { + s := httptest.NewServer(nil) + s.Close() + return s.URL[strings.LastIndex(s.URL, ":"):] +} diff --git a/server/new.go b/server/new.go index 2da9d8b..cbc94a5 100644 --- a/server/new.go +++ b/server/new.go @@ -1,30 +1,14 @@ package server import ( - "local/s2sa/s2sa/server/router" - "local/s2sa/s2sa/services" - "local/s2sa/s2sa/storage" + "local/rproxy3/config" + "local/rproxy3/storage" ) -func New(path string) *Server { - var db storage.DB - db = storage.NewMap() - if len(path) > 0 { - var err error - db, err = storage.NewBolt(path) - if err != nil { - return nil - } +func New() *Server { + port := config.GetPort() + return &Server{ + db: storage.NewMap(), + addr: port, } - authdb := storage.NewMap() - s := &Server{ - db: services.New(db), - authdb: services.New(authdb), - router: router.New(), - addr: ":18341", - } - if err := s.authdb.Register(serverNS); err != nil { - panic(err) - } - return s } diff --git a/server/new_test.go b/server/new_test.go index 97e766f..fda3aaa 100644 --- a/server/new_test.go +++ b/server/new_test.go @@ -1,11 +1,9 @@ package server import ( - "os" "testing" ) func TestServerNew(t *testing.T) { - New("") - New(os.DevNull) + New() } diff --git a/server/proxy.go b/server/proxy.go new file mode 100644 index 0000000..9b14904 --- /dev/null +++ b/server/proxy.go @@ -0,0 +1,29 @@ +package server + +import ( + "local/rproxy3/storage/packable" + "log" + "net/http" + "net/http/httputil" + "net/url" + "strings" +) + +func (s *Server) Proxy(w http.ResponseWriter, r *http.Request) { + newURL, err := s.lookup(r.Host) + if err != nil { + http.NotFound(w, r) + log.Printf("unknown host lookup %q", r.Host) + return + } + r.Host = newURL.Host + proxy := httputil.NewSingleHostReverseProxy(newURL) + proxy.ServeHTTP(w, r) +} + +func (s *Server) lookup(host string) (*url.URL, error) { + host = strings.Split(host, ".")[0] + v := packable.NewURL() + err := s.db.Get(nsRouting, host, v) + return v.URL(), err +} diff --git a/server/routes.go b/server/routes.go index 78591dc..fabc9d6 100644 --- a/server/routes.go +++ b/server/routes.go @@ -1,197 +1,15 @@ package server import ( - "encoding/json" - "local/s2sa/s2sa/server/router" - "local/s2sa/s2sa/token" - "net/http" - "path" - "strings" + "local/rproxy3/config" ) -const clientsNS = "clients" -const accessorsNS = "accessors" -const wildcard = "{}" -const serverNS = "server" - func (s *Server) Routes() error { - appendWildcards := func(s string, cnt int) string { - s = strings.Trim(s, "/") - return path.Join(s, strings.Repeat("/"+wildcard, cnt)) - } - paths := []struct { - base string - wildcards int - method http.HandlerFunc - }{ - { - base: "admin/register", - wildcards: 1, - method: s.adminRegister, - }, - { - base: "register", - wildcards: 1, - method: s.authenticate(s.registerClient), - }, - { - base: "generate", - wildcards: 2, - method: s.authenticate(s.generateToken), - }, - { - base: "retrieve", - wildcards: 3, - method: s.authenticate(s.retrieveToken), - }, - { - base: "revoke", - wildcards: 2, - method: s.authenticate(s.revokeToken), - }, - { - base: "lookup", - wildcards: 2, - method: s.authenticate(s.lookupToken), - }, - { - base: "policies", - wildcards: 0, - method: s.authenticate(s.getPolicies), - }, - } - for _, path := range paths { - if err := s.Add(appendWildcards(path.base, path.wildcards), path.method); err != nil { + routes := config.GetRoutes() + for k, v := range routes { + if err := s.Route(k, v); err != nil { return err } } return nil } - -func (s *Server) authenticate(foo http.HandlerFunc) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - recipient, accessorToken, ok := r.BasicAuth() - if !ok { - w.WriteHeader(http.StatusUnauthorized) - return - } - accessor := strings.Split(accessorToken, ":")[0] - tokenValue := strings.Split(accessorToken, ":")[1] - - token, err := s.authdb.Get(recipient, accessor) - if err != nil { - w.WriteHeader(http.StatusUnauthorized) - return - } - - if token.Token != tokenValue { - w.WriteHeader(http.StatusUnauthorized) - return - } - if token.To != recipient { - w.WriteHeader(http.StatusUnauthorized) - return - } - - foo(w, r) - } -} - -func (s *Server) adminRegister(w http.ResponseWriter, r *http.Request) { - var name string - if err := router.Params(r, &name); err != nil { - w.WriteHeader(http.StatusBadRequest) - return - } - - token, err := s.authdb.New(serverNS, name) - if err != nil { - w.WriteHeader(http.StatusInternalServerError) - return - } - - respondWithToken(token, w, r) -} - -func (s *Server) registerClient(w http.ResponseWriter, r *http.Request) { - var name string - if err := router.Params(r, &name); err != nil { - w.WriteHeader(http.StatusBadRequest) - return - } - - if err := s.db.Register(name); err != nil { - w.WriteHeader(http.StatusInternalServerError) - } -} - -func (s *Server) generateToken(w http.ResponseWriter, r *http.Request) { - var name, to string - if err := router.Params(r, &name, &to); err != nil { - w.WriteHeader(http.StatusBadRequest) - return - } - - token, err := s.db.New(name, to) - if err != nil { - w.WriteHeader(http.StatusInternalServerError) - return - } - - respondWithToken(token, w, r) -} - -func (s *Server) retrieveToken(w http.ResponseWriter, r *http.Request) { - var creator, recipient, accessor string - if err := router.Params(r, &creator, &recipient, &accessor); err != nil { - w.WriteHeader(http.StatusBadRequest) - return - } - - token, err := s.db.Get(recipient, accessor) - if err != nil { - w.WriteHeader(http.StatusBadRequest) - return - } - - respondWithToken(token, w, r) -} - -func respondWithToken(token token.Basic, w http.ResponseWriter, r *http.Request) { - if err := json.NewEncoder(w).Encode(token); err != nil { - w.WriteHeader(http.StatusInternalServerError) - } -} - -func (s *Server) revokeToken(w http.ResponseWriter, r *http.Request) { - var name, accessor string - if err := router.Params(r, &name, &accessor); err != nil { - w.WriteHeader(http.StatusBadRequest) - return - } - - if err := s.db.Revoke(name, accessor); err != nil { - w.WriteHeader(http.StatusInternalServerError) - return - } -} - -func (s *Server) lookupToken(w http.ResponseWriter, r *http.Request) { - var creator, recipient string - if err := router.Params(r, &creator, &recipient); err != nil { - w.WriteHeader(http.StatusBadRequest) - return - } - - token, err := s.db.Lookup(creator, recipient) - if err != nil { - w.WriteHeader(http.StatusBadRequest) - return - } - - respondWithToken(token, w, r) -} - -func (s *Server) getPolicies(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusInternalServerError) -} diff --git a/server/routes_test.go b/server/routes_test.go index 30b6a77..f75a53a 100644 --- a/server/routes_test.go +++ b/server/routes_test.go @@ -1,75 +1,15 @@ package server import ( - "bytes" - "encoding/json" - "errors" - "io/ioutil" "net/http" "net/http/httptest" "testing" ) -type badRouter struct { - accept []string -} - -func (br *badRouter) Add(p string, h http.HandlerFunc) error { - for i := range br.accept { - if br.accept[i] == p { - return nil - } - } - return errors.New("rejected path") -} - -func (br *badRouter) ServeHTTP(http.ResponseWriter, *http.Request) {} - -func TestServerRoutesBadRouter(t *testing.T) { - server, _, _ := mockServer() - br := badRouter{ - accept: make([]string, 0), - } - server.router = &br - - toAdd := []string{ - "/nil", - "/admin/register/{}", - "/register/{}", - "/generate/{}/{}", - "/retrieve/{}/{}", - "/revoke/{}/{}", - "/lookup/{}/{}", - "/policies", - } - - for _, path := range toAdd { - br.accept = append(br.accept, path) - if err := server.Routes(); err == nil { - t.Errorf("can add non-allowed routes") - } - } -} - func TestServerRoutes(t *testing.T) { - server, _, _ := mockServer() - if err := server.db.Register("a"); err != nil { - t.Fatalf("cannot register: %v", err) - } - token, err := server.db.New("a", "b") - if err != nil { - t.Fatalf("cannot new: %v", err) - } + server := mockServer() - paths := []string{ - "retrieve/a/b/" + token.Accessor, - "revoke/a/" + token.Accessor, - "lookup/a/b", - "policies", - "admin/register/a", - "register/a", - "generate/a/b", - } + paths := []string{} for _, p := range paths { w := httptest.NewRecorder() @@ -81,357 +21,6 @@ func TestServerRoutes(t *testing.T) { } } -func TestServerAdminRegister(t *testing.T) { - cases := []struct { - name string - status int - }{ - { - name: "", - status: 404, - }, - { - name: "name", - status: 200, - }, - } - - path := "admin/register" - server, _, _ := mockServer() - - for i, c := range cases { - w := httptest.NewRecorder() - r, _ := http.NewRequest("GET", path, nil) - r = addParam(r, "name", c.name) - server.ServeHTTP(w, r) - if w.Code != c.status { - t.Errorf("[%d] wrong code for %s: got %d, expected %d", i, path, w.Code, c.status) - } - body, _ := ioutil.ReadAll(w.Body) - if len(body) < 1 { - t.Errorf("[%d] empty body in admin/register response: %q", i, body) - } - t.Logf("[%d] admin/register body: %q", i, body) - } -} - -func TestServerRegister(t *testing.T) { - cases := []struct { - name string - status int - }{ - { - name: "", - status: 404, - }, - { - name: "name", - status: 200, - }, - } - - path := "register" - server, _, _ := mockServer() - - for i, c := range cases { - w := httptest.NewRecorder() - r, _ := http.NewRequest("GET", path, nil) - r = addParam(r, "name", c.name) - auth, err := server.authdb.New(serverNS, c.name) - if err != nil { - t.Fatalf("cannot authdb new: %v", err) - } - r.SetBasicAuth(c.name, auth.Accessor+":"+auth.Token) - server.ServeHTTP(w, r) - if w.Code != c.status { - t.Errorf("%d: wrong code for %s: got %d, expected %d", i, path, w.Code, c.status) - } - } -} - -func TestServerGenerate(t *testing.T) { - cases := []struct { - name string - to string - status int - }{ - { - name: "", - to: "", - status: 404, - }, - { - name: "name", - to: "", - status: 404, - }, - { - name: "", - to: "to", - status: 404, - }, - { - name: "name", - to: "to", - status: 200, - }, - } - - path := "generate" - server, _, _ := mockServer() - - for i, c := range cases { - w := httptest.NewRecorder() - r, _ := http.NewRequest("GET", path, nil) - r = addParam(r, "name", c.name) - r = addParam(r, "to", c.to) - auth, err := server.authdb.New(serverNS, c.name) - if err != nil { - t.Fatalf("cannot authdb new: %v", err) - } - r.SetBasicAuth(c.name, auth.Accessor+":"+auth.Token) - server.ServeHTTP(w, r) - if w.Code != c.status { - t.Errorf("%d: wrong code for %s: got %d, expected %d", i, path, w.Code, c.status) - } - var result struct { - Token string `json:"token"` - Acc string `json:"accessor"` - TTL int `json:"TTL"` - To string `json:"to"` - } - if err := json.NewDecoder(w.Body).Decode(&result); c.status == 200 && err != nil { - t.Errorf("invalid body: %v", err) - } else if c.status == 200 { - if result.To != c.to { - t.Errorf("wrong `to` in response: got %v, want %v", result.To, c.to) - } - if len(result.Token) == 0 { - t.Errorf("empty `token` in response") - } - if len(result.Acc) == 0 { - t.Errorf("empty `accessor` in response") - } - if result.TTL < 100 { - t.Errorf("short TTL in response") - } - } - } -} - -func TestServerRetrieve(t *testing.T) { - server, defaultName, defaultAccessor := mockServer() - - cases := []struct { - name string - acc string - status int - }{ - { - name: "", - acc: "", - status: 404, - }, - { - name: defaultName, - acc: "", - status: 404, - }, - { - name: "", - acc: defaultAccessor, - status: 404, - }, - { - name: defaultName, - acc: defaultAccessor, - status: 200, - }, - { - name: "fake", - acc: "fake", - status: 400, - }, - } - - path := "retrieve" - for i, c := range cases { - w := httptest.NewRecorder() - r, _ := http.NewRequest("GET", path, nil) - r = addParam(r, "name", c.name) - r = addParam(r, "to", "to") - r = addParam(r, "accessor", c.acc) - auth, err := server.authdb.New(serverNS, c.name) - if err != nil { - t.Fatalf("cannot authdb new: %v", err) - } - r.SetBasicAuth(c.name, auth.Accessor+":"+auth.Token) - server.ServeHTTP(w, r) - if w.Code != c.status { - //t.Errorf("%d: wrong code for %s with %v: got %d, expected %d", i, path, c, w.Code, c.status) - t.Fatalf("%d: wrong code for %s with %v: got %d, expected %d", i, path, c, w.Code, c.status) - } - } -} - -func TestServerRevoke(t *testing.T) { - cases := []struct { - name string - status int - }{ - { - name: "", - status: 404, - }, - { - name: "name", - status: 200, - }, - } - - path := "register" - server, _, _ := mockServer() - - for i, c := range cases { - w := httptest.NewRecorder() - r, _ := http.NewRequest("GET", path, nil) - r = addParam(r, "name", c.name) - auth, err := server.authdb.New(serverNS, c.name) - if err != nil { - t.Fatalf("cannot authdb new: %v", err) - } - r.SetBasicAuth(c.name, auth.Accessor+":"+auth.Token) - server.ServeHTTP(w, r) - if w.Code != c.status { - t.Errorf("%d: wrong code for %s: got %d, expected %d", i, path, w.Code, c.status) - } - } -} - -func TestServerLookup(t *testing.T) { - from := "from" - to := "to" - cases := []struct { - from string - to string - status int - }{ - { - from: "", - to: "", - status: 404, - }, - { - from: "", - to: to, - status: 404, - }, - { - from: from, - to: "", - status: 404, - }, - { - from: from, - to: to, - status: 200, - }, - } - - path := "lookup" - server, _, _ := mockServer() - server.db.Register(from) - generated, _ := server.db.New(from, to) - - for i, c := range cases { - w := httptest.NewRecorder() - r, _ := http.NewRequest("GET", path, nil) - r = addParam(r, "from", c.from) - r = addParam(r, "to", c.to) - auth, err := server.authdb.New(serverNS, c.from) - if err != nil { - t.Fatalf("cannot authdb new: %v", err) - } - r.SetBasicAuth(c.from, auth.Accessor+":"+auth.Token) - server.ServeHTTP(w, r) - if w.Code != c.status { - t.Errorf("%d: wrong code for %s: got %d, expected %d", i, path, w.Code, c.status) - } - if w.Code != http.StatusOK { - continue - } - b, err := ioutil.ReadAll(w.Body) - if err != nil { - t.Fatalf("%d: cannot read body: %v", i, err) - } - if !bytes.Contains(b, []byte(generated.Accessor)) { - t.Errorf("%d: response didn't contain accessor: got %s, want %s", i, b, generated.Accessor) - } - } -} - func echoHTTP(w http.ResponseWriter, r *http.Request) { w.Write([]byte(r.URL.Path)) } - -func addParam(r *http.Request, key, value string) *http.Request { - r.URL.Path += "/" + value - r.Header.Set(key, value) - return r -} - -func TestServerAuthenticate(t *testing.T) { - server, _, _ := mockServer() - - name := "name" - server.authdb.Register(serverNS) - token, err := server.authdb.New(serverNS, name) - if err != nil { - t.Fatalf("cannot authdb new: %v", err) - } - - nilHandle := func(http.ResponseWriter, *http.Request) {} - authFunc := server.authenticate(nilHandle) - - cases := []struct { - name string - token string - accessor string - code int - }{ - { - name: "bad", - token: token.Token, - accessor: token.Accessor, - code: http.StatusUnauthorized, - }, - { - name: name, - token: token.Token, - accessor: "bad", - code: http.StatusUnauthorized, - }, - { - name: name, - token: "bad", - accessor: token.Accessor, - code: http.StatusUnauthorized, - }, - { - name: name, - token: token.Token, - accessor: token.Accessor, - code: http.StatusOK, - }, - } - - for i, c := range cases { - w := httptest.NewRecorder() - r, _ := http.NewRequest("GET", "/any", nil) - r.SetBasicAuth(c.name, c.accessor+":"+c.token) - authFunc(w, r) - if w.Code != c.code { - t.Errorf("[case %d] failed auth: got %v, wanted %v", i, w.Code, c.code) - } - } -} diff --git a/server/server.go b/server/server.go index 9ee0bba..e2946d6 100644 --- a/server/server.go +++ b/server/server.go @@ -1,51 +1,84 @@ package server import ( - "fmt" - "local/s2sa/s2sa/logg" - "local/s2sa/s2sa/server/router" - "local/s2sa/s2sa/token" + "errors" + "local/rproxy3/config" + "local/rproxy3/storage" + "local/rproxy3/storage/packable" + "log" "net/http" - "strings" + "net/url" ) -type Router interface { - Add(string, http.HandlerFunc) error - ServeHTTP(http.ResponseWriter, *http.Request) -} +const nsRouting = "routing" -type TokenDatabase interface { - Register(string) error - New(string, string) (token.Basic, error) - Get(string, string) (token.Basic, error) - Revoke(string, string) error - Lookup(string, string) (token.Basic, error) -} +type listenerScheme int -type Messenger interface { - Produce(string, interface{}) error +const ( + schemeHTTP listenerScheme = iota + schemeHTTPS listenerScheme = iota +) + +func (ls listenerScheme) String() string { + switch ls { + case schemeHTTP: + return "http" + case schemeHTTPS: + return "https" + } + return "" } type Server struct { - router Router - db TokenDatabase - authdb TokenDatabase - messenger Messenger - addr string + db storage.DB + addr string + username string + password string } -func (s *Server) Add(path string, handler http.HandlerFunc) error { - logg.Logf("Adding path %v...\n", path) - path = strings.Replace(path, wildcard, router.Wildcard, -1) - return s.router.Add(path, handler) +func (s *Server) Route(src, dst string) error { + log.Printf("Adding route %q -> %q...\n", src, dst) + u, err := url.Parse(dst) + if err != nil { + return err + } + return s.db.Set(nsRouting, src, packable.NewURL(u)) } func (s *Server) Run() error { - logg.Logf("Listening on %v...\n", s.addr) - return http.ListenAndServe(s.addr, s) + scheme := schemeHTTP + if _, _, ok := config.GetSSL(); ok { + scheme = schemeHTTPS + } + log.Printf("Listening for %v on %v...\n", scheme, s.addr) + switch scheme { + case schemeHTTP: + return http.ListenAndServe(s.addr, s) + case schemeHTTPS: + c, k, _ := config.GetSSL() + return http.ListenAndServeTLS(s.addr, c, k, s) + } + return errors.New("did not load server") +} + +func (s *Server) doAuth(foo http.HandlerFunc) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + rusr, rpwd, ok := config.GetAuth() + if ok { + usr, pwd, ok := r.BasicAuth() + if !ok || rusr != usr || rpwd != pwd { + w.WriteHeader(http.StatusUnauthorized) + return + } + } + foo(w, r) + } +} + +func (s *Server) Pre(foo http.HandlerFunc) http.HandlerFunc { + return s.doAuth(foo) } func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { - fmt.Printf("REQ: %s\n", r.URL.Path) - s.router.ServeHTTP(w, r) + s.Pre(s.Proxy)(w, r) } diff --git a/server/server_test.go b/server/server_test.go index 6e27329..8457c42 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -2,25 +2,17 @@ package server import ( "fmt" - "io/ioutil" - "local/s2sa/s2sa/logg" - "local/s2sa/s2sa/server/router" - "local/s2sa/s2sa/services" - "local/s2sa/s2sa/storage" + "local/rproxy3/storage" "net/http" "net/http/httptest" - "os" "strings" "testing" ) func TestServerStart(t *testing.T) { - server, _, _ := mockServer() + server := mockServer() - checked := false - if err := server.Add("/hello/world", func(_ http.ResponseWriter, _ *http.Request) { - checked = true - }); err != nil { + if err := server.Route("world", "http://hello.localhost"+server.addr); err != nil { t.Fatalf("cannot add route: %v", err) } @@ -30,58 +22,35 @@ func TestServerStart(t *testing.T) { } }() - if resp, err := http.Get("http://localhost" + server.addr + "/hello/world"); err != nil { + r, _ := http.NewRequest("GET", "http://world.localhost"+server.addr, nil) + if _, err := (&http.Client{}).Do(r); err != nil { t.Errorf("failed to get: %v", err) - } else if resp.StatusCode != 200 { - t.Errorf("wrong status: %v", resp.StatusCode) - } else if !checked { - t.Errorf("didnt hit handler") } } -func mockServer() (*Server, string, string) { - f, _ := os.Open(os.DevNull) - logg.ConfigFile(f) +func mockServer() *Server { portServer := httptest.NewServer(nil) port := strings.Split(portServer.URL, ":")[2] portServer.Close() s := &Server{ - router: router.New(), - db: services.New(storage.NewMap()), - authdb: services.New(storage.NewMap()), - addr: ":" + port, + db: storage.NewMap(), + addr: ":" + port, } - s.authdb.Register(serverNS) if err := s.Routes(); err != nil { panic(fmt.Sprintf("cannot initiate server routes; %v", err)) } - - defaultName := "name" - if err := s.db.Register(defaultName); err != nil { - panic(fmt.Sprintf("cannot register: %v", err)) - } - token, err := s.db.New("name", "to") - if err != nil { - panic(fmt.Sprintf("cannot generate: %v", err)) - } - defaultAccessor := token.Accessor - return s, defaultName, defaultAccessor + return s } -func TestServerAdd(t *testing.T) { - server, _, _ := mockServer() - path := "/hello/world" - if err := server.Add(path, echoHTTP); err != nil { - t.Fatalf("cannot add path: %v", err) +func TestServerRoute(t *testing.T) { + server := mockServer() + if err := server.Route("world", "http://hello.localhost"+server.addr); err != nil { + t.Fatalf("cannot add route: %v", err) } w := httptest.NewRecorder() - r, _ := http.NewRequest("GET", path, nil) + r, _ := http.NewRequest("GET", "http://world.localhost"+server.addr, nil) server.ServeHTTP(w, r) - b, err := ioutil.ReadAll(w.Body) - if err != nil { - t.Fatalf("cannot read body: %v", err) - } - if string(b) != path { - t.Errorf("cannot hit endpoint: %s", b) + if w.Code != 502 { + t.Fatalf("cannot proxy from 'world' to 'hello', status %v", w.Code) } } diff --git a/storage/packable/packable.go b/storage/packable/packable.go index 27b636e..8b9421a 100644 --- a/storage/packable/packable.go +++ b/storage/packable/packable.go @@ -1,5 +1,11 @@ package packable +import ( + "bytes" + "encoding/gob" + "net/url" +) + type Packable interface { Encode() ([]byte, error) Decode([]byte) error @@ -20,7 +26,50 @@ func (s *String) Decode(b []byte) error { return nil } -func NewString(s string) *String { - w := String(s) +func NewString(s ...string) *String { + if len(s) == 0 { + s = append(s, "") + } + w := String(s[0]) return &w } + +type URL struct { + Scheme string + Host string + Path string +} + +func (u *URL) URL() *url.URL { + return &url.URL{ + Scheme: u.Scheme, + Host: u.Host, + Path: u.Path, + } +} + +func (u *URL) Encode() ([]byte, error) { + buf := bytes.NewBuffer(nil) + g := gob.NewEncoder(buf) + if err := g.Encode(*u); err != nil { + return nil, err + } + return buf.Bytes(), nil +} + +func (u *URL) Decode(b []byte) error { + buf := bytes.NewBuffer(b) + g := gob.NewDecoder(buf) + return g.Decode(&u) +} + +func NewURL(u ...*url.URL) *URL { + if len(u) == 0 { + u = append(u, &url.URL{}) + } + return &URL{ + Scheme: u[0].Scheme, + Host: u[0].Host, + Path: u[0].Path, + } +} diff --git a/storage/packable/packable_test.go b/storage/packable/packable_test.go index f342c8d..172ad09 100644 --- a/storage/packable/packable_test.go +++ b/storage/packable/packable_test.go @@ -1,6 +1,9 @@ package packable -import "testing" +import ( + "net/url" + "testing" +) func TestPackableString(t *testing.T) { raw := "hello" @@ -21,3 +24,24 @@ func TestPackableString(t *testing.T) { t.Errorf("wrong decoded string: %v vs %v", x, raw) } } + +func TestPackableURL(t *testing.T) { + raw := &url.URL{ + Scheme: "a", + Host: "b", + Path: "c", + } + s := NewURL(raw) + + packed, err := s.Encode() + if err != nil { + t.Errorf("cannot encode URL: %v", err) + } + + x := NewURL() + if err := x.Decode(packed); err != nil { + t.Errorf("cannot decode URL: %v", err) + } else if *x != *s { + t.Errorf("wrong decoded URL: %v (%T) vs %v (%T)", x, x, s, s) + } +} diff --git a/testdata/rproxy3server.crt b/testdata/rproxy3server.crt new file mode 100644 index 0000000..524382e --- /dev/null +++ b/testdata/rproxy3server.crt @@ -0,0 +1,30 @@ +-----BEGIN CERTIFICATE----- +MIIFJDCCAwygAwIBAgIJAJnIjAlj+0HEMA0GCSqGSIb3DQEBCwUAMD4xCzAJBgNV +BAYTAlVTMQswCQYDVQQIDAJVVDEOMAwGA1UECgwFYnJlZWwxEjAQBgNVBAMMCWxv +Y2FsaG9zdDAeFw0xOTAyMTgyMzA5MjdaFw0yOTAyMTUyMzA5MjdaMD4xCzAJBgNV +BAYTAlVTMQswCQYDVQQIDAJVVDEOMAwGA1UECgwFYnJlZWwxEjAQBgNVBAMMCWxv +Y2FsaG9zdDCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAO1bmespMb8R +IRnf2jvuvKTEeLuovCmEIDQchxGEyJg2Vvpqcs8yA+7NdJzbX1I1xrH+Ne+FMlwV +obUbjxSvU6qchHqwzG8oXf+Tjr498YLl/xoP5axpUdwE4YqSXGgeTBNuvem3zjxb +Ukp74TfkyzSqgteUu66QLZlGwOcnSF+/pLoncQvqfESppUJIk0so2DH+J3TcQnBP +/jQpcVcDyl0jZzxdDT+qlHSl6P9khX8VT7FlMjVl2ztt9tPQRMcf337pl/QCLHld +MgpbAe9YnvWHsUGHHwgqh6X3R8Mq42UTAjY80k1/qmjVf0b8wymYMV7HNULShiUF +LZQBD0+t+0SD8Np4C3oZU1pf1e+hL8bSsBaEFfe0FAZIFZY2xijTEyxxah2zvKEF +Ntm6BrT6BM5rewSOk9myQochBcmzTQudnpxZjZWjG4tM+cOB7rWUBLyQQtaJLNFM +LRRkdLRAODfhNkN/RW2vCFPxlUvdVQYjB5ftopmEnL6L+8uP3N4FUGRDGvYT6WRN +zlvWe/2iV7LnFHs75OzRRrkXjSkbF2ScRKHPFroNXFsP/R70hWGwdzxLiB5UlbPi +3N3zov4+++O0dUUUo5ok2ZvTTpcIxZMN9dPZxV9+nNOgtyI3gdyzAPgxU0n1JSG+ +4/V0+QjpxiiY8zBn6mbem9LwuW2OhNUhAgMBAAGjJTAjMCEGA1UdEQQaMBiCCWxv +Y2FsaG9zdIILKi5sb2NhbGhvc3QwDQYJKoZIhvcNAQELBQADggIBALt1KCip8/a2 +eL1vyJhwrkggT18KNOv7ylIOuOSS7AoEgkppUPyGNgQIrnyzjuEEIvKAlMjSW1jV +dfRLcV4ha1GAZhQrL19vAKExcjwBoN1sB8VX8NTAcyXsIANEoGxsSkrIn9hTi2my +JyIn+fSVX82IxZ4bf9J1jmNGE2b6X0v6urgVrjzkZbsydJPlT8JAKWFHUhwXiGGi +f9FbFmkyghX4p1UrqL6hUCVuudCtNSW3r3psiwycaUSaivLv/BAWtQCUyLvPMjPa +vTxx7UokWQPIRT/ugAlX8WsgW3wSPtnSkcHC8/eOlBAhAzS7keEOaKW7SZeiby4c +ghjqEqX8k1Ptss7cuaHwmbQ5X0HA+J4aZ0/LKVacPsPPdELFujfsUxOoh/3m7mnM +hTebVPk7D3lYIYur8cAjdGxdaYPvGpJnKuRT7KKplrMHxWD8UvsTscN/cVvK7YPq +wJN6QlBjaZLuOcPZwJ72tGXV3D1cK3zETpFpHfFY6+2DgsbgziP3qGT4g7czjy7z +NF1ByoXYMBANC+bJ1u7wRWQxPHkNk2GtxgddNAogq0b+cPuX/5mdzM7nXEbRPqD9 +sQWI/9VyTJ/xdAv4UJNpzRBcFwE52M7tqFwKNLAtMYKrrVk3XfPWupo2bjfxkyyC +9qFlXy+DWku/zqx/Ytf/shwz70cHmdjY +-----END CERTIFICATE----- diff --git a/testdata/rproxy3server.key b/testdata/rproxy3server.key new file mode 100644 index 0000000..49ee223 --- /dev/null +++ b/testdata/rproxy3server.key @@ -0,0 +1,51 @@ +-----BEGIN RSA PRIVATE KEY----- +MIIJKAIBAAKCAgEA7VuZ6ykxvxEhGd/aO+68pMR4u6i8KYQgNByHEYTImDZW+mpy +zzID7s10nNtfUjXGsf4174UyXBWhtRuPFK9TqpyEerDMbyhd/5OOvj3xguX/Gg/l +rGlR3AThipJcaB5ME2696bfOPFtSSnvhN+TLNKqC15S7rpAtmUbA5ydIX7+kuidx +C+p8RKmlQkiTSyjYMf4ndNxCcE/+NClxVwPKXSNnPF0NP6qUdKXo/2SFfxVPsWUy +NWXbO23209BExx/ffumX9AIseV0yClsB71ie9YexQYcfCCqHpfdHwyrjZRMCNjzS +TX+qaNV/RvzDKZgxXsc1QtKGJQUtlAEPT637RIPw2ngLehlTWl/V76EvxtKwFoQV +97QUBkgVljbGKNMTLHFqHbO8oQU22boGtPoEzmt7BI6T2bJChyEFybNNC52enFmN +laMbi0z5w4HutZQEvJBC1oks0UwtFGR0tEA4N+E2Q39Fba8IU/GVS91VBiMHl+2i +mYScvov7y4/c3gVQZEMa9hPpZE3OW9Z7/aJXsucUezvk7NFGuReNKRsXZJxEoc8W +ug1cWw/9HvSFYbB3PEuIHlSVs+Lc3fOi/j7747R1RRSjmiTZm9NOlwjFkw3109nF +X36c06C3IjeB3LMA+DFTSfUlIb7j9XT5COnGKJjzMGfqZt6b0vC5bY6E1SECAwEA +AQKCAgBMCdNWTZ0dK5yiEF92YbXXRwWygIy+9A/pAdaXWyVz9byJfn6HN+ugnfsk +oPZ5fLbJoLmgoNgQPfHO9iQxKTWxa3DZaTgkyBbM4HWTJn7vQ0UlEUCvqhHKXVnv +rZGi3Unb09dNP0/3b/391I/C+y3KEnHWJFS3yIKDHvJ/WstJuThJVodVnOnwiTRi +9qMRKeWQpm33dvRlzIqQJVKk4Jb8nXHeGaU75yal89yfrJFDtA0StGuQRbAk6sZu +9sKB1AkiPC0sw2GCA8QbIyqMhaRu3eiAKIxdblqEZaJ8uImegKdtvrBUmXh7GyIq +GKFg+tueFE1Quea1h2IhjvFbwk3C+FwiaOyemP/kwyuYuIMOZBCX78v0X3b+BacS +PRUISo2PfoOJsYnlcMXPEgmjGMEmlGSwQVIMzqPzE1wloyMkzP9peJ27b7IeyUlD +qlvQchG3uMPvjDEx+IuWo0OHL7wZ4tHA+8GZAcitFopnR1ZD7C9DfxzuLWjyg+M2 ++XL0fVIpal1BA5lA7MHGEjg9uiN6hYF91Dcrk7xhmoBDVGy68jZrL/e7bkknqLhq +O1cR3ZjvFmjpC+wz9vYXCUIzas+Or6ESPp3AdaFX6HQNLHrJl37zHVW7w34os5y0 +Coa1pkgxW7LBJ1NE/5vg9yUJwsEO3BHpfBvZWG+5JGmGmfbwPQKCAQEA+IkAOB8s +d5r14qsHWO/LeCdjxOx4P9Iss1nnb9TUUQELn6WBKHvINYG1wyP/eJH6XBMd8n8E +zWi/b2GsRTJWiMwG8r+tzDvCoXPOzp2w6Es+/H+r92P7SDLspoYGj/rFZg7xctcf +vnK1Ww4vdARRIhGI3DQDzfnmiSYMiqTPJvv5GLxiEmKm1yUZlo8HKHKEQcvPgHzV +aJKUetXrhaxo1xIVaGHgstyR07cjPRdu+Zu7TeVjW5rs7VvTm9TSPgDl0C8BLJ3l +6YcJ6STsJtnoaYqq+kMtcEzlrw23lemEbpb9DIAi6PTpeV+q3us46JFRfygDeBfw +hfi2hdYdHqI3SwKCAQEA9HyoQZGdXm5k0T+vqsNRY38NoaSsXpKshMI4PlAE84tR +T8MGGeJfdGaIHElHL/gorPQICzhl7P+0dg++9uqtqoGELAedwmUKGEBCU/63H1DZ +eQnlcZKfS3jqVn5u1A+IE5ZhVTTBLLz+U+jLWlwLMXnyQ0LX99UkdBDiwEdJ0fIO +znKVNwPykMh2t8hiIHqLBdrzYkkDVv/Y+mMPPY0666ajAMVRQ1MZF2ZSGO+u4vad +FIKuO4csUhA3EGXl52lGWswNdFseI4YB3VX4eYV19qHHT3U1WGmOvaoOCCs5Xovb +zuWN3HhzDqwk/mFmN60rUaFCl4zYQW6sI6VlZDjFwwKCAQAFKkO45tzzt2K4zTkf +3gvqeVcXdpqhKOsI5ytqJZAsBsuJC9V9U+U7R3DRl2Pty4racwNCUOV2p2CjHfKI +lZ06xpK2ZMll3zASTufTX2+XxEiQ0s5uG8PTIkvMwihMwbdsgY7/Bf2A5b1jIQ5X +aOeOv7KKcQJLO5BQ9Vt5Xid39fCdVPzsyOQSwJChkmNhsc+R1nx/fRiqTbGMQ/nb +E1QydAvB2Zbj4LWfIsU2tc/2nAfufH/uHXAjFq91I50i8L3V78QkezuV+Np53+yQ +eiLcXqdjABZnqRF06+IO2YWJDtY/dkJDeSOJBDX24AiUvMBm1hwMWOMExcB1U5fT +VSe1AoIBACaJR5QeOanayLVaznuiEcAXFLT32duGTv8WBft8bWtd+FEUMm/+CIG2 +b5nVywy6lt68EFCEcTxsCavaS4Vr9De99nFiOfGcL30nE81dhsu+2KiFcM74B3fD +9VvwzdNLJ22+9FST3icJGyVqujLh8zm3OQJ1FMmRpQ9OYg9grTgUaVUNJovnaWJ8 +1omdYiowZp1jp51EWOxro23EE3DmQ/VE7MiAUZDFFl1j2Wjozq4jpRrIhmIHqmq8 +95D0HsrdAyPoqe4/Nn7u7nhOxr2Q1cksMthJZ0EqUj9/AHs1JPPMI7d242du7OPR +KnuWKqB4AS36tx5gKu4VXbi4p0Sm1jUCggEBAOmQ0WqyDIbpD+mp/bKAWxUws6PV +ZnZj+OP1edRmaDwoAKcur53MVBJFINF+sob+MoEq+aqytyHydk7fhqG5bfgFK4Zf +73ZQZSwun0DQaK905s7yzkcUeQcey1GuwSr3OG6r6jmwGxPm9rt3uCjJVmhC8lUs ++iqm8X0gfAxNu+W4yigNrgnY1fAHckpN2xl9/T3OXcRwFAxblfF4Pn3y6KBAXCir +M3B+Y9xKsC4ZoUDYCw49TeAXdEBbBtjOd4Qa/rMTRx0MAgjkvD4QEC7KyNjJUqhH +upfNtQtVhPQ3dNO+8zDbWzUMsoBNKl8LLnUoYDFpDk6EeLkvdPyBmMRLuxQ= +-----END RSA PRIVATE KEY----- diff --git a/testdata/rproxy3server.pkcs12 b/testdata/rproxy3server.pkcs12 new file mode 100644 index 0000000..32378a7 Binary files /dev/null and b/testdata/rproxy3server.pkcs12 differ