From efad96aa36a6fc4d9f492c5594ce7c49e1dbeed5 Mon Sep 17 00:00:00 2001 From: Bel LaPointe Date: Wed, 7 Oct 2020 21:41:03 -0600 Subject: [PATCH] basic bitch --- config.go | 31 +++++++++++++++++++ main.go | 16 ++++++++++ server.go | 89 +++++++++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 136 insertions(+) create mode 100644 config.go create mode 100644 main.go create mode 100644 server.go diff --git a/config.go b/config.go new file mode 100644 index 0000000..5e21e69 --- /dev/null +++ b/config.go @@ -0,0 +1,31 @@ +package main + +import ( + "fmt" + "local/args" + "time" +) + +type Config struct { + Listen string + Timeout time.Duration + TLSInsecure bool +} + +func NewConfig() *Config { + as := args.NewArgSet() + + as.Append(args.INT, "p", "port to listen on", 61113) + as.Append(args.BOOL, "tls-insecure", "permit tls insecure", false) + as.Append(args.DURATION, "t", "timeout", time.Minute) + + if err := as.Parse(); err != nil { + panic(err) + } + + return &Config{ + Listen: fmt.Sprintf(":%v", as.GetInt("p")), + Timeout: as.GetDuration("t"), + TLSInsecure: as.GetBool("tls-insecure"), + } +} diff --git a/main.go b/main.go new file mode 100644 index 0000000..25fa252 --- /dev/null +++ b/main.go @@ -0,0 +1,16 @@ +package main + +import ( + "log" + "net/http" +) + +func main() { + config := NewConfig() + server := NewServer(config) + + log.Printf("config: %+v", *config) + if err := http.ListenAndServe(config.Listen, server); err != nil { + panic(err) + } +} diff --git a/server.go b/server.go new file mode 100644 index 0000000..cda5104 --- /dev/null +++ b/server.go @@ -0,0 +1,89 @@ +package main + +import ( + "crypto/tls" + "errors" + "io" + "log" + "net" + "net/http" + "time" +) + +type Server struct { + Transport http.RoundTripper + Timeout time.Duration +} + +func NewServer(c *Config) *Server { + transport := &http.Transport{} + transport.TLSClientConfig = &tls.Config{ + InsecureSkipVerify: c.TLSInsecure, + } + return &Server{ + Transport: transport, + Timeout: c.Timeout, + } +} + +func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case http.MethodConnect: + s.Connect(w, r) + default: + s.Serve(w, r) + } +} + +func (s *Server) Error(w http.ResponseWriter, err error) { + log.Println(err) + http.Error(w, err.Error(), http.StatusServiceUnavailable) + return +} + +func (s *Server) Connect(w http.ResponseWriter, r *http.Request) { + dest, err := net.DialTimeout("tcp", r.Host, 30*time.Second) + if err != nil { + s.Error(w, err) + return + } + + w.WriteHeader(http.StatusOK) + hijacker, ok := w.(http.Hijacker) + if !ok { + s.Error(w, errors.New("hijack not available")) + return + } + + client, _, err := hijacker.Hijack() + if err != nil { + s.Error(w, err) + return + } + + xfer := func(dst io.WriteCloser, src io.ReadCloser) { + defer dst.Close() + defer src.Close() + io.Copy(dst, src) + } + + go xfer(dest, client) + go xfer(client, dest) +} + +func (s *Server) Serve(w http.ResponseWriter, r *http.Request) { + resp, err := s.Transport.RoundTrip(r) + if err != nil { + s.Error(w, err) + return + } + defer resp.Body.Close() + + w.WriteHeader(resp.StatusCode) + for k, v := range resp.Header { + for _, s := range v { + w.Header().Add(k, s) + } + } + io.Copy(w, resp.Body) +}