From d904433c8f75f5d0bc20b6108f44f1e9d0b84249 Mon Sep 17 00:00:00 2001 From: Shengjing Zhu Date: Wed, 27 Jul 2022 01:31:36 +0800 Subject: [PATCH] refactor: use UnmarshalFlag interface to parse options Updates: #4 --- conf.go | 54 +++++++++++-------------------------- main.go | 45 +++++-------------------------- option.go | 80 +++++++++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 102 insertions(+), 77 deletions(-) create mode 100644 option.go diff --git a/conf.go b/conf.go index d8c64f0..8de99fd 100644 --- a/conf.go +++ b/conf.go @@ -2,12 +2,9 @@ package main import ( "context" - "encoding/base64" - "encoding/hex" "fmt" "net" "net/netip" - "strconv" "time" "github.com/zhsj/wghttp/internal/resolver" @@ -17,8 +14,8 @@ import ( type peer struct { resolver *resolver.Resolver - pubKey string - psk string + pubKey keyT + psk keyT host string ip netip.Addr @@ -26,30 +23,13 @@ type peer struct { } func newPeerEndpoint() (*peer, error) { - pubKey, err := base64.StdEncoding.DecodeString(opts.PeerKey) - if err != nil { - return nil, fmt.Errorf("parse peer public key: %w", err) - } - psk, err := base64.StdEncoding.DecodeString(opts.PresharedKey) - if err != nil { - return nil, fmt.Errorf("parse preshared key: %w", err) - } - p := &peer{ - pubKey: hex.EncodeToString(pubKey), - psk: hex.EncodeToString(psk), + pubKey: opts.PeerKey, + psk: opts.PresharedKey, + host: opts.PeerEndpoint.host, + port: opts.PeerEndpoint.port, } - host, port, err := net.SplitHostPort(opts.PeerEndpoint) - if err != nil { - return nil, fmt.Errorf("parse peer endpoint: %w", err) - } - port16, err := strconv.ParseUint(port, 10, 16) - if err != nil { - return nil, fmt.Errorf("parse peer endpoint port: %w", err) - } - p.host = host - p.port = uint16(port16) - + var err error p.ip, err = netip.ParseAddr(p.host) if err == nil { return p, nil @@ -73,16 +53,16 @@ func newPeerEndpoint() (*peer, error) { } func (p *peer) initConf() string { - conf := "public_key=" + p.pubKey + "\n" - conf += "endpoint=" + netip.AddrPortFrom(p.ip, p.port).String() + "\n" + conf := fmt.Sprintf("public_key=%s\n", p.pubKey) + conf += fmt.Sprintf("endpoint=%s\n", netip.AddrPortFrom(p.ip, p.port)) conf += "allowed_ip=0.0.0.0/0\n" conf += "allowed_ip=::/0\n" if opts.KeepaliveInterval > 0 { - conf += fmt.Sprintf("persistent_keepalive_interval=%.f\n", opts.KeepaliveInterval.Seconds()) + conf += fmt.Sprintf("persistent_keepalive_interval=%d\n", opts.KeepaliveInterval) } if p.psk != "" { - conf += "preshared_key=" + p.psk + "\n" + conf += fmt.Sprintf("preshared_key=%s\n", p.psk) } return conf @@ -100,9 +80,9 @@ func (p *peer) updateConf() (string, bool) { p.ip = newIP logger.Verbosef("PeerEndpoint is changed to: %s", p.ip) - conf := "public_key=" + p.pubKey + "\n" + conf := fmt.Sprintf("public_key=%s\n", p.pubKey) conf += "update_only=true\n" - conf += "endpoint=" + netip.AddrPortFrom(p.ip, p.port).String() + "\n" + conf += fmt.Sprintf("endpoint=%s\n", netip.AddrPortFrom(p.ip, p.port)) return conf, true } @@ -126,11 +106,7 @@ func (p *peer) resolveHost() (netip.Addr, error) { } func ipcSet(dev *device.Device) error { - privateKey, err := base64.StdEncoding.DecodeString(opts.PrivateKey) - if err != nil { - return fmt.Errorf("parse client private key: %w", err) - } - conf := "private_key=" + hex.EncodeToString(privateKey) + "\n" + conf := fmt.Sprintf("private_key=%s\n", opts.PrivateKey) if opts.ClientPort != 0 { conf += fmt.Sprintf("listen_port=%d\n", opts.ClientPort) } @@ -148,7 +124,7 @@ func ipcSet(dev *device.Device) error { if peer.resolver != nil { go func() { - c := time.Tick(opts.ResolveInterval) + c := time.Tick(time.Duration(opts.ResolveInterval) * time.Second) for range c { conf, needUpdate := peer.updateConf() diff --git a/main.go b/main.go index 9db2f22..d8b5ba1 100644 --- a/main.go +++ b/main.go @@ -4,12 +4,12 @@ import ( "bufio" "context" _ "embed" + "errors" "fmt" "net" "net/netip" "os" "strings" - "time" "github.com/jessevdk/go-flags" "golang.zx2c4.com/wireguard/device" @@ -26,28 +26,6 @@ var ( opts options ) -type options struct { - ClientIPs []string `long:"client-ip" env:"CLIENT_IP" env-delim:"," description:"[Interface].Address\tfor WireGuard client (can be set multiple times)"` - ClientPort int `long:"client-port" env:"CLIENT_PORT" description:"[Interface].ListenPort\tfor WireGuard client (optional)"` - PrivateKey string `long:"private-key" env:"PRIVATE_KEY" description:"[Interface].PrivateKey\tfor WireGuard client (format: base64)"` - DNS string `long:"dns" env:"DNS" description:"[Interface].DNS\tfor WireGuard network (format: protocol://ip:port)\nProtocol includes udp(default), tcp, tls(DNS over TLS) and https(DNS over HTTPS)"` - MTU int `long:"mtu" env:"MTU" default:"1280" description:"[Interface].MTU\tfor WireGuard network"` - - PeerEndpoint string `long:"peer-endpoint" env:"PEER_ENDPOINT" description:"[Peer].Endpoint\tfor WireGuard server (format: host:port)"` - PeerKey string `long:"peer-key" env:"PEER_KEY" description:"[Peer].PublicKey\tfor WireGuard server (format: base64)"` - PresharedKey string `long:"preshared-key" env:"PRESHARED_KEY" description:"[Peer].PresharedKey\tfor WireGuard network (optional, format: base64)"` - KeepaliveInterval time.Duration `long:"keepalive-interval" env:"KEEPALIVE_INTERVAL" description:"[Peer].PersistentKeepalive\tfor WireGuard network (optional)"` - - ResolveDNS string `long:"resolve-dns" env:"RESOLVE_DNS" description:"DNS for resolving WireGuard server address (optional, format: protocol://ip:port)\nProtocol includes udp(default), tcp, tls(DNS over TLS) and https(DNS over HTTPS)"` - ResolveInterval time.Duration `long:"resolve-interval" env:"RESOLVE_INTERVAL" default:"1m" description:"Interval for resolving WireGuard server address (set 0 to disable)"` - - Listen string `long:"listen" env:"LISTEN" default:"localhost:8080" description:"HTTP & SOCKS5 server address"` - ExitMode string `long:"exit-mode" env:"EXIT_MODE" choice:"remote" choice:"local" default:"remote" description:"Exit mode"` - Verbose bool `short:"v" long:"verbose" description:"Show verbose debug information"` - - ClientID string `long:"client-id" env:"CLIENT_ID" hidden:"true"` -} - func main() { parser := flags.NewParser(&opts, flags.Default) parser.Usage = `[OPTIONS] @@ -60,10 +38,9 @@ Description:` parser.Usage = strings.TrimSuffix(parser.Usage, "\n") if _, err := parser.Parse(); err != nil { code := 1 - if fe, ok := err.(*flags.Error); ok { - if fe.Type == flags.ErrHelp { - code = 0 - } + fe := &flags.Error{} + if errors.As(err, &fe) && fe.Type == flags.ErrHelp { + code = 0 } os.Exit(code) } @@ -130,18 +107,10 @@ func proxyListener(tnet *netstack.Net) (net.Listener, error) { } func setupNet() (*device.Device, *netstack.Net, error) { - if len(opts.ClientIPs) == 0 { - return nil, nil, fmt.Errorf("client IP is required") + clientIPs := []netip.Addr{} + for _, ip := range opts.ClientIPs { + clientIPs = append(clientIPs, netip.Addr(ip)) } - var clientIPs []netip.Addr - for _, s := range opts.ClientIPs { - ip, err := netip.ParseAddr(s) - if err != nil { - return nil, nil, fmt.Errorf("parse client IP: %w", err) - } - clientIPs = append(clientIPs, ip) - } - tun, tnet, err := netstack.CreateNetTUN(clientIPs, nil, opts.MTU) if err != nil { return nil, nil, fmt.Errorf("create netstack tun: %w", err) diff --git a/option.go b/option.go new file mode 100644 index 0000000..26295df --- /dev/null +++ b/option.go @@ -0,0 +1,80 @@ +package main + +import ( + "encoding/base64" + "encoding/hex" + "net" + "net/netip" + "strconv" + "time" +) + +type ipT netip.Addr + +func (o *ipT) UnmarshalFlag(value string) error { + ip, err := netip.ParseAddr(value) + *o = ipT(ip) + return err +} + +func (o ipT) String() string { + return netip.Addr(o).String() +} + +type hostPortT struct { + host string + port uint16 +} + +func (o *hostPortT) UnmarshalFlag(value string) error { + host, port, err := net.SplitHostPort(value) + if err != nil { + return err + } + port16, err := strconv.ParseUint(port, 10, 16) + *o = hostPortT{host, uint16(port16)} + return err +} + +type keyT string + +func (o *keyT) UnmarshalFlag(value string) error { + key, err := base64.StdEncoding.DecodeString(value) + *o = keyT(hex.EncodeToString(key)) + return err +} + +type timeT int64 + +func (o *timeT) UnmarshalFlag(value string) error { + i, err := strconv.ParseInt(value, 10, 32) + if err == nil { + *o = timeT(i) + return nil + } + d, err := time.ParseDuration(value) + *o = timeT(d.Seconds()) + return err +} + +type options struct { + ClientIPs []ipT `long:"client-ip" env:"CLIENT_IP" env-delim:"," required:"true" description:"[Interface].Address\tfor WireGuard client (can be set multiple times)"` + ClientPort int `long:"client-port" env:"CLIENT_PORT" description:"[Interface].ListenPort\tfor WireGuard client (optional)"` + PrivateKey keyT `long:"private-key" env:"PRIVATE_KEY" required:"true" description:"[Interface].PrivateKey\tfor WireGuard client (format: base64)"` + DNS string `long:"dns" env:"DNS" description:"[Interface].DNS\tfor WireGuard network (format: protocol://ip:port)\nProtocol includes udp(default), tcp, tls(DNS over TLS) and https(DNS over HTTPS)"` + MTU int `long:"mtu" env:"MTU" default:"1280" description:"[Interface].MTU\tfor WireGuard network"` + + PeerEndpoint hostPortT `long:"peer-endpoint" env:"PEER_ENDPOINT" required:"true" description:"[Peer].Endpoint\tfor WireGuard server (format: host:port)"` + PeerKey keyT `long:"peer-key" env:"PEER_KEY" required:"true" description:"[Peer].PublicKey\tfor WireGuard server (format: base64)"` + PresharedKey keyT `long:"preshared-key" env:"PRESHARED_KEY" description:"[Peer].PresharedKey\tfor WireGuard network (optional, format: base64)"` + KeepaliveInterval timeT `long:"keepalive-interval" env:"KEEPALIVE_INTERVAL" description:"[Peer].PersistentKeepalive\tfor WireGuard network (optional)"` + + ResolveDNS string `long:"resolve-dns" env:"RESOLVE_DNS" description:"DNS for resolving WireGuard server address (optional, format: protocol://ip:port)\nProtocol includes udp(default), tcp, tls(DNS over TLS) and https(DNS over HTTPS)"` + ResolveInterval timeT `long:"resolve-interval" env:"RESOLVE_INTERVAL" default:"1m" description:"Interval for resolving WireGuard server address (set 0 to disable)"` + + Listen string `long:"listen" env:"LISTEN" default:"localhost:8080" description:"HTTP & SOCKS5 server address"` + ExitMode string `long:"exit-mode" env:"EXIT_MODE" choice:"remote" choice:"local" default:"remote" description:"Exit mode"` + Verbose bool `short:"v" long:"verbose" description:"Show verbose debug information"` + + ClientID string `long:"client-id" env:"CLIENT_ID" hidden:"true"` +}