81
internal/resolver/doh.go
Normal file
81
internal/resolver/doh.go
Normal file
@@ -0,0 +1,81 @@
|
||||
package resolver
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
var _ net.Conn = &dohConn{}
|
||||
|
||||
type dohConn struct {
|
||||
addr string
|
||||
|
||||
once sync.Once
|
||||
onceErr error
|
||||
|
||||
in, ret bytes.Buffer
|
||||
}
|
||||
|
||||
func (c *dohConn) Close() error { return nil }
|
||||
func (c *dohConn) LocalAddr() net.Addr { return nil }
|
||||
func (c *dohConn) RemoteAddr() net.Addr { return nil }
|
||||
func (c *dohConn) SetDeadline(t time.Time) error { return nil }
|
||||
func (c *dohConn) SetReadDeadline(t time.Time) error { return nil }
|
||||
func (c *dohConn) SetWriteDeadline(t time.Time) error { return nil }
|
||||
|
||||
func (c *dohConn) Write(b []byte) (int, error) { return c.in.Write(b) }
|
||||
|
||||
func (c *dohConn) Read(b []byte) (int, error) {
|
||||
c.once.Do(func() {
|
||||
url, err := url.Parse(c.addr)
|
||||
if err != nil {
|
||||
c.onceErr = err
|
||||
return
|
||||
}
|
||||
// RFC 8484
|
||||
url.Path = "/dns-query"
|
||||
|
||||
// Skip 2 bytes which are length
|
||||
reqBody := bytes.NewReader(c.in.Bytes()[2:])
|
||||
req, err := http.NewRequest("POST", url.String(), reqBody)
|
||||
if err != nil {
|
||||
c.onceErr = err
|
||||
return
|
||||
}
|
||||
req.Header.Set("content-type", "application/dns-message")
|
||||
req.Header.Set("accept", "application/dns-message")
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
c.onceErr = err
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
c.onceErr = err
|
||||
return
|
||||
}
|
||||
|
||||
l := uint16(len(respBody))
|
||||
_, err = c.ret.Write([]byte{uint8(l >> 8), uint8(l & ((1 << 8) - 1))})
|
||||
if err != nil {
|
||||
c.onceErr = err
|
||||
return
|
||||
}
|
||||
|
||||
_, err = c.ret.Write(respBody)
|
||||
if err != nil {
|
||||
c.onceErr = err
|
||||
return
|
||||
}
|
||||
})
|
||||
if c.onceErr != nil {
|
||||
return 0, c.onceErr
|
||||
}
|
||||
return c.ret.Read(b)
|
||||
}
|
||||
55
internal/resolver/resolver.go
Normal file
55
internal/resolver/resolver.go
Normal file
@@ -0,0 +1,55 @@
|
||||
package resolver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"net"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func New(addr string) *net.Resolver {
|
||||
switch {
|
||||
case strings.HasPrefix(addr, "tls://"):
|
||||
return &net.Resolver{
|
||||
PreferGo: true,
|
||||
Dial: func(ctx context.Context, _, _ string) (net.Conn, error) {
|
||||
d := tls.Dialer{}
|
||||
address := addr[len("tls://"):]
|
||||
return d.DialContext(ctx, "tcp", withDefaultPort(address, "853"))
|
||||
},
|
||||
}
|
||||
case strings.HasPrefix(addr, "https://"):
|
||||
return &net.Resolver{
|
||||
PreferGo: true,
|
||||
Dial: func(_ context.Context, _, _ string) (net.Conn, error) {
|
||||
conn := &dohConn{addr: addr}
|
||||
return conn, nil
|
||||
},
|
||||
}
|
||||
case addr != "":
|
||||
return &net.Resolver{
|
||||
PreferGo: true,
|
||||
Dial: func(ctx context.Context, _, _ string) (net.Conn, error) {
|
||||
d := net.Dialer{}
|
||||
address := addr
|
||||
network := "udp"
|
||||
|
||||
if strings.HasPrefix(addr, "tcp://") || strings.HasPrefix(addr, "udp://") {
|
||||
network = addr[:len("tcp")]
|
||||
address = addr[len("tcp://"):]
|
||||
}
|
||||
|
||||
return d.DialContext(ctx, network, withDefaultPort(address, "53"))
|
||||
},
|
||||
}
|
||||
default:
|
||||
return &net.Resolver{}
|
||||
}
|
||||
}
|
||||
|
||||
func withDefaultPort(addr, port string) string {
|
||||
if _, _, err := net.SplitHostPort(addr); err == nil {
|
||||
return addr
|
||||
}
|
||||
return net.JoinHostPort(addr, port)
|
||||
}
|
||||
39
internal/resolver/resolver_test.go
Normal file
39
internal/resolver/resolver_test.go
Normal file
@@ -0,0 +1,39 @@
|
||||
package resolver
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestResolve(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip()
|
||||
}
|
||||
|
||||
for _, server := range []string{
|
||||
"",
|
||||
"223.5.5.5",
|
||||
"223.5.5.5:53",
|
||||
"tcp://223.5.5.5",
|
||||
"tcp://223.5.5.5:53",
|
||||
"udp://223.5.5.5",
|
||||
"udp://223.5.5.5:53",
|
||||
"tls://223.5.5.5",
|
||||
"tls://223.5.5.5:853",
|
||||
"https://223.5.5.5",
|
||||
"https://223.5.5.5:443",
|
||||
"https://223.5.5.5:443/dns-query",
|
||||
} {
|
||||
t.Run(server, func(t *testing.T) {
|
||||
d := &net.Dialer{
|
||||
Resolver: New(server),
|
||||
}
|
||||
c, err := d.Dial("tcp4", "www.example.com:80")
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
} else {
|
||||
t.Logf("got %s", c.RemoteAddr())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user