package server

import (
	"bytes"
	"encoding/json"
	"errors"
	"io/ioutil"
	"net/http"
	"net/http/httptest"
	"testing"
)

// badRouter is a router test double: Add succeeds only for paths listed in
// accept and returns an error for every other path. ServeHTTP does nothing.
type badRouter struct {
	accept []string
}

// Add returns nil when p is one of br.accept and an error otherwise. The
// handler h is discarded.
func (br *badRouter) Add(p string, h http.HandlerFunc) error {
	for i := range br.accept {
		if br.accept[i] == p {
			return nil
		}
	}
	return errors.New("rejected path")
}

// ServeHTTP is a no-op; badRouter exists only to exercise route registration.
func (br *badRouter) ServeHTTP(http.ResponseWriter, *http.Request) {}

// TestServerRoutesBadRouter checks that server.Routes returns an error when
// the router rejects a route. The accepted set grows by one path per
// iteration, and every call to Routes is expected to fail.
// NOTE(review): on the last iteration every path in toAdd is accepted, so the
// test presumably relies on Routes registering at least one path not listed
// here — confirm against Routes.
func TestServerRoutesBadRouter(t *testing.T) {
	server, _, _ := mockServer()
	br := badRouter{
		accept: make([]string, 0),
	}
	server.router = &br
	toAdd := []string{
		"/nil",
		"/admin/register/{}",
		"/register/{}",
		"/generate/{}/{}",
		"/retrieve/{}/{}",
		"/revoke/{}/{}",
		"/lookup/{}/{}",
		"/policies",
	}
	for _, path := range toAdd {
		br.accept = append(br.accept, path)
		if err := server.Routes(); err == nil {
			t.Errorf("can add non-allowed routes")
		}
	}
}

// TestServerRoutes registers "a", creates a token from "a" to "b", and then
// sends a GET request for each route; any 404 response is a failure.
func TestServerRoutes(t *testing.T) {
	server, _, _ := mockServer()
	if err := server.db.Register("a"); err != nil {
		t.Fatalf("cannot register: %v", err)
	}
	token, err := server.db.New("a", "b")
	if err != nil {
		t.Fatalf("cannot new: %v", err)
	}
	paths := []string{
		"retrieve/a/b/" + token.Accessor,
		"revoke/a/" + token.Accessor,
		"lookup/a/b",
		"policies",
		"admin/register/a",
		"register/a",
		"generate/a/b",
	}
	for _, p := range paths {
		w := httptest.NewRecorder()
		r, _ := http.NewRequest("GET", p, nil)
		server.ServeHTTP(w, r)
		if w.Code == 404 {
			t.Errorf("not found for %v", p)
		}
	}
}

// TestServerAdminRegister checks the admin/register status code for an empty
// and a non-empty name, and that every response has a non-empty body.
func TestServerAdminRegister(t *testing.T) {
	cases := []struct {
		name   string
		status int
	}{
		{name: "", status: 404},
		{name: "name", status: 200},
	}
	path := "admin/register"
	server, _, _ := mockServer()
	for i, c := range cases {
		w := httptest.NewRecorder()
		r, _ := http.NewRequest("GET", path, nil)
		r = addParam(r, "name", c.name)
		server.ServeHTTP(w, r)
		if w.Code != c.status {
			t.Errorf("[%d] wrong code for %s: got %d, expected %d", i, path, w.Code, c.status)
		}
		body, _ := ioutil.ReadAll(w.Body)
		if len(body) < 1 {
			t.Errorf("[%d] empty body in admin/register response: %q", i, body)
		}
		t.Logf("[%d] admin/register body: %q", i, body)
	}
}

// TestServerRegister checks the register status code for an empty and a
// non-empty name. Each request carries basic-auth credentials issued by
// server.authdb under serverNS.
func TestServerRegister(t *testing.T) {
	cases := []struct {
		name   string
		status int
	}{
		{name: "", status: 404},
		{name: "name", status: 200},
	}
	path := "register"
	server, _, _ := mockServer()
	for i, c := range cases {
		w := httptest.NewRecorder()
		r, _ := http.NewRequest("GET", path, nil)
		r = addParam(r, "name", c.name)
		auth, err := server.authdb.New(serverNS, c.name)
		if err != nil {
			t.Fatalf("cannot authdb new: %v", err)
		}
		r.SetBasicAuth(c.name, auth.Accessor+":"+auth.Token)
		server.ServeHTTP(w, r)
		if w.Code != c.status {
			t.Errorf("%d: wrong code for %s: got %d, expected %d", i, path, w.Code, c.status)
		}
	}
}

