fproxy/main.go

218 lines
5.1 KiB
Go

package main
import (
"crypto/tls"
"crypto/x509"
"flag"
"fmt"
"io"
"io/ioutil"
"local1/logger"
"net"
"net/http"
"net/http/httputil"
"net/url"
"os"
"strings"
"time"
)
type Server struct {
	// transport sends proxied requests through the upstream proxy, presenting
	// the client certificate and trusting the roots loaded by NewServer.
	transport *http.Transport
	// whitelist holds the hosts requests are allowed to target.
	whitelist []string
	// bypass holds hosts that are reached directly instead of via transport.
	bypass []string
}

// NewServer builds a Server that routes non-bypassed traffic through the
// upstream proxy at addr.
//
// addr must be an absolute URL with a scheme and host, for example
// "https://proxy.example:20018". clientcrt and clientkey name the PEM client
// certificate and private key presented on TLS connections. servercrt names a
// PEM file whose certificates are trusted in addition to the system roots.
// whitelist and bypass are stored as given.
//
// An error is returned if addr is malformed or lacks a scheme/host, if any
// file cannot be read or parsed, if servercrt contains no PEM certificate, or
// if the system certificate pool cannot be loaded.
func NewServer(addr, clientcrt, clientkey, servercrt string, whitelist []string, bypass []string) (*Server, error) {
	// Parse the upstream address once so a bad value fails at startup rather
	// than on every proxied request.
	proxyURL, err := url.Parse(addr)
	if err != nil {
		return nil, fmt.Errorf("parsing proxy address %q: %w", addr, err)
	}
	if proxyURL.Scheme == "" || proxyURL.Host == "" {
		return nil, fmt.Errorf("proxy address %q must include a scheme and host", addr)
	}
	caCert, err := ioutil.ReadFile(servercrt)
	if err != nil {
		return nil, fmt.Errorf("reading server certificate: %w", err)
	}
	rootCAs, err := x509.SystemCertPool()
	if err != nil {
		return nil, fmt.Errorf("loading system certificate pool: %w", err)
	}
	// AppendCertsFromPEM signals failure only through its return value; a file
	// holding no certificate would otherwise be ignored silently.
	if !rootCAs.AppendCertsFromPEM(caCert) {
		return nil, fmt.Errorf("no PEM certificates found in %q", servercrt)
	}
	clientCert, err := tls.LoadX509KeyPair(clientcrt, clientkey)
	if err != nil {
		return nil, fmt.Errorf("loading client key pair: %w", err)
	}
	return &Server{
		transport: &http.Transport{
			Proxy: http.ProxyURL(proxyURL),
			TLSClientConfig: &tls.Config{
				RootCAs:      rootCAs,
				Certificates: []tls.Certificate{clientCert},
			},
		},
		whitelist: whitelist,
		bypass:    bypass,
	}, nil
}
// ServeHTTP implements http.Handler.
//
// The request URL is normalised with fixScheme first. Requests are then
// checked in order:
//   - not from localhost (per fromLocalhost), or not whitelisted: denied;
//   - bypass host: forwarded directly by passthrough;
//   - otherwise: sent through the upstream proxy by handleHTTP.
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	fixScheme(r.URL)
	host := r.URL.Host

	switch {
	case !fromLocalhost(r.RemoteAddr), !toWhitelist(s.whitelist, host):
		denyAccess(w)
	case toWhitelist(s.bypass, host):
		s.passthrough(w, r)
	default:
		s.handleHTTP(w, r)
	}
}
// passthrough forwards r to its destination without using the upstream
// proxy.
//
// Plain "http" requests are relayed with a single-host reverse proxy on the
// default transport. Anything else is tunnelled:
//   - r.Host is dialed over TCP (10s timeout);
//   - a 200 status is written and the client connection is hijacked;
//   - bytes are copied in both directions until either side closes.
func (s *Server) passthrough(w http.ResponseWriter, r *http.Request) {
	if r.URL.Scheme == "http" {
		proxy := httputil.NewSingleHostReverseProxy(pathlessURL(r.URL))
		proxy.ServeHTTP(w, r)
		return
	}
	// Check hijack support before dialing or writing a status. That way an
	// unsupported writer gets a clean error and no outbound connection leaks.
	hijacker, ok := w.(http.Hijacker)
	if !ok {
		http.Error(w, "Hijacking not supported", http.StatusInternalServerError)
		return
	}
	destConn, err := net.DialTimeout("tcp", r.Host, 10*time.Second)
	if err != nil {
		http.Error(w, err.Error(), http.StatusServiceUnavailable)
		return
	}
	w.WriteHeader(http.StatusOK)
	clientConn, _, err := hijacker.Hijack()
	if err != nil {
		// Without a client connection there is nothing to tunnel. Release the
		// dialed connection instead of starting copies on a nil conn.
		destConn.Close()
		http.Error(w, err.Error(), http.StatusServiceUnavailable)
		return
	}
	// pipe copies src into dst until src is exhausted, then closes both. That
	// also ends the copy running in the opposite direction.
	pipe := func(dst io.WriteCloser, src io.ReadCloser) {
		defer dst.Close()
		defer src.Close()
		io.Copy(dst, src)
	}
	go pipe(destConn, clientConn)
	go pipe(clientConn, destConn)
}
// handleHTTP relays r through the upstream proxy. A single-host reverse proxy
// targeting pathlessURL(r.URL) is built per request. It uses s.transport,
// whose Proxy function and TLS client configuration were set up by NewServer.
func (s *Server) handleHTTP(w http.ResponseWriter, r *http.Request) {
	target := pathlessURL(r.URL)
	rp := httputil.NewSingleHostReverseProxy(target)
	rp.Transport = s.transport
	rp.ServeHTTP(w, r)
}
// copyHeader appends every value of every header in src to dst. Values dst
// already holds are kept. Keys pass through http.Header.Add, so they are
// stored in canonical form.
func copyHeader(dst, src http.Header) {
	for name, values := range src {
		for i := range values {
			dst.Add(name, values[i])
		}
	}
}
// fixScheme normalises a proxied request URL in place:
//   - a host ending in ":443" loses that port and gets scheme "https";
//   - otherwise an empty scheme becomes "http".
func fixScheme(u *url.URL) {
	const tlsPort = ":443"
	if trimmed := strings.TrimSuffix(u.Host, tlsPort); trimmed != u.Host {
		u.Scheme, u.Host = "https", trimmed
		return
	}
	if u.Scheme == "" {
		u.Scheme = "http"
	}
}
// toWhitelist reports whether host is permitted by the entries in okay.
//
// Both host and each entry are normalised with normalizeHost. After that:
//   - an empty host never matches;
//   - IP literals only match an identical entry;
//   - a DNS name is reduced to its last two labels ("a.b.example.com" becomes
//     "example.com"). It matches an entry equal to that domain, or an entry
//     ending in "."+domain, so any host under an entry's two-label domain
//     is allowed;
//   - a single-label name must match an entry exactly.
func toWhitelist(okay []string, host string) bool {
	host = normalizeHost(host)
	if host == "" {
		return false
	}
	hostIsIP := net.ParseIP(host) != nil

	domain := host
	labels := strings.Split(host, ".")
	multiLabel := len(labels) > 1
	if multiLabel {
		domain = labels[len(labels)-2] + "." + labels[len(labels)-1]
	}

	for _, raw := range okay {
		entry := normalizeHost(raw)
		if entry == "" {
			continue
		}
		if hostIsIP || net.ParseIP(entry) != nil {
			// Reducing IPs to "labels" would let 10.0.0.86 match
			// 192.168.0.86, so IP literals only ever match exactly.
			if entry == host {
				return true
			}
			continue
		}
		// Compare on label boundaries. A plain substring test would let
		// "e.com" match "...blapointe.com".
		if entry == domain || (multiLabel && strings.HasSuffix(entry, "."+domain)) {
			return true
		}
	}
	return false
}

