refactor: only export LookupNetIP for resolver package
parent
9298006bc7
commit
4dc0f8e861
82
conf.go
82
conf.go
|
|
@ -6,6 +6,8 @@ import (
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"strconv"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/zhsj/wghttp/internal/resolver"
|
"github.com/zhsj/wghttp/internal/resolver"
|
||||||
|
|
@ -13,13 +15,14 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
type peer struct {
|
type peer struct {
|
||||||
dialer *net.Dialer
|
resolver *resolver.Resolver
|
||||||
|
|
||||||
pubKey string
|
pubKey string
|
||||||
psk string
|
psk string
|
||||||
|
|
||||||
addr string
|
host string
|
||||||
ipPort string
|
ip netip.Addr
|
||||||
|
port uint16
|
||||||
}
|
}
|
||||||
|
|
||||||
func newPeerEndpoint() (*peer, error) {
|
func newPeerEndpoint() (*peer, error) {
|
||||||
|
|
@ -33,30 +36,45 @@ func newPeerEndpoint() (*peer, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
p := &peer{
|
p := &peer{
|
||||||
dialer: &net.Dialer{
|
|
||||||
Resolver: resolver.New(
|
|
||||||
opts.ResolveDNS,
|
|
||||||
func(ctx context.Context, network, address string) (net.Conn, error) {
|
|
||||||
netConn, err := (&net.Dialer{}).DialContext(ctx, network, address)
|
|
||||||
logger.Verbosef("Using %s to resolve peer endpoint: %v", opts.ResolveDNS, err)
|
|
||||||
return netConn, err
|
|
||||||
},
|
|
||||||
),
|
|
||||||
},
|
|
||||||
pubKey: hex.EncodeToString(pubKey),
|
pubKey: hex.EncodeToString(pubKey),
|
||||||
psk: hex.EncodeToString(psk),
|
psk: hex.EncodeToString(psk),
|
||||||
addr: opts.PeerEndpoint,
|
|
||||||
}
|
}
|
||||||
p.ipPort, err = p.resolveAddr()
|
host, port, err := net.SplitHostPort(opts.PeerEndpoint)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("resolve peer endpoint: %w", err)
|
return nil, fmt.Errorf("parse peer endpoint: %w", err)
|
||||||
}
|
}
|
||||||
|
port16, err := strconv.ParseUint(port, 10, 16)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("parse peer endpoint port: %w", err)
|
||||||
|
}
|
||||||
|
p.host = host
|
||||||
|
p.port = uint16(port16)
|
||||||
|
|
||||||
|
p.ip, err = netip.ParseAddr(p.host)
|
||||||
|
if err == nil {
|
||||||
|
return p, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
p.resolver = resolver.New(
|
||||||
|
opts.ResolveDNS,
|
||||||
|
func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||||
|
netConn, err := (&net.Dialer{}).DialContext(ctx, network, address)
|
||||||
|
logger.Verbosef("Using %s to resolve peer endpoint: %v", opts.ResolveDNS, err)
|
||||||
|
return netConn, err
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
p.ip, err = p.resolveHost()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("resolve peer endpoint ip: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
return p, err
|
return p, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *peer) initConf() string {
|
func (p *peer) initConf() string {
|
||||||
conf := "public_key=" + p.pubKey + "\n"
|
conf := "public_key=" + p.pubKey + "\n"
|
||||||
conf += "endpoint=" + p.ipPort + "\n"
|
conf += "endpoint=" + netip.AddrPortFrom(p.ip, p.port).String() + "\n"
|
||||||
conf += "allowed_ip=0.0.0.0/0\n"
|
conf += "allowed_ip=0.0.0.0/0\n"
|
||||||
conf += "allowed_ip=::/0\n"
|
conf += "allowed_ip=::/0\n"
|
||||||
|
|
||||||
|
|
@ -71,30 +89,38 @@ func (p *peer) initConf() string {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *peer) updateConf() (string, bool) {
|
func (p *peer) updateConf() (string, bool) {
|
||||||
newIPPort, err := p.resolveAddr()
|
newIP, err := p.resolveHost()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Verbosef("Resolve peer endpoint: %v", err)
|
logger.Verbosef("Resolve peer endpoint: %v", err)
|
||||||
return "", false
|
return "", false
|
||||||
}
|
}
|
||||||
if p.ipPort == newIPPort {
|
if p.ip == newIP {
|
||||||
return "", false
|
return "", false
|
||||||
}
|
}
|
||||||
p.ipPort = newIPPort
|
p.ip = newIP
|
||||||
logger.Verbosef("PeerEndpoint is changed to: %s", p.ipPort)
|
logger.Verbosef("PeerEndpoint is changed to: %s", p.ip)
|
||||||
|
|
||||||
conf := "public_key=" + p.pubKey + "\n"
|
conf := "public_key=" + p.pubKey + "\n"
|
||||||
conf += "update_only=true\n"
|
conf += "update_only=true\n"
|
||||||
conf += "endpoint=" + p.ipPort + "\n"
|
conf += "endpoint=" + netip.AddrPortFrom(p.ip, p.port).String() + "\n"
|
||||||
return conf, true
|
return conf, true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *peer) resolveAddr() (string, error) {
|
func (p *peer) resolveHost() (netip.Addr, error) {
|
||||||
c, err := p.dialer.Dial("udp", p.addr)
|
ips, err := p.resolver.LookupNetIP(context.Background(), "ip", p.host)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return netip.Addr{}, fmt.Errorf("resolve ip for %s: %w", p.host, err)
|
||||||
}
|
}
|
||||||
defer c.Close()
|
for _, ip := range ips {
|
||||||
return c.RemoteAddr().String(), nil
|
conn, err := net.DialUDP("udp", nil, net.UDPAddrFromAddrPort(netip.AddrPortFrom(ip, p.port)))
|
||||||
|
if err == nil {
|
||||||
|
conn.Close()
|
||||||
|
return ip, nil
|
||||||
|
} else {
|
||||||
|
logger.Verbosef("Dial %s: %s", ip, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return netip.Addr{}, fmt.Errorf("no available ip for %s", p.host)
|
||||||
}
|
}
|
||||||
|
|
||||||
func ipcSet(dev *device.Device) error {
|
func ipcSet(dev *device.Device) error {
|
||||||
|
|
@ -118,7 +144,7 @@ func ipcSet(dev *device.Device) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if peer.addr != peer.ipPort {
|
if peer.resolver != nil {
|
||||||
go func() {
|
go func() {
|
||||||
c := time.Tick(opts.ResolveInterval)
|
c := time.Tick(opts.ResolveInterval)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -51,7 +51,7 @@ func dialWithDNS(dial dialer, dns string) dialer {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
ips, err := resolv.LookupHost(ctx, host)
|
ips, err := resolv.LookupNetIP(ctx, network, host)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
@ -61,7 +61,7 @@ func dialWithDNS(dial dialer, dns string) dialer {
|
||||||
conn net.Conn
|
conn net.Conn
|
||||||
)
|
)
|
||||||
for _, ip := range ips {
|
for _, ip := range ips {
|
||||||
addr := net.JoinHostPort(ip, port)
|
addr := net.JoinHostPort(ip.String(), port)
|
||||||
conn, lastErr = dial(ctx, network, addr)
|
conn, lastErr = dial(ctx, network, addr)
|
||||||
if lastErr == nil {
|
if lastErr == nil {
|
||||||
return conn, nil
|
return conn, nil
|
||||||
|
|
|
||||||
|
|
@ -3,61 +3,109 @@ package resolver
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
|
"errors"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"net/netip"
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
// PreferGo works on Windows since go1.19, https://github.com/golang/go/issues/33097
|
var errNotRetry = errors.New("not retry")
|
||||||
|
|
||||||
func New(addr string, dialContext func(context.Context, string, string) (net.Conn, error)) *net.Resolver {
|
type Resolver struct {
|
||||||
|
sysAddr, addr string
|
||||||
|
network string
|
||||||
|
tlsConfig *tls.Config
|
||||||
|
httpClient *http.Client
|
||||||
|
|
||||||
|
r *net.Resolver
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Resolver) LookupNetIP(ctx context.Context, network, host string) ([]netip.Addr, error) {
|
||||||
|
ipNetwork := network
|
||||||
|
switch network {
|
||||||
|
case "tcp", "udp":
|
||||||
|
ipNetwork = "ip"
|
||||||
|
case "tcp4", "udp4":
|
||||||
|
ipNetwork = "ip4"
|
||||||
|
case "tcp6", "udp6":
|
||||||
|
ipNetwork = "ip6"
|
||||||
|
}
|
||||||
|
|
||||||
|
return r.r.LookupNetIP(ctx, ipNetwork, host)
|
||||||
|
}
|
||||||
|
|
||||||
|
func New(dns string, dial func(ctx context.Context, network, address string) (net.Conn, error)) *Resolver {
|
||||||
|
r := &Resolver{}
|
||||||
switch {
|
switch {
|
||||||
case strings.HasPrefix(addr, "tls://"):
|
case strings.HasPrefix(dns, "tls://"):
|
||||||
return &net.Resolver{
|
r.addr = withDefaultPort(dns[len("tls://"):], "853")
|
||||||
|
host, _, _ := net.SplitHostPort(r.addr)
|
||||||
|
r.tlsConfig = &tls.Config{
|
||||||
|
ServerName: host,
|
||||||
|
}
|
||||||
|
r.r = &net.Resolver{
|
||||||
PreferGo: true,
|
PreferGo: true,
|
||||||
Dial: func(ctx context.Context, _, _ string) (net.Conn, error) {
|
Dial: func(ctx context.Context, _, address string) (net.Conn, error) {
|
||||||
address := withDefaultPort(addr[len("tls://"):], "853")
|
if r.sysAddr == "" {
|
||||||
conn, err := dialContext(ctx, "tcp", address)
|
r.sysAddr = address
|
||||||
|
}
|
||||||
|
if r.sysAddr != address {
|
||||||
|
return nil, errNotRetry
|
||||||
|
}
|
||||||
|
conn, err := dial(ctx, "tcp", r.addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
host, _, _ := net.SplitHostPort(address)
|
return tls.Client(conn, r.tlsConfig), nil
|
||||||
c := tls.Client(conn, &tls.Config{
|
|
||||||
ServerName: host,
|
|
||||||
})
|
|
||||||
return c, nil
|
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
case strings.HasPrefix(addr, "https://"):
|
case strings.HasPrefix(dns, "https://"):
|
||||||
c := &http.Client{
|
r.httpClient = &http.Client{
|
||||||
Transport: &http.Transport{
|
Transport: &http.Transport{
|
||||||
DialContext: dialContext,
|
DialContext: dial,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
return &net.Resolver{
|
r.r = &net.Resolver{
|
||||||
PreferGo: true,
|
PreferGo: true,
|
||||||
Dial: func(ctx context.Context, _, _ string) (net.Conn, error) {
|
Dial: func(ctx context.Context, _, address string) (net.Conn, error) {
|
||||||
return newDoHConn(ctx, c, addr)
|
if r.sysAddr == "" {
|
||||||
},
|
r.sysAddr = address
|
||||||
}
|
}
|
||||||
case addr != "":
|
if r.sysAddr != address {
|
||||||
return &net.Resolver{
|
return nil, errNotRetry
|
||||||
PreferGo: true,
|
|
||||||
Dial: func(ctx context.Context, _, _ string) (net.Conn, error) {
|
|
||||||
address := addr
|
|
||||||
network := "udp"
|
|
||||||
|
|
||||||
if strings.HasPrefix(addr, "tcp://") || strings.HasPrefix(addr, "udp://") {
|
|
||||||
network = addr[:len("tcp")]
|
|
||||||
address = addr[len("tcp://"):]
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return dialContext(ctx, network, withDefaultPort(address, "53"))
|
return newDoHConn(ctx, r.httpClient, dns)
|
||||||
|
},
|
||||||
|
}
|
||||||
|
case dns != "":
|
||||||
|
r.addr = dns
|
||||||
|
r.network = "udp"
|
||||||
|
|
||||||
|
if strings.HasPrefix(dns, "tcp://") || strings.HasPrefix(dns, "udp://") {
|
||||||
|
r.addr = dns[len("tcp://"):]
|
||||||
|
r.network = dns[:len("tcp")]
|
||||||
|
}
|
||||||
|
r.addr = withDefaultPort(r.addr, "53")
|
||||||
|
|
||||||
|
r.r = &net.Resolver{
|
||||||
|
PreferGo: true,
|
||||||
|
Dial: func(ctx context.Context, _, address string) (net.Conn, error) {
|
||||||
|
if r.sysAddr == "" {
|
||||||
|
r.sysAddr = address
|
||||||
|
}
|
||||||
|
if r.sysAddr != address {
|
||||||
|
return nil, errNotRetry
|
||||||
|
}
|
||||||
|
|
||||||
|
return dial(ctx, r.network, r.addr)
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
return &net.Resolver{}
|
r.r = &net.Resolver{}
|
||||||
}
|
}
|
||||||
|
return r
|
||||||
}
|
}
|
||||||
|
|
||||||
func withDefaultPort(addr, port string) string {
|
func withDefaultPort(addr, port string) string {
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
package resolver
|
package resolver
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"net"
|
"net"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
@ -25,14 +26,13 @@ func TestResolve(t *testing.T) {
|
||||||
"https://223.5.5.5:443/dns-query",
|
"https://223.5.5.5:443/dns-query",
|
||||||
} {
|
} {
|
||||||
t.Run(server, func(t *testing.T) {
|
t.Run(server, func(t *testing.T) {
|
||||||
d := &net.Dialer{
|
r := New(server, (&net.Dialer{}).DialContext)
|
||||||
Resolver: New(server, (&net.Dialer{}).DialContext),
|
ips, err := r.LookupNetIP(context.TODO(), "ip4", "www.example.com")
|
||||||
}
|
|
||||||
c, err := d.Dial("tcp4", "www.example.com:80")
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
} else {
|
} else {
|
||||||
t.Logf("got %s", c.RemoteAddr())
|
t.Logf("got %s", ips)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue