From 25a43c8a0bfcc59be14846319a86bd8d3a48da85 Mon Sep 17 00:00:00 2001 From: breel Date: Fri, 28 Aug 2020 14:47:18 -0600 Subject: [PATCH] Impl direct uplod --- server/files.go | 29 +++++++++++++++++++++++++++-- server/files_test.go | 35 +++++++++++++++++++++++++++++++++++ 2 files changed, 62 insertions(+), 2 deletions(-) diff --git a/server/files.go b/server/files.go index 90e1cc1..31f504a 100644 --- a/server/files.go +++ b/server/files.go @@ -2,6 +2,7 @@ package server import ( "io" + "io/ioutil" "local/dndex/config" "net/http" "os" @@ -50,12 +51,34 @@ func (rest *REST) filesCreate(w http.ResponseWriter, r *http.Request) { return } defer f.Close() - if _, err := io.Copy(f, r.Body); err != nil { + if err := rest.filesStream(r, f); err != nil { rest.respError(w, err) + return } w.Write([]byte(id)) } +func (rest *REST) filesStream(r *http.Request, f io.Writer) error { + var reader io.Reader = r.Body + _, direct := r.URL.Query()["direct"] + if direct { + target, err := ioutil.ReadAll(r.Body) + if err != nil { + return err + } + resp, err := http.Get(string(target)) + if err != nil { + return err + } + defer resp.Body.Close() + reader = resp.Body + } + if _, err := io.Copy(f, reader); err != nil { + return err + } + return nil +} + func (rest *REST) filesDelete(w http.ResponseWriter, r *http.Request) { localPath := rest.filesPath(r) if stat, err := os.Stat(localPath); os.IsNotExist(err) { @@ -106,10 +129,12 @@ func (rest *REST) filesUpdate(w http.ResponseWriter, r *http.Request) { return } defer f.Close() - if _, err := io.Copy(f, r.Body); err != nil { + + if err := rest.filesStream(r, f); err != nil { rest.respError(w, err) return } + if err := os.Rename(localPath+".tmp", localPath); err != nil { rest.respError(w, err) return diff --git a/server/files_test.go b/server/files_test.go index 7f21421..8be910a 100644 --- a/server/files_test.go +++ b/server/files_test.go @@ -1,6 +1,7 @@ package server import ( + "bytes" "net/http" "net/http/httptest" "strings" @@ -101,3 +102,37 @@ func testFilesReq(t *testing.T, rest *REST, id string, scope func(*http.Request) rest.files(w, r) return w } + +func TestFilesStream(t *testing.T) { + rest, _, clean := testREST(t) + defer clean() + + t.Run("simple upload", func(t *testing.T) { + value := uuid.New().String() + r := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(value)) + buff := bytes.NewBuffer(nil) + if err := rest.filesStream(r, buff); err != nil { + t.Fatal(err) + } + if s := string(buff.Bytes()); s != value { + t.Fatal(s) + } + }) + + t.Run("direct upload", func(t *testing.T) { + s := httptest.NewServer(http.HandlerFunc(http.NotFound)) + defer s.Close() + r := httptest.NewRequest(http.MethodPost, "/?direct", strings.NewReader(s.URL)) + buff := bytes.NewBuffer(nil) + if err := rest.filesStream(r, buff); err != nil { + t.Fatal(err) + } + w := httptest.NewRecorder() + http.NotFound(w, nil) + want := string(w.Body.Bytes()) + got := string(buff.Bytes()) + if want != got { + t.Fatal(want, got) + } + }) +}