diff --git a/storage/driver/boltdb_test.go b/storage/driver/boltdb_test.go index 1a73627..2fd7baf 100644 --- a/storage/driver/boltdb_test.go +++ b/storage/driver/boltdb_test.go @@ -117,9 +117,6 @@ func TestBoltDBFind(t *testing.T) { if o.Title == "" { t.Error(o.Title) } - if o.Image == "" { - t.Error(o.Image) - } if o.Text == "" { t.Error(o.Text) } @@ -380,7 +377,6 @@ func fillBoltDB(t *testing.T, bdb *BoltDB) { Name: "name-" + uuid.New().String()[:5], Type: "type-" + uuid.New().String()[:5], Title: "titl-" + uuid.New().String()[:5], - Image: "imge-" + uuid.New().String()[:5], Text: "text-" + uuid.New().String()[:5], Modified: time.Now().UnixNano(), Connections: map[string]entity.One{p.Name: p}, diff --git a/storage/entity/one.go b/storage/entity/one.go index 50d942e..00cedf0 100644 --- a/storage/entity/one.go +++ b/storage/entity/one.go @@ -14,7 +14,6 @@ const ( Relationship = "relationship" Type = "type" Title = "title" - Image = "image" Text = "text" Modified = "modified" Connections = "connections" @@ -25,7 +24,6 @@ type One struct { Name string `bson:"_id,omitempty" json:"name,omitempty"` Type string `bson:"type,omitempty" json:"type,omitempty"` Title string `bson:"title,omitempty" json:"title,omitempty"` - Image string `bson:"image,omitempty" json:"image,omitempty"` Text string `bson:"text,omitempty" json:"text,omitempty"` Relationship string `bson:"relationship,omitempty" json:"relationship,omitempty"` Modified int64 `bson:"modified,omitempty" json:"modified,omitempty"` diff --git a/storage/graph_test.go b/storage/graph_test.go index bef7502..6a0b3e2 100644 --- a/storage/graph_test.go +++ b/storage/graph_test.go @@ -187,7 +187,6 @@ func randomOne() entity.One { Name: uuid.New().String()[:5], Type: "Humman", Title: "Biggus", - Image: "/path/to.jpg", Text: "tee hee xd", Modified: time.Now().UnixNano(), Connections: map[string]entity.One{}, diff --git a/view/files.go b/view/files.go index 18bbc0a..cee06b0 100644 --- a/view/files.go +++ b/view/files.go @@ -1,10 +1,17 @@ package view import ( + "encoding/json" + "errors" + "fmt" + "io" + "io/ioutil" "local/dndex/config" "local/dndex/storage" "local/simpleserve/simpleserve" "net/http" + "net/url" + "os" "path" "strings" ) @@ -19,6 +26,8 @@ func files(_ storage.Graph, w http.ResponseWriter, r *http.Request) error { switch r.Method { case http.MethodGet: return filesGet(w, r) + case http.MethodPost: + return filesPost(w, r) default: http.NotFound(w, r) return nil @@ -26,6 +35,80 @@ func files(_ storage.Graph, w http.ResponseWriter, r *http.Request) error { } func filesGet(w http.ResponseWriter, r *http.Request) error { - http.ServeFile(w, r, path.Join(config.New().FileRoot, r.URL.Path)) + http.ServeFile(w, r, toLocalPath(r.URL.Path)) return nil } + +func filesPost(w http.ResponseWriter, r *http.Request) error { + p := toLocalPath(r.URL.Path) + if err := os.MkdirAll(path.Dir(p), os.ModePerm); err != nil { + return err + } + switch r.URL.Query().Get("direct") { + case "true": + return filesPostFromDirectLink(w, r) + default: + return filesPostFromUpload(w, r) + } +} + +func filesPostFromDirectLink(w http.ResponseWriter, r *http.Request) error { + b, err := ioutil.ReadAll(r.Body) + if err != nil { + return err + } + url, err := url.Parse(string(b)) + if err != nil { + return err + } + resp, err := http.Get(url.String()) + if err != nil { + return err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("bad status from direct: %v", resp.StatusCode) + } + + path := toLocalPath(r.URL.Path) + f, err := os.Create(path) + if err != nil { + return err + } + defer f.Close() + _, err = io.Copy(f, resp.Body) + return err +} + +func filesPostFromUpload(w http.ResponseWriter, r *http.Request) error { + p := toLocalPath(r.URL.Path) + if err := os.MkdirAll(path.Dir(p), os.ModePerm); err != nil { + return err + } + if fi, err := os.Stat(p); err != nil && !os.IsNotExist(err) { + return err + } else if err == nil && fi.IsDir() { + return errors.New("path is a directory") + } + f, err := os.Create(p) + if err != nil { + return err + } + defer f.Close() + megabyte := 100 << 20 + r.ParseMultipartForm(int64(megabyte)) + file, _, err := r.FormFile("file") + if err != nil { + return err + } + defer file.Close() + if _, err := io.Copy(f, file); err != nil { + return err + } + return json.NewEncoder(w).Encode(map[string]interface{}{"status": "ok"}) +} + +func toLocalPath(p string) string { + return path.Join(config.New().FileRoot, p) +} diff --git a/view/files_test.go b/view/files_test.go index 8786194..166d981 100644 --- a/view/files_test.go +++ b/view/files_test.go @@ -1,10 +1,12 @@ package view import ( + "bytes" "fmt" "io/ioutil" "local/dndex/config" "local/dndex/storage" + "mime/multipart" "net/http" "net/http/httptest" "os" @@ -59,7 +61,7 @@ func TestFiles(t *testing.T) { } f.Write([]byte("hello, world")) f.Close() - r := httptest.NewRequest(http.MethodGet, path.Join(config.New().FilePrefix, path.Base(f.Name())), nil) + r := httptest.NewRequest(http.MethodGet, fmt.Sprintf("%s?direct=true", path.Join(config.New().FilePrefix, path.Base(f.Name()))), nil) w := httptest.NewRecorder() t.Logf("URL = %q", r.URL.String()) handler.ServeHTTP(w, r) @@ -75,4 +77,85 @@ func TestFiles(t *testing.T) { } } }) + + t.Run("post file: direct link", func(t *testing.T) { + f, err := ioutil.TempFile(os.TempDir(), "*.html") + if err != nil { + t.Fatal(err) + } + f.Write([]byte("hello, world")) + f.Close() + name := path.Base(f.Name()) + + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("hi")) + })) + defer s.Close() + + r := httptest.NewRequest(http.MethodPost, fmt.Sprintf("%s?direct=true", path.Join(config.New().FilePrefix, name)), strings.NewReader(s.URL)) + w := httptest.NewRecorder() + handler.ServeHTTP(w, r) + if w.Code != http.StatusOK { + t.Fatalf("%d: %s", w.Code, w.Body.Bytes()) + } + + r = httptest.NewRequest(http.MethodGet, path.Join(config.New().FilePrefix, name), nil) + w = httptest.NewRecorder() + handler.ServeHTTP(w, r) + if w.Code != http.StatusOK { + t.Fatalf("%d: %s", w.Code, w.Body.Bytes()) + } + if body := string(w.Body.Bytes()); body != "hi" { + t.Fatal(body) + } + }) + + t.Run("post file: bad direct link", func(t *testing.T) { + f, err := ioutil.TempFile(os.TempDir(), "*.txt") + if err != nil { + t.Fatal(err) + } + f.Write([]byte("hello, world")) + f.Close() + + r := httptest.NewRequest(http.MethodPost, path.Join(config.New().FilePrefix, path.Base(f.Name())), strings.NewReader(`bad link teehee`)) + w := httptest.NewRecorder() + handler.ServeHTTP(w, r) + if w.Code == http.StatusOK { + t.Fatalf("%d: %s", w.Code, w.Body.Bytes()) + } + }) + + t.Run("post file: form file", func(t *testing.T) { + f, err := ioutil.TempFile(os.TempDir(), "*.html") + if err != nil { + t.Fatal(err) + } + f.Close() + name := path.Base(f.Name()) + b := bytes.NewBuffer(nil) + writer := multipart.NewWriter(b) + w2, _ := writer.CreateFormFile("file", name) + w2.Write([]byte("hello, world")) + writer.Close() + + r := httptest.NewRequest(http.MethodPost, path.Join(config.New().FilePrefix, name), b) + r.Header.Set("Content-Type", writer.FormDataContentType()) + w := httptest.NewRecorder() + handler.ServeHTTP(w, r) + if w.Code != http.StatusOK { + t.Fatalf("%d: %s", w.Code, w.Body.Bytes()) + } + + r = httptest.NewRequest(http.MethodGet, path.Join(config.New().FilePrefix, name), nil) + w = httptest.NewRecorder() + handler.ServeHTTP(w, r) + if w.Code != http.StatusOK { + t.Fatalf("%d: %s", w.Code, w.Body.Bytes()) + } + if body := string(w.Body.Bytes()); body != "hello, world" { + t.Fatal(body) + } + }) + } diff --git a/view/who_test.go b/view/who_test.go index 7e4b8fc..2cf042b 100644 --- a/view/who_test.go +++ b/view/who_test.go @@ -291,7 +291,6 @@ func randomOne() entity.One { Name: "name-" + uuid.New().String()[:5], Type: "type-" + uuid.New().String()[:5], Title: "titl-" + uuid.New().String()[:5], - Image: "imge-" + uuid.New().String()[:5], Text: "text-" + uuid.New().String()[:5], Modified: time.Now().UnixNano(), Connections: map[string]entity.One{},