refactor: add diale options to internal resolver
parent
252040b47c
commit
5dc8b57908
10
conf.go
10
conf.go
|
|
@ -1,6 +1,7 @@
|
||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
@ -33,7 +34,14 @@ func newPeerEndpoint() (*peer, error) {
|
||||||
|
|
||||||
p := &peer{
|
p := &peer{
|
||||||
dialer: &net.Dialer{
|
dialer: &net.Dialer{
|
||||||
Resolver: resolver.New(opts.ResolveDNS),
|
Resolver: resolver.New(
|
||||||
|
opts.ResolveDNS,
|
||||||
|
func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||||
|
netConn, err := (&net.Dialer{}).DialContext(ctx, network, address)
|
||||||
|
logger.Verbosef("Using %s to resolve peer endpoint: %v", opts.ResolveDNS, err)
|
||||||
|
return netConn, err
|
||||||
|
},
|
||||||
|
),
|
||||||
},
|
},
|
||||||
pubKey: hex.EncodeToString(pubKey),
|
pubKey: hex.EncodeToString(pubKey),
|
||||||
psk: hex.EncodeToString(psk),
|
psk: hex.EncodeToString(psk),
|
||||||
|
|
|
||||||
|
|
@ -2,80 +2,93 @@ package resolver
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"sync"
|
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
var _ net.Conn = &dohConn{}
|
var _ net.Conn = &dohConn{}
|
||||||
|
|
||||||
type dohConn struct {
|
type dohConn struct {
|
||||||
addr string
|
query, resp *bytes.Buffer
|
||||||
|
|
||||||
once sync.Once
|
do func() error
|
||||||
onceErr error
|
}
|
||||||
|
|
||||||
in, ret bytes.Buffer
|
func newDoHConn(ctx context.Context, client *http.Client, addr string) (*dohConn, error) {
|
||||||
|
c := &dohConn{
|
||||||
|
query: &bytes.Buffer{},
|
||||||
|
resp: &bytes.Buffer{},
|
||||||
|
}
|
||||||
|
|
||||||
|
url, err := url.Parse(addr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
// RFC 8484
|
||||||
|
url.Path = "/dns-query"
|
||||||
|
|
||||||
|
c.do = func() error {
|
||||||
|
if c.query.Len() <= 2 || c.resp.Len() > 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Skip length header
|
||||||
|
c.query.Next(2)
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url.String(), c.query)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
req.Header.Set("content-type", "application/dns-message")
|
||||||
|
req.Header.Set("accept", "application/dns-message")
|
||||||
|
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
respBody, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return fmt.Errorf("server return %d: %s", resp.StatusCode, respBody)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add length header
|
||||||
|
l := uint16(len(respBody))
|
||||||
|
_, err = c.resp.Write([]byte{uint8(l >> 8), uint8(l & ((1 << 8) - 1))})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = c.resp.Write(respBody)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return c, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *dohConn) Close() error { return nil }
|
func (c *dohConn) Close() error { return nil }
|
||||||
func (c *dohConn) LocalAddr() net.Addr { return nil }
|
func (c *dohConn) LocalAddr() net.Addr { return nil }
|
||||||
func (c *dohConn) RemoteAddr() net.Addr { return nil }
|
func (c *dohConn) RemoteAddr() net.Addr { return nil }
|
||||||
func (c *dohConn) SetDeadline(t time.Time) error { return nil }
|
func (c *dohConn) SetDeadline(time.Time) error { return nil }
|
||||||
func (c *dohConn) SetReadDeadline(t time.Time) error { return nil }
|
func (c *dohConn) SetReadDeadline(time.Time) error { return nil }
|
||||||
func (c *dohConn) SetWriteDeadline(t time.Time) error { return nil }
|
func (c *dohConn) SetWriteDeadline(time.Time) error { return nil }
|
||||||
|
|
||||||
func (c *dohConn) Write(b []byte) (int, error) { return c.in.Write(b) }
|
func (c *dohConn) Write(b []byte) (int, error) { return c.query.Write(b) }
|
||||||
|
|
||||||
func (c *dohConn) Read(b []byte) (int, error) {
|
func (c *dohConn) Read(b []byte) (int, error) {
|
||||||
c.once.Do(func() {
|
if err := c.do(); err != nil {
|
||||||
url, err := url.Parse(c.addr)
|
return 0, err
|
||||||
if err != nil {
|
|
||||||
c.onceErr = err
|
|
||||||
return
|
|
||||||
}
|
|
||||||
// RFC 8484
|
|
||||||
url.Path = "/dns-query"
|
|
||||||
|
|
||||||
// Skip 2 bytes which are length
|
|
||||||
reqBody := bytes.NewReader(c.in.Bytes()[2:])
|
|
||||||
req, err := http.NewRequest("POST", url.String(), reqBody)
|
|
||||||
if err != nil {
|
|
||||||
c.onceErr = err
|
|
||||||
return
|
|
||||||
}
|
|
||||||
req.Header.Set("content-type", "application/dns-message")
|
|
||||||
req.Header.Set("accept", "application/dns-message")
|
|
||||||
resp, err := http.DefaultClient.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
c.onceErr = err
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
respBody, err := io.ReadAll(resp.Body)
|
|
||||||
if err != nil {
|
|
||||||
c.onceErr = err
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
l := uint16(len(respBody))
|
return c.resp.Read(b)
|
||||||
_, err = c.ret.Write([]byte{uint8(l >> 8), uint8(l & ((1 << 8) - 1))})
|
|
||||||
if err != nil {
|
|
||||||
c.onceErr = err
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = c.ret.Write(respBody)
|
|
||||||
if err != nil {
|
|
||||||
c.onceErr = err
|
|
||||||
return
|
|
||||||
}
|
|
||||||
})
|
|
||||||
if c.onceErr != nil {
|
|
||||||
return 0, c.onceErr
|
|
||||||
}
|
|
||||||
return c.ret.Read(b)
|
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -4,33 +4,46 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"net"
|
"net"
|
||||||
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
func New(addr string) *net.Resolver {
|
// PreferGo works on Windows since go1.19, https://github.com/golang/go/issues/33097
|
||||||
|
|
||||||
|
func New(addr string, dialContext func(context.Context, string, string) (net.Conn, error)) *net.Resolver {
|
||||||
switch {
|
switch {
|
||||||
case strings.HasPrefix(addr, "tls://"):
|
case strings.HasPrefix(addr, "tls://"):
|
||||||
return &net.Resolver{
|
return &net.Resolver{
|
||||||
PreferGo: true,
|
PreferGo: true,
|
||||||
Dial: func(ctx context.Context, _, _ string) (net.Conn, error) {
|
Dial: func(ctx context.Context, _, _ string) (net.Conn, error) {
|
||||||
d := tls.Dialer{}
|
address := withDefaultPort(addr[len("tls://"):], "853")
|
||||||
address := addr[len("tls://"):]
|
conn, err := dialContext(ctx, "tcp", address)
|
||||||
return d.DialContext(ctx, "tcp", withDefaultPort(address, "853"))
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
host, _, _ := net.SplitHostPort(address)
|
||||||
|
c := tls.Client(conn, &tls.Config{
|
||||||
|
ServerName: host,
|
||||||
|
})
|
||||||
|
return c, nil
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
case strings.HasPrefix(addr, "https://"):
|
case strings.HasPrefix(addr, "https://"):
|
||||||
|
c := &http.Client{
|
||||||
|
Transport: &http.Transport{
|
||||||
|
DialContext: dialContext,
|
||||||
|
},
|
||||||
|
}
|
||||||
return &net.Resolver{
|
return &net.Resolver{
|
||||||
PreferGo: true,
|
PreferGo: true,
|
||||||
Dial: func(_ context.Context, _, _ string) (net.Conn, error) {
|
Dial: func(ctx context.Context, _, _ string) (net.Conn, error) {
|
||||||
conn := &dohConn{addr: addr}
|
return newDoHConn(ctx, c, addr)
|
||||||
return conn, nil
|
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
case addr != "":
|
case addr != "":
|
||||||
return &net.Resolver{
|
return &net.Resolver{
|
||||||
PreferGo: true,
|
PreferGo: true,
|
||||||
Dial: func(ctx context.Context, _, _ string) (net.Conn, error) {
|
Dial: func(ctx context.Context, _, _ string) (net.Conn, error) {
|
||||||
d := net.Dialer{}
|
|
||||||
address := addr
|
address := addr
|
||||||
network := "udp"
|
network := "udp"
|
||||||
|
|
||||||
|
|
@ -39,7 +52,7 @@ func New(addr string) *net.Resolver {
|
||||||
address = addr[len("tcp://"):]
|
address = addr[len("tcp://"):]
|
||||||
}
|
}
|
||||||
|
|
||||||
return d.DialContext(ctx, network, withDefaultPort(address, "53"))
|
return dialContext(ctx, network, withDefaultPort(address, "53"))
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
|
|
|
||||||
|
|
@ -26,7 +26,7 @@ func TestResolve(t *testing.T) {
|
||||||
} {
|
} {
|
||||||
t.Run(server, func(t *testing.T) {
|
t.Run(server, func(t *testing.T) {
|
||||||
d := &net.Dialer{
|
d := &net.Dialer{
|
||||||
Resolver: New(server),
|
Resolver: New(server, (&net.Dialer{}).DialContext),
|
||||||
}
|
}
|
||||||
c, err := d.Dial("tcp4", "www.example.com:80")
|
c, err := d.Dial("tcp4", "www.example.com:80")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue