package main import ( "crypto/tls" "crypto/x509" "flag" "fmt" "io" "io/ioutil" "local1/logger" "log" "net" "net/http" "net/http/httputil" "net/url" "os" "strings" ) type Server struct { transport *http.Transport whitelist []string } func NewServer(addr, clientcrt, clientkey, servercrt string, whitelist []string) (*Server, error) { caCert, err := ioutil.ReadFile(servercrt) if err != nil { return nil, err } rootCAs := x509.NewCertPool() rootCAs.AppendCertsFromPEM(caCert) clientCert, err := tls.LoadX509KeyPair(clientcrt, clientkey) if err != nil { return nil, err } return &Server{ transport: &http.Transport{ Proxy: func(*http.Request) (*url.URL, error) { return url.Parse(addr) }, TLSClientConfig: &tls.Config{ RootCAs: rootCAs, Certificates: []tls.Certificate{clientCert}, }, //DialTLS: dialtls, }, whitelist: whitelist, }, nil } func dialtls(network, addr string) (net.Conn, error) { conn, err := net.Dial(network, addr) if err != nil { return nil, err } host, _, err := net.SplitHostPort(addr) if err != nil { return nil, err } cfg := &tls.Config{ServerName: host} tlsConn := tls.Client(conn, cfg) if err := tlsConn.Handshake(); err != nil { conn.Close() return nil, err } cs := tlsConn.ConnectionState() cert := cs.PeerCertificates[0] // Verify here cert.VerifyHostname(host) log.Println(cert.Subject) return tlsConn, nil } func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { // fix scheme if necessary fixScheme(r.URL) // if not from localhost if !fromLocalhost(r.RemoteAddr) { denyAccess(w) return } if !toWhitelist(s.whitelist, r.URL.Host) { denyAccess(w) return } // proxy via stuncaddsies logger.Log("Proxying", r.URL.String()) s.handleHTTP(w, r) } func (s *Server) handleHTTP(w http.ResponseWriter, r *http.Request) { proxy := httputil.NewSingleHostReverseProxy(pathlessURL(r.URL)) proxy.Transport = s.transport director := proxy.Director proxy.Director = func(req *http.Request) { director(req) req.Host = r.URL.Host } proxy.ServeHTTP(w, r) return resp, err := s.transport.RoundTrip(r) if err != nil { http.Error(w, err.Error(), http.StatusServiceUnavailable) return } defer resp.Body.Close() copyHeader(w.Header(), resp.Header) w.WriteHeader(resp.StatusCode) io.Copy(w, resp.Body) } func copyHeader(dst, src http.Header) { for k, vv := range src { for _, v := range vv { dst.Add(k, v) } } } func fixScheme(u *url.URL) { if u.Scheme == "" { u.Scheme = "http" if strings.Contains(u.Host, "443") { u.Scheme = "https" } } } func toWhitelist(okay []string, host string) bool { host = strings.Split(host, ":")[0] host = strings.Replace(host, "www", "", -1) for i := range okay { if strings.Contains(okay[i], host) { return true } } return false } func fromLocalhost(addr string) bool { return strings.Contains(addr, "[::1]") || addr == "127.0.0.1" || addr == "::1" } func denyAccess(w http.ResponseWriter) { w.WriteHeader(http.StatusUnauthorized) fmt.Fprintln(w, "You shouldn't be here") } func pathlessURL(u *url.URL) *url.URL { return &url.URL{ Scheme: u.Scheme, Opaque: u.Opaque, User: u.User, Host: u.Host, Path: "", RawPath: "", ForceQuery: u.ForceQuery, RawQuery: u.RawQuery, Fragment: u.Fragment, } } func transfer(destination io.WriteCloser, source io.ReadCloser) { defer destination.Close() defer source.Close() io.Copy(destination, source) } func flagEnvFallback(keyFallback map[string]string) map[string]string { results := map[string]*string{} for k, v := range keyFallback { results[k] = flag.String(k, v, "") } flag.Parse() final := map[string]string{} for k := range results { if *results[k] == keyFallback[k] && os.Getenv(strings.ToUpper(k)) != "" { *results[k] = os.Getenv(strings.ToUpper(k)) } final[k] = *results[k] } return final } func main() { conf := flagEnvFallback(map[string]string{ "stunaddr": "https://bel.house:20018", "clientcrt": "/Volumes/bldisk/client.crt", "clientkey": "/Volumes/bldisk/client.key", "servercrt": "/Volumes/bldisk/server.crt", "port": "8888", "whitelist": "192.168.0.86,,bel.house,,gcp.blapointe.com", }) if !strings.HasPrefix(conf["port"], ":") { conf["port"] = ":" + conf["port"] } logger.Log(conf) server, err := NewServer(conf["stunaddr"], conf["clientcrt"], conf["clientkey"], conf["servercrt"], strings.Split(conf["whitelist"], ",,")) if err != nil { logger.Fatal(err) } logger.Fatal(http.ListenAndServe(conf["port"], server)) }