From 8a90a3adda184a1d70bdbca15f90e580a7b80349 Mon Sep 17 00:00:00 2001 From: Bel LaPointe Date: Wed, 2 Oct 2019 09:07:59 -0600 Subject: [PATCH] Implement tcp proxy with single forward --- config/config.go | 7 +++++++ config/new.go | 6 ++++++ server/server.go | 42 ++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 55 insertions(+) diff --git a/config/config.go b/config/config.go index d50034f..fd5ee3e 100644 --- a/config/config.go +++ b/config/config.go @@ -28,6 +28,13 @@ func GetRoutes() map[string]string { return m } +func GetTCP() (string, bool) { + v := packable.NewString() + conf.Get(nsConf, flagTCP, v) + tcpAddr := v.String() + return tcpAddr, notEmpty(tcpAddr) +} + func GetSSL() (string, string, bool) { v := packable.NewString() conf.Get(nsConf, flagCert, v) diff --git a/config/new.go b/config/new.go index 4bdebdf..6f0003e 100644 --- a/config/new.go +++ b/config/new.go @@ -17,6 +17,7 @@ const flagPort = "p" const flagRoutes = "r" const flagConf = "c" const flagCert = "crt" +const flagTCP = "tcp" const flagKey = "key" const flagUser = "user" const flagPass = "pass" @@ -36,6 +37,7 @@ type fileConf struct { Port string `yaml:"p"` Routes []string `yaml:"r"` CertPath string `yaml:"crt"` + TCPPath string `yaml:"tcp"` KeyPath string `yaml:"key"` Username string `yaml:"user"` Password string `yaml:"pass"` @@ -85,6 +87,9 @@ func fromFile() error { if err := conf.Set(nsConf, flagCert, packable.NewString(c.CertPath)); err != nil { return err } + if err := conf.Set(nsConf, flagTCP, packable.NewString(c.TCPPath)); err != nil { + return err + } if err := conf.Set(nsConf, flagKey, packable.NewString(c.KeyPath)); err != nil { return err } @@ -115,6 +120,7 @@ func fromFlags() error { binds = append(binds, addFlag(flagConf, "", "configuration file path")) binds = append(binds, addFlag(flagRoutes, "", "comma-separated routes to map, each as from:scheme://to.tld:port")) binds = append(binds, addFlag(flagCert, "", "path to .crt")) + binds = append(binds, addFlag(flagTCP, "", "tcp addr")) binds = append(binds, addFlag(flagKey, "", "path to .key")) binds = append(binds, addFlag(flagUser, "", "basic auth username")) binds = append(binds, addFlag(flagPass, "", "basic auth password")) diff --git a/server/server.go b/server/server.go index 249cf2f..7c9d747 100644 --- a/server/server.go +++ b/server/server.go @@ -5,10 +5,12 @@ import ( "crypto/tls" "encoding/base64" "errors" + "io" "local/rproxy3/config" "local/rproxy3/storage" "local/rproxy3/storage/packable" "log" + "net" "net/http" "net/url" "strings" @@ -24,6 +26,7 @@ type listenerScheme int const ( schemeHTTP listenerScheme = iota schemeHTTPS listenerScheme = iota + schemeTCP listenerScheme = iota ) func (ls listenerScheme) String() string { @@ -32,6 +35,8 @@ func (ls listenerScheme) String() string { return "http" case schemeHTTPS: return "https" + case schemeTCP: + return "tcp" } return "" } @@ -58,6 +63,9 @@ func (s *Server) Run() error { if _, _, ok := config.GetSSL(); ok { scheme = schemeHTTPS } + if _, ok := config.GetTCP(); ok { + scheme = schemeTCP + } log.Printf("Listening for %v on %v...\n", scheme, s.addr) switch scheme { case schemeHTTP: @@ -83,6 +91,10 @@ func (s *Server) Run() error { TLSNextProto: make(map[string]func(*http.Server, *tls.Conn, http.Handler), 0), } return httpsServer.ListenAndServeTLS(c, k) + case schemeTCP: + log.Printf("Serve tcp") + addr, _ := config.GetTCP() + return s.ServeTCP(addr) } return errors.New("did not load server") } @@ -103,6 +115,36 @@ func (s *Server) doAuth(foo http.HandlerFunc) http.HandlerFunc { } } +func (s *Server) ServeTCP(addr string) error { + listen, err := net.Listen("tcp", s.addr) + if err != nil { + return err + } + for { + c, err := listen.Accept() + if err != nil { + return err + } + go func(c net.Conn) { + d, err := net.Dial("tcp", addr) + if err != nil { + log.Println(err) + return + } + go pipe(c, d) + go pipe(d, c) + }(c) + } +} + +func pipe(a, b net.Conn) { + log.Println("open pipe") + defer log.Println("close pipe") + defer a.Close() + defer b.Close() + io.Copy(a, b) +} + func (s *Server) Pre(foo http.HandlerFunc) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { ctx, can := context.WithTimeout(r.Context(), time.Second*time.Duration(config.GetTimeout()))