move conn acquire release to wrapper func accepting callback
This commit is contained in:
32
src/adapt.go
32
src/adapt.go
@@ -5,7 +5,6 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"hash/crc32"
|
|
||||||
"io"
|
"io"
|
||||||
"log"
|
"log"
|
||||||
"net"
|
"net"
|
||||||
@@ -33,41 +32,36 @@ func adapt(ctx context.Context, config Config, conn net.Conn) error {
|
|||||||
if len(message) > 1 {
|
if len(message) > 1 {
|
||||||
hashKey = message[1].(string)
|
hashKey = message[1].(string)
|
||||||
}
|
}
|
||||||
hash := int(crc32.ChecksumIEEE([]byte(hashKey)))
|
var reply []byte
|
||||||
forward := config.forwards[hash%len(config.forwards)]
|
if err := config.WithConn(hashKey, func(forwardConn net.Conn) error {
|
||||||
forwardCon := forward.Get()
|
|
||||||
if forwardCon == nil {
|
|
||||||
return true, io.EOF
|
|
||||||
}
|
|
||||||
forwardConn := forwardCon.(net.Conn)
|
|
||||||
log.Printf("forwarding %q", raw)
|
log.Printf("forwarding %q", raw)
|
||||||
|
|
||||||
written := 0
|
written := 0
|
||||||
for written < len(raw) {
|
for written < len(raw) {
|
||||||
more, err := forwardConn.Write(raw[written:])
|
more, err := forwardConn.Write(raw[written:])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
forwardConn.Close()
|
return err
|
||||||
return true, err
|
|
||||||
}
|
}
|
||||||
written += more
|
written += more
|
||||||
}
|
}
|
||||||
|
|
||||||
replyer := bufio.NewReader(forwardConn)
|
replyer := bufio.NewReader(forwardConn)
|
||||||
log.Printf("reading reply to %q", raw)
|
log.Printf("reading reply to %q", raw)
|
||||||
raw, _, err := readMessage(replyer)
|
reply, _, err = readMessage(replyer)
|
||||||
log.Printf("read reply %q", raw)
|
log.Printf("read reply %q", raw)
|
||||||
if err != nil {
|
return err
|
||||||
forwardConn.Close()
|
}); err != nil {
|
||||||
return true, err
|
return true, err
|
||||||
}
|
}
|
||||||
forward.Put(forwardCon)
|
log.Printf("replying: %v", len(reply))
|
||||||
log.Printf("replying: %v", len(raw))
|
written := 0
|
||||||
written = 0
|
for written < len(reply) {
|
||||||
for written < len(raw) {
|
more, err := conn.Write(reply[written:min(len(reply), written+4096)])
|
||||||
more, err := conn.Write(raw[written:min(len(raw), written+4096)])
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return true, err
|
return true, err
|
||||||
}
|
}
|
||||||
written += more
|
written += more
|
||||||
log.Printf("replied %v of %v...", written, len(raw))
|
log.Printf("replied %v of %v...", written, len(reply))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -6,6 +6,8 @@ import (
|
|||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"hash/crc32"
|
||||||
|
"io"
|
||||||
"log"
|
"log"
|
||||||
"net"
|
"net"
|
||||||
"os"
|
"os"
|
||||||
@@ -102,3 +104,27 @@ func (c Config) Close() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c Config) WithConn(hashKey string, foo func(net.Conn) error) error {
|
||||||
|
hash := int(crc32.ChecksumIEEE([]byte(hashKey)))
|
||||||
|
hashIdx := hash % len(c.forwards)
|
||||||
|
forward := c.forwards[hashIdx]
|
||||||
|
|
||||||
|
log.Printf("acquire conn %v", hashIdx)
|
||||||
|
forwardCon := forward.Get()
|
||||||
|
if forwardCon == nil {
|
||||||
|
log.Printf("got a nil conn to %v", hashIdx)
|
||||||
|
return io.EOF
|
||||||
|
}
|
||||||
|
forwardConn := forwardCon.(net.Conn)
|
||||||
|
|
||||||
|
if err := foo(forwardConn); err != nil {
|
||||||
|
log.Printf("errored with conn %v: %v", hashIdx, err)
|
||||||
|
forwardConn.Close()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Printf("release conn %v", hashIdx)
|
||||||
|
forward.Put(forwardConn)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user