feat: new exit mode
parent
0fc2fafb50
commit
72493d4c3c
|
|
@ -0,0 +1 @@
|
|||
wghttp
|
||||
|
|
@ -0,0 +1,69 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"golang.zx2c4.com/wireguard/device"
|
||||
)
|
||||
|
||||
func ipcSet(dev *device.Device, opts options) error {
|
||||
privateKey, err := base64.StdEncoding.DecodeString(opts.PrivateKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse private key: %w", err)
|
||||
}
|
||||
peerKey, err := base64.StdEncoding.DecodeString(opts.PeerKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse peer key: %w", err)
|
||||
}
|
||||
conf := "private_key=" + hex.EncodeToString(privateKey) + "\n"
|
||||
conf += "public_key=" + hex.EncodeToString(peerKey) + "\n"
|
||||
|
||||
peerAddr, err := net.ResolveUDPAddr("udp", opts.PeerEndpoint)
|
||||
if err != nil {
|
||||
return fmt.Errorf("resolve peer endpoint: %w", err)
|
||||
}
|
||||
|
||||
conf += "endpoint=" + peerAddr.String() + "\n"
|
||||
conf += "allowed_ip=0.0.0.0/0\n"
|
||||
conf += "allowed_ip=::/0\n"
|
||||
|
||||
if opts.ExitMode == "local" {
|
||||
conf += "persistent_keepalive_interval=10\n"
|
||||
}
|
||||
|
||||
if err := dev.IpcSet(conf); err != nil {
|
||||
return fmt.Errorf("set device config: %w", err)
|
||||
}
|
||||
|
||||
if peerAddr.String() != opts.PeerEndpoint {
|
||||
go refreshEndpoint(dev, peerKey, peerAddr.String(), opts.PeerEndpoint)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func refreshEndpoint(dev *device.Device, peerKey []byte, currentPeerAddr, peerEndpoint string) {
|
||||
c := time.Tick(10 * time.Second)
|
||||
|
||||
for range c {
|
||||
addr, err := net.ResolveUDPAddr("udp", peerEndpoint)
|
||||
if err != nil {
|
||||
logger.Errorf("Resolve peer endpoint: %v", err)
|
||||
continue
|
||||
}
|
||||
if currentPeerAddr == addr.String() {
|
||||
continue
|
||||
}
|
||||
currentPeerAddr = addr.String()
|
||||
logger.Verbosef("Endpoint is changed to: %s", addr)
|
||||
conf := "public_key=" + hex.EncodeToString(peerKey) + "\n"
|
||||
conf += "update_only=true\n"
|
||||
conf += "endpoint=" + addr.String() + "\n"
|
||||
if err := dev.IpcSet(conf); err != nil {
|
||||
logger.Errorf("Set device config: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,57 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
|
||||
"golang.zx2c4.com/wireguard/conn"
|
||||
)
|
||||
|
||||
type connBind struct {
|
||||
// magic 3 bytes in wireguard header reserved section.
|
||||
clientID []uint8
|
||||
defaultBind conn.Bind
|
||||
}
|
||||
|
||||
func newConnBind(clientID string) conn.Bind {
|
||||
defaultBind := conn.NewDefaultBind()
|
||||
if clientID == "" {
|
||||
return defaultBind
|
||||
}
|
||||
parsed, err := base64.StdEncoding.DecodeString(clientID)
|
||||
if err != nil {
|
||||
logger.Errorf("Invalid client id: %v, fallback to default", err)
|
||||
return defaultBind
|
||||
}
|
||||
return &connBind{clientID: parsed, defaultBind: defaultBind}
|
||||
}
|
||||
|
||||
func (c *connBind) Open(port uint16) ([]conn.ReceiveFunc, uint16, error) {
|
||||
fns, actualPort, err := c.defaultBind.Open(port)
|
||||
newFNs := make([]conn.ReceiveFunc, 0, len(fns))
|
||||
for i := range fns {
|
||||
f := fns[i]
|
||||
newFNs = append(newFNs, func(b []byte) (n int, ep conn.Endpoint, err error) {
|
||||
n, ep, err = f(b)
|
||||
if len(b) > 4 {
|
||||
copy(b[1:4], []byte{0, 0, 0})
|
||||
}
|
||||
return
|
||||
})
|
||||
}
|
||||
return newFNs, actualPort, err
|
||||
}
|
||||
|
||||
func (c *connBind) Close() error { return c.defaultBind.Close() }
|
||||
|
||||
func (c *connBind) SetMark(mark uint32) error { return c.defaultBind.SetMark(mark) }
|
||||
|
||||
func (c *connBind) Send(b []byte, ep conn.Endpoint) error {
|
||||
if len(b) > 4 {
|
||||
copy(b[1:4], c.clientID)
|
||||
}
|
||||
return c.defaultBind.Send(b, ep)
|
||||
}
|
||||
|
||||
func (c *connBind) ParseEndpoint(s string) (conn.Endpoint, error) {
|
||||
return c.defaultBind.ParseEndpoint(s)
|
||||
}
|
||||
161
main.go
161
main.go
|
|
@ -1,15 +1,13 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"log"
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
|
||||
"github.com/jessevdk/go-flags"
|
||||
"golang.zx2c4.com/wireguard/conn"
|
||||
"golang.zx2c4.com/wireguard/device"
|
||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||
|
||||
|
|
@ -18,14 +16,18 @@ import (
|
|||
"github.com/zhsj/wghttp/internal/third_party/tailscale/socks5"
|
||||
)
|
||||
|
||||
var logger *device.Logger
|
||||
|
||||
type options struct {
|
||||
PeerEndpoint string `long:"peer-endpoint" env:"PEER_ENDPOINT" required:"true" description:"WireGuard server address"`
|
||||
PeerKey string `long:"peer-key" env:"PEER_KEY" required:"true" description:"WireGuard server public key in base64 format"`
|
||||
PrivateKey string `long:"private-key" env:"PRIVATE_KEY" required:"true" description:"WireGuard client private key in base64 format"`
|
||||
ClientIPs []string `long:"client-ip" env:"CLIENT_IP" env-delim:"," required:"true" description:"WireGuard client IP address"`
|
||||
Listen string `long:"listen" env:"LISTEN" default:"localhost:8080" description:"HTTP & SOCKS5 proxy server listen address"`
|
||||
DNS []string `long:"dns" env:"DNS" env-delim:"," default:"1.0.0.1" description:"DNS server IP address"`
|
||||
Listen string `long:"listen" env:"LISTEN" default:"localhost:8080" description:"HTTP & SOCKS5 server address"`
|
||||
DNS []string `long:"dns" env:"DNS" env-delim:"," default:"" description:"DNS server IP address, only used WireGuard"`
|
||||
MTU int `long:"mtu" env:"MTU" default:"1280" description:"MTU"`
|
||||
ExitMode string `long:"exit-mode" env:"EXIT_MODE" default:"remote" choice:"remote" choice:"local" description:"Exit mode"`
|
||||
Verbose bool `short:"v" long:"verbose" description:"Show verbose debug information"`
|
||||
|
||||
ClientID string `long:"client-id" env:"CLIENT_ID" hidden:"true"`
|
||||
}
|
||||
|
|
@ -42,53 +44,88 @@ func main() {
|
|||
}
|
||||
os.Exit(code)
|
||||
}
|
||||
|
||||
tnet := setupNet(opts)
|
||||
ln, err := net.Listen("tcp", opts.Listen)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
if opts.Verbose {
|
||||
logger = device.NewLogger(device.LogLevelVerbose, "")
|
||||
} else {
|
||||
logger = device.NewLogger(device.LogLevelError, "")
|
||||
}
|
||||
log.Printf("listening on %s", ln.Addr())
|
||||
|
||||
socksListener, httpListener := proxymux.SplitSOCKSAndHTTP(ln)
|
||||
dev, tnet, err := setupNet(opts)
|
||||
if err != nil {
|
||||
logger.Errorf("Setup netstack: %v", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
httpProxy := &http.Server{Handler: httpproxy.Handler(tnet.DialContext)}
|
||||
socksProxy := &socks5.Server{Dialer: tnet.DialContext}
|
||||
listener, err := netListener(opts, tnet)
|
||||
if err != nil {
|
||||
logger.Errorf("Create net listener: %v", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
socksListener, httpListener := proxymux.SplitSOCKSAndHTTP(listener)
|
||||
dialer := netDialer(opts, tnet)
|
||||
|
||||
httpProxy := &http.Server{Handler: statsHandler(httpproxy.Handler(dialer), dev)}
|
||||
socksProxy := &socks5.Server{Dialer: dialer}
|
||||
|
||||
errc := make(chan error, 2)
|
||||
go func() {
|
||||
if err := httpProxy.Serve(httpListener); err != nil {
|
||||
logger.Errorf("Serving http proxy: %v", err)
|
||||
errc <- err
|
||||
}
|
||||
}()
|
||||
go func() {
|
||||
if err := socksProxy.Serve(socksListener); err != nil {
|
||||
logger.Errorf("Serving socks5 proxy: %v", err)
|
||||
errc <- err
|
||||
}
|
||||
}()
|
||||
|
||||
log.Fatal(<-errc)
|
||||
<-errc
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
func setupNet(opts options) *netstack.Net {
|
||||
privateKey, err := base64.StdEncoding.DecodeString(opts.PrivateKey)
|
||||
func netListener(opts options, tnet *netstack.Net) (net.Listener, error) {
|
||||
var tcpListener net.Listener
|
||||
|
||||
tcpAddr, err := net.ResolveTCPAddr("tcp", opts.Listen)
|
||||
if err != nil {
|
||||
log.Fatalf("parse private key: %v", err)
|
||||
return nil, fmt.Errorf("resolve listen addr: %w", err)
|
||||
}
|
||||
peerKey, err := base64.StdEncoding.DecodeString(opts.PeerKey)
|
||||
if err != nil {
|
||||
log.Fatalf("parse peer key: %v", err)
|
||||
|
||||
switch opts.ExitMode {
|
||||
case "local":
|
||||
tcpListener, err = tnet.ListenTCP(tcpAddr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create listener on netstack: %w", err)
|
||||
}
|
||||
case "remote":
|
||||
tcpListener, err = net.ListenTCP("tcp", tcpAddr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create listener on local net: %w", err)
|
||||
}
|
||||
}
|
||||
conf := "private_key=" + hex.EncodeToString(privateKey) + "\n"
|
||||
conf += "public_key=" + hex.EncodeToString(peerKey) + "\n"
|
||||
conf += "endpoint=" + opts.PeerEndpoint + "\n"
|
||||
conf += "allowed_ip=0.0.0.0/0\n"
|
||||
conf += "allowed_ip=::/0\n"
|
||||
logger.Verbosef("Listening on %s", tcpListener.Addr())
|
||||
return tcpListener, nil
|
||||
}
|
||||
|
||||
func netDialer(opts options, tnet *netstack.Net) (dialer func(ctx context.Context, network, address string) (net.Conn, error)) {
|
||||
switch opts.ExitMode {
|
||||
case "local":
|
||||
var d net.Dialer
|
||||
dialer = d.DialContext
|
||||
case "remote":
|
||||
dialer = tnet.DialContext
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func setupNet(opts options) (*device.Device, *netstack.Net, error) {
|
||||
ips := []net.IP{}
|
||||
for _, s := range opts.ClientIPs {
|
||||
ip := net.ParseIP(s)
|
||||
if ip == nil {
|
||||
log.Fatalf("invalid local ip: %s", s)
|
||||
return nil, nil, fmt.Errorf("invalid client ip: %s", s)
|
||||
}
|
||||
ips = append(ips, ip)
|
||||
}
|
||||
|
|
@ -96,70 +133,24 @@ func setupNet(opts options) *netstack.Net {
|
|||
for _, s := range opts.DNS {
|
||||
ip := net.ParseIP(s)
|
||||
if ip == nil {
|
||||
log.Fatalf("invalid dns ip: %s", s)
|
||||
return nil, nil, fmt.Errorf("invalid dns ip: %s", s)
|
||||
}
|
||||
dnsServers = append(dnsServers, ip)
|
||||
|
||||
}
|
||||
tun, tnet, err := netstack.CreateNetTUN(ips, dnsServers, opts.MTU)
|
||||
if err != nil {
|
||||
log.Fatalf("create netstack tun: %v", err)
|
||||
return nil, nil, fmt.Errorf("create netstack tun: %w", err)
|
||||
}
|
||||
dev := device.NewDevice(tun, newConnBind(opts.ClientID), device.NewLogger(device.LogLevelError, ""))
|
||||
if err := dev.IpcSet(conf); err != nil {
|
||||
log.Fatal(err)
|
||||
dev := device.NewDevice(tun, newConnBind(opts.ClientID), logger)
|
||||
|
||||
if err := ipcSet(dev, opts); err != nil {
|
||||
return nil, nil, fmt.Errorf("config device: %w", err)
|
||||
}
|
||||
|
||||
if err := dev.Up(); err != nil {
|
||||
log.Fatalf("bring up device: %v", err)
|
||||
return nil, nil, fmt.Errorf("bring up device: %w", err)
|
||||
}
|
||||
return tnet
|
||||
}
|
||||
|
||||
type connBind struct {
|
||||
// magic 3 bytes in wireguard header reserved section.
|
||||
clientID []uint8
|
||||
defaultBind conn.Bind
|
||||
}
|
||||
|
||||
func newConnBind(clientID string) conn.Bind {
|
||||
defaultBind := conn.NewDefaultBind()
|
||||
if clientID == "" {
|
||||
return defaultBind
|
||||
}
|
||||
parsed, err := base64.StdEncoding.DecodeString(clientID)
|
||||
if err != nil {
|
||||
log.Fatalf("parse client id: %v", err)
|
||||
}
|
||||
return &connBind{clientID: parsed, defaultBind: defaultBind}
|
||||
}
|
||||
|
||||
func (c *connBind) Open(port uint16) ([]conn.ReceiveFunc, uint16, error) {
|
||||
fns, actualPort, err := c.defaultBind.Open(port)
|
||||
newFNs := make([]conn.ReceiveFunc, 0, len(fns))
|
||||
for i := range fns {
|
||||
f := fns[i]
|
||||
newFNs = append(newFNs, func(b []byte) (n int, ep conn.Endpoint, err error) {
|
||||
n, ep, err = f(b)
|
||||
if len(b) > 4 {
|
||||
copy(b[1:4], []byte{0, 0, 0})
|
||||
}
|
||||
return
|
||||
})
|
||||
}
|
||||
return newFNs, actualPort, err
|
||||
}
|
||||
|
||||
func (c *connBind) Close() error { return c.defaultBind.Close() }
|
||||
|
||||
func (c *connBind) SetMark(mark uint32) error { return c.defaultBind.SetMark(mark) }
|
||||
|
||||
func (c *connBind) Send(b []byte, ep conn.Endpoint) error {
|
||||
if len(b) > 4 {
|
||||
copy(b[1:4], c.clientID)
|
||||
}
|
||||
return c.defaultBind.Send(b, ep)
|
||||
}
|
||||
|
||||
func (c *connBind) ParseEndpoint(s string) (conn.Endpoint, error) {
|
||||
return c.defaultBind.ParseEndpoint(s)
|
||||
|
||||
return dev, tnet, nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,59 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"golang.zx2c4.com/wireguard/device"
|
||||
)
|
||||
|
||||
func statsHandler(next http.Handler, dev *device.Device) http.Handler {
|
||||
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Host != "" || r.URL.Path != "/stats" {
|
||||
next.ServeHTTP(rw, r)
|
||||
return
|
||||
}
|
||||
|
||||
stats := struct {
|
||||
Endpoint string
|
||||
LastHandshakeTimestamp int64
|
||||
ReceivedBytes int64
|
||||
SentBytes int64
|
||||
|
||||
NumGoroutine int
|
||||
}{
|
||||
NumGoroutine: runtime.NumGoroutine(),
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
if err := dev.IpcGetOperation(&buf); err != nil {
|
||||
logger.Errorf("Get device config: %v", err)
|
||||
rw.WriteHeader(http.StatusInternalServerError)
|
||||
} else {
|
||||
scanner := bufio.NewScanner(&buf)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if prefix := "endpoint="; strings.HasPrefix(line, prefix) {
|
||||
stats.Endpoint = strings.TrimPrefix(line, prefix)
|
||||
}
|
||||
if prefix := "last_handshake_time_sec="; strings.HasPrefix(line, prefix) {
|
||||
stats.LastHandshakeTimestamp, _ = strconv.ParseInt(strings.TrimPrefix(line, prefix), 10, 64)
|
||||
}
|
||||
if prefix := "rx_bytes="; strings.HasPrefix(line, prefix) {
|
||||
stats.ReceivedBytes, _ = strconv.ParseInt(strings.TrimPrefix(line, prefix), 10, 64)
|
||||
}
|
||||
if prefix := "tx_bytes="; strings.HasPrefix(line, prefix) {
|
||||
stats.SentBytes, _ = strconv.ParseInt(strings.TrimPrefix(line, prefix), 10, 64)
|
||||
}
|
||||
}
|
||||
resp, _ := json.MarshalIndent(stats, "", " ")
|
||||
rw.Header().Set("Content-Type", "application/json")
|
||||
rw.Write(append(resp, '\n'))
|
||||
}
|
||||
})
|
||||
}
|
||||
Loading…
Reference in New Issue