diff --git a/src/adapt.go b/src/adapt.go index 6790861..567cd7d 100644 --- a/src/adapt.go +++ b/src/adapt.go @@ -1,13 +1,143 @@ package src import ( + "bufio" + "bytes" "context" + "fmt" "io" + "log" "net" + "strconv" ) func adapt(ctx context.Context, config Config, conn net.Conn) error { + reader := bufio.NewReader(conn) + for ctx.Err() == nil { + message, err := readMessage(reader) + if err != nil { + if err == io.EOF { + return nil + } + return err + } + + if len(message) > 0 { + var reply string + switch message[0] { + case "hello": + reply = "*0\r\n" + case "client": + reply = "+OK\r\n" + case "ping": + reply = "+PONG\r\n" + case "set": + k := message[1].(string) + v := message[2] + reply = "+OK\r\n" + switch v2 := v.(type) { + case []byte: + if err := config.db.Put(ctx, 1, k, string(v2)); err != nil { + reply = "-" + err.Error() + "\r\n" + } + case string: + if err := config.db.Put(ctx, 1, k, v2); err != nil { + reply = "-" + err.Error() + "\r\n" + } + default: + log.Fatalf("not impl: set %T", v) + } + case "del": + k := message[1].(string) + reply = "+OK\r\n" + if err := config.db.Del(ctx, 1, k); err != nil { + reply = "-" + err.Error() + "\r\n" + } + case "get": + v, err := Select[any](ctx, 1, fmt.Sprint(message[1])) + if err != nil { + reply = "-" + err.Error() + } else if v == nil || *v == nil { + reply = "_\r\n" // RESP3 null + } else { + switch v2 := (*v).(type) { + case []byte: + reply = fmt.Sprintf("+%s\r\n", v2) + case string: + reply = fmt.Sprintf("+%s\r\n", v2) + default: + log.Fatalf("not impl: get type %T", v2) + return nil + } + } + default: + log.Fatalf("unhandled: %+v", message) + } + log.Printf("replying: %q", reply) + if _, err := fmt.Fprintf(conn, reply); err != nil { + return err + } + } + + log.Println() } + return io.EOF } + +func readMessage(reader *bufio.Reader) ([]any, error) { + firstLine, _, err := reader.ReadLine() + log.Printf("%q", firstLine) + if err != nil { + return nil, err + } + firstLine = bytes.TrimSuffix(firstLine, []byte("\r\n")) + if len(firstLine) == 0 { + return nil, nil + } + + switch firstLine[0] { + case '+': // simple string, like +OK + return []any{string(firstLine[1:])}, nil + case '-': // simple error, like -message + return nil, fmt.Errorf("error: %s", firstLine[1:]) + case ':': // number, like /[+-][0-9]+/ + firstLine = bytes.TrimPrefix(firstLine[1:], []byte("+")) + n, err := strconv.Atoi(string(firstLine[1:])) + return []any{n}, err + case '$': // long string, like $-1 for nil, like $LEN\r\nSTRING\r\n + if firstLine[1] == '-' { + return []any{nil}, nil + } + nextLine, _, err := reader.ReadLine() + log.Printf("%q", nextLine) + nextLine = bytes.TrimSuffix(nextLine, []byte("\r\n")) + return []any{string(nextLine)}, err + case '*': // array, like *-1 for nil, like *4 for [1,2,3,4] + n, err := strconv.Atoi(string(firstLine[1:])) + if err != nil { + return nil, err + } else if n == -1 { + return nil, nil + } + var result []any + for i := 0; i < n; i++ { + more, err := readMessage(reader) + if err != nil { + return nil, err + } + result = append(result, more...) + } + return result, nil + case '_': // nil + return nil, nil + case '#': // boolean, like #t or #f + return []any{firstLine[1] == 't'}, nil + case ',': // double + log.Fatal("not impl") + } + + log.Fatalf("not impl: %q", firstLine) + return nil, nil +} diff --git a/src/db.go b/src/db.go index 5920765..6f5d06c 100644 --- a/src/db.go +++ b/src/db.go @@ -36,7 +36,7 @@ func (db DB) setup(ctx context.Context) error { return err } if _, err := db.ExecContext(ctx, ` - CREATE TABLE data ( + CREATE TABLE IF NOT EXISTS data ( database INTEGER , key TEXT , value TEXT