From 25b5218e5acda15550ef3982ba118c6585474e36 Mon Sep 17 00:00:00 2001 From: bel Date: Wed, 4 Feb 2026 00:02:51 -0700 Subject: [PATCH] mvp with terrible conn handling --- go.mod | 0 go.sum | 0 main.go | 0 mise.toml | 0 src/adapt.go | 80 ++++++++++++++++++++++++++++++++--------- src/config.go | 47 +++++++++++++++++++----- src/listen.go | 5 +-- src/main.go | 3 ++ src/testdata/redises.sh | 17 +++++++++ 9 files changed, 125 insertions(+), 27 deletions(-) mode change 100644 => 100755 go.mod mode change 100644 => 100755 go.sum mode change 100644 => 100755 main.go mode change 100644 => 100755 mise.toml mode change 100644 => 100755 src/adapt.go mode change 100644 => 100755 src/config.go mode change 100644 => 100755 src/listen.go mode change 100644 => 100755 src/main.go create mode 100644 src/testdata/redises.sh diff --git a/go.mod b/go.mod old mode 100644 new mode 100755 diff --git a/go.sum b/go.sum old mode 100644 new mode 100755 diff --git a/main.go b/main.go old mode 100644 new mode 100755 diff --git a/mise.toml b/mise.toml old mode 100644 new mode 100755 diff --git a/src/adapt.go b/src/adapt.go old mode 100644 new mode 100755 index 215a301..996302b --- a/src/adapt.go +++ b/src/adapt.go @@ -18,16 +18,21 @@ func adapt(ctx context.Context, config Config, conn net.Conn) error { for ctx.Err() == nil { if done, err := func() (bool, error) { raw, message, err := readMessage(reader) - log.Printf("%q", raw) if err != nil { if err == io.EOF { return true, nil } return true, err } + log.Printf("routing: %q (%+v)", raw, message) - if len(message) > 0 { - hashKey := message[max(0, len(message)-1)].(string) + if len(message) > 0 && message[0].(string) == "COMMAND" { + fmt.Fprintf(conn, "*0\r\n") + } else if len(raw) > 0 && len(message) > 0 { + hashKey := message[0].(string) + if len(message) > 1 { + hashKey = message[1].(string) + } hash := int(crc32.ChecksumIEEE([]byte(hashKey))) forward := config.forwards[hash%len(config.forwards)] forwardCon := forward.Get() @@ -35,17 +40,34 @@ func adapt(ctx context.Context, config Config, conn net.Conn) error { return true, io.EOF } forwardConn := forwardCon.(net.Conn) - if _, err := forwardConn.Write(raw); err != nil { - return true, err + log.Printf("forwarding %q", raw) + written := 0 + for written < len(raw) { + more, err := forwardConn.Write(raw[written:]) + if err != nil { + forwardConn.Close() + return true, err + } + written += more } replyer := bufio.NewReader(forwardConn) + log.Printf("reading reply to %q", raw) raw, _, err := readMessage(replyer) + log.Printf("read reply %q", raw) if err != nil { + forwardConn.Close() return true, err } - log.Printf("%q", raw) - if _, err := conn.Write(raw); err != nil { - return true, err + forward.Put(forwardCon) + log.Printf("replying: %v", len(raw)) + written = 0 + for written < len(raw) { + more, err := conn.Write(raw[written:min(len(raw), written+4096)]) + if err != nil { + return true, err + } + written += more + log.Printf("replied %v of %v...", written, len(raw)) } } @@ -61,9 +83,13 @@ func adapt(ctx context.Context, config Config, conn net.Conn) error { } func readMessage(reader *bufio.Reader) ([]byte, []any, error) { + b, arr, err := _readMessage(reader) + return b, arr, err +} + +func _readMessage(reader *bufio.Reader) ([]byte, []any, error) { w := bytes.NewBuffer(nil) firstLine, _, err := reader.ReadLine() - w.Write(firstLine) if err != nil { return w.Bytes(), nil, err } @@ -71,35 +97,55 @@ func readMessage(reader *bufio.Reader) ([]byte, []any, error) { if len(firstLine) == 0 { return w.Bytes(), nil, nil } + fmt.Fprintf(w, "%s\r\n", firstLine) switch firstLine[0] { case '+': // simple string, like +OK return w.Bytes(), []any{string(firstLine[1:])}, nil case '-': // simple error, like -message - return w.Bytes(), nil, fmt.Errorf("error: %s", firstLine[1:]) + return w.Bytes(), []any{string(firstLine)}, nil case ':': // number, like /[+-][0-9]+/ firstLine = bytes.TrimPrefix(firstLine[1:], []byte("+")) - n, err := strconv.Atoi(string(firstLine[1:])) + n, err := strconv.Atoi(string(firstLine)) + if err != nil { + return w.Bytes(), nil, fmt.Errorf("num not a num in %q: %q: %w", w.Bytes(), firstLine, err) + } return w.Bytes(), []any{n}, err case '$': // long string, like $-1 for nil, like $LEN\r\nSTRING\r\n if firstLine[1] == '-' { return w.Bytes(), []any{nil}, nil } - nextLine, _, err := reader.ReadLine() - w.Write(nextLine) - nextLine = bytes.TrimSuffix(nextLine, []byte("\r\n")) + firstLine = bytes.TrimPrefix(firstLine[1:], []byte("+")) + n, err := strconv.Atoi(string(firstLine)) + if err != nil { + return w.Bytes(), nil, fmt.Errorf("num not a num in %q: %q: %w", w.Bytes(), firstLine, err) + } + log.Printf("reading %v+2 bytes for bulk string", n) + nextLine := make([]byte, n+2) + nAt := 0 + for nAt < n+2 { + nMore, err := reader.Read(nextLine[nAt:]) + if err != nil { + return w.Bytes(), nil, fmt.Errorf("couldnt read %v more/%v bytes for long string: %w", n+2-nAt, n+2, err) + } + nAt += nMore + } + fmt.Fprintf(w, "%s", nextLine) return w.Bytes(), []any{string(nextLine)}, err case '*': // array, like *-1 for nil, like *4 for [1,2,3,4] n, err := strconv.Atoi(string(firstLine[1:])) if err != nil { - return w.Bytes(), nil, err + return w.Bytes(), nil, fmt.Errorf("arr not a num: %q: %w", firstLine, err) } else if n == -1 { return w.Bytes(), nil, nil } var result []any for i := 0; i < n; i++ { - moreBytes, more, err := readMessage(reader) - w.Write(moreBytes) + moreBytes, more, err := _readMessage(reader) + moreBytes = bytes.TrimSuffix(moreBytes, []byte("\r\n")) + if len(moreBytes) > 0 { + fmt.Fprintf(w, "%s\r\n", moreBytes) + } if err != nil { return w.Bytes(), nil, err } diff --git a/src/config.go b/src/config.go old mode 100644 new mode 100755 index ea3ad69..285a102 --- a/src/config.go +++ b/src/config.go @@ -1,9 +1,12 @@ package src import ( + "bufio" "context" + "encoding/base64" "encoding/json" "fmt" + "log" "net" "os" "slices" @@ -14,6 +17,7 @@ import ( type Config struct { Listen string `json:"LISTEN"` Forwards string `json:"FORWARDS"` + Hello string `json:"HELLO_B64"` forwards []*sync.Pool } @@ -21,6 +25,7 @@ func NewConfig(ctx context.Context) (Config, error) { config := Config{ Listen: ":10000", Forwards: "", + Hello: base64.StdEncoding.EncodeToString([]byte("*1\r\n$4\r\nping\r\n")), } b, _ := json.Marshal(config) var m map[string]any @@ -37,6 +42,11 @@ func NewConfig(ctx context.Context) (Config, error) { return config, err } + hello, err := base64.StdEncoding.DecodeString(config.Hello) + if err != nil { + return config, err + } + forwards := strings.Split(config.Forwards, ",") forwards = slices.DeleteFunc(forwards, func(s string) bool { return s == "" }) if len(forwards) == 0 { @@ -44,12 +54,30 @@ func NewConfig(ctx context.Context) (Config, error) { } config.forwards = make([]*sync.Pool, len(forwards)) for i := range forwards { + forward := forwards[i] config.forwards[i] = &sync.Pool{ New: func() any { - v, err := (&net.Dialer{}).DialContext(ctx, "tcp", forwards[i]) - if err != nil { + if ctx.Err() != nil { return nil } + log.Printf("dialing %q with %q", forward, hello) + v, err := (&net.Dialer{}).DialContext(ctx, "tcp", forward) + if err != nil { + log.Printf("! failed dial %q: %v", forward, err) + return nil + } + if _, err := v.Write([]byte(hello)); err != nil { + log.Printf("! failed write hello %q: %v", forward, err) + v.Close() + return nil + } + if raw, _, err := readMessage(bufio.NewReader(v)); err != nil { + log.Printf("! failed read hello %q: %v", forward, err) + v.Close() + return nil + } else { + log.Printf("dial reply: %q", raw) + } return v }, } @@ -61,13 +89,16 @@ func NewConfig(ctx context.Context) (Config, error) { func (c Config) Close() { for i := range c.forwards { if c.forwards[i] != nil { - c.forwards[i].New = nil - for { - got := c.forwards[i].Get() - if got != nil { - got.(net.Conn).Close() + i := i + go func() { + c.forwards[i].New = nil + for { + got := c.forwards[i].Get() + if got != nil { + go got.(net.Conn).Close() + } } - } + }() } } } diff --git a/src/listen.go b/src/listen.go old mode 100644 new mode 100755 index 8b0b7a7..1134508 --- a/src/listen.go +++ b/src/listen.go @@ -8,6 +8,8 @@ import ( ) func listen(ctx context.Context, config Config) error { + defer log.Println("/listen()") + wg := &sync.WaitGroup{} defer wg.Wait() @@ -15,8 +17,6 @@ func listen(ctx context.Context, config Config) error { if err != nil { return err } - defer listener.Close() - wg.Add(1) go func() { defer wg.Done() @@ -31,6 +31,7 @@ func listen(ctx context.Context, config Config) error { } else { wg.Add(1) go func() { + defer log.Println("/handle()") defer wg.Done() handle(ctx, config, conn) }() diff --git a/src/main.go b/src/main.go old mode 100644 new mode 100755 index d85f4f3..f48c2dd --- a/src/main.go +++ b/src/main.go @@ -2,9 +2,12 @@ package src import ( "context" + "log" ) func Main(ctx context.Context) error { + defer log.Println("/Main()") + config, err := NewConfig(ctx) if err != nil { return err diff --git a/src/testdata/redises.sh b/src/testdata/redises.sh new file mode 100644 index 0000000..5ed2d07 --- /dev/null +++ b/src/testdata/redises.sh @@ -0,0 +1,17 @@ +#! /usr/bin/bash + +d=$(mktemp -d) +mkdir -p $d/1 +mkdir -p $d/2 +valkey-server --dir $d/1 --port 60113 --logfile $d/1.log & +bg=${!} +cleanup() { + kill $bg +} +trap cleanup EXIT +( + sleep 5 + echo FORWARDS=127.0.0.1:60113,127.0.0.1:60114 +) & +valkey-server --dir $d/2 --port 60114 --logfile $d/2.log +