From f90ccaae3809cb3d591637c0ea76e2103c9434e8 Mon Sep 17 00:00:00 2001 From: breel Date: Sat, 25 Jul 2020 20:55:22 -0600 Subject: [PATCH] Ensure cleaning up --- config/config.go | 1 + public/swagger/swagger-files.yaml | 5 +++++ view/files.go | 3 ++- view/files_test.go | 22 ++++++++++------------ 4 files changed, 18 insertions(+), 13 deletions(-) diff --git a/config/config.go b/config/config.go index f4a6310..71a585a 100644 --- a/config/config.go +++ b/config/config.go @@ -39,6 +39,7 @@ func New() Config { as.Append(args.INT, "max-file-size", "max file size for uploads in bytes", 50*(1<<20)) if err := as.Parse(); err != nil { + os.Remove(f.Name()) panic(err) } diff --git a/public/swagger/swagger-files.yaml b/public/swagger/swagger-files.yaml index 68cff0e..855e830 100644 --- a/public/swagger/swagger-files.yaml +++ b/public/swagger/swagger-files.yaml @@ -5,6 +5,7 @@ paths: summary: "Fetch a file" parameters: - $ref: "#/components/parameters/path" + - $ref: "#/components/parameters/namespace" responses: 200: content: @@ -18,6 +19,7 @@ paths: summary: "Provide or direct link to a file" parameters: - $ref: "#/components/parameters/path" + - $ref: "#/components/parameters/namespace" - $ref: "#/components/parameters/direct" requestBody: description: "If ?direct=true, then a direct link to the file, else a multi-part form" @@ -39,6 +41,9 @@ components: path: $ref: "./swagger.yaml#/components/parameters/path" + namespace: + $ref: "./swagger.yaml#/components/parameters/namespace" + direct: name: direct in: query diff --git a/view/files.go b/view/files.go index bd24f04..27e8a47 100644 --- a/view/files.go +++ b/view/files.go @@ -22,11 +22,12 @@ func files(_ storage.Graph, w http.ResponseWriter, r *http.Request) error { http.Error(w, err.Error(), http.StatusBadRequest) return nil } - r.URL.Path = strings.TrimPrefix(r.URL.Path, path.Join(config.New().FilePrefix, namespace)) + r.URL.Path = strings.TrimPrefix(r.URL.Path, config.New().FilePrefix) if len(r.URL.Path) < 2 { http.NotFound(w, r) return nil } + r.URL.Path = path.Join(namespace, r.URL.Path) simpleserve.SetContentTypeIfMedia(w, r) switch r.Method { case http.MethodGet: diff --git a/view/files_test.go b/view/files_test.go index ecbb63c..f87a3f9 100644 --- a/view/files_test.go +++ b/view/files_test.go @@ -30,12 +30,15 @@ func TestFiles(t *testing.T) { t.Fatal(err) } defer os.RemoveAll(d) + if err := os.MkdirAll(path.Join(d, "col"), os.ModePerm); err != nil { + t.Fatal(err) + } os.Setenv("FILEROOT", d) t.Logf("config: %+v", config.New()) handler := jsonHandler(storage.Graph{}) - prefix := path.Join(config.New().FilePrefix, "col") + prefix := path.Join(config.New().FilePrefix) qparam := "namespace=col" t.Run("bad qparam doesnt have namespace", func(t *testing.T) { @@ -47,15 +50,6 @@ func TestFiles(t *testing.T) { } }) - t.Run("bad prefix doesnt have namespace", func(t *testing.T) { - r := httptest.NewRequest(http.MethodGet, fmt.Sprintf("%s/fake?%s", config.New().FilePrefix, qparam), nil) - w := httptest.NewRecorder() - handler.ServeHTTP(w, r) - if w.Code != http.StatusNotFound { - t.Fatalf("%d: %s", w.Code, w.Body.Bytes()) - } - }) - t.Run("get fake file 404", func(t *testing.T) { r := httptest.NewRequest(http.MethodGet, fmt.Sprintf("%s/fake?%s", prefix, qparam), nil) w := httptest.NewRecorder() @@ -75,13 +69,14 @@ func TestFiles(t *testing.T) { } for ext, ct := range cases { for _, extC := range []string{strings.ToLower(ext), strings.ToUpper(ext)} { - f, err := ioutil.TempFile(d, "*."+extC) - t.Logf("tempFile(%q, *.%s) = %q", d, extC, f.Name()) + f, err := ioutil.TempFile(path.Join(d, "col"), "*."+extC) + t.Logf("tempFile(%q/col, *.%s) = %q", d, extC, f.Name()) if err != nil { t.Fatal(err) } f.Write([]byte("hello, world")) f.Close() + defer os.Remove(f.Name()) r := httptest.NewRequest(http.MethodGet, fmt.Sprintf("%s?direct=true&%s", path.Join(prefix, path.Base(f.Name())), qparam), nil) w := httptest.NewRecorder() t.Logf("URL = %q", r.URL.String()) @@ -106,6 +101,7 @@ func TestFiles(t *testing.T) { } f.Write([]byte("hello, world")) f.Close() + defer os.Remove(f.Name()) name := path.Base(f.Name()) s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -138,6 +134,7 @@ func TestFiles(t *testing.T) { } f.Write([]byte("hello, world")) f.Close() + defer os.Remove(f.Name()) r := httptest.NewRequest(http.MethodPost, path.Join(prefix, path.Base(f.Name()))+"?"+qparam, strings.NewReader(`bad link teehee`)) w := httptest.NewRecorder() @@ -153,6 +150,7 @@ func TestFiles(t *testing.T) { t.Fatal(err) } f.Close() + defer os.Remove(f.Name()) name := path.Base(f.Name()) b := bytes.NewBuffer(nil) writer := multipart.NewWriter(b)