mfproxy/server.go

172 lines
3.7 KiB
Go

package main
import (
	"context"
	"crypto/tls"
	"errors"
	"fmt"
	"io"
	"log"
	"net"
	"net/http"
	"net/http/httputil"
	"strings"
	"sync"
	"time"

	"golang.org/x/time/rate"
)
// Server is the proxy's http.Handler (see ServeHTTP): CONNECT requests are
// tunnelled by Connect and every other request is forwarded by Serve.
// NewServer initialises all fields.
type Server struct {
// Transport forwards non-CONNECT requests in Serve. NewServer installs an
// *http.Transport whose dialer resolves addresses through dig.
Transport http.RoundTripper
// resolver, when non-nil, is used by dig to resolve host names; when nil,
// dig returns its input unchanged.
resolver *net.Resolver
// limiter is handed to throttledWriter by xfer for every relayed stream —
// presumably a bandwidth cap; how nil is handled depends on throttledWriter
// (not visible here).
limiter *rate.Limiter
// Timeout bounds the upstream TCP dial performed by Connect.
Timeout time.Duration
// dnsCache memoizes dig results: the host string passed to dig -> resolved
// address. It is shared by concurrent requests, so every access must be
// synchronized.
dnsCache map[string]string
}
// NewServer builds a Server from c.
//
// When c.DNS is non-empty, host names are resolved by dig through a pure-Go
// resolver that always talks to the c.DNS server (dialed with c.Timeout);
// otherwise no custom resolver is installed and dig is a no-op.
//
// Non-CONNECT requests are sent through an *http.Transport whose dialer
// first passes the target address through dig and then dials it with
// c.Timeout. This is the same dial timeout Connect applies to tunnels. A
// zero timeout means no timeout. c.TLSInsecure disables upstream
// certificate verification for those requests.
func NewServer(c *Config) *Server {
	var resolver *net.Resolver
	if c.DNS != "" {
		resolver = &net.Resolver{
			PreferGo: true,
			// Ignore the address the resolver asks for and always query
			// the configured DNS server.
			Dial: func(ctx context.Context, network, _ string) (net.Conn, error) {
				d := net.Dialer{Timeout: c.Timeout}
				return d.DialContext(ctx, network, c.DNS)
			},
		}
	}

	s := &Server{
		limiter:  c.Limiter,
		Timeout:  c.Timeout,
		resolver: resolver,
		dnsCache: map[string]string{},
	}

	upstream := &net.Dialer{Timeout: c.Timeout}
	s.Transport = &http.Transport{
		TLSClientConfig: &tls.Config{
			InsecureSkipVerify: c.TLSInsecure,
		},
		DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
			resolved, err := s.dig(ctx, addr)
			if err != nil {
				return nil, fmt.Errorf("err digging for dial context: %w", err)
			}
			return upstream.DialContext(ctx, network, resolved)
		},
	}
	return s
}
// ServeHTTP implements http.Handler by dispatching on the request method:
// CONNECT opens a tunnel via Connect, anything else is forwarded by Serve.
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	if r.Method == http.MethodConnect {
		s.Connect(w, r)
		return
	}
	s.Serve(w, r)
}
// Error logs err together with the request URL and answers the client with
// 503 Service Unavailable, using err's text as the plain-text body.
func (s *Server) Error(r *http.Request, w http.ResponseWriter, err error) {
	target := r.URL.String()
	log.Printf("err: %s: %v", target, err)
	http.Error(w, err.Error(), http.StatusServiceUnavailable)
}
// Connect handles an HTTP CONNECT request.
//
// The target r.Host is first passed through dig. If dig returns it
// unchanged, the request is handed to connectHTTPReverseProxy. That happens
// when no custom resolver is configured, or when the host already is the
// resolved address.
//
// Otherwise Connect:
//   - dials the resolved address (bounded by s.Timeout and the request context);
//   - answers the client with 200 OK and hijacks the client connection;
//   - relays bytes in both directions through xfer until either side finishes.
//
// Connect returns only once both connections are closed.
func (s *Server) Connect(w http.ResponseWriter, r *http.Request) {
	host, err := s.dig(r.Context(), r.Host)
	if err != nil {
		s.Error(r, w, err)
		return
	}
	if host == r.Host {
		s.connectHTTPReverseProxy(w, r)
		return
	}

	// Check hijack support before dialing or sending the status line, so the
	// failure can still be reported properly and no upstream conn is wasted.
	hijacker, ok := w.(http.Hijacker)
	if !ok {
		s.Error(r, w, errors.New("hijack not available"))
		return
	}

	dialer := net.Dialer{Timeout: s.Timeout}
	dest, err := dialer.DialContext(r.Context(), "tcp", host)
	if err != nil {
		s.Error(r, w, fmt.Errorf("error dialing w timeout %s=>%s: %w", r.Host, host, err))
		return
	}

	w.WriteHeader(http.StatusOK)
	client, brw, err := hijacker.Hijack()
	if err != nil {
		dest.Close()
		s.Error(r, w, err)
		return
	}

	// The server may already have buffered bytes the client sent right after
	// the CONNECT request; forward them before relaying the raw connection,
	// or they would be lost. (These few bytes bypass the limiter.)
	if n := brw.Reader.Buffered(); n > 0 {
		pending, _ := brw.Reader.Peek(n) // cannot fail: n bytes are buffered
		if _, err := dest.Write(pending); err != nil {
			log.Printf("err: %s: %v", r.URL.String(), err)
			client.Close()
			dest.Close()
			return
		}
	}

	// net/http cancels r.Context() as soon as ServeHTTP returns, and that
	// context is handed to xfer, so stay in the handler for the lifetime of
	// the tunnel. When either direction ends, close both connections so
	// the other direction unblocks instead of leaking.
	done := make(chan struct{}, 2)
	go func() {
		s.xfer(r.Context(), dest, client)
		done <- struct{}{}
	}()
	go func() {
		s.xfer(r.Context(), client, dest)
		done <- struct{}{}
	}()
	<-done
	client.Close()
	dest.Close()
	<-done
}
// connectHTTPReverseProxy serves r through a freshly built single-host
// reverse proxy whose target is r.URL itself.
func (s *Server) connectHTTPReverseProxy(w http.ResponseWriter, r *http.Request) {
	target := r.URL
	proxy := httputil.NewSingleHostReverseProxy(target)
	proxy.ServeHTTP(w, r)
}
// hopByHopHeaders are connection-specific header fields that a proxy must
// not forward (RFC 9110 §7.6.1). This is the same set httputil.ReverseProxy
// strips; Proxy-Connection is non-standard but still sent by some clients.
var hopByHopHeaders = []string{
	"Connection",
	"Proxy-Connection",
	"Keep-Alive",
	"Proxy-Authenticate",
	"Proxy-Authorization",
	"Te",
	"Trailer",
	"Transfer-Encoding",
	"Upgrade",
}

