package proxy import ( "bufio" "context" "crypto/tls" "encoding/json" "errors" "fmt" "io" "log" "net" "net/http" "strings" "github.com/zhsj/wghttp/internal/resolver" "github.com/zhsj/wghttp/internal/third_party/tailscale/httpproxy" "github.com/zhsj/wghttp/internal/third_party/tailscale/proxymux" "github.com/zhsj/wghttp/internal/third_party/tailscale/socks5" ) type dialer func(ctx context.Context, network, address string) (net.Conn, error) type Proxy struct { Dial dialer DNS string Stats func() (any, error) } func statsHandler(next http.Handler, stats func() (any, error)) http.Handler { return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { if r.URL.Host != "" || r.URL.Path != "/stats" { next.ServeHTTP(rw, r) return } s, err := stats() if err != nil { rw.WriteHeader(http.StatusInternalServerError) } else { resp, _ := json.MarshalIndent(s, "", " ") rw.Header().Set("Content-Type", "application/json") _, _ = rw.Write(append(resp, '\n')) } }) } func dialWithDNS(dial dialer, dns string) dialer { resolv := resolver.New(dns, dial) return func(ctx context.Context, network, address string) (net.Conn, error) { host, port, err := net.SplitHostPort(address) if err != nil { return nil, err } if err == nil { if ip := net.ParseIP(host); ip != nil { return dial(ctx, network, address) } } ips, err := resolv.LookupNetIP(ctx, network, host) if err != nil { return nil, err } var ( lastErr error conn net.Conn ) for _, ip := range ips { addr := net.JoinHostPort(ip.String(), port) conn, lastErr = dial(ctx, network, addr) if lastErr == nil { return conn, nil } } return nil, lastErr } } func (p Proxy) Serve(ln net.Listener) { d := dialWithDNS(p.Dial, p.DNS) socksListener, httpListener := proxymux.SplitSOCKSAndHTTP(ln) httpProxy := &http.Server{Handler: statsHandler(httpproxy.Handler(d), p.Stats)} socksProxy := &socks5.Server{Dialer: d} errc := make(chan error, 2) go func() { if err := httpProxy.Serve(httpListener); err != nil { errc <- err } }() go func() { if err := socksProxy.Serve(socksListener); err != nil { errc <- err } }() <-errc } func (p Proxy) ServeTLS(ln net.Listener, certFile, keyFile string) { d := dialWithDNS(p.Dial, p.DNS) cert, err := tls.LoadX509KeyPair(certFile, keyFile) if err != nil { panic(err) } cfg := &tls.Config{Certificates: []tls.Certificate{cert}} ln = tls.NewListener(ln, cfg) defer ln.Close() for { if err := acceptAndHandle(ln, httpproxy.Handler(d), d); err != nil { panic(err) } } } type singleListener struct { single net.Conn addr net.Addr } func (listener *singleListener) Accept() (net.Conn, error) { if listener.single == nil { return nil, errors.New("single listener closed") } defer listener.Close() return listener.single, nil } func (listener *singleListener) Close() error { listener.single = nil return nil } func (listener *singleListener) Addr() net.Addr { return listener.addr } func acceptAndHandle(ln net.Listener, h http.Handler, d dialer) error { conn, err := ln.Accept() if err != nil { return err } go func() { conn, ok := isConnect(conn) if ok { handleConnectConn(conn, d) } else { handleHTTPConn(conn, h) } }() return nil } func handleConnectConn(conn peekingConn, d dialer) { if err := _handleConnectConn(conn, d); err != nil { log.Println("connect err:", err) } } func _handleConnectConn(conn peekingConn, d dialer) error { defer conn.Close() var host string for { line, _ := conn.buff.ReadString('\n') if strings.HasPrefix(line, "Host: ") { host = strings.TrimSpace(line[6:]) } if strings.TrimSpace(line) == "" { break } } if host == "" { return errors.New("no Host:") } conn2, err := d(conn.ctx, "tcp", host) if err != nil { return fmt.Errorf("failed to dial tcp:%q: %w", host, err) } defer conn2.Close() io.WriteString(conn, "HTTP/1.1 200 OK\r\n\r\n") c := make(chan error, 1) go func() { _, err := io.Copy(conn2, conn) select { case c <- err: default: close(c) } }() go func() { io.Copy(conn, conn2) select { case c <- err: default: close(c) } }() <-c return nil } func handleHTTPConn(conn net.Conn, h http.Handler) { l := &singleListener{ single: conn, addr: conn.LocalAddr(), } s := &http.Server{ Handler: h, } s.Serve(l) } type peekingConn struct { net.Conn buff *bufio.Reader ctx context.Context can context.CancelFunc } func (conn peekingConn) Close() error { conn.can() return conn.Conn.Close() } func (conn peekingConn) Read(b []byte) (int, error) { return conn.buff.Read(b) } func isConnect(c net.Conn) (peekingConn, bool) { ctx, can := context.WithCancel(context.Background()) conn := peekingConn{Conn: c, buff: bufio.NewReader(c), ctx: ctx, can: can} b, _ := conn.buff.Peek(3) return conn, string(b) == "CON" }