From 16e3b4adf184d97be10f45acfa360b8aa5bb10ac Mon Sep 17 00:00:00 2001 From: Shengjing Zhu Date: Sun, 6 Mar 2022 01:59:34 +0800 Subject: [PATCH] feat: support override reserved header section --- main.go | 63 +++++++++++++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 57 insertions(+), 6 deletions(-) diff --git a/main.go b/main.go index 4d9b037..581fd7f 100644 --- a/main.go +++ b/main.go @@ -26,6 +26,8 @@ type options struct { Listen string `long:"listen" env:"LISTEN" default:"localhost:8080" description:"HTTP proxy server listen address"` DNS []string `long:"dns" env:"DNS" env-delim:"," default:"1.0.0.1" description:"DNS server IP address"` MTU int `long:"mtu" env:"MTU" default:"1280" description:"MTU"` + + ClientID string `long:"client-id" env:"CLIENT_ID" hidden:"true"` } func main() { @@ -46,7 +48,7 @@ func main() { if err != nil { log.Fatal(err) } - log.Printf("Listening on %s", ln.Addr()) + log.Printf("listening on %s", ln.Addr()) socksListener, httpListener := proxymux.SplitSOCKSAndHTTP(ln) @@ -71,11 +73,11 @@ func main() { func setupNet(opts options) *netstack.Net { privateKey, err := base64.StdEncoding.DecodeString(opts.PrivateKey) if err != nil { - log.Fatalf("Parse private key: %v", err) + log.Fatalf("parse private key: %v", err) } peerKey, err := base64.StdEncoding.DecodeString(opts.PeerKey) if err != nil { - log.Fatalf("Parse peer key: %v", err) + log.Fatalf("parse peer key: %v", err) } conf := "private_key=" + hex.EncodeToString(privateKey) + "\n" conf += "public_key=" + hex.EncodeToString(peerKey) + "\n" @@ -101,14 +103,63 @@ func setupNet(opts options) *netstack.Net { } tun, tnet, err := netstack.CreateNetTUN(ips, dnsServers, opts.MTU) if err != nil { - log.Fatal(err) + log.Fatalf("create netstack tun: %v", err) } - dev := device.NewDevice(tun, conn.NewDefaultBind(), device.NewLogger(device.LogLevelError, "")) + dev := device.NewDevice(tun, newConnBind(opts.ClientID), device.NewLogger(device.LogLevelError, "")) if err := dev.IpcSet(conf); err != nil { log.Fatal(err) } if err := dev.Up(); err != nil { - log.Fatal(err) + log.Fatalf("bring up device: %v", err) } return tnet } + +type connBind struct { + // magic 3 bytes in wireguard header reserved section. + clientID []uint8 + defaultBind conn.Bind +} + +func newConnBind(clientID string) conn.Bind { + defaultBind := conn.NewDefaultBind() + if clientID == "" { + return defaultBind + } + parsed, err := base64.StdEncoding.DecodeString(clientID) + if err != nil { + log.Fatalf("parse client id: %v", err) + } + return &connBind{clientID: parsed, defaultBind: defaultBind} +} + +func (c *connBind) Open(port uint16) ([]conn.ReceiveFunc, uint16, error) { + fns, actualPort, err := c.defaultBind.Open(port) + newFNs := make([]conn.ReceiveFunc, 0, len(fns)) + for i := range fns { + f := fns[i] + newFNs = append(newFNs, func(b []byte) (n int, ep conn.Endpoint, err error) { + n, ep, err = f(b) + if len(b) > 4 { + copy(b[1:4], []byte{0, 0, 0}) + } + return + }) + } + return newFNs, actualPort, err +} + +func (c *connBind) Close() error { return c.defaultBind.Close() } + +func (c *connBind) SetMark(mark uint32) error { return c.defaultBind.SetMark(mark) } + +func (c *connBind) Send(b []byte, ep conn.Endpoint) error { + if len(b) > 4 { + copy(b[1:4], c.clientID) + } + return c.defaultBind.Send(b, ep) +} + +func (c *connBind) ParseEndpoint(s string) (conn.Endpoint, error) { + return c.defaultBind.ParseEndpoint(s) +}