From b867d8198414862bd900bb56c63bcb902b3c83e3 Mon Sep 17 00:00:00 2001 From: bel Date: Wed, 4 Feb 2026 00:12:16 -0700 Subject: [PATCH] move conn acquire release to wrapper func accepting callback --- src/adapt.go | 52 +++++++++++++++++++++++---------------------------- src/config.go | 26 ++++++++++++++++++++++++++ 2 files changed, 49 insertions(+), 29 deletions(-) diff --git a/src/adapt.go b/src/adapt.go index 996302b..8805aad 100755 --- a/src/adapt.go +++ b/src/adapt.go @@ -5,7 +5,6 @@ import ( "bytes" "context" "fmt" - "hash/crc32" "io" "log" "net" @@ -33,41 +32,36 @@ func adapt(ctx context.Context, config Config, conn net.Conn) error { if len(message) > 1 { hashKey = message[1].(string) } - hash := int(crc32.ChecksumIEEE([]byte(hashKey))) - forward := config.forwards[hash%len(config.forwards)] - forwardCon := forward.Get() - if forwardCon == nil { - return true, io.EOF - } - forwardConn := forwardCon.(net.Conn) - log.Printf("forwarding %q", raw) - written := 0 - for written < len(raw) { - more, err := forwardConn.Write(raw[written:]) - if err != nil { - forwardConn.Close() - return true, err + var reply []byte + if err := config.WithConn(hashKey, func(forwardConn net.Conn) error { + log.Printf("forwarding %q", raw) + + written := 0 + for written < len(raw) { + more, err := forwardConn.Write(raw[written:]) + if err != nil { + return err + } + written += more } - written += more - } - replyer := bufio.NewReader(forwardConn) - log.Printf("reading reply to %q", raw) - raw, _, err := readMessage(replyer) - log.Printf("read reply %q", raw) - if err != nil { - forwardConn.Close() + + replyer := bufio.NewReader(forwardConn) + log.Printf("reading reply to %q", raw) + reply, _, err = readMessage(replyer) + log.Printf("read reply %q", raw) + return err + }); err != nil { return true, err } - forward.Put(forwardCon) - log.Printf("replying: %v", len(raw)) - written = 0 - for written < len(raw) { - more, err := conn.Write(raw[written:min(len(raw), written+4096)]) + log.Printf("replying: %v", len(reply)) + written := 0 + for written < len(reply) { + more, err := conn.Write(reply[written:min(len(reply), written+4096)]) if err != nil { return true, err } written += more - log.Printf("replied %v of %v...", written, len(raw)) + log.Printf("replied %v of %v...", written, len(reply)) } } diff --git a/src/config.go b/src/config.go index 285a102..b3f684e 100755 --- a/src/config.go +++ b/src/config.go @@ -6,6 +6,8 @@ import ( "encoding/base64" "encoding/json" "fmt" + "hash/crc32" + "io" "log" "net" "os" @@ -102,3 +104,27 @@ func (c Config) Close() { } } } + +func (c Config) WithConn(hashKey string, foo func(net.Conn) error) error { + hash := int(crc32.ChecksumIEEE([]byte(hashKey))) + hashIdx := hash % len(c.forwards) + forward := c.forwards[hashIdx] + + log.Printf("acquire conn %v", hashIdx) + forwardCon := forward.Get() + if forwardCon == nil { + log.Printf("got a nil conn to %v", hashIdx) + return io.EOF + } + forwardConn := forwardCon.(net.Conn) + + if err := foo(forwardConn); err != nil { + log.Printf("errored with conn %v: %v", hashIdx, err) + forwardConn.Close() + return err + } + + log.Printf("release conn %v", hashIdx) + forward.Put(forwardConn) + return nil +}