diff --git a/view/auth_test.go b/view/auth_test.go index 1948f04..32aea13 100644 --- a/view/auth_test.go +++ b/view/auth_test.go @@ -81,7 +81,7 @@ func TestAuth(t *testing.T) { }) t.Run("auth: none provided: files", func(t *testing.T) { - r := httptest.NewRequest(http.MethodGet, "/__files__/myfile?namespace=col", nil) + r := httptest.NewRequest(http.MethodGet, "/__files__/col/myfile", nil) w := httptest.NewRecorder() handler.ServeHTTP(w, r) if w.Code != http.StatusSeeOther { @@ -140,7 +140,7 @@ func TestAuth(t *testing.T) { t.Fatalf("%d: %s", w.Code, w.Body.Bytes()) } - r = httptest.NewRequest(http.MethodTrace, "/__files__/myfile?namespace=col", nil) + r = httptest.NewRequest(http.MethodTrace, "/__files__/col/myfile", nil) w = httptest.NewRecorder() r.Header.Set("Cookie", fmt.Sprintf("%s=%s", AuthKey, token)) handler.ServeHTTP(w, r) diff --git a/view/files.go b/view/files.go index 27e8a47..e948169 100644 --- a/view/files.go +++ b/view/files.go @@ -23,11 +23,10 @@ func files(_ storage.Graph, w http.ResponseWriter, r *http.Request) error { return nil } r.URL.Path = strings.TrimPrefix(r.URL.Path, config.New().FilePrefix) - if len(r.URL.Path) < 2 { + if len(strings.TrimPrefix("/"+namespace, r.URL.Path)) < 2 { http.NotFound(w, r) return nil } - r.URL.Path = path.Join(namespace, r.URL.Path) simpleserve.SetContentTypeIfMedia(w, r) switch r.Method { case http.MethodGet: diff --git a/view/files_test.go b/view/files_test.go index f87a3f9..f5b9a44 100644 --- a/view/files_test.go +++ b/view/files_test.go @@ -38,11 +38,10 @@ func TestFiles(t *testing.T) { t.Logf("config: %+v", config.New()) handler := jsonHandler(storage.Graph{}) - prefix := path.Join(config.New().FilePrefix) - qparam := "namespace=col" + prefix := path.Join(config.New().FilePrefix, "col") - t.Run("bad qparam doesnt have namespace", func(t *testing.T) { - r := httptest.NewRequest(http.MethodGet, fmt.Sprintf("%s/fake", prefix), nil) + t.Run("has qparam, doesnt fix namespace for files prefix", func(t *testing.T) { + r := httptest.NewRequest(http.MethodGet, fmt.Sprintf("%s/?namespace=col", config.New().FilePrefix), nil) w := httptest.NewRecorder() handler.ServeHTTP(w, r) if w.Code != http.StatusBadRequest { @@ -51,7 +50,7 @@ func TestFiles(t *testing.T) { }) t.Run("get fake file 404", func(t *testing.T) { - r := httptest.NewRequest(http.MethodGet, fmt.Sprintf("%s/fake?%s", prefix, qparam), nil) + r := httptest.NewRequest(http.MethodGet, fmt.Sprintf("%s/fake", prefix), nil) w := httptest.NewRecorder() handler.ServeHTTP(w, r) if w.Code != http.StatusNotFound { @@ -77,7 +76,7 @@ func TestFiles(t *testing.T) { f.Write([]byte("hello, world")) f.Close() defer os.Remove(f.Name()) - r := httptest.NewRequest(http.MethodGet, fmt.Sprintf("%s?direct=true&%s", path.Join(prefix, path.Base(f.Name())), qparam), nil) + r := httptest.NewRequest(http.MethodGet, fmt.Sprintf("%s?direct=true", path.Join(prefix, path.Base(f.Name()))), nil) w := httptest.NewRecorder() t.Logf("URL = %q", r.URL.String()) handler.ServeHTTP(w, r) @@ -109,14 +108,14 @@ func TestFiles(t *testing.T) { })) defer s.Close() - r := httptest.NewRequest(http.MethodPost, fmt.Sprintf("%s?direct=true&%s", path.Join(prefix, name), qparam), strings.NewReader(s.URL)) + r := httptest.NewRequest(http.MethodPost, fmt.Sprintf("%s?direct=true", path.Join(prefix, name)), strings.NewReader(s.URL)) w := httptest.NewRecorder() handler.ServeHTTP(w, r) if w.Code != http.StatusOK { t.Fatalf("%d: %s", w.Code, w.Body.Bytes()) } - r = httptest.NewRequest(http.MethodGet, path.Join(prefix, name)+"?"+qparam, nil) + r = httptest.NewRequest(http.MethodGet, path.Join(prefix, name), nil) w = httptest.NewRecorder() handler.ServeHTTP(w, r) if w.Code != http.StatusOK { @@ -136,7 +135,7 @@ func TestFiles(t *testing.T) { f.Close() defer os.Remove(f.Name()) - r := httptest.NewRequest(http.MethodPost, path.Join(prefix, path.Base(f.Name()))+"?"+qparam, strings.NewReader(`bad link teehee`)) + r := httptest.NewRequest(http.MethodPost, path.Join(prefix, path.Base(f.Name())), strings.NewReader(`bad link teehee`)) w := httptest.NewRecorder() handler.ServeHTTP(w, r) if w.Code == http.StatusOK { @@ -158,7 +157,7 @@ func TestFiles(t *testing.T) { w2.Write([]byte("hello, world")) writer.Close() - r := httptest.NewRequest(http.MethodPost, path.Join(prefix, name)+"?"+qparam, b) + r := httptest.NewRequest(http.MethodPost, path.Join(prefix, name), b) r.Header.Set("Content-Type", writer.FormDataContentType()) w := httptest.NewRecorder() handler.ServeHTTP(w, r) @@ -166,7 +165,7 @@ func TestFiles(t *testing.T) { t.Fatalf("%d: %s", w.Code, w.Body.Bytes()) } - r = httptest.NewRequest(http.MethodGet, path.Join(prefix, name)+"?"+qparam, nil) + r = httptest.NewRequest(http.MethodGet, path.Join(prefix, name), nil) w = httptest.NewRecorder() handler.ServeHTTP(w, r) if w.Code != http.StatusOK { diff --git a/view/json.go b/view/json.go index edf7c0c..d474759 100644 --- a/view/json.go +++ b/view/json.go @@ -77,14 +77,22 @@ func jsonHandler(g storage.Graph) http.Handler { } func getAuthNamespace(r *http.Request) (string, error) { - namespace := r.URL.Query().Get("namespace") - if len(namespace) == 0 { - return "", errors.New("no namespace found") - } - return strings.Join([]string{namespace, AuthKey}, "."), nil + namespace, err := getNamespace(r) + return strings.Join([]string{namespace, AuthKey}, "."), err } func getNamespace(r *http.Request) (string, error) { + if strings.HasPrefix(r.URL.Path, config.New().FilePrefix) { + path := strings.TrimPrefix(r.URL.Path, config.New().FilePrefix+"/") + if path == r.URL.Path { + return "", errors.New("no namespace on files") + } + path = strings.Split(path, "/")[0] + if path == "" { + return "", errors.New("empty namespace on files") + } + return path, nil + } namespace := r.URL.Query().Get("namespace") if len(namespace) == 0 { return "", errors.New("no namespace found") diff --git a/view/json_test.go b/view/json_test.go new file mode 100644 index 0000000..2d1493b --- /dev/null +++ b/view/json_test.go @@ -0,0 +1,61 @@ +package view + +import ( + "local/dndex/config" + "net/http" + "net/url" + "os" + "testing" +) + +func TestGetNamespace(t *testing.T) { + os.Args = os.Args[:1] + + cases := map[string]struct { + url string + want string + invalid bool + }{ + "no query param, not files, should fail": { + url: "/", + invalid: true, + }, + "empty query param, not files, should fail": { + url: "/a?namespace=", + invalid: true, + }, + "query param, not files": { + url: "/a?namespace=OK", + want: "OK", + }, + "files, no query param": { + url: config.New().FilePrefix + "/OK", + want: "OK", + }, + } + + for name, d := range cases { + c := d + t.Run(name, func(t *testing.T) { + c.url = "http://host.tld:80" + c.url + uri, err := url.Parse(c.url) + if err != nil { + t.Fatal(err) + } + + ns, err := getNamespace(&http.Request{URL: uri}) + if err != nil && !c.invalid { + t.Fatal(c.invalid, err) + } else if err == nil && ns != c.want { + t.Fatal(c.want, ns) + } + + authns, err := getAuthNamespace(&http.Request{URL: uri}) + if err != nil && !c.invalid { + t.Fatal(c.invalid, err) + } else if err == nil && authns != c.want+"."+AuthKey { + t.Fatal(c.want+"."+AuthKey, authns) + } + }) + } +}