diff --git a/src/adapt.go b/src/adapt.go index 4d9e755..21a1c96 100755 --- a/src/adapt.go +++ b/src/adapt.go @@ -32,7 +32,6 @@ func adapt(ctx context.Context, config Config, conn net.Conn) error { if len(message) > 1 { hashKey = message[1].(string) } - var reply []byte if err := config.WithConn(hashKey, func(forwardConn net.Conn) error { //log.Printf("forwarding %q", raw) @@ -47,22 +46,12 @@ func adapt(ctx context.Context, config Config, conn net.Conn) error { replyer := bufio.NewReader(forwardConn) //log.Printf("reading reply to %q", raw) - reply, _, err = readMessage(replyer) + _, err = readMessageTo(conn, replyer) //log.Printf("read reply %q", raw) return err }); err != nil { return true, err } - log.Printf("replying: %v", len(reply)) - written := 0 - for written < len(reply) { - more, err := conn.Write(reply[written:min(len(reply), written+4096)]) - if err != nil { - return true, err - } - written += more - //log.Printf("replied %v of %v...", written, len(reply)) - } } return false, nil @@ -77,42 +66,42 @@ func adapt(ctx context.Context, config Config, conn net.Conn) error { } func readMessage(reader *bufio.Reader) ([]byte, []any, error) { - b, arr, err := _readMessage(reader) - return b, arr, err + w := bytes.NewBuffer(nil) + arr, err := readMessageTo(w, reader) + return w.Bytes(), arr, err } -func _readMessage(reader *bufio.Reader) ([]byte, []any, error) { - w := bytes.NewBuffer(nil) +func readMessageTo(w io.Writer, reader *bufio.Reader) ([]any, error) { firstLine, _, err := reader.ReadLine() if err != nil { - return w.Bytes(), nil, err + return nil, err } firstLine = bytes.TrimSuffix(firstLine, []byte("\r\n")) if len(firstLine) == 0 { - return w.Bytes(), nil, nil + return nil, nil } fmt.Fprintf(w, "%s\r\n", firstLine) switch firstLine[0] { case '+': // simple string, like +OK - return w.Bytes(), []any{string(firstLine[1:])}, nil + return []any{string(firstLine[1:])}, nil case '-': // simple error, like -message - return w.Bytes(), []any{string(firstLine)}, nil + return []any{string(firstLine)}, nil case ':': // number, like /[+-][0-9]+/ firstLine = bytes.TrimPrefix(firstLine[1:], []byte("+")) n, err := strconv.Atoi(string(firstLine)) if err != nil { - return w.Bytes(), nil, fmt.Errorf("num not a num in %q: %q: %w", w.Bytes(), firstLine, err) + return nil, fmt.Errorf("num not a num: %q: %w", firstLine, err) } - return w.Bytes(), []any{n}, err + return []any{n}, err case '$': // long string, like $-1 for nil, like $LEN\r\nSTRING\r\n if firstLine[1] == '-' { - return w.Bytes(), []any{nil}, nil + return []any{nil}, nil } firstLine = bytes.TrimPrefix(firstLine[1:], []byte("+")) n, err := strconv.Atoi(string(firstLine)) if err != nil { - return w.Bytes(), nil, fmt.Errorf("num not a num in %q: %q: %w", w.Bytes(), firstLine, err) + return nil, fmt.Errorf("num not a num: %q: %w", firstLine, err) } //log.Printf("reading %v+2 bytes for bulk string", n) nextLine := make([]byte, n+2) @@ -120,40 +109,36 @@ func _readMessage(reader *bufio.Reader) ([]byte, []any, error) { for nAt < n+2 { nMore, err := reader.Read(nextLine[nAt:]) if err != nil { - return w.Bytes(), nil, fmt.Errorf("couldnt read %v more/%v bytes for long string: %w", n+2-nAt, n+2, err) + return nil, fmt.Errorf("couldnt read %v more/%v bytes for long string: %w", n+2-nAt, n+2, err) } nAt += nMore } fmt.Fprintf(w, "%s", nextLine) - return w.Bytes(), []any{string(nextLine)}, err + return []any{string(nextLine)}, err case '*': // array, like *-1 for nil, like *4 for [1,2,3,4] n, err := strconv.Atoi(string(firstLine[1:])) if err != nil { - return w.Bytes(), nil, fmt.Errorf("arr not a num: %q: %w", firstLine, err) + return nil, fmt.Errorf("arr not a num: %q: %w", firstLine, err) } else if n == -1 { - return w.Bytes(), nil, nil + return nil, nil } var result []any for i := 0; i < n; i++ { - moreBytes, more, err := _readMessage(reader) - moreBytes = bytes.TrimSuffix(moreBytes, []byte("\r\n")) - if len(moreBytes) > 0 { - fmt.Fprintf(w, "%s\r\n", moreBytes) - } + more, err := readMessageTo(w, reader) if err != nil { - return w.Bytes(), nil, err + return nil, err } result = append(result, more...) } - return w.Bytes(), result, nil + return result, nil case '_': // nil - return w.Bytes(), nil, nil + return nil, nil case '#': // boolean, like #t or #f - return w.Bytes(), []any{firstLine[1] == 't'}, nil + return []any{firstLine[1] == 't'}, nil case ',': // double log.Fatal("not impl") } log.Fatalf("not impl: %q", firstLine) - return w.Bytes(), nil, nil + return nil, nil }