diff --git a/conf.go b/conf.go index 4e3fa08..8ff95f2 100644 --- a/conf.go +++ b/conf.go @@ -1,6 +1,7 @@ package main import ( + "context" "encoding/base64" "encoding/hex" "fmt" @@ -33,7 +34,14 @@ func newPeerEndpoint() (*peer, error) { p := &peer{ dialer: &net.Dialer{ - Resolver: resolver.New(opts.ResolveDNS), + Resolver: resolver.New( + opts.ResolveDNS, + func(ctx context.Context, network, address string) (net.Conn, error) { + netConn, err := (&net.Dialer{}).DialContext(ctx, network, address) + logger.Verbosef("Using %s to resolve peer endpoint: %v", opts.ResolveDNS, err) + return netConn, err + }, + ), }, pubKey: hex.EncodeToString(pubKey), psk: hex.EncodeToString(psk), diff --git a/internal/resolver/doh.go b/internal/resolver/doh.go index 6741757..b0362b0 100644 --- a/internal/resolver/doh.go +++ b/internal/resolver/doh.go @@ -2,80 +2,93 @@ package resolver import ( "bytes" + "context" + "fmt" "io" "net" "net/http" "net/url" - "sync" "time" ) var _ net.Conn = &dohConn{} type dohConn struct { - addr string + query, resp *bytes.Buffer - once sync.Once - onceErr error - - in, ret bytes.Buffer + do func() error } -func (c *dohConn) Close() error { return nil } -func (c *dohConn) LocalAddr() net.Addr { return nil } -func (c *dohConn) RemoteAddr() net.Addr { return nil } -func (c *dohConn) SetDeadline(t time.Time) error { return nil } -func (c *dohConn) SetReadDeadline(t time.Time) error { return nil } -func (c *dohConn) SetWriteDeadline(t time.Time) error { return nil } +func newDoHConn(ctx context.Context, client *http.Client, addr string) (*dohConn, error) { + c := &dohConn{ + query: &bytes.Buffer{}, + resp: &bytes.Buffer{}, + } -func (c *dohConn) Write(b []byte) (int, error) { return c.in.Write(b) } + url, err := url.Parse(addr) + if err != nil { + return nil, err + } + // RFC 8484 + url.Path = "/dns-query" -func (c *dohConn) Read(b []byte) (int, error) { - c.once.Do(func() { - url, err := url.Parse(c.addr) - if err != nil { - c.onceErr = err - return + c.do = func() error { + if c.query.Len() <= 2 || c.resp.Len() > 0 { + return nil } - // RFC 8484 - url.Path = "/dns-query" - // Skip 2 bytes which are length - reqBody := bytes.NewReader(c.in.Bytes()[2:]) - req, err := http.NewRequest("POST", url.String(), reqBody) + // Skip length header + c.query.Next(2) + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, url.String(), c.query) if err != nil { - c.onceErr = err - return + return err } req.Header.Set("content-type", "application/dns-message") req.Header.Set("accept", "application/dns-message") - resp, err := http.DefaultClient.Do(req) + + resp, err := client.Do(req) if err != nil { - c.onceErr = err - return + return err } defer resp.Body.Close() + respBody, err := io.ReadAll(resp.Body) if err != nil { - c.onceErr = err - return + return err } + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("server return %d: %s", resp.StatusCode, respBody) + } + + // Add length header l := uint16(len(respBody)) - _, err = c.ret.Write([]byte{uint8(l >> 8), uint8(l & ((1 << 8) - 1))}) + _, err = c.resp.Write([]byte{uint8(l >> 8), uint8(l & ((1 << 8) - 1))}) if err != nil { - c.onceErr = err - return + return err } - _, err = c.ret.Write(respBody) - if err != nil { - c.onceErr = err - return - } - }) - if c.onceErr != nil { - return 0, c.onceErr + _, err = c.resp.Write(respBody) + return err } - return c.ret.Read(b) + + return c, nil +} + +func (c *dohConn) Close() error { return nil } +func (c *dohConn) LocalAddr() net.Addr { return nil } +func (c *dohConn) RemoteAddr() net.Addr { return nil } +func (c *dohConn) SetDeadline(time.Time) error { return nil } +func (c *dohConn) SetReadDeadline(time.Time) error { return nil } +func (c *dohConn) SetWriteDeadline(time.Time) error { return nil } + +func (c *dohConn) Write(b []byte) (int, error) { return c.query.Write(b) } + +func (c *dohConn) Read(b []byte) (int, error) { + if err := c.do(); err != nil { + return 0, err + } + + return c.resp.Read(b) } diff --git a/internal/resolver/resolver.go b/internal/resolver/resolver.go index e587b54..a8d8e7a 100644 --- a/internal/resolver/resolver.go +++ b/internal/resolver/resolver.go @@ -4,33 +4,46 @@ import ( "context" "crypto/tls" "net" + "net/http" "strings" ) -func New(addr string) *net.Resolver { +// PreferGo works on Windows since go1.19, https://github.com/golang/go/issues/33097 + +func New(addr string, dialContext func(context.Context, string, string) (net.Conn, error)) *net.Resolver { switch { case strings.HasPrefix(addr, "tls://"): return &net.Resolver{ PreferGo: true, Dial: func(ctx context.Context, _, _ string) (net.Conn, error) { - d := tls.Dialer{} - address := addr[len("tls://"):] - return d.DialContext(ctx, "tcp", withDefaultPort(address, "853")) + address := withDefaultPort(addr[len("tls://"):], "853") + conn, err := dialContext(ctx, "tcp", address) + if err != nil { + return nil, err + } + host, _, _ := net.SplitHostPort(address) + c := tls.Client(conn, &tls.Config{ + ServerName: host, + }) + return c, nil }, } case strings.HasPrefix(addr, "https://"): + c := &http.Client{ + Transport: &http.Transport{ + DialContext: dialContext, + }, + } return &net.Resolver{ PreferGo: true, - Dial: func(_ context.Context, _, _ string) (net.Conn, error) { - conn := &dohConn{addr: addr} - return conn, nil + Dial: func(ctx context.Context, _, _ string) (net.Conn, error) { + return newDoHConn(ctx, c, addr) }, } case addr != "": return &net.Resolver{ PreferGo: true, Dial: func(ctx context.Context, _, _ string) (net.Conn, error) { - d := net.Dialer{} address := addr network := "udp" @@ -39,7 +52,7 @@ func New(addr string) *net.Resolver { address = addr[len("tcp://"):] } - return d.DialContext(ctx, network, withDefaultPort(address, "53")) + return dialContext(ctx, network, withDefaultPort(address, "53")) }, } default: diff --git a/internal/resolver/resolver_test.go b/internal/resolver/resolver_test.go index e4a4e18..5819f26 100644 --- a/internal/resolver/resolver_test.go +++ b/internal/resolver/resolver_test.go @@ -26,7 +26,7 @@ func TestResolve(t *testing.T) { } { t.Run(server, func(t *testing.T) { d := &net.Dialer{ - Resolver: New(server), + Resolver: New(server, (&net.Dialer{}).DialContext), } c, err := d.Dial("tcp4", "www.example.com:80") if err != nil {