move conn acquire release to wrapper func accepting callback

This commit is contained in:
bel
2026-02-04 00:12:16 -07:00
parent 25b5218e5a
commit b867d81984
2 changed files with 49 additions and 29 deletions

View File

@@ -5,7 +5,6 @@ import (
"bytes"
"context"
"fmt"
"hash/crc32"
"io"
"log"
"net"
@@ -33,41 +32,36 @@ func adapt(ctx context.Context, config Config, conn net.Conn) error {
if len(message) > 1 {
hashKey = message[1].(string)
}
hash := int(crc32.ChecksumIEEE([]byte(hashKey)))
forward := config.forwards[hash%len(config.forwards)]
forwardCon := forward.Get()
if forwardCon == nil {
return true, io.EOF
}
forwardConn := forwardCon.(net.Conn)
log.Printf("forwarding %q", raw)
written := 0
for written < len(raw) {
more, err := forwardConn.Write(raw[written:])
if err != nil {
forwardConn.Close()
return true, err
var reply []byte
if err := config.WithConn(hashKey, func(forwardConn net.Conn) error {
log.Printf("forwarding %q", raw)
written := 0
for written < len(raw) {
more, err := forwardConn.Write(raw[written:])
if err != nil {
return err
}
written += more
}
written += more
}
replyer := bufio.NewReader(forwardConn)
log.Printf("reading reply to %q", raw)
raw, _, err := readMessage(replyer)
log.Printf("read reply %q", raw)
if err != nil {
forwardConn.Close()
replyer := bufio.NewReader(forwardConn)
log.Printf("reading reply to %q", raw)
reply, _, err = readMessage(replyer)
log.Printf("read reply %q", raw)
return err
}); err != nil {
return true, err
}
forward.Put(forwardCon)
log.Printf("replying: %v", len(raw))
written = 0
for written < len(raw) {
more, err := conn.Write(raw[written:min(len(raw), written+4096)])
log.Printf("replying: %v", len(reply))
written := 0
for written < len(reply) {
more, err := conn.Write(reply[written:min(len(reply), written+4096)])
if err != nil {
return true, err
}
written += more
log.Printf("replied %v of %v...", written, len(raw))
log.Printf("replied %v of %v...", written, len(reply))
}
}

View File

@@ -6,6 +6,8 @@ import (
"encoding/base64"
"encoding/json"
"fmt"
"hash/crc32"
"io"
"log"
"net"
"os"
@@ -102,3 +104,27 @@ func (c Config) Close() {
}
}
}
func (c Config) WithConn(hashKey string, foo func(net.Conn) error) error {
hash := int(crc32.ChecksumIEEE([]byte(hashKey)))
hashIdx := hash % len(c.forwards)
forward := c.forwards[hashIdx]
log.Printf("acquire conn %v", hashIdx)
forwardCon := forward.Get()
if forwardCon == nil {
log.Printf("got a nil conn to %v", hashIdx)
return io.EOF
}
forwardConn := forwardCon.(net.Conn)
if err := foo(forwardConn); err != nil {
log.Printf("errored with conn %v: %v", hashIdx, err)
forwardConn.Close()
return err
}
log.Printf("release conn %v", hashIdx)
forward.Put(forwardConn)
return nil
}