package main

import (
	"crypto/tls"
	"crypto/x509"
	"flag"
	"fmt"
	"io"
	"io/ioutil"
	"log"
	"net"
	"net/http"
	"net/http/httputil"
	"net/url"
	"os"
	"strings"
	"time"
)

// Server is a forward HTTP proxy. It only answers local clients, and only
// for whitelisted hosts. Hosts on the bypass list are reached directly;
// every other request is relayed through the upstream proxy configured in
// NewServer, with hosts on the secure list upgraded to https first.
type Server struct {
	transport *http.Transport
	server    *http.Server
	whitelist []string
	bypass    []string
	secure    []string
}

// NewServer builds a proxy Server from the given configuration.
//
//   - port: listen address for the returned server; a bare port such as
//     "8888" is turned into ":8888".
//   - addr: URL of the upstream proxy used for every non-bypassed request.
//   - mecrt, mekey: certificate/key pair presented to TLS peers of the
//     upstream transport (oldmain also serves it as the server certificate).
//   - tocrt: PEM file of extra CAs trusted, on top of the system pool, when
//     verifying upstream TLS peers. It must contain at least one certificate.
//   - fromcrt: PEM file of CAs that must have signed TLS client certificates.
//     When empty, no CA is trusted, so every TLS client is rejected; callers
//     are expected to serve plain HTTP in that case.
//   - whitelist, bypass, secure: host lists matched with toWhitelist.
func NewServer(port, addr, mecrt, mekey, tocrt, fromcrt string, whitelist, bypass, secure []string) (*Server, error) {
	// Parse once up front so a bad address fails here instead of on every
	// proxied request.
	proxyURL, err := url.Parse(addr)
	if err != nil {
		return nil, fmt.Errorf("parsing upstream proxy address %q: %w", addr, err)
	}

	rootCAs, err := x509.SystemCertPool()
	if err != nil {
		return nil, err
	}
	if err := appendPEMFile(rootCAs, tocrt); err != nil {
		return nil, err
	}

	clientCert, err := tls.LoadX509KeyPair(mecrt, mekey)
	if err != nil {
		return nil, err
	}

	// An empty pool combined with RequireAndVerifyClientCert rejects every
	// TLS client, so leaving fromcrt unset fails closed for TLS listeners.
	clientCAs := x509.NewCertPool()
	if fromcrt != "" {
		if err := appendPEMFile(clientCAs, fromcrt); err != nil {
			return nil, err
		}
	}

	listenAddr := port
	if listenAddr != "" && !strings.Contains(listenAddr, ":") {
		listenAddr = ":" + listenAddr
	}

	s := &Server{
		transport: &http.Transport{
			Proxy: http.ProxyURL(proxyURL),
			TLSClientConfig: &tls.Config{
				RootCAs:      rootCAs,
				Certificates: []tls.Certificate{clientCert},
			},
		},
		whitelist: whitelist,
		bypass:    bypass,
		secure:    secure,
	}
	s.server = &http.Server{
		Addr:    listenAddr,
		Handler: s,
		// Bounds how long a client may take to send request headers. With
		// ReadTimeout unset, net/http clears the read deadline once the
		// headers are read, so hijacked CONNECT tunnels are not cut off.
		ReadHeaderTimeout: 30 * time.Second,
		TLSConfig: &tls.Config{
			ClientCAs:                clientCAs,
			ClientAuth:               tls.RequireAndVerifyClientCert,
			MinVersion:               tls.VersionTLS12,
			CurvePreferences:         []tls.CurveID{tls.CurveP521, tls.CurveP384, tls.CurveP256},
			PreferServerCipherSuites: true,
			// NOTE(review): only RSA-keyed suites are listed; an ECDSA server
			// certificate would fail the TLS 1.2 handshake.
			CipherSuites: []uint16{
				tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
				tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA,
				tls.TLS_RSA_WITH_AES_256_GCM_SHA384,
				tls.TLS_RSA_WITH_AES_256_CBC_SHA,
			},
		},
		// A non-nil, empty TLSNextProto disables HTTP/2, so ResponseWriters
		// keep implementing http.Hijacker, which passthrough relies on.
		TLSNextProto: make(map[string]func(*http.Server, *tls.Conn, http.Handler)),
	}
	return s, nil
}

// appendPEMFile adds the PEM certificates in the file at path to pool. It
// fails when the file cannot be read or holds no parseable certificate,
// which would otherwise leave the pool silently unchanged.
func appendPEMFile(pool *x509.CertPool, path string) error {
	pemData, err := ioutil.ReadFile(path)
	if err != nil {
		return err
	}
	if !pool.AppendCertsFromPEM(pemData) {
		return fmt.Errorf("no PEM certificates found in %q", path)
	}
	return nil
}

