From 4dc0f8e861bc174e49cf7fc3e562862e04fc09c0 Mon Sep 17 00:00:00 2001 From: Shengjing Zhu Date: Sat, 23 Jul 2022 17:14:01 +0800 Subject: [PATCH] refactor: only export LookupNetIP for resolver package --- conf.go | 82 +++++++++++++-------- internal/proxy/proxy.go | 4 +- internal/resolver/resolver.go | 112 ++++++++++++++++++++--------- internal/resolver/resolver_test.go | 10 +-- 4 files changed, 141 insertions(+), 67 deletions(-) diff --git a/conf.go b/conf.go index 8ff95f2..5c7f661 100644 --- a/conf.go +++ b/conf.go @@ -6,6 +6,8 @@ import ( "encoding/hex" "fmt" "net" + "net/netip" + "strconv" "time" "github.com/zhsj/wghttp/internal/resolver" @@ -13,13 +15,14 @@ import ( ) type peer struct { - dialer *net.Dialer + resolver *resolver.Resolver pubKey string psk string - addr string - ipPort string + host string + ip netip.Addr + port uint16 } func newPeerEndpoint() (*peer, error) { @@ -33,30 +36,45 @@ func newPeerEndpoint() (*peer, error) { } p := &peer{ - dialer: &net.Dialer{ - Resolver: resolver.New( - opts.ResolveDNS, - func(ctx context.Context, network, address string) (net.Conn, error) { - netConn, err := (&net.Dialer{}).DialContext(ctx, network, address) - logger.Verbosef("Using %s to resolve peer endpoint: %v", opts.ResolveDNS, err) - return netConn, err - }, - ), - }, pubKey: hex.EncodeToString(pubKey), psk: hex.EncodeToString(psk), - addr: opts.PeerEndpoint, } - p.ipPort, err = p.resolveAddr() + host, port, err := net.SplitHostPort(opts.PeerEndpoint) if err != nil { - return nil, fmt.Errorf("resolve peer endpoint: %w", err) + return nil, fmt.Errorf("parse peer endpoint: %w", err) } + port16, err := strconv.ParseUint(port, 10, 16) + if err != nil { + return nil, fmt.Errorf("parse peer endpoint port: %w", err) + } + p.host = host + p.port = uint16(port16) + + p.ip, err = netip.ParseAddr(p.host) + if err == nil { + return p, nil + } + + p.resolver = resolver.New( + opts.ResolveDNS, + func(ctx context.Context, network, address string) (net.Conn, error) { + netConn, err := (&net.Dialer{}).DialContext(ctx, network, address) + logger.Verbosef("Using %s to resolve peer endpoint: %v", opts.ResolveDNS, err) + return netConn, err + }, + ) + + p.ip, err = p.resolveHost() + if err != nil { + return nil, fmt.Errorf("resolve peer endpoint ip: %w", err) + } + return p, err } func (p *peer) initConf() string { conf := "public_key=" + p.pubKey + "\n" - conf += "endpoint=" + p.ipPort + "\n" + conf += "endpoint=" + netip.AddrPortFrom(p.ip, p.port).String() + "\n" conf += "allowed_ip=0.0.0.0/0\n" conf += "allowed_ip=::/0\n" @@ -71,30 +89,38 @@ func (p *peer) initConf() string { } func (p *peer) updateConf() (string, bool) { - newIPPort, err := p.resolveAddr() + newIP, err := p.resolveHost() if err != nil { logger.Verbosef("Resolve peer endpoint: %v", err) return "", false } - if p.ipPort == newIPPort { + if p.ip == newIP { return "", false } - p.ipPort = newIPPort - logger.Verbosef("PeerEndpoint is changed to: %s", p.ipPort) + p.ip = newIP + logger.Verbosef("PeerEndpoint is changed to: %s", p.ip) conf := "public_key=" + p.pubKey + "\n" conf += "update_only=true\n" - conf += "endpoint=" + p.ipPort + "\n" + conf += "endpoint=" + netip.AddrPortFrom(p.ip, p.port).String() + "\n" return conf, true } -func (p *peer) resolveAddr() (string, error) { - c, err := p.dialer.Dial("udp", p.addr) +func (p *peer) resolveHost() (netip.Addr, error) { + ips, err := p.resolver.LookupNetIP(context.Background(), "ip", p.host) if err != nil { - return "", err + return netip.Addr{}, fmt.Errorf("resolve ip for %s: %w", p.host, err) } - defer c.Close() - return c.RemoteAddr().String(), nil + for _, ip := range ips { + conn, err := net.DialUDP("udp", nil, net.UDPAddrFromAddrPort(netip.AddrPortFrom(ip, p.port))) + if err == nil { + conn.Close() + return ip, nil + } else { + logger.Verbosef("Dial %s: %s", ip, err) + } + } + return netip.Addr{}, fmt.Errorf("no available ip for %s", p.host) } func ipcSet(dev *device.Device) error { @@ -118,7 +144,7 @@ func ipcSet(dev *device.Device) error { return err } - if peer.addr != peer.ipPort { + if peer.resolver != nil { go func() { c := time.Tick(opts.ResolveInterval) diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go index 4ed4771..acd8148 100644 --- a/internal/proxy/proxy.go +++ b/internal/proxy/proxy.go @@ -51,7 +51,7 @@ func dialWithDNS(dial dialer, dns string) dialer { } } - ips, err := resolv.LookupHost(ctx, host) + ips, err := resolv.LookupNetIP(ctx, network, host) if err != nil { return nil, err } @@ -61,7 +61,7 @@ func dialWithDNS(dial dialer, dns string) dialer { conn net.Conn ) for _, ip := range ips { - addr := net.JoinHostPort(ip, port) + addr := net.JoinHostPort(ip.String(), port) conn, lastErr = dial(ctx, network, addr) if lastErr == nil { return conn, nil diff --git a/internal/resolver/resolver.go b/internal/resolver/resolver.go index a8d8e7a..d161ac7 100644 --- a/internal/resolver/resolver.go +++ b/internal/resolver/resolver.go @@ -3,61 +3,109 @@ package resolver import ( "context" "crypto/tls" + "errors" "net" "net/http" + "net/netip" "strings" ) -// PreferGo works on Windows since go1.19, https://github.com/golang/go/issues/33097 +var errNotRetry = errors.New("not retry") -func New(addr string, dialContext func(context.Context, string, string) (net.Conn, error)) *net.Resolver { +type Resolver struct { + sysAddr, addr string + network string + tlsConfig *tls.Config + httpClient *http.Client + + r *net.Resolver +} + +func (r *Resolver) LookupNetIP(ctx context.Context, network, host string) ([]netip.Addr, error) { + ipNetwork := network + switch network { + case "tcp", "udp": + ipNetwork = "ip" + case "tcp4", "udp4": + ipNetwork = "ip4" + case "tcp6", "udp6": + ipNetwork = "ip6" + } + + return r.r.LookupNetIP(ctx, ipNetwork, host) +} + +func New(dns string, dial func(ctx context.Context, network, address string) (net.Conn, error)) *Resolver { + r := &Resolver{} switch { - case strings.HasPrefix(addr, "tls://"): - return &net.Resolver{ + case strings.HasPrefix(dns, "tls://"): + r.addr = withDefaultPort(dns[len("tls://"):], "853") + host, _, _ := net.SplitHostPort(r.addr) + r.tlsConfig = &tls.Config{ + ServerName: host, + } + r.r = &net.Resolver{ PreferGo: true, - Dial: func(ctx context.Context, _, _ string) (net.Conn, error) { - address := withDefaultPort(addr[len("tls://"):], "853") - conn, err := dialContext(ctx, "tcp", address) + Dial: func(ctx context.Context, _, address string) (net.Conn, error) { + if r.sysAddr == "" { + r.sysAddr = address + } + if r.sysAddr != address { + return nil, errNotRetry + } + conn, err := dial(ctx, "tcp", r.addr) if err != nil { return nil, err } - host, _, _ := net.SplitHostPort(address) - c := tls.Client(conn, &tls.Config{ - ServerName: host, - }) - return c, nil + return tls.Client(conn, r.tlsConfig), nil }, } - case strings.HasPrefix(addr, "https://"): - c := &http.Client{ + case strings.HasPrefix(dns, "https://"): + r.httpClient = &http.Client{ Transport: &http.Transport{ - DialContext: dialContext, + DialContext: dial, }, } - return &net.Resolver{ + r.r = &net.Resolver{ PreferGo: true, - Dial: func(ctx context.Context, _, _ string) (net.Conn, error) { - return newDoHConn(ctx, c, addr) - }, - } - case addr != "": - return &net.Resolver{ - PreferGo: true, - Dial: func(ctx context.Context, _, _ string) (net.Conn, error) { - address := addr - network := "udp" - - if strings.HasPrefix(addr, "tcp://") || strings.HasPrefix(addr, "udp://") { - network = addr[:len("tcp")] - address = addr[len("tcp://"):] + Dial: func(ctx context.Context, _, address string) (net.Conn, error) { + if r.sysAddr == "" { + r.sysAddr = address + } + if r.sysAddr != address { + return nil, errNotRetry } - return dialContext(ctx, network, withDefaultPort(address, "53")) + return newDoHConn(ctx, r.httpClient, dns) + }, + } + case dns != "": + r.addr = dns + r.network = "udp" + + if strings.HasPrefix(dns, "tcp://") || strings.HasPrefix(dns, "udp://") { + r.addr = dns[len("tcp://"):] + r.network = dns[:len("tcp")] + } + r.addr = withDefaultPort(r.addr, "53") + + r.r = &net.Resolver{ + PreferGo: true, + Dial: func(ctx context.Context, _, address string) (net.Conn, error) { + if r.sysAddr == "" { + r.sysAddr = address + } + if r.sysAddr != address { + return nil, errNotRetry + } + + return dial(ctx, r.network, r.addr) }, } default: - return &net.Resolver{} + r.r = &net.Resolver{} } + return r } func withDefaultPort(addr, port string) string { diff --git a/internal/resolver/resolver_test.go b/internal/resolver/resolver_test.go index 5819f26..8fcd174 100644 --- a/internal/resolver/resolver_test.go +++ b/internal/resolver/resolver_test.go @@ -1,6 +1,7 @@ package resolver import ( + "context" "net" "testing" ) @@ -25,14 +26,13 @@ func TestResolve(t *testing.T) { "https://223.5.5.5:443/dns-query", } { t.Run(server, func(t *testing.T) { - d := &net.Dialer{ - Resolver: New(server, (&net.Dialer{}).DialContext), - } - c, err := d.Dial("tcp4", "www.example.com:80") + r := New(server, (&net.Dialer{}).DialContext) + ips, err := r.LookupNetIP(context.TODO(), "ip4", "www.example.com") + if err != nil { t.Error(err) } else { - t.Logf("got %s", c.RemoteAddr()) + t.Logf("got %s", ips) } }) }