// TestServerGenerate checks the generate status code for each combination of
// empty and non-empty name/to. For a 200 response it also checks that the
// JSON body decodes, echoes `to`, has a non-empty token and accessor, and has
// a TTL of at least 100.
func TestServerGenerate(t *testing.T) {
	cases := []struct {
		name   string
		to     string
		status int
	}{
		{name: "", to: "", status: 404},
		{name: "name", to: "", status: 404},
		{name: "", to: "to", status: 404},
		{name: "name", to: "to", status: 200},
	}
	path := "generate"
	server, _, _ := mockServer()
	for i, c := range cases {
		w := httptest.NewRecorder()
		r, _ := http.NewRequest("GET", path, nil)
		r = addParam(r, "name", c.name)
		r = addParam(r, "to", c.to)
		auth, err := server.authdb.New(serverNS, c.name)
		if err != nil {
			t.Fatalf("cannot authdb new: %v", err)
		}
		r.SetBasicAuth(c.name, auth.Accessor+":"+auth.Token)
		server.ServeHTTP(w, r)
		if w.Code != c.status {
			t.Errorf("%d: wrong code for %s: got %d, expected %d", i, path, w.Code, c.status)
		}
		var result struct {
			Token string `json:"token"`
			Acc   string `json:"accessor"`
			TTL   int    `json:"TTL"`
			To    string `json:"to"`
		}
		// The body is decoded for every case; only 200 responses are required
		// to be valid JSON.
		if err := json.NewDecoder(w.Body).Decode(&result); c.status == 200 && err != nil {
			t.Errorf("invalid body: %v", err)
		} else if c.status == 200 {
			if result.To != c.to {
				t.Errorf("wrong `to` in response: got %v, want %v", result.To, c.to)
			}
			if len(result.Token) == 0 {
				t.Errorf("empty `token` in response")
			}
			if len(result.Acc) == 0 {
				t.Errorf("empty `accessor` in response")
			}
			if result.TTL < 100 {
				t.Errorf("short TTL in response")
			}
		}
	}
}

// TestServerRetrieve checks the retrieve status code for missing name or
// accessor (404), the mock server's default name/accessor pair (200), and an
// unknown name/accessor pair (400). It stops at the first mismatch.
func TestServerRetrieve(t *testing.T) {
	server, defaultName, defaultAccessor := mockServer()
	cases := []struct {
		name   string
		acc    string
		status int
	}{
		{name: "", acc: "", status: 404},
		{name: defaultName, acc: "", status: 404},
		{name: "", acc: defaultAccessor, status: 404},
		{name: defaultName, acc: defaultAccessor, status: 200},
		{name: "fake", acc: "fake", status: 400},
	}
	path := "retrieve"
	for i, c := range cases {
		w := httptest.NewRecorder()
		r, _ := http.NewRequest("GET", path, nil)
		r = addParam(r, "name", c.name)
		r = addParam(r, "to", "to")
		r = addParam(r, "accessor", c.acc)
		auth, err := server.authdb.New(serverNS, c.name)
		if err != nil {
			t.Fatalf("cannot authdb new: %v", err)
		}
		r.SetBasicAuth(c.name, auth.Accessor+":"+auth.Token)
		server.ServeHTTP(w, r)
		if w.Code != c.status {
			t.Fatalf("%d: wrong code for %s with %v: got %d, expected %d", i, path, c, w.Code, c.status)
		}
	}
}

// TestServerRevoke checks the status code for an empty and a non-empty name.
// NOTE(review): path is "register", so this test currently exercises the
// register endpoint rather than revoke (which routes as revoke/{}/{}) —
// confirm intent and point it at the revoke route.
func TestServerRevoke(t *testing.T) {
	cases := []struct {
		name   string
		status int
	}{
		{name: "", status: 404},
		{name: "name", status: 200},
	}
	path := "register"
	server, _, _ := mockServer()
	for i, c := range cases {
		w := httptest.NewRecorder()
		r, _ := http.NewRequest("GET", path, nil)
		r = addParam(r, "name", c.name)
		auth, err := server.authdb.New(serverNS, c.name)
		if err != nil {
			t.Fatalf("cannot authdb new: %v", err)
		}
		r.SetBasicAuth(c.name, auth.Accessor+":"+auth.Token)
		server.ServeHTTP(w, r)
		if w.Code != c.status {
			t.Errorf("%d: wrong code for %s: got %d, expected %d", i, path, w.Code, c.status)
		}
	}
}

