master
Bel LaPointe 2018-10-12 09:53:27 -06:00
parent 2bdf518440
commit c429c4cc34
2 changed files with 89 additions and 7 deletions

1
.gitignore vendored
View File

@ -1,6 +1,7 @@
*.key *.key
*.crt *.crt
*.pem *.pem
fproxy
*.swp *.swp
*.swo *.swo
*.pub *.pub

95
main.go
View File

@ -5,8 +5,11 @@ import (
"crypto/x509" "crypto/x509"
"flag" "flag"
"fmt" "fmt"
"io"
"io/ioutil" "io/ioutil"
"local1/logger" "local1/logger"
"log"
"net"
"net/http" "net/http"
"net/http/httputil" "net/http/httputil"
"net/url" "net/url"
@ -16,9 +19,10 @@ import (
type Server struct { type Server struct {
transport *http.Transport transport *http.Transport
whitelist []string
} }
func NewServer(addr, clientcrt, clientkey, servercrt string) (*Server, error) { func NewServer(addr, clientcrt, clientkey, servercrt string, whitelist []string) (*Server, error) {
caCert, err := ioutil.ReadFile(servercrt) caCert, err := ioutil.ReadFile(servercrt)
if err != nil { if err != nil {
return nil, err return nil, err
@ -38,10 +42,39 @@ func NewServer(addr, clientcrt, clientkey, servercrt string) (*Server, error) {
RootCAs: rootCAs, RootCAs: rootCAs,
Certificates: []tls.Certificate{clientCert}, Certificates: []tls.Certificate{clientCert},
}, },
//DialTLS: dialtls,
}, },
whitelist: whitelist,
}, nil }, nil
} }
func dialtls(network, addr string) (net.Conn, error) {
conn, err := net.Dial(network, addr)
if err != nil {
return nil, err
}
host, _, err := net.SplitHostPort(addr)
if err != nil {
return nil, err
}
cfg := &tls.Config{ServerName: host}
tlsConn := tls.Client(conn, cfg)
if err := tlsConn.Handshake(); err != nil {
conn.Close()
return nil, err
}
cs := tlsConn.ConnectionState()
cert := cs.PeerCertificates[0]
// Verify here
cert.VerifyHostname(host)
log.Println(cert.Subject)
return tlsConn, nil
}
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// fix scheme if necessary // fix scheme if necessary
fixScheme(r.URL) fixScheme(r.URL)
@ -50,12 +83,42 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
denyAccess(w) denyAccess(w)
return return
} }
if !toWhitelist(s.whitelist, r.URL.Host) {
denyAccess(w)
return
}
// proxy via stuncaddsies // proxy via stuncaddsies
//logger.Log("Proxying", r.URL.String()) logger.Log("Proxying", r.URL.String())
s.handleHTTP(w, r)
}
func (s *Server) handleHTTP(w http.ResponseWriter, r *http.Request) {
proxy := httputil.NewSingleHostReverseProxy(pathlessURL(r.URL)) proxy := httputil.NewSingleHostReverseProxy(pathlessURL(r.URL))
proxy.Transport = s.transport proxy.Transport = s.transport
director := proxy.Director
proxy.Director = func(req *http.Request) {
director(req)
req.Host = r.URL.Host
}
proxy.ServeHTTP(w, r) proxy.ServeHTTP(w, r)
return return
resp, err := s.transport.RoundTrip(r)
if err != nil {
http.Error(w, err.Error(), http.StatusServiceUnavailable)
return
}
defer resp.Body.Close()
copyHeader(w.Header(), resp.Header)
w.WriteHeader(resp.StatusCode)
io.Copy(w, resp.Body)
}
func copyHeader(dst, src http.Header) {
for k, vv := range src {
for _, v := range vv {
dst.Add(k, v)
}
}
} }
func fixScheme(u *url.URL) { func fixScheme(u *url.URL) {
@ -67,6 +130,17 @@ func fixScheme(u *url.URL) {
} }
} }
func toWhitelist(okay []string, host string) bool {
host = strings.Split(host, ":")[0]
host = strings.Replace(host, "www", "", -1)
for i := range okay {
if strings.Contains(okay[i], host) {
return true
}
}
return false
}
func fromLocalhost(addr string) bool { func fromLocalhost(addr string) bool {
return strings.Contains(addr, "[::1]") || addr == "127.0.0.1" || addr == "::1" return strings.Contains(addr, "[::1]") || addr == "127.0.0.1" || addr == "::1"
} }
@ -90,6 +164,12 @@ func pathlessURL(u *url.URL) *url.URL {
} }
} }
func transfer(destination io.WriteCloser, source io.ReadCloser) {
defer destination.Close()
defer source.Close()
io.Copy(destination, source)
}
func flagEnvFallback(keyFallback map[string]string) map[string]string { func flagEnvFallback(keyFallback map[string]string) map[string]string {
results := map[string]*string{} results := map[string]*string{}
for k, v := range keyFallback { for k, v := range keyFallback {
@ -108,17 +188,18 @@ func flagEnvFallback(keyFallback map[string]string) map[string]string {
func main() { func main() {
conf := flagEnvFallback(map[string]string{ conf := flagEnvFallback(map[string]string{
"stunaddr": "https://localhost:20018", "stunaddr": "https://bel.house:20018",
"clientcrt": "../../stuncaddsies/mnt/stunclient.crt", "clientcrt": "/Volumes/bldisk/client.crt",
"clientkey": "../../stuncaddsies/mnt/stunclient.key", "clientkey": "/Volumes/bldisk/client.key",
"servercrt": "../../stuncaddsies/mnt/stunserver.crt", "servercrt": "/Volumes/bldisk/server.crt",
"port": "8888", "port": "8888",
"whitelist": "192.168.0.86,,bel.house,,gcp.blapointe.com",
}) })
if !strings.HasPrefix(conf["port"], ":") { if !strings.HasPrefix(conf["port"], ":") {
conf["port"] = ":" + conf["port"] conf["port"] = ":" + conf["port"]
} }
logger.Log(conf) logger.Log(conf)
server, err := NewServer(conf["stunaddr"], conf["clientcrt"], conf["clientkey"], conf["servercrt"]) server, err := NewServer(conf["stunaddr"], conf["clientcrt"], conf["clientkey"], conf["servercrt"], strings.Split(conf["whitelist"], ",,"))
if err != nil { if err != nil {
logger.Fatal(err) logger.Fatal(err)
} }