// normalizeHost lower-cases h, strips surrounding whitespace, an
// "http://"/"https://" prefix, a port, IPv6 brackets, a trailing dot and a
// leading "www.".
func normalizeHost(h string) string {
	h = strings.ToLower(strings.TrimSpace(h))
	h = strings.TrimPrefix(h, "http://")
	h = strings.TrimPrefix(h, "https://")
	if hostOnly, _, err := net.SplitHostPort(h); err == nil {
		h = hostOnly
	}
	h = strings.TrimSuffix(strings.TrimPrefix(h, "["), "]")
	h = strings.TrimSuffix(h, ".")
	return strings.TrimPrefix(h, "www.")
}
// fromLocalhost reports whether addr is a loopback address. addr is a client
// address in the form of http.Request.RemoteAddr ("host:port", IPv6 hosts in
// brackets); a bare IP without a port is accepted too.
func fromLocalhost(addr string) bool {
	host := addr
	// RemoteAddr includes the port, so it must be split off before the IP can
	// be inspected; comparing the whole string to "127.0.0.1" never matches.
	if h, _, err := net.SplitHostPort(addr); err == nil {
		host = h
	}
	host = strings.TrimSuffix(strings.TrimPrefix(host, "["), "]")
	ip := net.ParseIP(host)
	return ip != nil && ip.IsLoopback()
}
// denyAccess rejects a request: it writes status 401 Unauthorized followed by
// a one-line message body.
func denyAccess(w http.ResponseWriter) {
	const status = http.StatusUnauthorized
	w.WriteHeader(status)
	fmt.Fprintln(w, "You shouldn't be here")
}
// pathlessURL returns a new URL holding only the parts of u that identify the
// target server: scheme, opaque part, userinfo and host. Path, query and
// fragment are dropped, and u itself is not modified.
//
// The result is used as the target of httputil.NewSingleHostReverseProxy.
// That proxy's director appends the incoming request's path and query to the
// target's, so keeping u's query here would send it twice ("a=1&a=1").
func pathlessURL(u *url.URL) *url.URL {
	return &url.URL{
		Scheme: u.Scheme,
		Opaque: u.Opaque,
		User:   u.User,
		Host:   u.Host,
	}
}
// transfer streams everything readable from source into destination. Once
// the copy stops, source is closed first and then destination. Copy errors
// are ignored.
func transfer(destination io.WriteCloser, source io.ReadCloser) {
	defer func() {
		source.Close()
		destination.Close()
	}()
	io.Copy(destination, source)
}
// flagEnvFallback resolves configuration values from flags and environment.
//
// It registers one string flag per key of keyFallback (the map value is the
// default), calls flag.Parse, and returns the final value for every key. A
// flag still equal to its default takes the value of the environment variable
// named by the upper-cased key, if that variable is non-empty. In that case
// the flag's stored value is updated as well.
func flagEnvFallback(keyFallback map[string]string) map[string]string {
	ptrs := make(map[string]*string, len(keyFallback))
	for name, def := range keyFallback {
		ptrs[name] = flag.String(name, def, "")
	}
	flag.Parse()

	out := make(map[string]string, len(ptrs))
	for name, p := range ptrs {
		if *p == keyFallback[name] {
			if env := os.Getenv(strings.ToUpper(name)); env != "" {
				*p = env
			}
		}
		out[name] = *p
	}
	return out
}
// main loads configuration from flags, falling back to upper-cased
// environment variables. It then builds the Server and serves it on the
// configured port.
//
// Whitelist and bypass entries are separated by ",,". Bypass hosts are also
// appended to the whitelist so they pass the whitelist check.
func main() {
	defaults := map[string]string{
		"stunaddr":  "https://bel.house:20018",
		"clientcrt": "/Volumes/bldisk/client.crt",
		"clientkey": "/Volumes/bldisk/client.key",
		"servercrt": "/Volumes/bldisk/server.crt",
		"port":      "8888",
		"whitelist": "192.168.0.86,,bel.house,,minio.gcp.blapointe.com",
		"bypass":    "plex.tv",
	}
	conf := flagEnvFallback(defaults)
	if port := conf["port"]; !strings.HasPrefix(port, ":") {
		conf["port"] = ":" + port
	}

	const sep = ",,"
	whitelist := strings.Split(conf["whitelist"], sep)
	bypass := strings.Split(conf["bypass"], sep)
	logger.Log(conf)

	allowed := append(whitelist, bypass...)
	server, err := NewServer(conf["stunaddr"], conf["clientcrt"], conf["clientkey"], conf["servercrt"], allowed, bypass)
	if err != nil {
		logger.Fatal(err)
	}
	logger.Fatal(http.ListenAndServe(conf["port"], server))
}