package src import ( "bufio" "bytes" "context" "fmt" "io" "log" "net" "strconv" ) func adapt(ctx context.Context, config Config, conn net.Conn) error { reader := bufio.NewReader(conn) for ctx.Err() == nil { if done, err := func() (bool, error) { raw, message, err := readMessage(reader) if err != nil { if err == io.EOF { return true, nil } return true, err } log.Printf("routing: %q (%+v)", raw, message) if len(message) > 0 && message[0].(string) == "COMMAND" { fmt.Fprintf(conn, "*0\r\n") } else if len(raw) > 0 && len(message) > 0 { hashKey := message[0].(string) if len(message) > 1 { hashKey = message[1].(string) } if err := config.WithConn(hashKey, func(forwardConn net.Conn) error { //log.Printf("forwarding %q", raw) written := 0 for written < len(raw) { more, err := forwardConn.Write(raw[written:]) if err != nil { return err } written += more } replyer := bufio.NewReader(forwardConn) //log.Printf("reading reply to %q", raw) _, err = readMessageTo(conn, replyer) //log.Printf("read reply %q", raw) return err }); err != nil { return true, err } } return false, nil }(); err != nil { return err } else if done { return nil } } return io.EOF } func readMessage(reader *bufio.Reader) ([]byte, error) { w := bytes.NewBuffer(nil) err := readMessageTo(w, reader) return w.Bytes(), err } func readMessageTo(w io.Writer, reader *bufio.Reader) error { w2 := bufio.NewWriter(w) defer w2.Flush() err := _readMessageTo(w2, reader) if err != nil { return err } return w2.Flush() } func _readMessageTo(w io.Writer, reader *bufio.Reader) error { firstLine, _, err := reader.ReadLine() if err != nil { return err } firstLine = bytes.TrimSuffix(firstLine, []byte("\r\n")) if len(firstLine) == 0 { return nil } fmt.Fprintf(w, "%s\r\n", firstLine) switch firstLine[0] { case '+': // simple string, like +OK return nil case '-': // simple error, like -message return nil case ':': // number, like /[+-][0-9]+/ return err case '$': // long string, like $-1 for nil, like $LEN\r\nSTRING\r\n if firstLine[1] == '-' { return nil } firstLine = bytes.TrimPrefix(firstLine[1:], []byte("+")) n, err := strconv.Atoi(string(firstLine)) if err != nil { return fmt.Errorf("num not a num: %q: %w", firstLine, err) } nextLine := make([]byte, n+2) nAt := 0 for nAt < n+2 { nMore, err := reader.Read(nextLine[nAt:]) if err != nil { return fmt.Errorf("couldnt read %v more/%v bytes for long string: %w", n+2-nAt, n+2, err) } nAt += nMore } fmt.Fprintf(w, "%s", nextLine) return err case '*', '%': // *=array %=map, like *-1 for nil, like *4 for [1,2,3,4] n, err := strconv.Atoi(string(firstLine[1:])) if err != nil { return fmt.Errorf("arr not a num: %q: %w", firstLine, err) } else if n == -1 { return nil } if firstLine[0] == '%' { n *= 2 } for i := 0; i < n; i++ { err := _readMessageTo(w, reader) if err != nil { return err } } return nil case '_': // nil return nil case '#': // boolean, like #t or #f return nil case ',': // double return nil } log.Fatalf("not impl: %q", firstLine) return nil }