refactor: add diale options to internal resolver

master
Shengjing Zhu 2022-07-22 22:38:20 +08:00
parent 252040b47c
commit 5dc8b57908
4 changed files with 88 additions and 54 deletions

10
conf.go
View File

@ -1,6 +1,7 @@
package main package main
import ( import (
"context"
"encoding/base64" "encoding/base64"
"encoding/hex" "encoding/hex"
"fmt" "fmt"
@ -33,7 +34,14 @@ func newPeerEndpoint() (*peer, error) {
p := &peer{ p := &peer{
dialer: &net.Dialer{ dialer: &net.Dialer{
Resolver: resolver.New(opts.ResolveDNS), Resolver: resolver.New(
opts.ResolveDNS,
func(ctx context.Context, network, address string) (net.Conn, error) {
netConn, err := (&net.Dialer{}).DialContext(ctx, network, address)
logger.Verbosef("Using %s to resolve peer endpoint: %v", opts.ResolveDNS, err)
return netConn, err
},
),
}, },
pubKey: hex.EncodeToString(pubKey), pubKey: hex.EncodeToString(pubKey),
psk: hex.EncodeToString(psk), psk: hex.EncodeToString(psk),

View File

@ -2,80 +2,93 @@ package resolver
import ( import (
"bytes" "bytes"
"context"
"fmt"
"io" "io"
"net" "net"
"net/http" "net/http"
"net/url" "net/url"
"sync"
"time" "time"
) )
var _ net.Conn = &dohConn{} var _ net.Conn = &dohConn{}
type dohConn struct { type dohConn struct {
addr string query, resp *bytes.Buffer
once sync.Once do func() error
onceErr error }
in, ret bytes.Buffer func newDoHConn(ctx context.Context, client *http.Client, addr string) (*dohConn, error) {
c := &dohConn{
query: &bytes.Buffer{},
resp: &bytes.Buffer{},
}
url, err := url.Parse(addr)
if err != nil {
return nil, err
}
// RFC 8484
url.Path = "/dns-query"
c.do = func() error {
if c.query.Len() <= 2 || c.resp.Len() > 0 {
return nil
}
// Skip length header
c.query.Next(2)
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url.String(), c.query)
if err != nil {
return err
}
req.Header.Set("content-type", "application/dns-message")
req.Header.Set("accept", "application/dns-message")
resp, err := client.Do(req)
if err != nil {
return err
}
defer resp.Body.Close()
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return err
}
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("server return %d: %s", resp.StatusCode, respBody)
}
// Add length header
l := uint16(len(respBody))
_, err = c.resp.Write([]byte{uint8(l >> 8), uint8(l & ((1 << 8) - 1))})
if err != nil {
return err
}
_, err = c.resp.Write(respBody)
return err
}
return c, nil
} }
func (c *dohConn) Close() error { return nil } func (c *dohConn) Close() error { return nil }
func (c *dohConn) LocalAddr() net.Addr { return nil } func (c *dohConn) LocalAddr() net.Addr { return nil }
func (c *dohConn) RemoteAddr() net.Addr { return nil } func (c *dohConn) RemoteAddr() net.Addr { return nil }
func (c *dohConn) SetDeadline(t time.Time) error { return nil } func (c *dohConn) SetDeadline(time.Time) error { return nil }
func (c *dohConn) SetReadDeadline(t time.Time) error { return nil } func (c *dohConn) SetReadDeadline(time.Time) error { return nil }
func (c *dohConn) SetWriteDeadline(t time.Time) error { return nil } func (c *dohConn) SetWriteDeadline(time.Time) error { return nil }
func (c *dohConn) Write(b []byte) (int, error) { return c.in.Write(b) } func (c *dohConn) Write(b []byte) (int, error) { return c.query.Write(b) }
func (c *dohConn) Read(b []byte) (int, error) { func (c *dohConn) Read(b []byte) (int, error) {
c.once.Do(func() { if err := c.do(); err != nil {
url, err := url.Parse(c.addr) return 0, err
if err != nil {
c.onceErr = err
return
}
// RFC 8484
url.Path = "/dns-query"
// Skip 2 bytes which are length
reqBody := bytes.NewReader(c.in.Bytes()[2:])
req, err := http.NewRequest("POST", url.String(), reqBody)
if err != nil {
c.onceErr = err
return
}
req.Header.Set("content-type", "application/dns-message")
req.Header.Set("accept", "application/dns-message")
resp, err := http.DefaultClient.Do(req)
if err != nil {
c.onceErr = err
return
}
defer resp.Body.Close()
respBody, err := io.ReadAll(resp.Body)
if err != nil {
c.onceErr = err
return
} }
l := uint16(len(respBody)) return c.resp.Read(b)
_, err = c.ret.Write([]byte{uint8(l >> 8), uint8(l & ((1 << 8) - 1))})
if err != nil {
c.onceErr = err
return
}
_, err = c.ret.Write(respBody)
if err != nil {
c.onceErr = err
return
}
})
if c.onceErr != nil {
return 0, c.onceErr
}
return c.ret.Read(b)
} }

View File

@ -4,33 +4,46 @@ import (
"context" "context"
"crypto/tls" "crypto/tls"
"net" "net"
"net/http"
"strings" "strings"
) )
func New(addr string) *net.Resolver { // PreferGo works on Windows since go1.19, https://github.com/golang/go/issues/33097
func New(addr string, dialContext func(context.Context, string, string) (net.Conn, error)) *net.Resolver {
switch { switch {
case strings.HasPrefix(addr, "tls://"): case strings.HasPrefix(addr, "tls://"):
return &net.Resolver{ return &net.Resolver{
PreferGo: true, PreferGo: true,
Dial: func(ctx context.Context, _, _ string) (net.Conn, error) { Dial: func(ctx context.Context, _, _ string) (net.Conn, error) {
d := tls.Dialer{} address := withDefaultPort(addr[len("tls://"):], "853")
address := addr[len("tls://"):] conn, err := dialContext(ctx, "tcp", address)
return d.DialContext(ctx, "tcp", withDefaultPort(address, "853")) if err != nil {
return nil, err
}
host, _, _ := net.SplitHostPort(address)
c := tls.Client(conn, &tls.Config{
ServerName: host,
})
return c, nil
}, },
} }
case strings.HasPrefix(addr, "https://"): case strings.HasPrefix(addr, "https://"):
c := &http.Client{
Transport: &http.Transport{
DialContext: dialContext,
},
}
return &net.Resolver{ return &net.Resolver{
PreferGo: true, PreferGo: true,
Dial: func(_ context.Context, _, _ string) (net.Conn, error) { Dial: func(ctx context.Context, _, _ string) (net.Conn, error) {
conn := &dohConn{addr: addr} return newDoHConn(ctx, c, addr)
return conn, nil
}, },
} }
case addr != "": case addr != "":
return &net.Resolver{ return &net.Resolver{
PreferGo: true, PreferGo: true,
Dial: func(ctx context.Context, _, _ string) (net.Conn, error) { Dial: func(ctx context.Context, _, _ string) (net.Conn, error) {
d := net.Dialer{}
address := addr address := addr
network := "udp" network := "udp"
@ -39,7 +52,7 @@ func New(addr string) *net.Resolver {
address = addr[len("tcp://"):] address = addr[len("tcp://"):]
} }
return d.DialContext(ctx, network, withDefaultPort(address, "53")) return dialContext(ctx, network, withDefaultPort(address, "53"))
}, },
} }
default: default:

View File

@ -26,7 +26,7 @@ func TestResolve(t *testing.T) {
} { } {
t.Run(server, func(t *testing.T) { t.Run(server, func(t *testing.T) {
d := &net.Dialer{ d := &net.Dialer{
Resolver: New(server), Resolver: New(server, (&net.Dialer{}).DialContext),
} }
c, err := d.Dial("tcp4", "www.example.com:80") c, err := d.Dial("tcp4", "www.example.com:80")
if err != nil { if err != nil {