package view

import (
	"bytes"
	"fmt"
	"io/ioutil"
	"local/dndex/config"
	"local/dndex/storage"
	"mime/multipart"
	"net/http"
	"net/http/httptest"
	"os"
	"path"
	"path/filepath"
	"strings"
	"testing"
)

// TestFiles exercises the file routes (under config FilePrefix) served by
// jsonHandler. It points DBURI at a throwaway file and FILEROOT at a throwaway
// directory containing a "col" namespace directory, then checks:
//   - a namespace query parameter on the bare file prefix is rejected (400),
//   - an unknown file is a 404,
//   - files of several extensions (either case) are served with the expected
//     Content-Type prefix,
//   - uploads via a direct link in the request body and via a multipart form
//     file can be read back, and a malformed link upload is not accepted.
func TestFiles(t *testing.T) {
	// Truncate the process args before config.New runs; presumably this keeps
	// config's own flag parsing from seeing `go test` flags — TODO confirm.
	os.Args = os.Args[:1]

	dbFile, err := ioutil.TempFile(os.TempDir(), "pattern*")
	if err != nil {
		t.Fatal(err)
	}
	dbFile.Close()
	defer os.Remove(dbFile.Name())
	if err := os.Setenv("DBURI", dbFile.Name()); err != nil {
		t.Fatal(err)
	}

	root, err := ioutil.TempDir(os.TempDir(), "pattern*")
	if err != nil {
		t.Fatal(err)
	}
	defer os.RemoveAll(root)
	// Filesystem paths use filepath; URL paths below use path.
	colDir := filepath.Join(root, "col")
	if err := os.MkdirAll(colDir, os.ModePerm); err != nil {
		t.Fatal(err)
	}
	if err := os.Setenv("FILEROOT", root); err != nil {
		t.Fatal(err)
	}

	cfg := config.New()
	t.Logf("config: %+v", cfg)
	handler := jsonHandler(storage.Graph{})
	prefix := path.Join(cfg.FilePrefix, "col")

	// serve runs r through handler and returns the recorded response.
	serve := func(r *http.Request) *httptest.ResponseRecorder {
		w := httptest.NewRecorder()
		handler.ServeHTTP(w, r)
		return w
	}

	// writeTemp creates a temp file in dir matching pattern, writes content
	// into it, closes it and returns its path. The caller removes the file.
	writeTemp := func(t *testing.T, dir, pattern string, content []byte) string {
		t.Helper()
		f, err := ioutil.TempFile(dir, pattern)
		if err != nil {
			t.Fatal(err)
		}
		if _, err := f.Write(content); err != nil {
			f.Close()
			os.Remove(f.Name())
			t.Fatal(err)
		}
		if err := f.Close(); err != nil {
			os.Remove(f.Name())
			t.Fatal(err)
		}
		return f.Name()
	}

	t.Run("has qparam, doesnt fix namespace for files prefix", func(t *testing.T) {
		r := httptest.NewRequest(http.MethodGet, fmt.Sprintf("%s/?namespace=col", cfg.FilePrefix), nil)
		w := serve(r)
		if w.Code != http.StatusBadRequest {
			t.Fatalf("%d: %s", w.Code, w.Body.Bytes())
		}
	})

	t.Run("get fake file 404", func(t *testing.T) {
		r := httptest.NewRequest(http.MethodGet, fmt.Sprintf("%s/fake", prefix), nil)
		w := serve(r)
		if w.Code != http.StatusNotFound {
			t.Fatalf("%d: %s", w.Code, w.Body.Bytes())
		}
	})

	t.Run("get typed files", func(t *testing.T) {
		cases := map[string]string{
			"txt":  "text/plain",
			"jpeg": "image/jpeg",
			"jpg":  "image/jpeg",
			"gif":  "image/gif",
			"mkv":  "video/x-matroska",
		}
		for ext, ct := range cases {
			for _, extC := range []string{strings.ToLower(ext), strings.ToUpper(ext)} {
				// One subtest per file: the deferred removal fires per case and
				// a failure names the offending extension. t.Run is synchronous,
				// so capturing the loop variables is safe.
				t.Run(extC, func(t *testing.T) {
					name := writeTemp(t, colDir, "*."+extC, []byte("hello, world"))
					defer os.Remove(name)
					t.Logf("tempFile(%q/col, *.%s) = %q", root, extC, name)

					r := httptest.NewRequest(http.MethodGet, fmt.Sprintf("%s?direct=true", path.Join(prefix, filepath.Base(name))), nil)
					t.Logf("URL = %q", r.URL.String())
					w := serve(r)
					if w.Code != http.StatusOK {
						t.Fatalf("%d: %s", w.Code, w.Body.Bytes())
					}
					// A missing header yields "", which fails the prefix check.
					if got := w.Header().Get("Content-Type"); !strings.HasPrefix(got, ct) {
						t.Fatalf("Content-Type = %q, want prefix %q (headers: %v)", got, ct, w.Header())
					}
					t.Logf("%+v", w)
				})
			}
		}
	})

	t.Run("post file: direct link", func(t *testing.T) {
		// The local temp file only supplies a unique name for the URL; the
		// uploaded content comes from the request body.
		tmp := writeTemp(t, os.TempDir(), "*.html", []byte("hello, world"))
		defer os.Remove(tmp)
		name := filepath.Base(tmp)

		s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			w.Write([]byte("hi"))
		}))
		defer s.Close()

		r := httptest.NewRequest(http.MethodPost, fmt.Sprintf("%s?direct=true", path.Join(prefix, name)), strings.NewReader(s.URL))
		w := serve(r)
		if w.Code != http.StatusOK {
			t.Fatalf("%d: %s", w.Code, w.Body.Bytes())
		}

		r = httptest.NewRequest(http.MethodGet, path.Join(prefix, name), nil)
		w = serve(r)
		if w.Code != http.StatusOK {
			t.Fatalf("%d: %s", w.Code, w.Body.Bytes())
		}
		if body := w.Body.String(); body != "hi" {
			t.Fatal(body)
		}
	})

	t.Run("post file: bad direct link", func(t *testing.T) {
		tmp := writeTemp(t, os.TempDir(), "*.txt", []byte("hello, world"))
		defer os.Remove(tmp)

		r := httptest.NewRequest(http.MethodPost, path.Join(prefix, filepath.Base(tmp)), strings.NewReader(`bad link teehee`))
		w := serve(r)
		if w.Code == http.StatusOK {
			t.Fatalf("%d: %s", w.Code, w.Body.Bytes())
		}
	})

	t.Run("post file: form file", func(t *testing.T) {
		tmp := writeTemp(t, os.TempDir(), "*.html", nil)
		defer os.Remove(tmp)
		name := filepath.Base(tmp)

		var body bytes.Buffer
		mw := multipart.NewWriter(&body)
		part, err := mw.CreateFormFile("file", name)
		if err != nil {
			t.Fatal(err)
		}
		if _, err := part.Write([]byte("hello, world")); err != nil {
			t.Fatal(err)
		}
		// Close writes the trailing boundary; without it the form is truncated.
		if err := mw.Close(); err != nil {
			t.Fatal(err)
		}

		r := httptest.NewRequest(http.MethodPost, path.Join(prefix, name), &body)
		r.Header.Set("Content-Type", mw.FormDataContentType())
		w := serve(r)
		if w.Code != http.StatusOK {
			t.Fatalf("%d: %s", w.Code, w.Body.Bytes())
		}

		r = httptest.NewRequest(http.MethodGet, path.Join(prefix, name), nil)
		w = serve(r)
		if w.Code != http.StatusOK {
			t.Fatalf("%d: %s", w.Code, w.Body.Bytes())
		}
		if got := w.Body.String(); got != "hello, world" {
			t.Fatal(got)
		}
	})
}