move conn acquire release to wrapper func accepting callback
This commit is contained in:
52
src/adapt.go
52
src/adapt.go
@@ -5,7 +5,6 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"hash/crc32"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
@@ -33,41 +32,36 @@ func adapt(ctx context.Context, config Config, conn net.Conn) error {
|
||||
if len(message) > 1 {
|
||||
hashKey = message[1].(string)
|
||||
}
|
||||
hash := int(crc32.ChecksumIEEE([]byte(hashKey)))
|
||||
forward := config.forwards[hash%len(config.forwards)]
|
||||
forwardCon := forward.Get()
|
||||
if forwardCon == nil {
|
||||
return true, io.EOF
|
||||
}
|
||||
forwardConn := forwardCon.(net.Conn)
|
||||
log.Printf("forwarding %q", raw)
|
||||
written := 0
|
||||
for written < len(raw) {
|
||||
more, err := forwardConn.Write(raw[written:])
|
||||
if err != nil {
|
||||
forwardConn.Close()
|
||||
return true, err
|
||||
var reply []byte
|
||||
if err := config.WithConn(hashKey, func(forwardConn net.Conn) error {
|
||||
log.Printf("forwarding %q", raw)
|
||||
|
||||
written := 0
|
||||
for written < len(raw) {
|
||||
more, err := forwardConn.Write(raw[written:])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
written += more
|
||||
}
|
||||
written += more
|
||||
}
|
||||
replyer := bufio.NewReader(forwardConn)
|
||||
log.Printf("reading reply to %q", raw)
|
||||
raw, _, err := readMessage(replyer)
|
||||
log.Printf("read reply %q", raw)
|
||||
if err != nil {
|
||||
forwardConn.Close()
|
||||
|
||||
replyer := bufio.NewReader(forwardConn)
|
||||
log.Printf("reading reply to %q", raw)
|
||||
reply, _, err = readMessage(replyer)
|
||||
log.Printf("read reply %q", raw)
|
||||
return err
|
||||
}); err != nil {
|
||||
return true, err
|
||||
}
|
||||
forward.Put(forwardCon)
|
||||
log.Printf("replying: %v", len(raw))
|
||||
written = 0
|
||||
for written < len(raw) {
|
||||
more, err := conn.Write(raw[written:min(len(raw), written+4096)])
|
||||
log.Printf("replying: %v", len(reply))
|
||||
written := 0
|
||||
for written < len(reply) {
|
||||
more, err := conn.Write(reply[written:min(len(reply), written+4096)])
|
||||
if err != nil {
|
||||
return true, err
|
||||
}
|
||||
written += more
|
||||
log.Printf("replied %v of %v...", written, len(raw))
|
||||
log.Printf("replied %v of %v...", written, len(reply))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -6,6 +6,8 @@ import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"hash/crc32"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"os"
|
||||
@@ -102,3 +104,27 @@ func (c Config) Close() {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c Config) WithConn(hashKey string, foo func(net.Conn) error) error {
|
||||
hash := int(crc32.ChecksumIEEE([]byte(hashKey)))
|
||||
hashIdx := hash % len(c.forwards)
|
||||
forward := c.forwards[hashIdx]
|
||||
|
||||
log.Printf("acquire conn %v", hashIdx)
|
||||
forwardCon := forward.Get()
|
||||
if forwardCon == nil {
|
||||
log.Printf("got a nil conn to %v", hashIdx)
|
||||
return io.EOF
|
||||
}
|
||||
forwardConn := forwardCon.(net.Conn)
|
||||
|
||||
if err := foo(forwardConn); err != nil {
|
||||
log.Printf("errored with conn %v: %v", hashIdx, err)
|
||||
forwardConn.Close()
|
||||
return err
|
||||
}
|
||||
|
||||
log.Printf("release conn %v", hashIdx)
|
||||
forward.Put(forwardConn)
|
||||
return nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user