refactor: use UnmarshalFlag interface to parse options

Updates: #4
master
Shengjing Zhu 2022-07-27 01:31:36 +08:00
parent 56268e4ae6
commit d904433c8f
3 changed files with 102 additions and 77 deletions

54
conf.go
View File

@ -2,12 +2,9 @@ package main
import (
"context"
"encoding/base64"
"encoding/hex"
"fmt"
"net"
"net/netip"
"strconv"
"time"
"github.com/zhsj/wghttp/internal/resolver"
@ -17,8 +14,8 @@ import (
type peer struct {
resolver *resolver.Resolver
pubKey string
psk string
pubKey keyT
psk keyT
host string
ip netip.Addr
@ -26,30 +23,13 @@ type peer struct {
}
func newPeerEndpoint() (*peer, error) {
pubKey, err := base64.StdEncoding.DecodeString(opts.PeerKey)
if err != nil {
return nil, fmt.Errorf("parse peer public key: %w", err)
}
psk, err := base64.StdEncoding.DecodeString(opts.PresharedKey)
if err != nil {
return nil, fmt.Errorf("parse preshared key: %w", err)
}
p := &peer{
pubKey: hex.EncodeToString(pubKey),
psk: hex.EncodeToString(psk),
pubKey: opts.PeerKey,
psk: opts.PresharedKey,
host: opts.PeerEndpoint.host,
port: opts.PeerEndpoint.port,
}
host, port, err := net.SplitHostPort(opts.PeerEndpoint)
if err != nil {
return nil, fmt.Errorf("parse peer endpoint: %w", err)
}
port16, err := strconv.ParseUint(port, 10, 16)
if err != nil {
return nil, fmt.Errorf("parse peer endpoint port: %w", err)
}
p.host = host
p.port = uint16(port16)
var err error
p.ip, err = netip.ParseAddr(p.host)
if err == nil {
return p, nil
@ -73,16 +53,16 @@ func newPeerEndpoint() (*peer, error) {
}
func (p *peer) initConf() string {
conf := "public_key=" + p.pubKey + "\n"
conf += "endpoint=" + netip.AddrPortFrom(p.ip, p.port).String() + "\n"
conf := fmt.Sprintf("public_key=%s\n", p.pubKey)
conf += fmt.Sprintf("endpoint=%s\n", netip.AddrPortFrom(p.ip, p.port))
conf += "allowed_ip=0.0.0.0/0\n"
conf += "allowed_ip=::/0\n"
if opts.KeepaliveInterval > 0 {
conf += fmt.Sprintf("persistent_keepalive_interval=%.f\n", opts.KeepaliveInterval.Seconds())
conf += fmt.Sprintf("persistent_keepalive_interval=%d\n", opts.KeepaliveInterval)
}
if p.psk != "" {
conf += "preshared_key=" + p.psk + "\n"
conf += fmt.Sprintf("preshared_key=%s\n", p.psk)
}
return conf
@ -100,9 +80,9 @@ func (p *peer) updateConf() (string, bool) {
p.ip = newIP
logger.Verbosef("PeerEndpoint is changed to: %s", p.ip)
conf := "public_key=" + p.pubKey + "\n"
conf := fmt.Sprintf("public_key=%s\n", p.pubKey)
conf += "update_only=true\n"
conf += "endpoint=" + netip.AddrPortFrom(p.ip, p.port).String() + "\n"
conf += fmt.Sprintf("endpoint=%s\n", netip.AddrPortFrom(p.ip, p.port))
return conf, true
}
@ -126,11 +106,7 @@ func (p *peer) resolveHost() (netip.Addr, error) {
}
func ipcSet(dev *device.Device) error {
privateKey, err := base64.StdEncoding.DecodeString(opts.PrivateKey)
if err != nil {
return fmt.Errorf("parse client private key: %w", err)
}
conf := "private_key=" + hex.EncodeToString(privateKey) + "\n"
conf := fmt.Sprintf("private_key=%s\n", opts.PrivateKey)
if opts.ClientPort != 0 {
conf += fmt.Sprintf("listen_port=%d\n", opts.ClientPort)
}
@ -148,7 +124,7 @@ func ipcSet(dev *device.Device) error {
if peer.resolver != nil {
go func() {
c := time.Tick(opts.ResolveInterval)
c := time.Tick(time.Duration(opts.ResolveInterval) * time.Second)
for range c {
conf, needUpdate := peer.updateConf()

45
main.go
View File

@ -4,12 +4,12 @@ import (
"bufio"
"context"
_ "embed"
"errors"
"fmt"
"net"
"net/netip"
"os"
"strings"
"time"
"github.com/jessevdk/go-flags"
"golang.zx2c4.com/wireguard/device"
@ -26,28 +26,6 @@ var (
opts options
)
type options struct {
ClientIPs []string `long:"client-ip" env:"CLIENT_IP" env-delim:"," description:"[Interface].Address\tfor WireGuard client (can be set multiple times)"`
ClientPort int `long:"client-port" env:"CLIENT_PORT" description:"[Interface].ListenPort\tfor WireGuard client (optional)"`
PrivateKey string `long:"private-key" env:"PRIVATE_KEY" description:"[Interface].PrivateKey\tfor WireGuard client (format: base64)"`
DNS string `long:"dns" env:"DNS" description:"[Interface].DNS\tfor WireGuard network (format: protocol://ip:port)\nProtocol includes udp(default), tcp, tls(DNS over TLS) and https(DNS over HTTPS)"`
MTU int `long:"mtu" env:"MTU" default:"1280" description:"[Interface].MTU\tfor WireGuard network"`
PeerEndpoint string `long:"peer-endpoint" env:"PEER_ENDPOINT" description:"[Peer].Endpoint\tfor WireGuard server (format: host:port)"`
PeerKey string `long:"peer-key" env:"PEER_KEY" description:"[Peer].PublicKey\tfor WireGuard server (format: base64)"`
PresharedKey string `long:"preshared-key" env:"PRESHARED_KEY" description:"[Peer].PresharedKey\tfor WireGuard network (optional, format: base64)"`
KeepaliveInterval time.Duration `long:"keepalive-interval" env:"KEEPALIVE_INTERVAL" description:"[Peer].PersistentKeepalive\tfor WireGuard network (optional)"`
ResolveDNS string `long:"resolve-dns" env:"RESOLVE_DNS" description:"DNS for resolving WireGuard server address (optional, format: protocol://ip:port)\nProtocol includes udp(default), tcp, tls(DNS over TLS) and https(DNS over HTTPS)"`
ResolveInterval time.Duration `long:"resolve-interval" env:"RESOLVE_INTERVAL" default:"1m" description:"Interval for resolving WireGuard server address (set 0 to disable)"`
Listen string `long:"listen" env:"LISTEN" default:"localhost:8080" description:"HTTP & SOCKS5 server address"`
ExitMode string `long:"exit-mode" env:"EXIT_MODE" choice:"remote" choice:"local" default:"remote" description:"Exit mode"`
Verbose bool `short:"v" long:"verbose" description:"Show verbose debug information"`
ClientID string `long:"client-id" env:"CLIENT_ID" hidden:"true"`
}
func main() {
parser := flags.NewParser(&opts, flags.Default)
parser.Usage = `[OPTIONS]
@ -60,10 +38,9 @@ Description:`
parser.Usage = strings.TrimSuffix(parser.Usage, "\n")
if _, err := parser.Parse(); err != nil {
code := 1
if fe, ok := err.(*flags.Error); ok {
if fe.Type == flags.ErrHelp {
code = 0
}
fe := &flags.Error{}
if errors.As(err, &fe) && fe.Type == flags.ErrHelp {
code = 0
}
os.Exit(code)
}
@ -130,18 +107,10 @@ func proxyListener(tnet *netstack.Net) (net.Listener, error) {
}
func setupNet() (*device.Device, *netstack.Net, error) {
if len(opts.ClientIPs) == 0 {
return nil, nil, fmt.Errorf("client IP is required")
clientIPs := []netip.Addr{}
for _, ip := range opts.ClientIPs {
clientIPs = append(clientIPs, netip.Addr(ip))
}
var clientIPs []netip.Addr
for _, s := range opts.ClientIPs {
ip, err := netip.ParseAddr(s)
if err != nil {
return nil, nil, fmt.Errorf("parse client IP: %w", err)
}
clientIPs = append(clientIPs, ip)
}
tun, tnet, err := netstack.CreateNetTUN(clientIPs, nil, opts.MTU)
if err != nil {
return nil, nil, fmt.Errorf("create netstack tun: %w", err)

80
option.go Normal file
View File

@ -0,0 +1,80 @@
package main
import (
"encoding/base64"
"encoding/hex"
"net"
"net/netip"
"strconv"
"time"
)
type ipT netip.Addr
func (o *ipT) UnmarshalFlag(value string) error {
ip, err := netip.ParseAddr(value)
*o = ipT(ip)
return err
}
func (o ipT) String() string {
return netip.Addr(o).String()
}
type hostPortT struct {
host string
port uint16
}
func (o *hostPortT) UnmarshalFlag(value string) error {
host, port, err := net.SplitHostPort(value)
if err != nil {
return err
}
port16, err := strconv.ParseUint(port, 10, 16)
*o = hostPortT{host, uint16(port16)}
return err
}
type keyT string
func (o *keyT) UnmarshalFlag(value string) error {
key, err := base64.StdEncoding.DecodeString(value)
*o = keyT(hex.EncodeToString(key))
return err
}
type timeT int64
func (o *timeT) UnmarshalFlag(value string) error {
i, err := strconv.ParseInt(value, 10, 32)
if err == nil {
*o = timeT(i)
return nil
}
d, err := time.ParseDuration(value)
*o = timeT(d.Seconds())
return err
}
type options struct {
ClientIPs []ipT `long:"client-ip" env:"CLIENT_IP" env-delim:"," required:"true" description:"[Interface].Address\tfor WireGuard client (can be set multiple times)"`
ClientPort int `long:"client-port" env:"CLIENT_PORT" description:"[Interface].ListenPort\tfor WireGuard client (optional)"`
PrivateKey keyT `long:"private-key" env:"PRIVATE_KEY" required:"true" description:"[Interface].PrivateKey\tfor WireGuard client (format: base64)"`
DNS string `long:"dns" env:"DNS" description:"[Interface].DNS\tfor WireGuard network (format: protocol://ip:port)\nProtocol includes udp(default), tcp, tls(DNS over TLS) and https(DNS over HTTPS)"`
MTU int `long:"mtu" env:"MTU" default:"1280" description:"[Interface].MTU\tfor WireGuard network"`
PeerEndpoint hostPortT `long:"peer-endpoint" env:"PEER_ENDPOINT" required:"true" description:"[Peer].Endpoint\tfor WireGuard server (format: host:port)"`
PeerKey keyT `long:"peer-key" env:"PEER_KEY" required:"true" description:"[Peer].PublicKey\tfor WireGuard server (format: base64)"`
PresharedKey keyT `long:"preshared-key" env:"PRESHARED_KEY" description:"[Peer].PresharedKey\tfor WireGuard network (optional, format: base64)"`
KeepaliveInterval timeT `long:"keepalive-interval" env:"KEEPALIVE_INTERVAL" description:"[Peer].PersistentKeepalive\tfor WireGuard network (optional)"`
ResolveDNS string `long:"resolve-dns" env:"RESOLVE_DNS" description:"DNS for resolving WireGuard server address (optional, format: protocol://ip:port)\nProtocol includes udp(default), tcp, tls(DNS over TLS) and https(DNS over HTTPS)"`
ResolveInterval timeT `long:"resolve-interval" env:"RESOLVE_INTERVAL" default:"1m" description:"Interval for resolving WireGuard server address (set 0 to disable)"`
Listen string `long:"listen" env:"LISTEN" default:"localhost:8080" description:"HTTP & SOCKS5 server address"`
ExitMode string `long:"exit-mode" env:"EXIT_MODE" choice:"remote" choice:"local" default:"remote" description:"Exit mode"`
Verbose bool `short:"v" long:"verbose" description:"Show verbose debug information"`
ClientID string `long:"client-id" env:"CLIENT_ID" hidden:"true"`
}