// stripHopByHopHeaders deletes every hop-by-hop field from h, including any
// additional field names the sender listed in its Connection header.
func stripHopByHopHeaders(h http.Header) {
	for _, value := range h["Connection"] {
		for _, name := range strings.Split(value, ",") {
			if name = strings.TrimSpace(name); name != "" {
				h.Del(name)
			}
		}
	}
	for _, name := range hopByHopHeaders {
		h.Del(name)
	}
}

// Serve forwards a plain (non-CONNECT) proxy request upstream through
// s.Transport. It relays the response status, end-to-end headers and body
// back to the client; the body is copied via xfer, so it is subject to
// s.limiter. Transport failures are reported to the client through Error.
func (s *Server) Serve(w http.ResponseWriter, r *http.Request) {
	// Handlers must not modify the incoming request, so strip hop-by-hop
	// headers on a clone. Notably Proxy-Authorization would otherwise leak
	// the client's proxy credentials to the origin server.
	out := r.Clone(r.Context())
	stripHopByHopHeaders(out.Header)

	resp, err := s.Transport.RoundTrip(out)
	if err != nil {
		s.Error(r, w, err)
		return
	}
	defer resp.Body.Close()

	stripHopByHopHeaders(resp.Header)
	dst := w.Header()
	for name, values := range resp.Header {
		for _, v := range values {
			dst.Add(name, v)
		}
	}
	w.WriteHeader(resp.StatusCode)
	s.xfer(r.Context(), w, resp.Body)
}
// xfer copies everything readable from r into w and closes r once the copy
// stops. The writes go through a throttledWriter bound to ctx and s.limiter.
// Copy errors, including a peer closing its end, are intentionally discarded.
func (s *Server) xfer(ctx context.Context, w io.Writer, r io.ReadCloser) {
	defer r.Close()
	dst := throttledWriter{
		ctx:     ctx,
		w:       w,
		limiter: s.limiter,
	}
	_, _ = io.Copy(dst, r)
}
// dnsCacheMu guards every Server's dnsCache. dig runs concurrently: from
// request handlers and from the Transport's dialer. Unsynchronized map
// access there is a data race that can abort the process with
// "concurrent map writes", and Server has no lock field of its own.
var dnsCacheMu sync.Mutex

// dig resolves host through s.resolver and returns the first address found.
// If host carried a port, it is re-attached, e.g.
// "example.com:443" -> "93.184.216.34:443", or "[2001:db8::1]:443" for an
// IPv6 result. A leading "https://" or "http://" is ignored. Results are
// cached per input string for the lifetime of the Server.
//
// When no custom resolver is configured (s.resolver == nil), host is
// returned unchanged.
func (s *Server) dig(ctx context.Context, host string) (string, error) {
	if s.resolver == nil {
		return host, nil
	}

	dnsCacheMu.Lock()
	cached, ok := s.dnsCache[host]
	dnsCacheMu.Unlock()
	if ok {
		return cached, nil
	}

	// Trim any scheme before splitting: "https://h:443" has too many colons
	// for net.SplitHostPort, so splitting first would never find the port.
	search := strings.TrimPrefix(strings.TrimPrefix(host, "https://"), "http://")
	port, hasPort := "", false
	if h, p, err := net.SplitHostPort(search); err == nil {
		search, port, hasPort = h, p, true
	}

	addrs, err := s.resolver.LookupHost(ctx, search)
	if err != nil {
		return "", fmt.Errorf("failed to dns lookup %s: %w", search, err)
	}
	if len(addrs) == 0 {
		return "", errors.New("name does not resolve")
	}

	result := strings.TrimPrefix(addrs[0], "//")
	if hasPort {
		// JoinHostPort brackets IPv6 literals ("[::1]:443"); plain
		// concatenation would produce the undialable "::1:443".
		result = net.JoinHostPort(result, port)
	}

	dnsCacheMu.Lock()
	s.dnsCache[host] = result
	dnsCacheMu.Unlock()
	return result, nil
}