diff --git a/server.go b/server.go index 8f84250..09f1df2 100644 --- a/server.go +++ b/server.go @@ -9,6 +9,7 @@ import ( "log" "net" "net/http" + "strings" "time" "golang.org/x/time/rate" @@ -44,7 +45,7 @@ func NewServer(c *Config) *Server { transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { addr, err := s.dig(ctx, addr) if err != nil { - return nil, err + return nil, fmt.Errorf("err digging for dial context: %w", err) } return (&net.Dialer{}).DialContext(ctx, network, addr) } @@ -76,7 +77,7 @@ func (s *Server) Connect(w http.ResponseWriter, r *http.Request) { dest, err := net.DialTimeout("tcp", host, s.Timeout) if err != nil { - s.Error(r, w, err) + s.Error(r, w, fmt.Errorf("error dialing w timeout %s=>%s: %w", r.Host, host, err)) return } @@ -129,13 +130,15 @@ func (s *Server) xfer(ctx context.Context, w io.Writer, r io.ReadCloser) { } func (s *Server) dig(ctx context.Context, host string) (string, error) { + search := host + search = strings.TrimPrefix(search, "https://") + search = strings.TrimPrefix(search, "http://") if v, ok := s.dnsCache[host]; ok { return v, nil } - search := host port := "" if splithost, splitport, err := net.SplitHostPort(host); err == nil { - search = splithost + search = strings.TrimPrefix(strings.TrimPrefix(splithost, "http://"), "https://") port = ":" + splitport } ip, err := s.resolver.LookupHost(ctx, search) @@ -145,7 +148,7 @@ func (s *Server) dig(ctx context.Context, host string) (string, error) { if len(ip) == 0 { return "", errors.New("name does not resolve") } - result := ip[0] + port + result := strings.TrimPrefix(ip[0], "//") + port s.dnsCache[host] = result //log.Printf("dug %s => %s => %s", host, search, result) return result, nil