diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..9ade5d9 --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +wghttp diff --git a/conf.go b/conf.go new file mode 100644 index 0000000..3277de1 --- /dev/null +++ b/conf.go @@ -0,0 +1,69 @@ +package main + +import ( + "encoding/base64" + "encoding/hex" + "fmt" + "net" + "time" + + "golang.zx2c4.com/wireguard/device" +) + +func ipcSet(dev *device.Device, opts options) error { + privateKey, err := base64.StdEncoding.DecodeString(opts.PrivateKey) + if err != nil { + return fmt.Errorf("parse private key: %w", err) + } + peerKey, err := base64.StdEncoding.DecodeString(opts.PeerKey) + if err != nil { + return fmt.Errorf("parse peer key: %w", err) + } + conf := "private_key=" + hex.EncodeToString(privateKey) + "\n" + conf += "public_key=" + hex.EncodeToString(peerKey) + "\n" + + peerAddr, err := net.ResolveUDPAddr("udp", opts.PeerEndpoint) + if err != nil { + return fmt.Errorf("resolve peer endpoint: %w", err) + } + + conf += "endpoint=" + peerAddr.String() + "\n" + conf += "allowed_ip=0.0.0.0/0\n" + conf += "allowed_ip=::/0\n" + + if opts.ExitMode == "local" { + conf += "persistent_keepalive_interval=10\n" + } + + if err := dev.IpcSet(conf); err != nil { + return fmt.Errorf("set device config: %w", err) + } + + if peerAddr.String() != opts.PeerEndpoint { + go refreshEndpoint(dev, peerKey, peerAddr.String(), opts.PeerEndpoint) + } + return nil +} + +func refreshEndpoint(dev *device.Device, peerKey []byte, currentPeerAddr, peerEndpoint string) { + c := time.Tick(10 * time.Second) + + for range c { + addr, err := net.ResolveUDPAddr("udp", peerEndpoint) + if err != nil { + logger.Errorf("Resolve peer endpoint: %v", err) + continue + } + if currentPeerAddr == addr.String() { + continue + } + currentPeerAddr = addr.String() + logger.Verbosef("Endpoint is changed to: %s", addr) + conf := "public_key=" + hex.EncodeToString(peerKey) + "\n" + conf += "update_only=true\n" + conf += "endpoint=" + addr.String() + "\n" + if err := dev.IpcSet(conf); err != nil { + logger.Errorf("Set device config: %v", err) + } + } +} diff --git a/conn.go b/conn.go new file mode 100644 index 0000000..123ae6e --- /dev/null +++ b/conn.go @@ -0,0 +1,57 @@ +package main + +import ( + "encoding/base64" + + "golang.zx2c4.com/wireguard/conn" +) + +type connBind struct { + // magic 3 bytes in wireguard header reserved section. + clientID []uint8 + defaultBind conn.Bind +} + +func newConnBind(clientID string) conn.Bind { + defaultBind := conn.NewDefaultBind() + if clientID == "" { + return defaultBind + } + parsed, err := base64.StdEncoding.DecodeString(clientID) + if err != nil { + logger.Errorf("Invalid client id: %v, fallback to default", err) + return defaultBind + } + return &connBind{clientID: parsed, defaultBind: defaultBind} +} + +func (c *connBind) Open(port uint16) ([]conn.ReceiveFunc, uint16, error) { + fns, actualPort, err := c.defaultBind.Open(port) + newFNs := make([]conn.ReceiveFunc, 0, len(fns)) + for i := range fns { + f := fns[i] + newFNs = append(newFNs, func(b []byte) (n int, ep conn.Endpoint, err error) { + n, ep, err = f(b) + if len(b) > 4 { + copy(b[1:4], []byte{0, 0, 0}) + } + return + }) + } + return newFNs, actualPort, err +} + +func (c *connBind) Close() error { return c.defaultBind.Close() } + +func (c *connBind) SetMark(mark uint32) error { return c.defaultBind.SetMark(mark) } + +func (c *connBind) Send(b []byte, ep conn.Endpoint) error { + if len(b) > 4 { + copy(b[1:4], c.clientID) + } + return c.defaultBind.Send(b, ep) +} + +func (c *connBind) ParseEndpoint(s string) (conn.Endpoint, error) { + return c.defaultBind.ParseEndpoint(s) +} diff --git a/main.go b/main.go index d98416c..fd0ac75 100644 --- a/main.go +++ b/main.go @@ -1,15 +1,13 @@ package main import ( - "encoding/base64" - "encoding/hex" - "log" + "context" + "fmt" "net" "net/http" "os" "github.com/jessevdk/go-flags" - "golang.zx2c4.com/wireguard/conn" "golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/tun/netstack" @@ -18,14 +16,18 @@ import ( "github.com/zhsj/wghttp/internal/third_party/tailscale/socks5" ) +var logger *device.Logger + type options struct { PeerEndpoint string `long:"peer-endpoint" env:"PEER_ENDPOINT" required:"true" description:"WireGuard server address"` PeerKey string `long:"peer-key" env:"PEER_KEY" required:"true" description:"WireGuard server public key in base64 format"` PrivateKey string `long:"private-key" env:"PRIVATE_KEY" required:"true" description:"WireGuard client private key in base64 format"` ClientIPs []string `long:"client-ip" env:"CLIENT_IP" env-delim:"," required:"true" description:"WireGuard client IP address"` - Listen string `long:"listen" env:"LISTEN" default:"localhost:8080" description:"HTTP & SOCKS5 proxy server listen address"` - DNS []string `long:"dns" env:"DNS" env-delim:"," default:"1.0.0.1" description:"DNS server IP address"` + Listen string `long:"listen" env:"LISTEN" default:"localhost:8080" description:"HTTP & SOCKS5 server address"` + DNS []string `long:"dns" env:"DNS" env-delim:"," default:"" description:"DNS server IP address, only used WireGuard"` MTU int `long:"mtu" env:"MTU" default:"1280" description:"MTU"` + ExitMode string `long:"exit-mode" env:"EXIT_MODE" default:"remote" choice:"remote" choice:"local" description:"Exit mode"` + Verbose bool `short:"v" long:"verbose" description:"Show verbose debug information"` ClientID string `long:"client-id" env:"CLIENT_ID" hidden:"true"` } @@ -42,53 +44,88 @@ func main() { } os.Exit(code) } - - tnet := setupNet(opts) - ln, err := net.Listen("tcp", opts.Listen) - if err != nil { - log.Fatal(err) + if opts.Verbose { + logger = device.NewLogger(device.LogLevelVerbose, "") + } else { + logger = device.NewLogger(device.LogLevelError, "") } - log.Printf("listening on %s", ln.Addr()) - socksListener, httpListener := proxymux.SplitSOCKSAndHTTP(ln) + dev, tnet, err := setupNet(opts) + if err != nil { + logger.Errorf("Setup netstack: %v", err) + os.Exit(1) + } - httpProxy := &http.Server{Handler: httpproxy.Handler(tnet.DialContext)} - socksProxy := &socks5.Server{Dialer: tnet.DialContext} + listener, err := netListener(opts, tnet) + if err != nil { + logger.Errorf("Create net listener: %v", err) + os.Exit(1) + } + + socksListener, httpListener := proxymux.SplitSOCKSAndHTTP(listener) + dialer := netDialer(opts, tnet) + + httpProxy := &http.Server{Handler: statsHandler(httpproxy.Handler(dialer), dev)} + socksProxy := &socks5.Server{Dialer: dialer} errc := make(chan error, 2) go func() { if err := httpProxy.Serve(httpListener); err != nil { + logger.Errorf("Serving http proxy: %v", err) errc <- err } }() go func() { if err := socksProxy.Serve(socksListener); err != nil { + logger.Errorf("Serving socks5 proxy: %v", err) errc <- err } }() - - log.Fatal(<-errc) + <-errc + os.Exit(1) } -func setupNet(opts options) *netstack.Net { - privateKey, err := base64.StdEncoding.DecodeString(opts.PrivateKey) +func netListener(opts options, tnet *netstack.Net) (net.Listener, error) { + var tcpListener net.Listener + + tcpAddr, err := net.ResolveTCPAddr("tcp", opts.Listen) if err != nil { - log.Fatalf("parse private key: %v", err) + return nil, fmt.Errorf("resolve listen addr: %w", err) } - peerKey, err := base64.StdEncoding.DecodeString(opts.PeerKey) - if err != nil { - log.Fatalf("parse peer key: %v", err) + + switch opts.ExitMode { + case "local": + tcpListener, err = tnet.ListenTCP(tcpAddr) + if err != nil { + return nil, fmt.Errorf("create listener on netstack: %w", err) + } + case "remote": + tcpListener, err = net.ListenTCP("tcp", tcpAddr) + if err != nil { + return nil, fmt.Errorf("create listener on local net: %w", err) + } } - conf := "private_key=" + hex.EncodeToString(privateKey) + "\n" - conf += "public_key=" + hex.EncodeToString(peerKey) + "\n" - conf += "endpoint=" + opts.PeerEndpoint + "\n" - conf += "allowed_ip=0.0.0.0/0\n" - conf += "allowed_ip=::/0\n" + logger.Verbosef("Listening on %s", tcpListener.Addr()) + return tcpListener, nil +} + +func netDialer(opts options, tnet *netstack.Net) (dialer func(ctx context.Context, network, address string) (net.Conn, error)) { + switch opts.ExitMode { + case "local": + var d net.Dialer + dialer = d.DialContext + case "remote": + dialer = tnet.DialContext + } + return +} + +func setupNet(opts options) (*device.Device, *netstack.Net, error) { ips := []net.IP{} for _, s := range opts.ClientIPs { ip := net.ParseIP(s) if ip == nil { - log.Fatalf("invalid local ip: %s", s) + return nil, nil, fmt.Errorf("invalid client ip: %s", s) } ips = append(ips, ip) } @@ -96,70 +133,24 @@ func setupNet(opts options) *netstack.Net { for _, s := range opts.DNS { ip := net.ParseIP(s) if ip == nil { - log.Fatalf("invalid dns ip: %s", s) + return nil, nil, fmt.Errorf("invalid dns ip: %s", s) } dnsServers = append(dnsServers, ip) } tun, tnet, err := netstack.CreateNetTUN(ips, dnsServers, opts.MTU) if err != nil { - log.Fatalf("create netstack tun: %v", err) + return nil, nil, fmt.Errorf("create netstack tun: %w", err) } - dev := device.NewDevice(tun, newConnBind(opts.ClientID), device.NewLogger(device.LogLevelError, "")) - if err := dev.IpcSet(conf); err != nil { - log.Fatal(err) + dev := device.NewDevice(tun, newConnBind(opts.ClientID), logger) + + if err := ipcSet(dev, opts); err != nil { + return nil, nil, fmt.Errorf("config device: %w", err) } + if err := dev.Up(); err != nil { - log.Fatalf("bring up device: %v", err) + return nil, nil, fmt.Errorf("bring up device: %w", err) } - return tnet -} - -type connBind struct { - // magic 3 bytes in wireguard header reserved section. - clientID []uint8 - defaultBind conn.Bind -} - -func newConnBind(clientID string) conn.Bind { - defaultBind := conn.NewDefaultBind() - if clientID == "" { - return defaultBind - } - parsed, err := base64.StdEncoding.DecodeString(clientID) - if err != nil { - log.Fatalf("parse client id: %v", err) - } - return &connBind{clientID: parsed, defaultBind: defaultBind} -} - -func (c *connBind) Open(port uint16) ([]conn.ReceiveFunc, uint16, error) { - fns, actualPort, err := c.defaultBind.Open(port) - newFNs := make([]conn.ReceiveFunc, 0, len(fns)) - for i := range fns { - f := fns[i] - newFNs = append(newFNs, func(b []byte) (n int, ep conn.Endpoint, err error) { - n, ep, err = f(b) - if len(b) > 4 { - copy(b[1:4], []byte{0, 0, 0}) - } - return - }) - } - return newFNs, actualPort, err -} - -func (c *connBind) Close() error { return c.defaultBind.Close() } - -func (c *connBind) SetMark(mark uint32) error { return c.defaultBind.SetMark(mark) } - -func (c *connBind) Send(b []byte, ep conn.Endpoint) error { - if len(b) > 4 { - copy(b[1:4], c.clientID) - } - return c.defaultBind.Send(b, ep) -} - -func (c *connBind) ParseEndpoint(s string) (conn.Endpoint, error) { - return c.defaultBind.ParseEndpoint(s) + + return dev, tnet, nil } diff --git a/stats.go b/stats.go new file mode 100644 index 0000000..6a86b39 --- /dev/null +++ b/stats.go @@ -0,0 +1,59 @@ +package main + +import ( + "bufio" + "bytes" + "encoding/json" + "net/http" + "runtime" + "strconv" + "strings" + + "golang.zx2c4.com/wireguard/device" +) + +func statsHandler(next http.Handler, dev *device.Device) http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + if r.URL.Host != "" || r.URL.Path != "/stats" { + next.ServeHTTP(rw, r) + return + } + + stats := struct { + Endpoint string + LastHandshakeTimestamp int64 + ReceivedBytes int64 + SentBytes int64 + + NumGoroutine int + }{ + NumGoroutine: runtime.NumGoroutine(), + } + + var buf bytes.Buffer + if err := dev.IpcGetOperation(&buf); err != nil { + logger.Errorf("Get device config: %v", err) + rw.WriteHeader(http.StatusInternalServerError) + } else { + scanner := bufio.NewScanner(&buf) + for scanner.Scan() { + line := scanner.Text() + if prefix := "endpoint="; strings.HasPrefix(line, prefix) { + stats.Endpoint = strings.TrimPrefix(line, prefix) + } + if prefix := "last_handshake_time_sec="; strings.HasPrefix(line, prefix) { + stats.LastHandshakeTimestamp, _ = strconv.ParseInt(strings.TrimPrefix(line, prefix), 10, 64) + } + if prefix := "rx_bytes="; strings.HasPrefix(line, prefix) { + stats.ReceivedBytes, _ = strconv.ParseInt(strings.TrimPrefix(line, prefix), 10, 64) + } + if prefix := "tx_bytes="; strings.HasPrefix(line, prefix) { + stats.SentBytes, _ = strconv.ParseInt(strings.TrimPrefix(line, prefix), 10, 64) + } + } + resp, _ := json.MarshalIndent(stats, "", " ") + rw.Header().Set("Content-Type", "application/json") + rw.Write(append(resp, '\n')) + } + }) +}