diff --git a/config/config.go b/config/config.go index d50034f..43a340b 100644 --- a/config/config.go +++ b/config/config.go @@ -106,3 +106,13 @@ func GetRewrites(hostMatch string) map[string]string { } return m } + +func GetProxyMode() string { + v := packable.NewString() + conf.Get(nsConf, flagMode, v) + s := v.String() + if s == "" { + return "domain" + } + return s +} diff --git a/config/new.go b/config/new.go index 4bdebdf..2b559fd 100644 --- a/config/new.go +++ b/config/new.go @@ -14,6 +14,7 @@ import ( const nsConf = "configuration" const flagPort = "p" +const flagMode = "mode" const flagRoutes = "r" const flagConf = "c" const flagCert = "crt" @@ -34,6 +35,7 @@ type toBind struct { type fileConf struct { Port string `yaml:"p"` + Mode string `yaml:"mode"` Routes []string `yaml:"r"` CertPath string `yaml:"crt"` KeyPath string `yaml:"key"` @@ -79,6 +81,9 @@ func fromFile() error { if err := conf.Set(nsConf, flagPort, packable.NewString(c.Port)); err != nil { return err } + if err := conf.Set(nsConf, flagMode, packable.NewString(c.Mode)); err != nil { + return err + } if err := conf.Set(nsConf, flagRoutes, packable.NewString(strings.Join(c.Routes, ","))); err != nil { return err } @@ -112,6 +117,7 @@ func fromFile() error { func fromFlags() error { binds := make([]toBind, 0) binds = append(binds, addFlag(flagPort, "51555", "port to bind to")) + binds = append(binds, addFlag(flagMode, "domain", "[domain] or [path] to match")) binds = append(binds, addFlag(flagConf, "", "configuration file path")) binds = append(binds, addFlag(flagRoutes, "", "comma-separated routes to map, each as from:scheme://to.tld:port")) binds = append(binds, addFlag(flagCert, "", "path to .crt")) diff --git a/server/proxy.go b/server/proxy.go index 6ee58cc..67556d9 100644 --- a/server/proxy.go +++ b/server/proxy.go @@ -24,7 +24,7 @@ type rewrite struct { } func (s *Server) Proxy(w http.ResponseWriter, r *http.Request) { - newURL, err := s.lookup(mapKey(r.Host)) + newURL, err := s.lookup(mapKey(r, config.GetProxyMode())) var transport http.RoundTripper transport = &redirPurge{ proxyHost: r.Host, @@ -32,7 +32,7 @@ func (s *Server) Proxy(w http.ResponseWriter, r *http.Request) { baseTransport: http.DefaultTransport, } transport = &rewrite{ - rewrites: config.GetRewrites(mapKey(r.Host)), + rewrites: config.GetRewrites(mapKey(r, config.GetProxyMode())), baseTransport: transport, } if err != nil { @@ -52,10 +52,20 @@ func (s *Server) lookup(host string) (*url.URL, error) { return v.URL(), err } -func mapKey(host string) string { - host = strings.Split(host, ".")[0] - host = strings.Split(host, ":")[0] - return host +func mapKey(r *http.Request, proxyMode string) string { + switch proxyMode { + case "domain": + host := strings.Split(r.Host, ".")[0] + host = strings.Split(host, ":")[0] + return host + case "path": + paths := strings.Split(r.URL.Path, "/") + if len(paths) < 2 { + return "" + } + return paths[1] + } + return "" } func (rp *redirPurge) RoundTrip(r *http.Request) (*http.Response, error) { diff --git a/server/proxy_test.go b/server/proxy_test.go index ca1c1f1..2c75be9 100644 --- a/server/proxy_test.go +++ b/server/proxy_test.go @@ -3,6 +3,7 @@ package server import ( "io/ioutil" "net/http" + "net/url" "strings" "testing" ) @@ -40,3 +41,35 @@ func TestRewrite(t *testing.T) { t.Errorf("failed to replace: got %q, want \"b\"", b) } } + +func TestMapKey(t *testing.T) { + r := &http.Request{ + Host: "a.b.c:123", + URL: &url.URL{ + Path: "/c/d/e", + }, + } + + if v := mapKey(r, "domain"); v != "a" { + t.Errorf("failed to get domain: got %v", v) + } + + if v := mapKey(r, "path"); v != "c" { + t.Errorf("failed to get domain: got %v", v) + } + + r.Host = "a:123" + if v := mapKey(r, "domain"); v != "a" { + t.Errorf("failed to get domain: got %v", v) + } + + r.URL.Path = "" + if v := mapKey(r, "path"); v != "" { + t.Errorf("failed to get domain: got %v", v) + } + + r.URL.Path = "/" + if v := mapKey(r, "path"); v != "" { + t.Errorf("failed to get domain: got %v", v) + } +}