refactor: only export LookupNetIP for resolver package

master
Shengjing Zhu 2022-07-23 17:14:01 +08:00
parent 9298006bc7
commit 4dc0f8e861
4 changed files with 141 additions and 67 deletions

76
conf.go
View File

@ -6,6 +6,8 @@ import (
"encoding/hex" "encoding/hex"
"fmt" "fmt"
"net" "net"
"net/netip"
"strconv"
"time" "time"
"github.com/zhsj/wghttp/internal/resolver" "github.com/zhsj/wghttp/internal/resolver"
@ -13,13 +15,14 @@ import (
) )
type peer struct { type peer struct {
dialer *net.Dialer resolver *resolver.Resolver
pubKey string pubKey string
psk string psk string
addr string host string
ipPort string ip netip.Addr
port uint16
} }
func newPeerEndpoint() (*peer, error) { func newPeerEndpoint() (*peer, error) {
@ -33,30 +36,45 @@ func newPeerEndpoint() (*peer, error) {
} }
p := &peer{ p := &peer{
dialer: &net.Dialer{ pubKey: hex.EncodeToString(pubKey),
Resolver: resolver.New( psk: hex.EncodeToString(psk),
}
host, port, err := net.SplitHostPort(opts.PeerEndpoint)
if err != nil {
return nil, fmt.Errorf("parse peer endpoint: %w", err)
}
port16, err := strconv.ParseUint(port, 10, 16)
if err != nil {
return nil, fmt.Errorf("parse peer endpoint port: %w", err)
}
p.host = host
p.port = uint16(port16)
p.ip, err = netip.ParseAddr(p.host)
if err == nil {
return p, nil
}
p.resolver = resolver.New(
opts.ResolveDNS, opts.ResolveDNS,
func(ctx context.Context, network, address string) (net.Conn, error) { func(ctx context.Context, network, address string) (net.Conn, error) {
netConn, err := (&net.Dialer{}).DialContext(ctx, network, address) netConn, err := (&net.Dialer{}).DialContext(ctx, network, address)
logger.Verbosef("Using %s to resolve peer endpoint: %v", opts.ResolveDNS, err) logger.Verbosef("Using %s to resolve peer endpoint: %v", opts.ResolveDNS, err)
return netConn, err return netConn, err
}, },
), )
},
pubKey: hex.EncodeToString(pubKey), p.ip, err = p.resolveHost()
psk: hex.EncodeToString(psk),
addr: opts.PeerEndpoint,
}
p.ipPort, err = p.resolveAddr()
if err != nil { if err != nil {
return nil, fmt.Errorf("resolve peer endpoint: %w", err) return nil, fmt.Errorf("resolve peer endpoint ip: %w", err)
} }
return p, err return p, err
} }
func (p *peer) initConf() string { func (p *peer) initConf() string {
conf := "public_key=" + p.pubKey + "\n" conf := "public_key=" + p.pubKey + "\n"
conf += "endpoint=" + p.ipPort + "\n" conf += "endpoint=" + netip.AddrPortFrom(p.ip, p.port).String() + "\n"
conf += "allowed_ip=0.0.0.0/0\n" conf += "allowed_ip=0.0.0.0/0\n"
conf += "allowed_ip=::/0\n" conf += "allowed_ip=::/0\n"
@ -71,30 +89,38 @@ func (p *peer) initConf() string {
} }
func (p *peer) updateConf() (string, bool) { func (p *peer) updateConf() (string, bool) {
newIPPort, err := p.resolveAddr() newIP, err := p.resolveHost()
if err != nil { if err != nil {
logger.Verbosef("Resolve peer endpoint: %v", err) logger.Verbosef("Resolve peer endpoint: %v", err)
return "", false return "", false
} }
if p.ipPort == newIPPort { if p.ip == newIP {
return "", false return "", false
} }
p.ipPort = newIPPort p.ip = newIP
logger.Verbosef("PeerEndpoint is changed to: %s", p.ipPort) logger.Verbosef("PeerEndpoint is changed to: %s", p.ip)
conf := "public_key=" + p.pubKey + "\n" conf := "public_key=" + p.pubKey + "\n"
conf += "update_only=true\n" conf += "update_only=true\n"
conf += "endpoint=" + p.ipPort + "\n" conf += "endpoint=" + netip.AddrPortFrom(p.ip, p.port).String() + "\n"
return conf, true return conf, true
} }
func (p *peer) resolveAddr() (string, error) { func (p *peer) resolveHost() (netip.Addr, error) {
c, err := p.dialer.Dial("udp", p.addr) ips, err := p.resolver.LookupNetIP(context.Background(), "ip", p.host)
if err != nil { if err != nil {
return "", err return netip.Addr{}, fmt.Errorf("resolve ip for %s: %w", p.host, err)
} }
defer c.Close() for _, ip := range ips {
return c.RemoteAddr().String(), nil conn, err := net.DialUDP("udp", nil, net.UDPAddrFromAddrPort(netip.AddrPortFrom(ip, p.port)))
if err == nil {
conn.Close()
return ip, nil
} else {
logger.Verbosef("Dial %s: %s", ip, err)
}
}
return netip.Addr{}, fmt.Errorf("no available ip for %s", p.host)
} }
func ipcSet(dev *device.Device) error { func ipcSet(dev *device.Device) error {
@ -118,7 +144,7 @@ func ipcSet(dev *device.Device) error {
return err return err
} }
if peer.addr != peer.ipPort { if peer.resolver != nil {
go func() { go func() {
c := time.Tick(opts.ResolveInterval) c := time.Tick(opts.ResolveInterval)

View File

@ -51,7 +51,7 @@ func dialWithDNS(dial dialer, dns string) dialer {
} }
} }
ips, err := resolv.LookupHost(ctx, host) ips, err := resolv.LookupNetIP(ctx, network, host)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -61,7 +61,7 @@ func dialWithDNS(dial dialer, dns string) dialer {
conn net.Conn conn net.Conn
) )
for _, ip := range ips { for _, ip := range ips {
addr := net.JoinHostPort(ip, port) addr := net.JoinHostPort(ip.String(), port)
conn, lastErr = dial(ctx, network, addr) conn, lastErr = dial(ctx, network, addr)
if lastErr == nil { if lastErr == nil {
return conn, nil return conn, nil

View File

@ -3,61 +3,109 @@ package resolver
import ( import (
"context" "context"
"crypto/tls" "crypto/tls"
"errors"
"net" "net"
"net/http" "net/http"
"net/netip"
"strings" "strings"
) )
// PreferGo works on Windows since go1.19, https://github.com/golang/go/issues/33097 var errNotRetry = errors.New("not retry")
func New(addr string, dialContext func(context.Context, string, string) (net.Conn, error)) *net.Resolver { type Resolver struct {
sysAddr, addr string
network string
tlsConfig *tls.Config
httpClient *http.Client
r *net.Resolver
}
func (r *Resolver) LookupNetIP(ctx context.Context, network, host string) ([]netip.Addr, error) {
ipNetwork := network
switch network {
case "tcp", "udp":
ipNetwork = "ip"
case "tcp4", "udp4":
ipNetwork = "ip4"
case "tcp6", "udp6":
ipNetwork = "ip6"
}
return r.r.LookupNetIP(ctx, ipNetwork, host)
}
func New(dns string, dial func(ctx context.Context, network, address string) (net.Conn, error)) *Resolver {
r := &Resolver{}
switch { switch {
case strings.HasPrefix(addr, "tls://"): case strings.HasPrefix(dns, "tls://"):
return &net.Resolver{ r.addr = withDefaultPort(dns[len("tls://"):], "853")
host, _, _ := net.SplitHostPort(r.addr)
r.tlsConfig = &tls.Config{
ServerName: host,
}
r.r = &net.Resolver{
PreferGo: true, PreferGo: true,
Dial: func(ctx context.Context, _, _ string) (net.Conn, error) { Dial: func(ctx context.Context, _, address string) (net.Conn, error) {
address := withDefaultPort(addr[len("tls://"):], "853") if r.sysAddr == "" {
conn, err := dialContext(ctx, "tcp", address) r.sysAddr = address
}
if r.sysAddr != address {
return nil, errNotRetry
}
conn, err := dial(ctx, "tcp", r.addr)
if err != nil { if err != nil {
return nil, err return nil, err
} }
host, _, _ := net.SplitHostPort(address) return tls.Client(conn, r.tlsConfig), nil
c := tls.Client(conn, &tls.Config{
ServerName: host,
})
return c, nil
}, },
} }
case strings.HasPrefix(addr, "https://"): case strings.HasPrefix(dns, "https://"):
c := &http.Client{ r.httpClient = &http.Client{
Transport: &http.Transport{ Transport: &http.Transport{
DialContext: dialContext, DialContext: dial,
}, },
} }
return &net.Resolver{ r.r = &net.Resolver{
PreferGo: true, PreferGo: true,
Dial: func(ctx context.Context, _, _ string) (net.Conn, error) { Dial: func(ctx context.Context, _, address string) (net.Conn, error) {
return newDoHConn(ctx, c, addr) if r.sysAddr == "" {
r.sysAddr = address
}
if r.sysAddr != address {
return nil, errNotRetry
}
return newDoHConn(ctx, r.httpClient, dns)
}, },
} }
case addr != "": case dns != "":
return &net.Resolver{ r.addr = dns
PreferGo: true, r.network = "udp"
Dial: func(ctx context.Context, _, _ string) (net.Conn, error) {
address := addr
network := "udp"
if strings.HasPrefix(addr, "tcp://") || strings.HasPrefix(addr, "udp://") { if strings.HasPrefix(dns, "tcp://") || strings.HasPrefix(dns, "udp://") {
network = addr[:len("tcp")] r.addr = dns[len("tcp://"):]
address = addr[len("tcp://"):] r.network = dns[:len("tcp")]
}
r.addr = withDefaultPort(r.addr, "53")
r.r = &net.Resolver{
PreferGo: true,
Dial: func(ctx context.Context, _, address string) (net.Conn, error) {
if r.sysAddr == "" {
r.sysAddr = address
}
if r.sysAddr != address {
return nil, errNotRetry
} }
return dialContext(ctx, network, withDefaultPort(address, "53")) return dial(ctx, r.network, r.addr)
}, },
} }
default: default:
return &net.Resolver{} r.r = &net.Resolver{}
} }
return r
} }
func withDefaultPort(addr, port string) string { func withDefaultPort(addr, port string) string {

View File

@ -1,6 +1,7 @@
package resolver package resolver
import ( import (
"context"
"net" "net"
"testing" "testing"
) )
@ -25,14 +26,13 @@ func TestResolve(t *testing.T) {
"https://223.5.5.5:443/dns-query", "https://223.5.5.5:443/dns-query",
} { } {
t.Run(server, func(t *testing.T) { t.Run(server, func(t *testing.T) {
d := &net.Dialer{ r := New(server, (&net.Dialer{}).DialContext)
Resolver: New(server, (&net.Dialer{}).DialContext), ips, err := r.LookupNetIP(context.TODO(), "ip4", "www.example.com")
}
c, err := d.Dial("tcp4", "www.example.com:80")
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} else { } else {
t.Logf("got %s", c.RemoteAddr()) t.Logf("got %s", ips)
} }
}) })
} }