package main import ( "context" "crypto/tls" _ "embed" "errors" "fmt" "log" "net" "net/netip" "os" "strings" "github.com/jessevdk/go-flags" "golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/tun/netstack" "github.com/zhsj/wghttp/internal/proxy" ) //go:embed README.md var readme string var ( logger *device.Logger opts options ) func main() { parser := flags.NewParser(&opts, flags.Default) parser.LongDescription = fmt.Sprintf("wghttp %s\n\n", version()) parser.LongDescription += strings.Trim(strings.TrimPrefix(readme, "# wghttp"), "\n") if _, err := parser.Parse(); err != nil { code := 1 fe := &flags.Error{} if errors.As(err, &fe) && fe.Type == flags.ErrHelp { code = 0 } os.Exit(code) } if opts.Verbose { logger = device.NewLogger(device.LogLevelVerbose, "") } else { logger = device.NewLogger(device.LogLevelError, "") } logger.Verbosef("Options: %+v", opts) dev, tnet, err := setupNet() if err != nil { logger.Errorf("Setup netstack: %v", err) os.Exit(1) } listener, err := proxyListener(tnet) if err != nil { logger.Errorf("Create net listener: %v", err) os.Exit(1) } proxier := proxy.Proxy{ Dial: proxyDialer(tnet), DNS: opts.DNS, Stats: stats(dev), } if opts.TLSKey != "" { proxier.ServeTLS(listener, opts.TLSCert, opts.TLSKey) } else { proxier.Serve(listener) } os.Exit(1) } func proxyDialer(tnet *netstack.Net) (dialer func(ctx context.Context, network, address string) (net.Conn, error)) { switch opts.ExitMode { case "local": d := net.Dialer{} dialer = d.DialContext case "remote": dialer = tnet.DialContext } return } func proxyListener(tnet *netstack.Net) (net.Listener, error) { var tcpListener net.Listener tcpAddr, err := net.ResolveTCPAddr("tcp", opts.Listen) if err != nil { return nil, fmt.Errorf("resolve listen addr: %w", err) } switch opts.ExitMode { case "local": tcpListener, err = tnet.ListenTCP(tcpAddr) if err != nil { return nil, fmt.Errorf("create listener on netstack: %w", err) } case "remote": tcpListener, err = net.ListenTCP("tcp", tcpAddr) if err != nil { return nil, fmt.Errorf("create listener on local net: %w", err) } } logger.Verbosef("Listening on %s", tcpListener.Addr()) return tcpListener, nil } type peekingNetListener struct { net.Listener } type peekingNetConn struct { net.Conn } func (peek peekingNetListener) Accept() (net.Conn, error) { con, err := peek.Listener.Accept() return peekingNetConn{Conn: con}, err } func (peek peekingNetConn) Read(b []byte) (int, error) { n, err := peek.Conn.Read(b) log.Printf("net.Conn.Read: (%d) %q", n, b[:n]) return n, err } func listenerToTLS(ln net.Listener, certFile, pemFile string) (net.Listener, error) { ln = peekingNetListener{Listener: ln} cert, err := tls.LoadX509KeyPair(certFile, pemFile) if err != nil { return nil, err } config := &tls.Config{Certificates: []tls.Certificate{cert}} return peekingNetListener{tls.NewListener(ln, config)}, nil } func setupNet() (*device.Device, *netstack.Net, error) { clientIPs := []netip.Addr{} for _, ip := range opts.ClientIPs { clientIPs = append(clientIPs, netip.Addr(ip)) } tun, tnet, err := netstack.CreateNetTUN(clientIPs, nil, opts.MTU) if err != nil { return nil, nil, fmt.Errorf("create netstack tun: %w", err) } dev := device.NewDevice(tun, newConnBind(opts.ClientID), logger) if err := ipcSet(dev); err != nil { return nil, nil, fmt.Errorf("config device: %w", err) } if err := dev.Up(); err != nil { return nil, nil, fmt.Errorf("bring up device: %w", err) } return dev, tnet, nil }