move conn acquire release to wrapper func accepting callback

This commit is contained in:
bel
2026-02-04 00:12:16 -07:00
parent 25b5218e5a
commit b867d81984
2 changed files with 49 additions and 29 deletions

View File

@@ -5,7 +5,6 @@ import (
"bytes" "bytes"
"context" "context"
"fmt" "fmt"
"hash/crc32"
"io" "io"
"log" "log"
"net" "net"
@@ -33,41 +32,36 @@ func adapt(ctx context.Context, config Config, conn net.Conn) error {
if len(message) > 1 { if len(message) > 1 {
hashKey = message[1].(string) hashKey = message[1].(string)
} }
hash := int(crc32.ChecksumIEEE([]byte(hashKey))) var reply []byte
forward := config.forwards[hash%len(config.forwards)] if err := config.WithConn(hashKey, func(forwardConn net.Conn) error {
forwardCon := forward.Get()
if forwardCon == nil {
return true, io.EOF
}
forwardConn := forwardCon.(net.Conn)
log.Printf("forwarding %q", raw) log.Printf("forwarding %q", raw)
written := 0 written := 0
for written < len(raw) { for written < len(raw) {
more, err := forwardConn.Write(raw[written:]) more, err := forwardConn.Write(raw[written:])
if err != nil { if err != nil {
forwardConn.Close() return err
return true, err
} }
written += more written += more
} }
replyer := bufio.NewReader(forwardConn) replyer := bufio.NewReader(forwardConn)
log.Printf("reading reply to %q", raw) log.Printf("reading reply to %q", raw)
raw, _, err := readMessage(replyer) reply, _, err = readMessage(replyer)
log.Printf("read reply %q", raw) log.Printf("read reply %q", raw)
if err != nil { return err
forwardConn.Close() }); err != nil {
return true, err return true, err
} }
forward.Put(forwardCon) log.Printf("replying: %v", len(reply))
log.Printf("replying: %v", len(raw)) written := 0
written = 0 for written < len(reply) {
for written < len(raw) { more, err := conn.Write(reply[written:min(len(reply), written+4096)])
more, err := conn.Write(raw[written:min(len(raw), written+4096)])
if err != nil { if err != nil {
return true, err return true, err
} }
written += more written += more
log.Printf("replied %v of %v...", written, len(raw)) log.Printf("replied %v of %v...", written, len(reply))
} }
} }

View File

@@ -6,6 +6,8 @@ import (
"encoding/base64" "encoding/base64"
"encoding/json" "encoding/json"
"fmt" "fmt"
"hash/crc32"
"io"
"log" "log"
"net" "net"
"os" "os"
@@ -102,3 +104,27 @@ func (c Config) Close() {
} }
} }
} }
func (c Config) WithConn(hashKey string, foo func(net.Conn) error) error {
hash := int(crc32.ChecksumIEEE([]byte(hashKey)))
hashIdx := hash % len(c.forwards)
forward := c.forwards[hashIdx]
log.Printf("acquire conn %v", hashIdx)
forwardCon := forward.Get()
if forwardCon == nil {
log.Printf("got a nil conn to %v", hashIdx)
return io.EOF
}
forwardConn := forwardCon.(net.Conn)
if err := foo(forwardConn); err != nil {
log.Printf("errored with conn %v: %v", hashIdx, err)
forwardConn.Close()
return err
}
log.Printf("release conn %v", hashIdx)
forward.Put(forwardConn)
return nil
}