// ServeHTTP implements http.Handler. It normalizes the request scheme and
// rejects non-local clients and non-whitelisted hosts with 401. Bypassed
// hosts are served directly. Everything else goes through the upstream
// proxy, with https forced for hosts on the secure list.
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	fixScheme(r)

	// log.Println rather than log.Print: Print only inserts a space between
	// operands when neither is a string, which glued these messages together.
	if !fromLocalhost(r.RemoteAddr) {
		log.Println("Denying non-localhost", r.RemoteAddr)
		denyAccess(w)
		return
	}
	if !toWhitelist(s.whitelist, r.URL.Host) {
		log.Println("Denying non-whitelisted", r.URL.Host)
		denyAccess(w)
		return
	}
	if toWhitelist(s.bypass, r.URL.Host) {
		log.Println("Bypassing", r.URL.String())
		s.passthrough(w, r)
		return
	}
	if toWhitelist(s.secure, r.URL.Host) {
		log.Println("Securing", r.URL.String(), r.Host)
		r.URL.Scheme = "https"
	}
	// proxy via stuncaddsies
	s.handleHTTP(w, r)
}

// passthrough serves a bypassed host without the upstream proxy. Plain http
// requests are reverse-proxied directly. Anything else is tunneled: r.Host
// is dialed, 200 is sent, and bytes are copied in both directions between
// the hijacked client connection and the target.
func (s *Server) passthrough(w http.ResponseWriter, r *http.Request) {
	if r.URL.Scheme == "http" {
		httputil.NewSingleHostReverseProxy(proxyTarget(r.URL)).ServeHTTP(w, r)
		return
	}

	// Check hijack support before dialing so an unsupported writer gets a
	// proper error status and no target connection is opened.
	hijacker, ok := w.(http.Hijacker)
	if !ok {
		http.Error(w, "Hijacking not supported", http.StatusInternalServerError)
		return
	}
	destConn, err := net.DialTimeout("tcp", r.Host, 10*time.Second)
	if err != nil {
		http.Error(w, err.Error(), http.StatusServiceUnavailable)
		return
	}
	w.WriteHeader(http.StatusOK)
	clientConn, clientBuf, err := hijacker.Hijack()
	if err != nil {
		// The 200 status is already committed, so the failure cannot be
		// reported to the client. Release the target connection rather than
		// leaking it or splicing it to a nil client connection.
		log.Println("Hijack failed for", r.Host, err)
		destConn.Close()
		return
	}
	// Bytes the client sent eagerly may already sit in the server's read
	// buffer; forward them before splicing the raw connections.
	if n := clientBuf.Reader.Buffered(); n > 0 {
		// Peeking exactly the buffered byte count cannot fail.
		pending, _ := clientBuf.Reader.Peek(n)
		if _, err := destConn.Write(pending); err != nil {
			clientConn.Close()
			destConn.Close()
			return
		}
	}
	go transfer(destConn, clientConn)
	go transfer(clientConn, destConn)
}

// handleHTTP relays r to its own scheme and host using s.transport, i.e.
// through the upstream proxy.
func (s *Server) handleHTTP(w http.ResponseWriter, r *http.Request) {
	proxy := httputil.NewSingleHostReverseProxy(proxyTarget(r.URL))
	proxy.Transport = s.transport
	proxy.ServeHTTP(w, r)
}

// copyHeader appends every value of every header in src to dst.
func copyHeader(dst, src http.Header) {
	for k, vv := range src {
		for _, v := range vv {
			dst.Add(k, v)
		}
	}
}

// fixScheme defaults an empty URL scheme to "http". A host ending in ":443"
// is switched to "https" and the explicit port is removed.
func fixScheme(r *http.Request) {
	if r.URL.Scheme == "" {
		r.URL.Scheme = "http"
	}
	if strings.HasSuffix(r.URL.Host, ":443") {
		r.URL.Scheme = "https"
		r.URL.Host = strings.TrimSuffix(r.URL.Host, ":443")
	}
}

// toWhitelist reports whether host names one of the entries in okay. A match
// is exact, or a subdomain of the entry (IP entries match exactly only).
// Both sides are normalized with normalizeHost first, so case, a scheme
// prefix, a port and a leading "www." are ignored. Empty hosts and empty
// entries never match.
func toWhitelist(okay []string, host string) bool {
	host = normalizeHost(host)
	if host == "" {
		return false
	}
	for _, entry := range okay {
		entry = normalizeHost(entry)
		switch {
		case entry == "":
			continue
		case host == entry:
			return true
		case net.ParseIP(entry) == nil && strings.HasSuffix(host, "."+entry):
			return true
		}
	}
	return false
}

