Static files served with authorization on path-identified namespace

master
breel 2020-07-25 21:15:39 -06:00
parent f90ccaae38
commit 480456fbdf
5 changed files with 87 additions and 20 deletions

View File

@ -81,7 +81,7 @@ func TestAuth(t *testing.T) {
}) })
t.Run("auth: none provided: files", func(t *testing.T) { t.Run("auth: none provided: files", func(t *testing.T) {
r := httptest.NewRequest(http.MethodGet, "/__files__/myfile?namespace=col", nil) r := httptest.NewRequest(http.MethodGet, "/__files__/col/myfile", nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
handler.ServeHTTP(w, r) handler.ServeHTTP(w, r)
if w.Code != http.StatusSeeOther { if w.Code != http.StatusSeeOther {
@ -140,7 +140,7 @@ func TestAuth(t *testing.T) {
t.Fatalf("%d: %s", w.Code, w.Body.Bytes()) t.Fatalf("%d: %s", w.Code, w.Body.Bytes())
} }
r = httptest.NewRequest(http.MethodTrace, "/__files__/myfile?namespace=col", nil) r = httptest.NewRequest(http.MethodTrace, "/__files__/col/myfile", nil)
w = httptest.NewRecorder() w = httptest.NewRecorder()
r.Header.Set("Cookie", fmt.Sprintf("%s=%s", AuthKey, token)) r.Header.Set("Cookie", fmt.Sprintf("%s=%s", AuthKey, token))
handler.ServeHTTP(w, r) handler.ServeHTTP(w, r)

View File

@ -23,11 +23,10 @@ func files(_ storage.Graph, w http.ResponseWriter, r *http.Request) error {
return nil return nil
} }
r.URL.Path = strings.TrimPrefix(r.URL.Path, config.New().FilePrefix) r.URL.Path = strings.TrimPrefix(r.URL.Path, config.New().FilePrefix)
if len(r.URL.Path) < 2 { if len(strings.TrimPrefix("/"+namespace, r.URL.Path)) < 2 {
http.NotFound(w, r) http.NotFound(w, r)
return nil return nil
} }
r.URL.Path = path.Join(namespace, r.URL.Path)
simpleserve.SetContentTypeIfMedia(w, r) simpleserve.SetContentTypeIfMedia(w, r)
switch r.Method { switch r.Method {
case http.MethodGet: case http.MethodGet:

View File

@ -38,11 +38,10 @@ func TestFiles(t *testing.T) {
t.Logf("config: %+v", config.New()) t.Logf("config: %+v", config.New())
handler := jsonHandler(storage.Graph{}) handler := jsonHandler(storage.Graph{})
prefix := path.Join(config.New().FilePrefix) prefix := path.Join(config.New().FilePrefix, "col")
qparam := "namespace=col"
t.Run("bad qparam doesnt have namespace", func(t *testing.T) { t.Run("has qparam, doesnt fix namespace for files prefix", func(t *testing.T) {
r := httptest.NewRequest(http.MethodGet, fmt.Sprintf("%s/fake", prefix), nil) r := httptest.NewRequest(http.MethodGet, fmt.Sprintf("%s/?namespace=col", config.New().FilePrefix), nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
handler.ServeHTTP(w, r) handler.ServeHTTP(w, r)
if w.Code != http.StatusBadRequest { if w.Code != http.StatusBadRequest {
@ -51,7 +50,7 @@ func TestFiles(t *testing.T) {
}) })
t.Run("get fake file 404", func(t *testing.T) { t.Run("get fake file 404", func(t *testing.T) {
r := httptest.NewRequest(http.MethodGet, fmt.Sprintf("%s/fake?%s", prefix, qparam), nil) r := httptest.NewRequest(http.MethodGet, fmt.Sprintf("%s/fake", prefix), nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
handler.ServeHTTP(w, r) handler.ServeHTTP(w, r)
if w.Code != http.StatusNotFound { if w.Code != http.StatusNotFound {
@ -77,7 +76,7 @@ func TestFiles(t *testing.T) {
f.Write([]byte("hello, world")) f.Write([]byte("hello, world"))
f.Close() f.Close()
defer os.Remove(f.Name()) defer os.Remove(f.Name())
r := httptest.NewRequest(http.MethodGet, fmt.Sprintf("%s?direct=true&%s", path.Join(prefix, path.Base(f.Name())), qparam), nil) r := httptest.NewRequest(http.MethodGet, fmt.Sprintf("%s?direct=true", path.Join(prefix, path.Base(f.Name()))), nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
t.Logf("URL = %q", r.URL.String()) t.Logf("URL = %q", r.URL.String())
handler.ServeHTTP(w, r) handler.ServeHTTP(w, r)
@ -109,14 +108,14 @@ func TestFiles(t *testing.T) {
})) }))
defer s.Close() defer s.Close()
r := httptest.NewRequest(http.MethodPost, fmt.Sprintf("%s?direct=true&%s", path.Join(prefix, name), qparam), strings.NewReader(s.URL)) r := httptest.NewRequest(http.MethodPost, fmt.Sprintf("%s?direct=true", path.Join(prefix, name)), strings.NewReader(s.URL))
w := httptest.NewRecorder() w := httptest.NewRecorder()
handler.ServeHTTP(w, r) handler.ServeHTTP(w, r)
if w.Code != http.StatusOK { if w.Code != http.StatusOK {
t.Fatalf("%d: %s", w.Code, w.Body.Bytes()) t.Fatalf("%d: %s", w.Code, w.Body.Bytes())
} }
r = httptest.NewRequest(http.MethodGet, path.Join(prefix, name)+"?"+qparam, nil) r = httptest.NewRequest(http.MethodGet, path.Join(prefix, name), nil)
w = httptest.NewRecorder() w = httptest.NewRecorder()
handler.ServeHTTP(w, r) handler.ServeHTTP(w, r)
if w.Code != http.StatusOK { if w.Code != http.StatusOK {
@ -136,7 +135,7 @@ func TestFiles(t *testing.T) {
f.Close() f.Close()
defer os.Remove(f.Name()) defer os.Remove(f.Name())
r := httptest.NewRequest(http.MethodPost, path.Join(prefix, path.Base(f.Name()))+"?"+qparam, strings.NewReader(`bad link teehee`)) r := httptest.NewRequest(http.MethodPost, path.Join(prefix, path.Base(f.Name())), strings.NewReader(`bad link teehee`))
w := httptest.NewRecorder() w := httptest.NewRecorder()
handler.ServeHTTP(w, r) handler.ServeHTTP(w, r)
if w.Code == http.StatusOK { if w.Code == http.StatusOK {
@ -158,7 +157,7 @@ func TestFiles(t *testing.T) {
w2.Write([]byte("hello, world")) w2.Write([]byte("hello, world"))
writer.Close() writer.Close()
r := httptest.NewRequest(http.MethodPost, path.Join(prefix, name)+"?"+qparam, b) r := httptest.NewRequest(http.MethodPost, path.Join(prefix, name), b)
r.Header.Set("Content-Type", writer.FormDataContentType()) r.Header.Set("Content-Type", writer.FormDataContentType())
w := httptest.NewRecorder() w := httptest.NewRecorder()
handler.ServeHTTP(w, r) handler.ServeHTTP(w, r)
@ -166,7 +165,7 @@ func TestFiles(t *testing.T) {
t.Fatalf("%d: %s", w.Code, w.Body.Bytes()) t.Fatalf("%d: %s", w.Code, w.Body.Bytes())
} }
r = httptest.NewRequest(http.MethodGet, path.Join(prefix, name)+"?"+qparam, nil) r = httptest.NewRequest(http.MethodGet, path.Join(prefix, name), nil)
w = httptest.NewRecorder() w = httptest.NewRecorder()
handler.ServeHTTP(w, r) handler.ServeHTTP(w, r)
if w.Code != http.StatusOK { if w.Code != http.StatusOK {

View File

@ -77,14 +77,22 @@ func jsonHandler(g storage.Graph) http.Handler {
} }
func getAuthNamespace(r *http.Request) (string, error) { func getAuthNamespace(r *http.Request) (string, error) {
namespace := r.URL.Query().Get("namespace") namespace, err := getNamespace(r)
if len(namespace) == 0 { return strings.Join([]string{namespace, AuthKey}, "."), err
return "", errors.New("no namespace found")
}
return strings.Join([]string{namespace, AuthKey}, "."), nil
} }
func getNamespace(r *http.Request) (string, error) { func getNamespace(r *http.Request) (string, error) {
if strings.HasPrefix(r.URL.Path, config.New().FilePrefix) {
path := strings.TrimPrefix(r.URL.Path, config.New().FilePrefix+"/")
if path == r.URL.Path {
return "", errors.New("no namespace on files")
}
path = strings.Split(path, "/")[0]
if path == "" {
return "", errors.New("empty namespace on files")
}
return path, nil
}
namespace := r.URL.Query().Get("namespace") namespace := r.URL.Query().Get("namespace")
if len(namespace) == 0 { if len(namespace) == 0 {
return "", errors.New("no namespace found") return "", errors.New("no namespace found")

61
view/json_test.go Normal file
View File

@ -0,0 +1,61 @@
package view
import (
"local/dndex/config"
"net/http"
"net/url"
"os"
"testing"
)
func TestGetNamespace(t *testing.T) {
os.Args = os.Args[:1]
cases := map[string]struct {
url string
want string
invalid bool
}{
"no query param, not files, should fail": {
url: "/",
invalid: true,
},
"empty query param, not files, should fail": {
url: "/a?namespace=",
invalid: true,
},
"query param, not files": {
url: "/a?namespace=OK",
want: "OK",
},
"files, no query param": {
url: config.New().FilePrefix + "/OK",
want: "OK",
},
}
for name, d := range cases {
c := d
t.Run(name, func(t *testing.T) {
c.url = "http://host.tld:80" + c.url
uri, err := url.Parse(c.url)
if err != nil {
t.Fatal(err)
}
ns, err := getNamespace(&http.Request{URL: uri})
if err != nil && !c.invalid {
t.Fatal(c.invalid, err)
} else if err == nil && ns != c.want {
t.Fatal(c.want, ns)
}
authns, err := getAuthNamespace(&http.Request{URL: uri})
if err != nil && !c.invalid {
t.Fatal(c.invalid, err)
} else if err == nil && authns != c.want+"."+AuthKey {
t.Fatal(c.want+"."+AuthKey, authns)
}
})
}
}