diff --git a/conf.go b/conf.go index 3277de1..d5d4b04 100644 --- a/conf.go +++ b/conf.go @@ -1,69 +1,147 @@ package main import ( + "context" + "crypto/tls" "encoding/base64" "encoding/hex" "fmt" + "math/rand" "net" "time" "golang.zx2c4.com/wireguard/device" ) +const ( + resolvePeerInterval = time.Second * 10 + keepaliveInterval = "10" +) + +type peer struct { + dialer *net.Dialer + + pubKey string + keepalive bool + + addr string + ipPort string +} + +func newPeerEndpoint(opts options) (*peer, error) { + pubKey, err := base64.StdEncoding.DecodeString(opts.PeerKey) + if err != nil { + return nil, fmt.Errorf("parse peer key: %w", err) + } + + p := &peer{ + dialer: &net.Dialer{ + Resolver: &net.Resolver{ + PreferGo: true, + Dial: func(ctx context.Context, network, address string) (net.Conn, error) { + if len(opts.DNS) > 0 { + host := opts.DNS[rand.Intn(len(opts.DNS))] + port := "53" + if opts.DoT != "" { + port = opts.DoT + } + address = net.JoinHostPort(host, port) + } + logger.Verbosef("Using %s to resolve peer endpoint", address) + + if opts.DoT == "" { + var d net.Dialer + return d.DialContext(ctx, network, address) + } + d := tls.Dialer{ + Config: &tls.Config{ + InsecureSkipVerify: true, + }, + } + return d.DialContext(ctx, "tcp", address) + }, + }, + }, + pubKey: hex.EncodeToString(pubKey), + keepalive: opts.ExitMode == "local", + addr: opts.PeerEndpoint, + } + p.ipPort, err = p.resolveAddr() + return p, err +} + +func (p *peer) initConf() string { + conf := "public_key=" + p.pubKey + "\n" + conf += "endpoint=" + p.ipPort + "\n" + conf += "allowed_ip=0.0.0.0/0\n" + conf += "allowed_ip=::/0\n" + + if p.keepalive { + conf += "persistent_keepalive_interval=" + keepaliveInterval + "\n" + } + + return conf +} + +func (p *peer) updateConf() (string, bool) { + newIPPort, err := p.resolveAddr() + if err != nil { + logger.Verbosef("Resolve peer endpoint: %v", err) + return "", false + } + if p.ipPort == newIPPort { + return "", false + } + p.ipPort = newIPPort + logger.Verbosef("PeerEndpoint is changed to: %s", p.ipPort) + + conf := "public_key=" + p.pubKey + "\n" + conf += "update_only=true\n" + conf += "endpoint=" + p.ipPort + "\n" + return conf, true +} + +func (p *peer) resolveAddr() (string, error) { + c, err := p.dialer.Dial("udp", p.addr) + if err != nil { + return "", fmt.Errorf("dial %s: %w", p.addr, err) + } + defer c.Close() + return c.RemoteAddr().String(), nil +} + func ipcSet(dev *device.Device, opts options) error { privateKey, err := base64.StdEncoding.DecodeString(opts.PrivateKey) if err != nil { return fmt.Errorf("parse private key: %w", err) } - peerKey, err := base64.StdEncoding.DecodeString(opts.PeerKey) - if err != nil { - return fmt.Errorf("parse peer key: %w", err) - } conf := "private_key=" + hex.EncodeToString(privateKey) + "\n" - conf += "public_key=" + hex.EncodeToString(peerKey) + "\n" - peerAddr, err := net.ResolveUDPAddr("udp", opts.PeerEndpoint) + peer, err := newPeerEndpoint(opts) if err != nil { - return fmt.Errorf("resolve peer endpoint: %w", err) - } - - conf += "endpoint=" + peerAddr.String() + "\n" - conf += "allowed_ip=0.0.0.0/0\n" - conf += "allowed_ip=::/0\n" - - if opts.ExitMode == "local" { - conf += "persistent_keepalive_interval=10\n" + return err } + conf += peer.initConf() if err := dev.IpcSet(conf); err != nil { return fmt.Errorf("set device config: %w", err) } - if peerAddr.String() != opts.PeerEndpoint { - go refreshEndpoint(dev, peerKey, peerAddr.String(), opts.PeerEndpoint) + if peer.addr != peer.ipPort { + go func() { + c := time.Tick(resolvePeerInterval) + + for range c { + conf, needUpdate := peer.updateConf() + if !needUpdate { + continue + } + + if err := dev.IpcSet(conf); err != nil { + logger.Errorf("Set device config: %v", err) + } + } + }() } return nil } - -func refreshEndpoint(dev *device.Device, peerKey []byte, currentPeerAddr, peerEndpoint string) { - c := time.Tick(10 * time.Second) - - for range c { - addr, err := net.ResolveUDPAddr("udp", peerEndpoint) - if err != nil { - logger.Errorf("Resolve peer endpoint: %v", err) - continue - } - if currentPeerAddr == addr.String() { - continue - } - currentPeerAddr = addr.String() - logger.Verbosef("Endpoint is changed to: %s", addr) - conf := "public_key=" + hex.EncodeToString(peerKey) + "\n" - conf += "update_only=true\n" - conf += "endpoint=" + addr.String() + "\n" - if err := dev.IpcSet(conf); err != nil { - logger.Errorf("Set device config: %v", err) - } - } -} diff --git a/main.go b/main.go index 60d0558..d366555 100644 --- a/main.go +++ b/main.go @@ -29,9 +29,10 @@ type options struct { PeerKey string `long:"peer-key" env:"PEER_KEY" required:"true" description:"WireGuard server public key in base64 format"` PrivateKey string `long:"private-key" env:"PRIVATE_KEY" required:"true" description:"WireGuard client private key in base64 format"` ClientIPs []string `long:"client-ip" env:"CLIENT_IP" env-delim:"," required:"true" description:"WireGuard client IP address"` + DNS []string `long:"dns" env:"DNS" env-delim:"," description:"DNS servers for WireGuard network and resolving server address"` + DoT string `long:"dot" env:"DOT" description:"Port for DNS over TLS, used to resolve WireGuard server address if available"` + MTU int `long:"mtu" env:"MTU" default:"1280" description:"MTU for WireGuard network"` Listen string `long:"listen" env:"LISTEN" default:"localhost:8080" description:"HTTP & SOCKS5 server address"` - DNS []string `long:"dns" env:"DNS" env-delim:"," description:"DNS server IP address, only used in WireGuard"` - MTU int `long:"mtu" env:"MTU" default:"1280" description:"MTU"` ExitMode string `long:"exit-mode" env:"EXIT_MODE" default:"remote" choice:"remote" choice:"local" description:"Exit mode"` Verbose bool `short:"v" long:"verbose" description:"Show verbose debug information"` @@ -70,14 +71,14 @@ Description:` os.Exit(1) } - listener, err := netListener(opts, tnet) + listener, err := proxyListener(opts, tnet) if err != nil { logger.Errorf("Create net listener: %v", err) os.Exit(1) } socksListener, httpListener := proxymux.SplitSOCKSAndHTTP(listener) - dialer := netDialer(opts, tnet) + dialer := proxyDialer(opts, tnet) httpProxy := &http.Server{Handler: statsHandler(httpproxy.Handler(dialer), dev)} socksProxy := &socks5.Server{Dialer: dialer} @@ -99,7 +100,18 @@ Description:` os.Exit(1) } -func netListener(opts options, tnet *netstack.Net) (net.Listener, error) { +func proxyDialer(opts options, tnet *netstack.Net) (dialer func(ctx context.Context, network, address string) (net.Conn, error)) { + switch opts.ExitMode { + case "local": + d := net.Dialer{} + dialer = d.DialContext + case "remote": + dialer = tnet.DialContext + } + return +} + +func proxyListener(opts options, tnet *netstack.Net) (net.Listener, error) { var tcpListener net.Listener tcpAddr, err := net.ResolveTCPAddr("tcp", opts.Listen) @@ -123,17 +135,6 @@ func netListener(opts options, tnet *netstack.Net) (net.Listener, error) { return tcpListener, nil } -func netDialer(opts options, tnet *netstack.Net) (dialer func(ctx context.Context, network, address string) (net.Conn, error)) { - switch opts.ExitMode { - case "local": - var d net.Dialer - dialer = d.DialContext - case "remote": - dialer = tnet.DialContext - } - return -} - func setupNet(opts options) (*device.Device, *netstack.Net, error) { ips := []net.IP{} for _, s := range opts.ClientIPs { @@ -150,7 +151,6 @@ func setupNet(opts options) (*device.Device, *netstack.Net, error) { return nil, nil, fmt.Errorf("invalid dns ip: %s", s) } dnsServers = append(dnsServers, ip) - } tun, tnet, err := netstack.CreateNetTUN(ips, dnsServers, opts.MTU) if err != nil {