// normalizeHost lowercases h and strips surrounding space, a leading
// http:// or https:// scheme, a trailing :port, IPv6 brackets, a trailing
// root dot and a leading "www.".
func normalizeHost(h string) string {
	h = strings.ToLower(strings.TrimSpace(h))
	h = strings.TrimPrefix(h, "http://")
	h = strings.TrimPrefix(h, "https://")
	if hostOnly, _, err := net.SplitHostPort(h); err == nil {
		h = hostOnly
	}
	h = strings.Trim(h, "[]")
	h = strings.TrimSuffix(h, ".")
	return strings.TrimPrefix(h, "www.")
}

// fromLocalhost reports whether addr, normally an http.Request.RemoteAddr in
// "ip:port" form, is a loopback address or lies in 192.168.0.0/24. Input
// that is not an IP falls back to checking for the "bel.pc" host name.
func fromLocalhost(addr string) bool {
	host := addr
	if h, _, err := net.SplitHostPort(addr); err == nil {
		host = h
	}
	if ip := net.ParseIP(host); ip != nil {
		if ip.IsLoopback() {
			return true
		}
		ip4 := ip.To4()
		return ip4 != nil && ip4[0] == 192 && ip4[1] == 168 && ip4[2] == 0
	}
	return strings.Contains(addr, "bel.pc")
}

// denyAccess writes a 401 status with a fixed refusal message.
func denyAccess(w http.ResponseWriter) {
	w.WriteHeader(http.StatusUnauthorized)
	fmt.Fprintln(w, "You shouldn't be here")
}

// pathlessURL returns a copy of u with Path and RawPath cleared and all
// other fields kept. It keeps the query, so it must not be used as a
// NewSingleHostReverseProxy target (which would repeat the query); use
// proxyTarget for that.
func pathlessURL(u *url.URL) *url.URL {
	return &url.URL{
		Scheme:     u.Scheme,
		Opaque:     u.Opaque,
		User:       u.User,
		Host:       u.Host,
		Path:       "",
		RawPath:    "",
		ForceQuery: u.ForceQuery,
		RawQuery:   u.RawQuery,
		Fragment:   u.Fragment,
	}
}

// proxyTarget returns only the scheme and host of u, for use as the target
// of httputil.NewSingleHostReverseProxy. The query is deliberately dropped:
// the reverse proxy appends the incoming request's query to the target's, so
// keeping it here would send every query parameter twice.
func proxyTarget(u *url.URL) *url.URL {
	return &url.URL{Scheme: u.Scheme, Host: u.Host}
}

// transfer copies source into destination until either side fails or
// reaches EOF, then closes both.
func transfer(destination io.WriteCloser, source io.ReadCloser) {
	defer destination.Close()
	defer source.Close()
	io.Copy(destination, source)
}

// flagEnvFallback registers one string flag per key of keyFallback (with
// the map value as its default), parses the command line and returns the
// resolved values. A flag still equal to its default is replaced by the
// environment variable named by the upper-cased key, if that is non-empty.
func flagEnvFallback(keyFallback map[string]string) map[string]string {
	results := map[string]*string{}
	for k, v := range keyFallback {
		results[k] = flag.String(k, v, "")
	}
	flag.Parse()
	final := map[string]string{}
	for k := range results {
		if *results[k] == keyFallback[k] && os.Getenv(strings.ToUpper(k)) != "" {
			*results[k] = os.Getenv(strings.ToUpper(k))
		}
		final[k] = *results[k]
	}
	return final
}

// oldmain appears to be a retired entry point (it is not main). It resolves
// configuration from flags and environment with hard-coded defaults, builds
// a Server (bypassed hosts are added to the whitelist so they pass that
// check), and serves with TLS when fromcrt is set, otherwise plain HTTP.
func oldmain() {
	conf := flagEnvFallback(map[string]string{
		"stunaddr":  "https://bel.house:20018",
		"mecrt":     "/Volumes/bldisk/client.crt",
		"mekey":     "/Volumes/bldisk/client.key",
		"tocrt":     "/Volumes/bldisk/server.crt",
		"fromcrt":   "/Volumes/bldisk/accept.crt",
		"port":      "8888",
		"whitelist": "192.168.0.86,,bel.house,,minio.gcp.blapointe.com",
		"bypass":    "plex.tv",
		"secure":    "gcp.blapointe.com",
	})
	if !strings.HasPrefix(conf["port"], ":") {
		conf["port"] = ":" + conf["port"]
	}
	whitelist := strings.Split(conf["whitelist"], ",,")
	bypass := strings.Split(conf["bypass"], ",,")
	secure := strings.Split(conf["secure"], ",,")
	log.Print(conf)
	server, err := NewServer(conf["port"], conf["stunaddr"], conf["mecrt"], conf["mekey"], conf["tocrt"], conf["fromcrt"], append(whitelist, bypass...), bypass, secure)
	if err != nil {
		log.Fatal(err)
	}
	if conf["fromcrt"] != "" {
		log.Fatal(server.server.ListenAndServeTLS(conf["mecrt"], conf["mekey"]))
	} else {
		log.Fatal(http.ListenAndServe(conf["port"], server))
	}
}