master
Bel LaPointe 2022-12-23 12:46:45 -05:00
parent 50781b6ee5
commit 569d3f1d8e
1 changed files with 8 additions and 5 deletions

View File

@ -9,6 +9,7 @@ import (
"log" "log"
"net" "net"
"net/http" "net/http"
"strings"
"time" "time"
"golang.org/x/time/rate" "golang.org/x/time/rate"
@ -44,7 +45,7 @@ func NewServer(c *Config) *Server {
transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
addr, err := s.dig(ctx, addr) addr, err := s.dig(ctx, addr)
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("err digging for dial context: %w", err)
} }
return (&net.Dialer{}).DialContext(ctx, network, addr) return (&net.Dialer{}).DialContext(ctx, network, addr)
} }
@ -76,7 +77,7 @@ func (s *Server) Connect(w http.ResponseWriter, r *http.Request) {
dest, err := net.DialTimeout("tcp", host, s.Timeout) dest, err := net.DialTimeout("tcp", host, s.Timeout)
if err != nil { if err != nil {
s.Error(r, w, err) s.Error(r, w, fmt.Errorf("error dialing w timeout %s=>%s: %w", r.Host, host, err))
return return
} }
@ -129,13 +130,15 @@ func (s *Server) xfer(ctx context.Context, w io.Writer, r io.ReadCloser) {
} }
func (s *Server) dig(ctx context.Context, host string) (string, error) { func (s *Server) dig(ctx context.Context, host string) (string, error) {
search := host
search = strings.TrimPrefix(search, "https://")
search = strings.TrimPrefix(search, "http://")
if v, ok := s.dnsCache[host]; ok { if v, ok := s.dnsCache[host]; ok {
return v, nil return v, nil
} }
search := host
port := "" port := ""
if splithost, splitport, err := net.SplitHostPort(host); err == nil { if splithost, splitport, err := net.SplitHostPort(host); err == nil {
search = splithost search = strings.TrimPrefix(strings.TrimPrefix(splithost, "http://"), "https://")
port = ":" + splitport port = ":" + splitport
} }
ip, err := s.resolver.LookupHost(ctx, search) ip, err := s.resolver.LookupHost(ctx, search)
@ -145,7 +148,7 @@ func (s *Server) dig(ctx context.Context, host string) (string, error) {
if len(ip) == 0 { if len(ip) == 0 {
return "", errors.New("name does not resolve") return "", errors.New("name does not resolve")
} }
result := ip[0] + port result := strings.TrimPrefix(ip[0], "//") + port
s.dnsCache[host] = result s.dnsCache[host] = result
//log.Printf("dug %s => %s => %s", host, search, result) //log.Printf("dug %s => %s => %s", host, search, result)
return result, nil return result, nil