diff --git a/conf.go b/conf.go index fb9490b..4e3fa08 100644 --- a/conf.go +++ b/conf.go @@ -1,14 +1,13 @@ package main import ( - "context" - "crypto/tls" "encoding/base64" "encoding/hex" "fmt" "net" "time" + "github.com/zhsj/wghttp/internal/resolver" "golang.zx2c4.com/wireguard/device" ) @@ -16,6 +15,7 @@ type peer struct { dialer *net.Dialer pubKey string + psk string addr string ipPort string @@ -26,37 +26,17 @@ func newPeerEndpoint() (*peer, error) { if err != nil { return nil, fmt.Errorf("parse peer public key: %w", err) } + psk, err := base64.StdEncoding.DecodeString(opts.PresharedKey) + if err != nil { + return nil, fmt.Errorf("parse preshared key: %w", err) + } p := &peer{ dialer: &net.Dialer{ - Resolver: &net.Resolver{ - PreferGo: true, - Dial: func(ctx context.Context, network, address string) (net.Conn, error) { - dot := false - if opts.DNS != "" { - port := "53" - if opts.DoT != "" { - port = opts.DoT - dot = true - } - address = net.JoinHostPort(opts.DNS, port) - } - logger.Verbosef("Using %s (DoT: %t) to resolve peer endpoint", address, dot) - - if !dot { - var d net.Dialer - return d.DialContext(ctx, network, address) - } - d := tls.Dialer{ - Config: &tls.Config{ - InsecureSkipVerify: true, - }, - } - return d.DialContext(ctx, "tcp", address) - }, - }, + Resolver: resolver.New(opts.ResolveDNS), }, pubKey: hex.EncodeToString(pubKey), + psk: hex.EncodeToString(psk), addr: opts.PeerEndpoint, } p.ipPort, err = p.resolveAddr() @@ -75,6 +55,9 @@ func (p *peer) initConf() string { if opts.KeepaliveInterval > 0 { conf += fmt.Sprintf("persistent_keepalive_interval=%.f\n", opts.KeepaliveInterval.Seconds()) } + if p.psk != "" { + conf += "preshared_key=" + p.psk + "\n" + } return conf } @@ -112,6 +95,9 @@ func ipcSet(dev *device.Device) error { return fmt.Errorf("parse client private key: %w", err) } conf := "private_key=" + hex.EncodeToString(privateKey) + "\n" + if opts.ClientPort != 0 { + conf += fmt.Sprintf("listen_port=%d\n", opts.ClientPort) + } peer, err := newPeerEndpoint() if err != nil { diff --git a/internal/resolver/doh.go b/internal/resolver/doh.go new file mode 100644 index 0000000..6741757 --- /dev/null +++ b/internal/resolver/doh.go @@ -0,0 +1,81 @@ +package resolver + +import ( + "bytes" + "io" + "net" + "net/http" + "net/url" + "sync" + "time" +) + +var _ net.Conn = &dohConn{} + +type dohConn struct { + addr string + + once sync.Once + onceErr error + + in, ret bytes.Buffer +} + +func (c *dohConn) Close() error { return nil } +func (c *dohConn) LocalAddr() net.Addr { return nil } +func (c *dohConn) RemoteAddr() net.Addr { return nil } +func (c *dohConn) SetDeadline(t time.Time) error { return nil } +func (c *dohConn) SetReadDeadline(t time.Time) error { return nil } +func (c *dohConn) SetWriteDeadline(t time.Time) error { return nil } + +func (c *dohConn) Write(b []byte) (int, error) { return c.in.Write(b) } + +func (c *dohConn) Read(b []byte) (int, error) { + c.once.Do(func() { + url, err := url.Parse(c.addr) + if err != nil { + c.onceErr = err + return + } + // RFC 8484 + url.Path = "/dns-query" + + // Skip 2 bytes which are length + reqBody := bytes.NewReader(c.in.Bytes()[2:]) + req, err := http.NewRequest("POST", url.String(), reqBody) + if err != nil { + c.onceErr = err + return + } + req.Header.Set("content-type", "application/dns-message") + req.Header.Set("accept", "application/dns-message") + resp, err := http.DefaultClient.Do(req) + if err != nil { + c.onceErr = err + return + } + defer resp.Body.Close() + respBody, err := io.ReadAll(resp.Body) + if err != nil { + c.onceErr = err + return + } + + l := uint16(len(respBody)) + _, err = c.ret.Write([]byte{uint8(l >> 8), uint8(l & ((1 << 8) - 1))}) + if err != nil { + c.onceErr = err + return + } + + _, err = c.ret.Write(respBody) + if err != nil { + c.onceErr = err + return + } + }) + if c.onceErr != nil { + return 0, c.onceErr + } + return c.ret.Read(b) +} diff --git a/internal/resolver/resolver.go b/internal/resolver/resolver.go new file mode 100644 index 0000000..e587b54 --- /dev/null +++ b/internal/resolver/resolver.go @@ -0,0 +1,55 @@ +package resolver + +import ( + "context" + "crypto/tls" + "net" + "strings" +) + +func New(addr string) *net.Resolver { + switch { + case strings.HasPrefix(addr, "tls://"): + return &net.Resolver{ + PreferGo: true, + Dial: func(ctx context.Context, _, _ string) (net.Conn, error) { + d := tls.Dialer{} + address := addr[len("tls://"):] + return d.DialContext(ctx, "tcp", withDefaultPort(address, "853")) + }, + } + case strings.HasPrefix(addr, "https://"): + return &net.Resolver{ + PreferGo: true, + Dial: func(_ context.Context, _, _ string) (net.Conn, error) { + conn := &dohConn{addr: addr} + return conn, nil + }, + } + case addr != "": + return &net.Resolver{ + PreferGo: true, + Dial: func(ctx context.Context, _, _ string) (net.Conn, error) { + d := net.Dialer{} + address := addr + network := "udp" + + if strings.HasPrefix(addr, "tcp://") || strings.HasPrefix(addr, "udp://") { + network = addr[:len("tcp")] + address = addr[len("tcp://"):] + } + + return d.DialContext(ctx, network, withDefaultPort(address, "53")) + }, + } + default: + return &net.Resolver{} + } +} + +func withDefaultPort(addr, port string) string { + if _, _, err := net.SplitHostPort(addr); err == nil { + return addr + } + return net.JoinHostPort(addr, port) +} diff --git a/internal/resolver/resolver_test.go b/internal/resolver/resolver_test.go new file mode 100644 index 0000000..e4a4e18 --- /dev/null +++ b/internal/resolver/resolver_test.go @@ -0,0 +1,39 @@ +package resolver + +import ( + "net" + "testing" +) + +func TestResolve(t *testing.T) { + if testing.Short() { + t.Skip() + } + + for _, server := range []string{ + "", + "223.5.5.5", + "223.5.5.5:53", + "tcp://223.5.5.5", + "tcp://223.5.5.5:53", + "udp://223.5.5.5", + "udp://223.5.5.5:53", + "tls://223.5.5.5", + "tls://223.5.5.5:853", + "https://223.5.5.5", + "https://223.5.5.5:443", + "https://223.5.5.5:443/dns-query", + } { + t.Run(server, func(t *testing.T) { + d := &net.Dialer{ + Resolver: New(server), + } + c, err := d.Dial("tcp4", "www.example.com:80") + if err != nil { + t.Error(err) + } else { + t.Logf("got %s", c.RemoteAddr()) + } + }) + } +} diff --git a/main.go b/main.go index 83751bf..6c20ce5 100644 --- a/main.go +++ b/main.go @@ -30,17 +30,19 @@ var ( ) type options struct { - PeerEndpoint string `long:"peer-endpoint" env:"PEER_ENDPOINT" description:"WireGuard server address"` - PeerKey string `long:"peer-key" env:"PEER_KEY" description:"WireGuard server public key in base64 format"` - PrivateKey string `long:"private-key" env:"PRIVATE_KEY" description:"WireGuard client private key in base64 format"` - ClientIPs []string `long:"client-ip" env:"CLIENT_IP" env-delim:"," description:"WireGuard client IP address"` + ClientIPs []string `long:"client-ip" env:"CLIENT_IP" env-delim:"," description:"[Interface].Address\tfor WireGuard client (can be set multiple times)"` + ClientPort int `long:"client-port" env:"CLIENT_PORT" description:"[Interface].ListenPort\tfor WireGuard client (optional)"` + PrivateKey string `long:"private-key" env:"PRIVATE_KEY" description:"[Interface].PrivateKey\tfor WireGuard client (format: base64)"` + DNS string `long:"dns" env:"DNS" description:"[Interface].DNS\tfor WireGuard network (format: IP)"` + MTU int `long:"mtu" env:"MTU" default:"1280" description:"[Interface].MTU\tfor WireGuard network"` - DNS string `long:"dns" env:"DNS" description:"DNS IP for WireGuard network and resolving server address"` - DoT string `long:"dot" env:"DOT" description:"Port for DNS over TLS, used to resolve WireGuard server address if available"` - MTU int `long:"mtu" env:"MTU" default:"1280" description:"MTU for WireGuard network"` + PeerEndpoint string `long:"peer-endpoint" env:"PEER_ENDPOINT" description:"[Peer].Endpoint\tfor WireGuard server (format: host:port)"` + PeerKey string `long:"peer-key" env:"PEER_KEY" description:"[Peer].PublicKey\tfor WireGuard server (format: base64)"` + PresharedKey string `long:"preshared-key" env:"PRESHARED_KEY" description:"[Peer].PresharedKey\tfor WireGuard network (optional, format: base64)"` + KeepaliveInterval time.Duration `long:"keepalive-interval" env:"KEEPALIVE_INTERVAL" description:"[Peer].PersistentKeepalive\tfor WireGuard network (optional)"` - KeepaliveInterval time.Duration `long:"keepalive-interval" env:"KEEPALIVE_INTERVAL" description:"Interval for sending keepalive packet"` - ResolveInterval time.Duration `long:"resolve-interval" env:"RESOLVE_INTERVAL" default:"1m" description:"Interval for resolving WireGuard server address"` + ResolveDNS string `long:"resolve-dns" env:"RESOLVE_DNS" description:"DNS for resolving WireGuard server address (optional, format: protocol://ip:port)\nProtocol includes tcp, udp, tls(DNS over TLS) and https(DNS over HTTPS)"` + ResolveInterval time.Duration `long:"resolve-interval" env:"RESOLVE_INTERVAL" default:"1m" description:"Interval for resolving WireGuard server address (set 0 to disable)"` Listen string `long:"listen" env:"LISTEN" default:"localhost:8080" description:"HTTP & SOCKS5 server address"` ExitMode string `long:"exit-mode" env:"EXIT_MODE" choice:"remote" choice:"local" default:"remote" description:"Exit mode"`