180 lines
3.8 KiB
Go
180 lines
3.8 KiB
Go
package main
|
|
|
|
import (
|
|
"context"
|
|
"crypto/tls"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"log"
|
|
"net"
|
|
"net/http"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"golang.org/x/time/rate"
|
|
)
|
|
|
|
type Server struct {
|
|
Transport http.RoundTripper
|
|
resolver *net.Resolver
|
|
limiter *rate.Limiter
|
|
Timeout time.Duration
|
|
dnsCacheLock sync.Mutex
|
|
dnsCache map[string]dns
|
|
}
|
|
|
|
// dns is a single memoised lookup stored in Server.dnsCache by dig.
type dns struct {
	// result is the resolved address returned to dig's caller
	// (empty when the lookup failed).
	result string
	// err is the error recorded for the lookup, or nil on success.
	err error
	// ts is when the entry was stored; ok() measures freshness from it.
	ts time.Time
}
|
|
|
|
func (dns dns) ok() bool {
|
|
dur := time.Minute * 10
|
|
if dns.err != nil {
|
|
dur = time.Minute * 1
|
|
}
|
|
return time.Since(dns.ts) < dur
|
|
}
|
|
|
|
func NewServer(c *Config) *Server {
|
|
resolver := &net.Resolver{
|
|
PreferGo: true,
|
|
Dial: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
|
d := net.Dialer{Timeout: c.Timeout}
|
|
if c.DNS != "" {
|
|
addr = c.DNS
|
|
}
|
|
return d.DialContext(ctx, network, addr)
|
|
},
|
|
}
|
|
transport := &http.Transport{}
|
|
transport.TLSClientConfig = &tls.Config{
|
|
InsecureSkipVerify: c.TLSInsecure,
|
|
}
|
|
s := &Server{
|
|
limiter: c.Limiter,
|
|
Transport: transport,
|
|
Timeout: c.Timeout,
|
|
resolver: resolver,
|
|
dnsCache: map[string]dns{},
|
|
}
|
|
transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
|
|
addr, err := s.dig(ctx, addr)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("err digging for dial context: %w", err)
|
|
}
|
|
return (&net.Dialer{}).DialContext(ctx, network, addr)
|
|
}
|
|
return s
|
|
}
|
|
|
|
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|
//log.Println(r.Method, r.Host)
|
|
switch r.Method {
|
|
case http.MethodConnect:
|
|
s.Connect(w, r)
|
|
default:
|
|
s.Serve(w, r)
|
|
}
|
|
}
|
|
|
|
func (s *Server) Error(r *http.Request, w http.ResponseWriter, err error) {
|
|
log.Printf("err: %s: %v", r.URL.String(), err)
|
|
http.Error(w, err.Error(), http.StatusServiceUnavailable)
|
|
return
|
|
}
|
|
|
|
func (s *Server) Connect(w http.ResponseWriter, r *http.Request) {
|
|
hijacker, ok := w.(http.Hijacker)
|
|
if !ok {
|
|
s.Error(r, w, errors.New("cannot hijack"))
|
|
return
|
|
}
|
|
|
|
host, err := s.dig(r.Context(), r.Host)
|
|
if err != nil {
|
|
s.Error(r, w, err)
|
|
return
|
|
}
|
|
|
|
network := "tcp"
|
|
if n := strings.Count(host, ":"); n > 3 {
|
|
network = "tcp6"
|
|
}
|
|
dest, err := net.DialTimeout(network, host, s.Timeout)
|
|
if err != nil {
|
|
s.Error(r, w, fmt.Errorf("error dialing w timeout %s=>%s: %w", r.Host, host, err))
|
|
return
|
|
}
|
|
|
|
w.WriteHeader(http.StatusOK)
|
|
|
|
client, _, err := hijacker.Hijack()
|
|
if err != nil {
|
|
s.Error(r, w, err)
|
|
return
|
|
}
|
|
|
|
go s.xfer(r.Context(), dest, client)
|
|
go s.xfer(r.Context(), client, dest)
|
|
}
|
|
|
|
func (s *Server) Serve(w http.ResponseWriter, r *http.Request) {
|
|
resp, err := s.Transport.RoundTrip(r)
|
|
if err != nil {
|
|
s.Error(r, w, err)
|
|
return
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
for k, v := range resp.Header {
|
|
//log.Println(k, v)
|
|
for _, s := range v {
|
|
w.Header().Add(k, s)
|
|
}
|
|
}
|
|
w.WriteHeader(resp.StatusCode)
|
|
s.xfer(r.Context(), w, resp.Body)
|
|
}
|
|
|
|
func (s *Server) xfer(ctx context.Context, w io.Writer, r io.ReadCloser) {
|
|
defer r.Close()
|
|
io.Copy(
|
|
throttledWriter{
|
|
ctx: ctx,
|
|
w: w,
|
|
limiter: s.limiter,
|
|
},
|
|
r,
|
|
)
|
|
}
|
|
|
|
func (s *Server) dig(ctx context.Context, host string) (string, error) {
|
|
search := host
|
|
search = strings.TrimPrefix(search, "https://")
|
|
search = strings.TrimPrefix(search, "http://")
|
|
s.dnsCacheLock.Lock()
|
|
defer s.dnsCacheLock.Unlock()
|
|
if v, ok := s.dnsCache[host]; ok && v.ok() {
|
|
return v.result, v.err
|
|
}
|
|
port := ""
|
|
if splithost, splitport, err := net.SplitHostPort(host); err == nil {
|
|
search = strings.TrimPrefix(strings.TrimPrefix(splithost, "http://"), "https://")
|
|
port = ":" + splitport
|
|
}
|
|
ip, err := s.resolver.LookupHost(ctx, search)
|
|
result := ""
|
|
if len(ip) > 0 {
|
|
result = strings.TrimPrefix(ip[0], "//") + port
|
|
}
|
|
s.dnsCache[host] = dns{result: result, err: err, ts: time.Now()}
|
|
if err != nil || result == "" {
|
|
return "", fmt.Errorf("failed to dns lookup %s: %v", search, err)
|
|
}
|
|
return result, nil
|
|
}
|