// TestServerLookup registers "from", creates a token from "from" to "to", and
// checks the lookup status code for each combination of empty and non-empty
// from/to. On 200 it also checks that the response body contains the created
// token's accessor.
func TestServerLookup(t *testing.T) {
	from := "from"
	to := "to"
	cases := []struct {
		from   string
		to     string
		status int
	}{
		{from: "", to: "", status: 404},
		{from: "", to: to, status: 404},
		{from: from, to: "", status: 404},
		{from: from, to: to, status: 200},
	}
	path := "lookup"
	server, _, _ := mockServer()
	// Fail fast on setup errors. Otherwise a nil token would panic on
	// generated.Accessor below instead of reporting the real cause.
	if err := server.db.Register(from); err != nil {
		t.Fatalf("cannot register %q: %v", from, err)
	}
	generated, err := server.db.New(from, to)
	if err != nil {
		t.Fatalf("cannot new: %v", err)
	}
	for i, c := range cases {
		w := httptest.NewRecorder()
		r, _ := http.NewRequest("GET", path, nil)
		r = addParam(r, "from", c.from)
		r = addParam(r, "to", c.to)
		auth, err := server.authdb.New(serverNS, c.from)
		if err != nil {
			t.Fatalf("cannot authdb new: %v", err)
		}
		r.SetBasicAuth(c.from, auth.Accessor+":"+auth.Token)
		server.ServeHTTP(w, r)
		if w.Code != c.status {
			t.Errorf("%d: wrong code for %s: got %d, expected %d", i, path, w.Code, c.status)
		}
		if w.Code != http.StatusOK {
			continue
		}
		b, err := ioutil.ReadAll(w.Body)
		if err != nil {
			t.Fatalf("%d: cannot read body: %v", i, err)
		}
		if !bytes.Contains(b, []byte(generated.Accessor)) {
			t.Errorf("%d: response didn't contain accessor: got %s, want %s", i, b, generated.Accessor)
		}
	}
}

// echoHTTP writes the request's URL path back as the response body.
func echoHTTP(w http.ResponseWriter, r *http.Request) {
	w.Write([]byte(r.URL.Path))
}

// addParam appends "/"+value to r's URL path and sets header key to value.
// It mutates r in place and returns the same request for chaining.
func addParam(r *http.Request, key, value string) *http.Request {
	r.URL.Path += "/" + value
	r.Header.Set(key, value)
	return r
}

// TestServerAuthenticate wraps a no-op handler with server.authenticate. It
// checks that a request gets 401 when the name, accessor, or token is wrong,
// and 200 when all three match a token issued by server.authdb.
func TestServerAuthenticate(t *testing.T) {
	server, _, _ := mockServer()
	name := "name"
	// NOTE(review): the error is ignored, presumably because mockServer may
	// already have registered serverNS — confirm.
	server.authdb.Register(serverNS)
	token, err := server.authdb.New(serverNS, name)
	if err != nil {
		t.Fatalf("cannot authdb new: %v", err)
	}
	nilHandle := func(http.ResponseWriter, *http.Request) {}
	authFunc := server.authenticate(nilHandle)
	cases := []struct {
		name     string
		token    string
		accessor string
		code     int
	}{
		{name: "bad", token: token.Token, accessor: token.Accessor, code: http.StatusUnauthorized},
		{name: name, token: token.Token, accessor: "bad", code: http.StatusUnauthorized},
		{name: name, token: "bad", accessor: token.Accessor, code: http.StatusUnauthorized},
		{name: name, token: token.Token, accessor: token.Accessor, code: http.StatusOK},
	}
	for i, c := range cases {
		w := httptest.NewRecorder()
		r, _ := http.NewRequest("GET", "/any", nil)
		r.SetBasicAuth(c.name, c.accessor+":"+c.token)
		authFunc(w, r)
		if w.Code != c.code {
			t.Errorf("[case %d] failed auth: got %v, wanted %v", i, w.Code, c.code)
		}
	}
}