diff --git a/.gitignore b/.gitignore index b402e50..320219d 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,7 @@ *.key *.crt *.pem +fproxy *.swp *.swo *.pub diff --git a/main.go b/main.go index 9f96444..be16846 100644 --- a/main.go +++ b/main.go @@ -5,8 +5,11 @@ import ( "crypto/x509" "flag" "fmt" + "io" "io/ioutil" "local1/logger" + "log" + "net" "net/http" "net/http/httputil" "net/url" @@ -16,9 +19,10 @@ import ( type Server struct { transport *http.Transport + whitelist []string } -func NewServer(addr, clientcrt, clientkey, servercrt string) (*Server, error) { +func NewServer(addr, clientcrt, clientkey, servercrt string, whitelist []string) (*Server, error) { caCert, err := ioutil.ReadFile(servercrt) if err != nil { return nil, err @@ -38,10 +42,39 @@ func NewServer(addr, clientcrt, clientkey, servercrt string) (*Server, error) { RootCAs: rootCAs, Certificates: []tls.Certificate{clientCert}, }, + //DialTLS: dialtls, }, + whitelist: whitelist, }, nil } +func dialtls(network, addr string) (net.Conn, error) { + conn, err := net.Dial(network, addr) + if err != nil { + return nil, err + } + + host, _, err := net.SplitHostPort(addr) + if err != nil { + return nil, err + } + cfg := &tls.Config{ServerName: host} + + tlsConn := tls.Client(conn, cfg) + if err := tlsConn.Handshake(); err != nil { + conn.Close() + return nil, err + } + + cs := tlsConn.ConnectionState() + cert := cs.PeerCertificates[0] + + // Verify here + cert.VerifyHostname(host) + log.Println(cert.Subject) + return tlsConn, nil +} + func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { // fix scheme if necessary fixScheme(r.URL) @@ -50,12 +83,42 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { denyAccess(w) return } + if !toWhitelist(s.whitelist, r.URL.Host) { + denyAccess(w) + return + } // proxy via stuncaddsies - //logger.Log("Proxying", r.URL.String()) + logger.Log("Proxying", r.URL.String()) + s.handleHTTP(w, r) +} + +func (s *Server) handleHTTP(w http.ResponseWriter, r *http.Request) { proxy := httputil.NewSingleHostReverseProxy(pathlessURL(r.URL)) proxy.Transport = s.transport + director := proxy.Director + proxy.Director = func(req *http.Request) { + director(req) + req.Host = r.URL.Host + } proxy.ServeHTTP(w, r) return + resp, err := s.transport.RoundTrip(r) + if err != nil { + http.Error(w, err.Error(), http.StatusServiceUnavailable) + return + } + defer resp.Body.Close() + copyHeader(w.Header(), resp.Header) + w.WriteHeader(resp.StatusCode) + io.Copy(w, resp.Body) +} + +func copyHeader(dst, src http.Header) { + for k, vv := range src { + for _, v := range vv { + dst.Add(k, v) + } + } } func fixScheme(u *url.URL) { @@ -67,6 +130,17 @@ func fixScheme(u *url.URL) { } } +func toWhitelist(okay []string, host string) bool { + host = strings.Split(host, ":")[0] + host = strings.Replace(host, "www", "", -1) + for i := range okay { + if strings.Contains(okay[i], host) { + return true + } + } + return false +} + func fromLocalhost(addr string) bool { return strings.Contains(addr, "[::1]") || addr == "127.0.0.1" || addr == "::1" } @@ -90,6 +164,12 @@ func pathlessURL(u *url.URL) *url.URL { } } +func transfer(destination io.WriteCloser, source io.ReadCloser) { + defer destination.Close() + defer source.Close() + io.Copy(destination, source) +} + func flagEnvFallback(keyFallback map[string]string) map[string]string { results := map[string]*string{} for k, v := range keyFallback { @@ -108,17 +188,18 @@ func flagEnvFallback(keyFallback map[string]string) map[string]string { func main() { conf := flagEnvFallback(map[string]string{ - "stunaddr": "https://localhost:20018", - "clientcrt": "../../stuncaddsies/mnt/stunclient.crt", - "clientkey": "../../stuncaddsies/mnt/stunclient.key", - "servercrt": "../../stuncaddsies/mnt/stunserver.crt", + "stunaddr": "https://bel.house:20018", + "clientcrt": "/Volumes/bldisk/client.crt", + "clientkey": "/Volumes/bldisk/client.key", + "servercrt": "/Volumes/bldisk/server.crt", "port": "8888", + "whitelist": "192.168.0.86,,bel.house,,gcp.blapointe.com", }) if !strings.HasPrefix(conf["port"], ":") { conf["port"] = ":" + conf["port"] } logger.Log(conf) - server, err := NewServer(conf["stunaddr"], conf["clientcrt"], conf["clientkey"], conf["servercrt"]) + server, err := NewServer(conf["stunaddr"], conf["clientcrt"], conf["clientkey"], conf["servercrt"], strings.Split(conf["whitelist"], ",,")) if err != nil { logger.Fatal(err) }