162 lines
3.5 KiB
Go
162 lines
3.5 KiB
Go
package main
|
|
|
|
import (
|
|
"context"
|
|
"crypto/tls"
|
|
_ "embed"
|
|
"errors"
|
|
"fmt"
|
|
"log"
|
|
"net"
|
|
"net/netip"
|
|
"os"
|
|
"strings"
|
|
|
|
"github.com/jessevdk/go-flags"
|
|
"golang.zx2c4.com/wireguard/device"
|
|
"golang.zx2c4.com/wireguard/tun/netstack"
|
|
|
|
"github.com/zhsj/wghttp/internal/proxy"
|
|
)
|
|
|
|
// readme holds the text of README.md, embedded into the binary at build
// time. main strips the leading "# wghttp" title from it and uses the rest
// as part of the flag parser's long description.
//
//go:embed README.md
|
|
var readme string
|
|
|
|
// Package-level state shared by the functions in this file.
var (
|
|
// logger is the WireGuard device logger; main creates it at verbose or
// error level depending on opts.Verbose.
logger *device.Logger
|
|
// opts holds the command-line options filled in by the go-flags parser
// in main.
opts options
|
|
)
|
|
|
|
func main() {
|
|
parser := flags.NewParser(&opts, flags.Default)
|
|
parser.LongDescription = fmt.Sprintf("wghttp %s\n\n", version())
|
|
parser.LongDescription += strings.Trim(strings.TrimPrefix(readme, "# wghttp"), "\n")
|
|
if _, err := parser.Parse(); err != nil {
|
|
code := 1
|
|
fe := &flags.Error{}
|
|
if errors.As(err, &fe) && fe.Type == flags.ErrHelp {
|
|
code = 0
|
|
}
|
|
os.Exit(code)
|
|
}
|
|
if opts.Verbose {
|
|
logger = device.NewLogger(device.LogLevelVerbose, "")
|
|
} else {
|
|
logger = device.NewLogger(device.LogLevelError, "")
|
|
}
|
|
logger.Verbosef("Options: %+v", opts)
|
|
|
|
dev, tnet, err := setupNet()
|
|
if err != nil {
|
|
logger.Errorf("Setup netstack: %v", err)
|
|
os.Exit(1)
|
|
}
|
|
|
|
listener, err := proxyListener(tnet)
|
|
if err != nil {
|
|
logger.Errorf("Create net listener: %v", err)
|
|
os.Exit(1)
|
|
}
|
|
|
|
proxier := proxy.Proxy{
|
|
Dial: proxyDialer(tnet),
|
|
DNS: opts.DNS,
|
|
Stats: stats(dev),
|
|
}
|
|
if opts.TLSKey != "" {
|
|
proxier.ServeTLS(listener, opts.TLSCert, opts.TLSKey)
|
|
} else {
|
|
proxier.Serve(listener)
|
|
}
|
|
|
|
os.Exit(1)
|
|
}
|
|
|
|
func proxyDialer(tnet *netstack.Net) (dialer func(ctx context.Context, network, address string) (net.Conn, error)) {
|
|
switch opts.ExitMode {
|
|
case "local":
|
|
d := net.Dialer{}
|
|
dialer = d.DialContext
|
|
case "remote":
|
|
dialer = tnet.DialContext
|
|
}
|
|
return
|
|
}
|
|
|
|
func proxyListener(tnet *netstack.Net) (net.Listener, error) {
|
|
var tcpListener net.Listener
|
|
|
|
tcpAddr, err := net.ResolveTCPAddr("tcp", opts.Listen)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("resolve listen addr: %w", err)
|
|
}
|
|
|
|
switch opts.ExitMode {
|
|
case "local":
|
|
tcpListener, err = tnet.ListenTCP(tcpAddr)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("create listener on netstack: %w", err)
|
|
}
|
|
case "remote":
|
|
tcpListener, err = net.ListenTCP("tcp", tcpAddr)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("create listener on local net: %w", err)
|
|
}
|
|
}
|
|
logger.Verbosef("Listening on %s", tcpListener.Addr())
|
|
return tcpListener, nil
|
|
}
|
|
|
|
// peekingNetListener wraps a net.Listener so that every connection it
// accepts is returned as a peekingNetConn.
type peekingNetListener struct {
	net.Listener
}

// peekingNetConn wraps a net.Conn and logs the bytes returned by each Read.
type peekingNetConn struct {
	net.Conn
}

// Accept waits for the next connection on the wrapped listener and returns
// it wrapped in a peekingNetConn. When the wrapped listener fails, the
// error is returned together with a nil net.Conn, so callers never receive
// a non-nil connection that wraps nothing.
func (l peekingNetListener) Accept() (net.Conn, error) {
	conn, err := l.Listener.Accept()
	if err != nil {
		return nil, err
	}
	return peekingNetConn{Conn: conn}, nil
}

// Read reads from the wrapped connection into b, logs the number of bytes
// read and their quoted content through the standard logger, and returns
// the wrapped connection's result unchanged.
func (c peekingNetConn) Read(b []byte) (int, error) {
	n, err := c.Conn.Read(b)
	log.Printf("net.Conn.Read: (%d) %q", n, b[:n])
	return n, err
}
|
|
|
|
func listenerToTLS(ln net.Listener, certFile, pemFile string) (net.Listener, error) {
|
|
ln = peekingNetListener{Listener: ln}
|
|
|
|
cert, err := tls.LoadX509KeyPair(certFile, pemFile)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
config := &tls.Config{Certificates: []tls.Certificate{cert}}
|
|
|
|
return peekingNetListener{tls.NewListener(ln, config)}, nil
|
|
}
|
|
|
|
func setupNet() (*device.Device, *netstack.Net, error) {
|
|
clientIPs := []netip.Addr{}
|
|
for _, ip := range opts.ClientIPs {
|
|
clientIPs = append(clientIPs, netip.Addr(ip))
|
|
}
|
|
tun, tnet, err := netstack.CreateNetTUN(clientIPs, nil, opts.MTU)
|
|
if err != nil {
|
|
return nil, nil, fmt.Errorf("create netstack tun: %w", err)
|
|
}
|
|
dev := device.NewDevice(tun, newConnBind(opts.ClientID), logger)
|
|
|
|
if err := ipcSet(dev); err != nil {
|
|
return nil, nil, fmt.Errorf("config device: %w", err)
|
|
}
|
|
|
|
if err := dev.Up(); err != nil {
|
|
return nil, nil, fmt.Errorf("bring up device: %w", err)
|
|
}
|
|
|
|
return dev, tnet, nil
|
|
}
|