diff --git a/config/config.go b/config/config.go index 1345055..9b357a7 100644 --- a/config/config.go +++ b/config/config.go @@ -13,6 +13,7 @@ type Config struct { Database string Driver []string FilePrefix string + APIPrefix string FileRoot string Auth bool AuthLifetime time.Duration @@ -33,6 +34,7 @@ func New() Config { as.Append(args.INT, "p", "port to listen on", 18114) as.Append(args.STRING, "fileprefix", "path prefix for file service", "/__files__") + as.Append(args.STRING, "api-prefix", "path prefix for api", "api") as.Append(args.STRING, "fileroot", "path to file hosting root", "/tmp/") as.Append(args.STRING, "database", "database name to use", "db") as.Append(args.STRING, "driver", "database driver args to use, like [local/storage.Type,arg1,arg2...] or [/path/to/boltdb]", "map") @@ -60,5 +62,6 @@ func New() Config { MaxFileSize: int64(as.GetInt("max-file-size")), RPS: as.GetInt("rps"), SysRPS: as.GetInt("sys-rps"), + APIPrefix: strings.TrimPrefix(as.GetString("api-prefix"), "/"), } } diff --git a/main_test.go b/main_test.go index 3ab7e47..8aadba4 100644 --- a/main_test.go +++ b/main_test.go @@ -6,6 +6,7 @@ import ( "fmt" "io" "io/ioutil" + "local/dndex/config" "local/dndex/server/auth" "local/dndex/storage/entity" "net/http" @@ -27,6 +28,8 @@ func Test(t *testing.T) { s := httptest.NewServer(http.HandlerFunc(nil)) s.Close() p := strings.Split(s.URL, ":")[2] + os.Args = []string{"a"} + s.URL = s.URL + "/" + config.New().APIPrefix os.Args = strings.Split(fmt.Sprintf(`dndex -auth=true -database db -delay 5ms -driver map -fileprefix /files -fileroot %s -p %v -rps 50 -sys-rps 40`, d, p), " ") go main() diff --git a/server/auth/generate.go b/server/auth/generate.go index 016df87..dee4891 100644 --- a/server/auth/generate.go +++ b/server/auth/generate.go @@ -1,13 +1,19 @@ package auth import ( + "bytes" "context" + "encoding/json" "errors" + "fmt" + "io" + "io/ioutil" "local/dndex/storage" "local/dndex/storage/entity" "net/http" "github.com/google/uuid" + "gopkg.in/mgo.v2/bson" ) func GeneratePlain(g storage.RateLimitedGraph, r *http.Request) (string, error) { @@ -47,7 +53,26 @@ func readRequestedNamespace(r *http.Request) string { } func readRequested(r *http.Request, key string) string { - return r.FormValue(key) + switch r.Header.Get("Content-Type") { + case "application/json": + b, _ := ioutil.ReadAll(r.Body) + r.Body = struct { + io.Reader + io.Closer + }{ + Reader: bytes.NewReader(b), + Closer: r.Body, + } + m := bson.M{} + json.Unmarshal(b, &m) + v, ok := m[key] + if !ok { + return "" + } + return fmt.Sprint(v) + default: + return r.FormValue(key) + } } func getKeyForNamespace(ctx context.Context, g storage.RateLimitedGraph, namespace string) (string, error) { diff --git a/server/auth/generate_test.go b/server/auth/generate_test.go index e65dce9..34aadb9 100644 --- a/server/auth/generate_test.go +++ b/server/auth/generate_test.go @@ -95,3 +95,30 @@ func TestGenerate(t *testing.T) { } }) } + +func TestReadRequested(t *testing.T) { + t.Run("form: ignore query params", func(t *testing.T) { + r := httptest.NewRequest(http.MethodPost, "/a=c", nil) + r.Header.Set("Content-Type", "application/x-www-form-urlencoded") + if got := readRequested(r, "a"); got != "" { + t.Fatal(got) + } + }) + + t.Run("form: body beats query params", func(t *testing.T) { + r := httptest.NewRequest(http.MethodPost, "/a=c", strings.NewReader(`a=b`)) + r.Header.Set("Content-Type", "application/x-www-form-urlencoded") + if got := readRequested(r, "a"); got != "b" { + t.Fatal(got) + } + }) + + t.Run("json: OK", func(t *testing.T) { + r := httptest.NewRequest(http.MethodPost, "/a=c", strings.NewReader(`{"a": "b"}`)) + r.Header.Set("Content-Type", "application/json") + if got := readRequested(r, "a"); got != "b" { + t.Fatal(got) + } + }) + +} diff --git a/server/middleware.go b/server/middleware.go index 8e71493..6cdc460 100644 --- a/server/middleware.go +++ b/server/middleware.go @@ -6,6 +6,7 @@ import ( "local/dndex/server/auth" "local/gziphttp" "net/http" + "path" "strings" "time" ) @@ -69,3 +70,11 @@ func (rest *REST) shift(foo http.HandlerFunc) http.HandlerFunc { foo(w, r) } } + +func (rest *REST) deprefix(foo http.HandlerFunc) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + newpath := strings.TrimPrefix(r.URL.Path, path.Join("/", config.New().APIPrefix)) + r.URL.Path = newpath + foo(w, r) + } +} diff --git a/server/middleware_test.go b/server/middleware_test.go index 97cf17f..4e48277 100644 --- a/server/middleware_test.go +++ b/server/middleware_test.go @@ -131,3 +131,20 @@ func TestMiddlewareShift(t *testing.T) { }) } } + +func TestMiddlewareDeprefix(t *testing.T) { + resolved := "" + bar := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + resolved = r.URL.Path + }) + + os.Args = []string{"a"} + os.Setenv("API_PREFIX", "myprefix") + r := httptest.NewRequest(http.MethodGet, "/myprefix/abc", nil) + rest := &REST{} + rest.deprefix(bar)(nil, r) + + if resolved != "/abc" { + t.Fatal(resolved) + } +} diff --git a/server/rest.go b/server/rest.go index 3aa39f1..bce926a 100644 --- a/server/rest.go +++ b/server/rest.go @@ -7,6 +7,7 @@ import ( "local/dndex/storage" "local/router" "net/http" + "path" "strings" ) @@ -42,11 +43,11 @@ func NewREST(g storage.RateLimitedGraph) (*REST, error) { fmt.Sprintf("dump"): rest.dump, } - for path, foo := range paths { + for urlpath, foo := range paths { bar := foo bar = rest.shift(bar) bar = rest.scoped(bar) - switch strings.Split(path, "/")[0] { + switch strings.Split(urlpath, "/")[0] { case "users": case "version": default: @@ -54,7 +55,9 @@ func NewREST(g storage.RateLimitedGraph) (*REST, error) { } bar = rest.defend(bar) bar = rest.delay(bar) - if err := rest.router.Add(path, bar); err != nil { + bar = rest.deprefix(bar) + routerpath := path.Join("/", config.New().APIPrefix, urlpath) + if err := rest.router.Add(routerpath, bar); err != nil { return nil, err } } diff --git a/server/rest_test.go b/server/rest_test.go index 00467c5..3786296 100644 --- a/server/rest_test.go +++ b/server/rest_test.go @@ -163,10 +163,10 @@ func TestRESTRouter(t *testing.T) { for name, d := range cases { c := d - path := name + urlpath := path.Join("/", config.New().APIPrefix, name) rest, setAuth, clean := testREST(t) defer clean() - r := httptest.NewRequest(c.method, path, strings.NewReader(``)) + r := httptest.NewRequest(c.method, urlpath, strings.NewReader(``)) setAuth(r) w := httptest.NewRecorder() rest.router.ServeHTTP(w, r) diff --git a/server/users.go b/server/users.go index 73f6b05..1f4468b 100644 --- a/server/users.go +++ b/server/users.go @@ -16,7 +16,7 @@ func (rest *REST) users(w http.ResponseWriter, r *http.Request) { rest.respNotFound(w) return } - r.Header.Set("Application-Type", "application/x-www-form-urlencoded") + rest.usersContentType(r) switch r.URL.Path { case "/register": rest.usersRegister(w, r) @@ -28,6 +28,7 @@ func (rest *REST) users(w http.ResponseWriter, r *http.Request) { } func (rest *REST) usersRegister(w http.ResponseWriter, r *http.Request) { + rest.usersContentType(r) err := auth.Register(rest.g, r) if err != nil { rest.respError(w, err) @@ -37,6 +38,7 @@ func (rest *REST) usersRegister(w http.ResponseWriter, r *http.Request) { } func (rest *REST) usersLogin(w http.ResponseWriter, r *http.Request) { + rest.usersContentType(r) salt := uuid.New().String()[:5] var token string var err error @@ -56,3 +58,9 @@ func (rest *REST) usersLogin(w http.ResponseWriter, r *http.Request) { "salt": salt, }) } + +func (rest *REST) usersContentType(r *http.Request) { + if r.Header.Get("Application-Type") == "" { + r.Header.Set("Application-Type", "application/x-www-form-urlencoded") + } +} diff --git a/server/users_test.go b/server/users_test.go index f73506a..d284602 100644 --- a/server/users_test.go +++ b/server/users_test.go @@ -17,9 +17,11 @@ func TestUsersRegister(t *testing.T) { defer clean() t.Run("register ok", func(t *testing.T) { - user := uuid.New().String()[:5] - pwd := uuid.New().String()[:5] - testRegisterOK(t, rest, user, pwd) + for _, json := range []bool{false, true} { + user := uuid.New().String()[:5] + pwd := uuid.New().String()[:5] + testRegisterOK(t, rest, user, pwd, json) + } }) t.Run("register 400: nil body", func(t *testing.T) { @@ -76,36 +78,54 @@ func TestUsersLogin(t *testing.T) { defer clean() t.Run("login ok", func(t *testing.T) { - user := uuid.New().String()[:5] - pwd := uuid.New().String()[:5] - testRegisterOK(t, rest, user, pwd) - testLoginOK(t, rest, user, pwd) + for _, json := range []bool{false, true} { + user := uuid.New().String()[:5] + pwd := uuid.New().String()[:5] + testRegisterOK(t, rest, user, pwd, json) + testLoginOK(t, rest, user, pwd, json) + } }) t.Run("login 404 user", func(t *testing.T) { - pwd := uuid.New().String()[:5] - testLoginNotOK(t, rest, "bad", pwd) + for _, json := range []bool{false, true} { + pwd := uuid.New().String()[:5] + testLoginNotOK(t, rest, "bad", pwd, json) + } }) t.Run("login bad user", func(t *testing.T) { - user := uuid.New().String()[:5] - pwd := uuid.New().String()[:5] - testRegisterOK(t, rest, user, pwd) - testLoginNotOK(t, rest, "bad", pwd) + for _, json := range []bool{false, true} { + user := uuid.New().String()[:5] + pwd := uuid.New().String()[:5] + testRegisterOK(t, rest, user, pwd, json) + testLoginNotOK(t, rest, "bad", pwd, json) + } }) t.Run("login bad pwd", func(t *testing.T) { - user := uuid.New().String()[:5] - pwd := uuid.New().String()[:5] - testRegisterOK(t, rest, user, pwd) - testLoginNotOK(t, rest, user, "bad") + for _, json := range []bool{false, true} { + user := uuid.New().String()[:5] + pwd := uuid.New().String()[:5] + testRegisterOK(t, rest, user, pwd, json) + testLoginNotOK(t, rest, user, "bad", json) + } }) } -func testRegisterOK(t *testing.T, rest *REST, user, pwd string) { +func testRegisterOK(t *testing.T, rest *REST, user, pwd string, useJSON bool) { body := fmt.Sprintf(`%s=%s&%s=%s`, auth.UserKey, user, auth.AuthKey, pwd) + if useJSON { + s, _ := json.Marshal(map[string]string{ + auth.UserKey: user, + auth.AuthKey: pwd, + }) + body = string(s) + } r := httptest.NewRequest(http.MethodPost, "/register", strings.NewReader(body)) r.Header.Set("Content-Type", "application/x-www-form-urlencoded") + if useJSON { + r.Header.Set("Content-Type", "application/json") + } w := httptest.NewRecorder() rest.users(w, r) if w.Code != http.StatusOK { @@ -113,10 +133,19 @@ func testRegisterOK(t *testing.T, rest *REST, user, pwd string) { } } -func testLoginNotOK(t *testing.T, rest *REST, user, pwd string) { +func testLoginNotOK(t *testing.T, rest *REST, user, pwd string, useJSON bool) { body := fmt.Sprintf(`%s=%s`, auth.UserKey, user) + if useJSON { + s, _ := json.Marshal(map[string]string{ + auth.UserKey: user, + }) + body = string(s) + } r := httptest.NewRequest(http.MethodPost, "/login", strings.NewReader(body)) r.Header.Set("Content-Type", "application/x-www-form-urlencoded") + if useJSON { + r.Header.Set("Content-Type", "application/json") + } w := httptest.NewRecorder() rest.users(w, r) if w.Code < http.StatusBadRequest { @@ -136,10 +165,19 @@ func testLoginNotOK(t *testing.T, rest *REST, user, pwd string) { } } -func testLoginOK(t *testing.T, rest *REST, user, pwd string) string { +func testLoginOK(t *testing.T, rest *REST, user, pwd string, useJSON bool) string { body := fmt.Sprintf(`%s=%s`, auth.UserKey, user) + if useJSON { + s, _ := json.Marshal(map[string]string{ + auth.UserKey: user, + }) + body = string(s) + } r := httptest.NewRequest(http.MethodPost, "/login", strings.NewReader(body)) r.Header.Set("Content-Type", "application/x-www-form-urlencoded") + if useJSON { + r.Header.Set("Content-Type", "application/json") + } w := httptest.NewRecorder() rest.users(w, r) if w.Code != http.StatusOK {