diff --git a/server/files.go b/server/files.go index 0f0eaaf..2c93059 100644 --- a/server/files.go +++ b/server/files.go @@ -33,7 +33,11 @@ func (rest *REST) filesCreate(w http.ResponseWriter, r *http.Request) { rest.respConflict(w) return } - f, err := os.Open(localPath) + if err := os.MkdirAll(path.Dir(localPath), os.ModePerm); err != nil { + rest.respError(w, err) + return + } + f, err := os.Create(localPath) if err != nil { rest.respError(w, err) return @@ -89,7 +93,7 @@ func (rest *REST) filesUpdate(w http.ResponseWriter, r *http.Request) { rest.respConflict(w) return } - f, err := os.Open(localPath) + f, err := os.Create(localPath + ".tmp") if err != nil { rest.respError(w, err) return @@ -97,6 +101,11 @@ func (rest *REST) filesUpdate(w http.ResponseWriter, r *http.Request) { defer f.Close() if _, err := io.Copy(f, r.Body); err != nil { rest.respError(w, err) + return + } + if err := os.Rename(localPath+".tmp", localPath); err != nil { + rest.respError(w, err) + return } rest.respOK(w) } diff --git a/server/files_test.go b/server/files_test.go new file mode 100644 index 0000000..ef71346 --- /dev/null +++ b/server/files_test.go @@ -0,0 +1,118 @@ +package server + +import ( + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/google/uuid" +) + +func TestFiles(t *testing.T) { + cases := map[string]func(*testing.T, *REST, func(*http.Request)){ + "create-get": func(t *testing.T, rest *REST, scope func(r *http.Request)) { + s := uuid.New().String() + w := testFilesPost(t, rest, s, scope, s) + if w.Code != http.StatusOK { + t.Fatal(w.Code, string(w.Body.Bytes())) + } + w = testFilesGet(t, rest, s, scope) + if w.Code != http.StatusOK { + t.Fatal(w.Code, string(w.Body.Bytes())) + } + if s2 := string(w.Body.Bytes()); s2 != s { + t.Fatalf("want %q, got %q", s, s2) + } + }, + "create-collision": func(t *testing.T, rest *REST, scope func(r *http.Request)) { + s := uuid.New().String() + for i := 0; i < 2; i++ { + w := testFilesPost(t, rest, s, scope, s) + ok := false + switch i { + case 0: + ok = w.Code == http.StatusOK + default: + ok = w.Code == http.StatusConflict + } + if !ok { + t.Fatal(w.Code, string(w.Body.Bytes())) + } + } + }, + "delete": func(t *testing.T, rest *REST, scope func(r *http.Request)) { + s := uuid.New().String() + w := testFilesPost(t, rest, s, scope, s) + if w.Code != http.StatusOK { + t.Fatal(w.Code, string(w.Body.Bytes())) + } + w = testFilesDelete(t, rest, s, scope) + if w.Code != http.StatusOK { + t.Fatal(w.Code, string(w.Body.Bytes())) + } + w = testFilesGet(t, rest, s, scope) + if w.Code != http.StatusNotFound { + t.Fatal(w.Code, string(w.Body.Bytes())) + } + }, + "404": func(t *testing.T, rest *REST, scope func(r *http.Request)) { + s := uuid.New().String() + w := testFilesGet(t, rest, s, scope) + if w.Code != http.StatusNotFound { + t.Fatal(w.Code, string(w.Body.Bytes())) + } + }, + "update": func(t *testing.T, rest *REST, scope func(r *http.Request)) { + s := uuid.New().String() + w := testFilesPost(t, rest, s, scope, s) + if w.Code != http.StatusOK { + t.Fatal(w.Code, string(w.Body.Bytes())) + } + w = testFilesPut(t, rest, s, scope, s+"new") + if w.Code != http.StatusOK { + t.Fatal(w.Code, string(w.Body.Bytes())) + } + w = testFilesGet(t, rest, s, scope) + if w.Code != http.StatusOK { + t.Fatal(w.Code, string(w.Body.Bytes())) + } + if v := string(w.Body.Bytes()); v != s+"new" { + t.Fatal(v) + } + }, + } + + for name, foo := range cases { + bar := foo + t.Run(name, func(t *testing.T) { + rest, scope, clean := testREST(t) + bar(t, rest, scope) + defer clean() + }) + } +} + +func testFilesDelete(t *testing.T, rest *REST, id string, scope func(*http.Request)) *httptest.ResponseRecorder { + return testFilesReq(t, rest, id, scope, http.MethodDelete, "") +} + +func testFilesGet(t *testing.T, rest *REST, id string, scope func(*http.Request)) *httptest.ResponseRecorder { + return testFilesReq(t, rest, id, scope, http.MethodGet, "") +} + +func testFilesPut(t *testing.T, rest *REST, id string, scope func(*http.Request), body string) *httptest.ResponseRecorder { + return testFilesReq(t, rest, id, scope, http.MethodPut, body) +} + +func testFilesPost(t *testing.T, rest *REST, id string, scope func(*http.Request), body string) *httptest.ResponseRecorder { + return testFilesReq(t, rest, id, scope, http.MethodPost, body) +} + +func testFilesReq(t *testing.T, rest *REST, id string, scope func(*http.Request), method, body string) *httptest.ResponseRecorder { + r := httptest.NewRequest(method, "/"+id, strings.NewReader(body)) + scope(r) + w := httptest.NewRecorder() + rest.files(w, r) + return w +}