feat: support override reserved header section

master
Shengjing Zhu 2022-03-06 01:59:34 +08:00
parent 6fa6690285
commit 16e3b4adf1
1 changed files with 57 additions and 6 deletions

63
main.go
View File

@ -26,6 +26,8 @@ type options struct {
Listen string `long:"listen" env:"LISTEN" default:"localhost:8080" description:"HTTP proxy server listen address"`
DNS []string `long:"dns" env:"DNS" env-delim:"," default:"1.0.0.1" description:"DNS server IP address"`
MTU int `long:"mtu" env:"MTU" default:"1280" description:"MTU"`
ClientID string `long:"client-id" env:"CLIENT_ID" hidden:"true"`
}
func main() {
@ -46,7 +48,7 @@ func main() {
if err != nil {
log.Fatal(err)
}
log.Printf("Listening on %s", ln.Addr())
log.Printf("listening on %s", ln.Addr())
socksListener, httpListener := proxymux.SplitSOCKSAndHTTP(ln)
@ -71,11 +73,11 @@ func main() {
func setupNet(opts options) *netstack.Net {
privateKey, err := base64.StdEncoding.DecodeString(opts.PrivateKey)
if err != nil {
log.Fatalf("Parse private key: %v", err)
log.Fatalf("parse private key: %v", err)
}
peerKey, err := base64.StdEncoding.DecodeString(opts.PeerKey)
if err != nil {
log.Fatalf("Parse peer key: %v", err)
log.Fatalf("parse peer key: %v", err)
}
conf := "private_key=" + hex.EncodeToString(privateKey) + "\n"
conf += "public_key=" + hex.EncodeToString(peerKey) + "\n"
@ -101,14 +103,63 @@ func setupNet(opts options) *netstack.Net {
}
tun, tnet, err := netstack.CreateNetTUN(ips, dnsServers, opts.MTU)
if err != nil {
log.Fatal(err)
log.Fatalf("create netstack tun: %v", err)
}
dev := device.NewDevice(tun, conn.NewDefaultBind(), device.NewLogger(device.LogLevelError, ""))
dev := device.NewDevice(tun, newConnBind(opts.ClientID), device.NewLogger(device.LogLevelError, ""))
if err := dev.IpcSet(conf); err != nil {
log.Fatal(err)
}
if err := dev.Up(); err != nil {
log.Fatal(err)
log.Fatalf("bring up device: %v", err)
}
return tnet
}
type connBind struct {
// magic 3 bytes in wireguard header reserved section.
clientID []uint8
defaultBind conn.Bind
}
func newConnBind(clientID string) conn.Bind {
defaultBind := conn.NewDefaultBind()
if clientID == "" {
return defaultBind
}
parsed, err := base64.StdEncoding.DecodeString(clientID)
if err != nil {
log.Fatalf("parse client id: %v", err)
}
return &connBind{clientID: parsed, defaultBind: defaultBind}
}
func (c *connBind) Open(port uint16) ([]conn.ReceiveFunc, uint16, error) {
fns, actualPort, err := c.defaultBind.Open(port)
newFNs := make([]conn.ReceiveFunc, 0, len(fns))
for i := range fns {
f := fns[i]
newFNs = append(newFNs, func(b []byte) (n int, ep conn.Endpoint, err error) {
n, ep, err = f(b)
if len(b) > 4 {
copy(b[1:4], []byte{0, 0, 0})
}
return
})
}
return newFNs, actualPort, err
}
func (c *connBind) Close() error { return c.defaultBind.Close() }
func (c *connBind) SetMark(mark uint32) error { return c.defaultBind.SetMark(mark) }
func (c *connBind) Send(b []byte, ep conn.Endpoint) error {
if len(b) > 4 {
copy(b[1:4], c.clientID)
}
return c.defaultBind.Send(b, ep)
}
func (c *connBind) ParseEndpoint(s string) (conn.Endpoint, error) {
return c.defaultBind.ParseEndpoint(s)
}