438 lines
8.7 KiB
Go
438 lines
8.7 KiB
Go
package server
|
|
|
|
import (
|
|
"bytes"
|
|
"encoding/json"
|
|
"errors"
|
|
"io/ioutil"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"testing"
|
|
)
|
|
|
|
// badRouter is a test double installed as server.router. Add only accepts
// patterns listed in accept and rejects every other pattern with an error,
// which lets tests force server.Routes to fail on chosen routes. It never
// dispatches requests.
type badRouter struct {
	// accept lists the exact route patterns Add will report as accepted.
	accept []string
}

// Add reports success when p exactly matches one of the accepted patterns and
// returns a "rejected path" error otherwise. The handler h is ignored and
// nothing is recorded.
func (br *badRouter) Add(p string, h http.HandlerFunc) error {
	accepted := false
	for _, allowed := range br.accept {
		if allowed == p {
			accepted = true
			break
		}
	}
	if !accepted {
		return errors.New("rejected path")
	}
	return nil
}

// ServeHTTP exists so badRouter can stand in for the real router; it
// deliberately ignores the request and writes nothing.
func (br *badRouter) ServeHTTP(http.ResponseWriter, *http.Request) {}
|
|
|
|
func TestServerRoutesBadRouter(t *testing.T) {
|
|
server, _, _ := mockServer()
|
|
br := badRouter{
|
|
accept: make([]string, 0),
|
|
}
|
|
server.router = &br
|
|
|
|
toAdd := []string{
|
|
"/nil",
|
|
"/admin/register/{}",
|
|
"/register/{}",
|
|
"/generate/{}/{}",
|
|
"/retrieve/{}/{}",
|
|
"/revoke/{}/{}",
|
|
"/lookup/{}/{}",
|
|
"/policies",
|
|
}
|
|
|
|
for _, path := range toAdd {
|
|
br.accept = append(br.accept, path)
|
|
if err := server.Routes(); err == nil {
|
|
t.Errorf("can add non-allowed routes")
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestServerRoutes(t *testing.T) {
|
|
server, _, _ := mockServer()
|
|
if err := server.db.Register("a"); err != nil {
|
|
t.Fatalf("cannot register: %v", err)
|
|
}
|
|
token, err := server.db.New("a", "b")
|
|
if err != nil {
|
|
t.Fatalf("cannot new: %v", err)
|
|
}
|
|
|
|
paths := []string{
|
|
"retrieve/a/b/" + token.Accessor,
|
|
"revoke/a/" + token.Accessor,
|
|
"lookup/a/b",
|
|
"policies",
|
|
"admin/register/a",
|
|
"register/a",
|
|
"generate/a/b",
|
|
}
|
|
|
|
for _, p := range paths {
|
|
w := httptest.NewRecorder()
|
|
r, _ := http.NewRequest("GET", p, nil)
|
|
server.ServeHTTP(w, r)
|
|
if w.Code == 404 {
|
|
t.Errorf("not found for %v", p)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestServerAdminRegister(t *testing.T) {
|
|
cases := []struct {
|
|
name string
|
|
status int
|
|
}{
|
|
{
|
|
name: "",
|
|
status: 404,
|
|
},
|
|
{
|
|
name: "name",
|
|
status: 200,
|
|
},
|
|
}
|
|
|
|
path := "admin/register"
|
|
server, _, _ := mockServer()
|
|
|
|
for i, c := range cases {
|
|
w := httptest.NewRecorder()
|
|
r, _ := http.NewRequest("GET", path, nil)
|
|
r = addParam(r, "name", c.name)
|
|
server.ServeHTTP(w, r)
|
|
if w.Code != c.status {
|
|
t.Errorf("[%d] wrong code for %s: got %d, expected %d", i, path, w.Code, c.status)
|
|
}
|
|
body, _ := ioutil.ReadAll(w.Body)
|
|
if len(body) < 1 {
|
|
t.Errorf("[%d] empty body in admin/register response: %q", i, body)
|
|
}
|
|
t.Logf("[%d] admin/register body: %q", i, body)
|
|
}
|
|
}
|
|
|
|
func TestServerRegister(t *testing.T) {
|
|
cases := []struct {
|
|
name string
|
|
status int
|
|
}{
|
|
{
|
|
name: "",
|
|
status: 404,
|
|
},
|
|
{
|
|
name: "name",
|
|
status: 200,
|
|
},
|
|
}
|
|
|
|
path := "register"
|
|
server, _, _ := mockServer()
|
|
|
|
for i, c := range cases {
|
|
w := httptest.NewRecorder()
|
|
r, _ := http.NewRequest("GET", path, nil)
|
|
r = addParam(r, "name", c.name)
|
|
auth, err := server.authdb.New(serverNS, c.name)
|
|
if err != nil {
|
|
t.Fatalf("cannot authdb new: %v", err)
|
|
}
|
|
r.SetBasicAuth(c.name, auth.Accessor+":"+auth.Token)
|
|
server.ServeHTTP(w, r)
|
|
if w.Code != c.status {
|
|
t.Errorf("%d: wrong code for %s: got %d, expected %d", i, path, w.Code, c.status)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestServerGenerate(t *testing.T) {
|
|
cases := []struct {
|
|
name string
|
|
to string
|
|
status int
|
|
}{
|
|
{
|
|
name: "",
|
|
to: "",
|
|
status: 404,
|
|
},
|
|
{
|
|
name: "name",
|
|
to: "",
|
|
status: 404,
|
|
},
|
|
{
|
|
name: "",
|
|
to: "to",
|
|
status: 404,
|
|
},
|
|
{
|
|
name: "name",
|
|
to: "to",
|
|
status: 200,
|
|
},
|
|
}
|
|
|
|
path := "generate"
|
|
server, _, _ := mockServer()
|
|
|
|
for i, c := range cases {
|
|
w := httptest.NewRecorder()
|
|
r, _ := http.NewRequest("GET", path, nil)
|
|
r = addParam(r, "name", c.name)
|
|
r = addParam(r, "to", c.to)
|
|
auth, err := server.authdb.New(serverNS, c.name)
|
|
if err != nil {
|
|
t.Fatalf("cannot authdb new: %v", err)
|
|
}
|
|
r.SetBasicAuth(c.name, auth.Accessor+":"+auth.Token)
|
|
server.ServeHTTP(w, r)
|
|
if w.Code != c.status {
|
|
t.Errorf("%d: wrong code for %s: got %d, expected %d", i, path, w.Code, c.status)
|
|
}
|
|
var result struct {
|
|
Token string `json:"token"`
|
|
Acc string `json:"accessor"`
|
|
TTL int `json:"TTL"`
|
|
To string `json:"to"`
|
|
}
|
|
if err := json.NewDecoder(w.Body).Decode(&result); c.status == 200 && err != nil {
|
|
t.Errorf("invalid body: %v", err)
|
|
} else if c.status == 200 {
|
|
if result.To != c.to {
|
|
t.Errorf("wrong `to` in response: got %v, want %v", result.To, c.to)
|
|
}
|
|
if len(result.Token) == 0 {
|
|
t.Errorf("empty `token` in response")
|
|
}
|
|
if len(result.Acc) == 0 {
|
|
t.Errorf("empty `accessor` in response")
|
|
}
|
|
if result.TTL < 100 {
|
|
t.Errorf("short TTL in response")
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestServerRetrieve(t *testing.T) {
|
|
server, defaultName, defaultAccessor := mockServer()
|
|
|
|
cases := []struct {
|
|
name string
|
|
acc string
|
|
status int
|
|
}{
|
|
{
|
|
name: "",
|
|
acc: "",
|
|
status: 404,
|
|
},
|
|
{
|
|
name: defaultName,
|
|
acc: "",
|
|
status: 404,
|
|
},
|
|
{
|
|
name: "",
|
|
acc: defaultAccessor,
|
|
status: 404,
|
|
},
|
|
{
|
|
name: defaultName,
|
|
acc: defaultAccessor,
|
|
status: 200,
|
|
},
|
|
{
|
|
name: "fake",
|
|
acc: "fake",
|
|
status: 400,
|
|
},
|
|
}
|
|
|
|
path := "retrieve"
|
|
for i, c := range cases {
|
|
w := httptest.NewRecorder()
|
|
r, _ := http.NewRequest("GET", path, nil)
|
|
r = addParam(r, "name", c.name)
|
|
r = addParam(r, "to", "to")
|
|
r = addParam(r, "accessor", c.acc)
|
|
auth, err := server.authdb.New(serverNS, c.name)
|
|
if err != nil {
|
|
t.Fatalf("cannot authdb new: %v", err)
|
|
}
|
|
r.SetBasicAuth(c.name, auth.Accessor+":"+auth.Token)
|
|
server.ServeHTTP(w, r)
|
|
if w.Code != c.status {
|
|
//t.Errorf("%d: wrong code for %s with %v: got %d, expected %d", i, path, c, w.Code, c.status)
|
|
t.Fatalf("%d: wrong code for %s with %v: got %d, expected %d", i, path, c, w.Code, c.status)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestServerRevoke(t *testing.T) {
|
|
cases := []struct {
|
|
name string
|
|
status int
|
|
}{
|
|
{
|
|
name: "",
|
|
status: 404,
|
|
},
|
|
{
|
|
name: "name",
|
|
status: 200,
|
|
},
|
|
}
|
|
|
|
path := "register"
|
|
server, _, _ := mockServer()
|
|
|
|
for i, c := range cases {
|
|
w := httptest.NewRecorder()
|
|
r, _ := http.NewRequest("GET", path, nil)
|
|
r = addParam(r, "name", c.name)
|
|
auth, err := server.authdb.New(serverNS, c.name)
|
|
if err != nil {
|
|
t.Fatalf("cannot authdb new: %v", err)
|
|
}
|
|
r.SetBasicAuth(c.name, auth.Accessor+":"+auth.Token)
|
|
server.ServeHTTP(w, r)
|
|
if w.Code != c.status {
|
|
t.Errorf("%d: wrong code for %s: got %d, expected %d", i, path, w.Code, c.status)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestServerLookup(t *testing.T) {
|
|
from := "from"
|
|
to := "to"
|
|
cases := []struct {
|
|
from string
|
|
to string
|
|
status int
|
|
}{
|
|
{
|
|
from: "",
|
|
to: "",
|
|
status: 404,
|
|
},
|
|
{
|
|
from: "",
|
|
to: to,
|
|
status: 404,
|
|
},
|
|
{
|
|
from: from,
|
|
to: "",
|
|
status: 404,
|
|
},
|
|
{
|
|
from: from,
|
|
to: to,
|
|
status: 200,
|
|
},
|
|
}
|
|
|
|
path := "lookup"
|
|
server, _, _ := mockServer()
|
|
server.db.Register(from)
|
|
generated, _ := server.db.New(from, to)
|
|
|
|
for i, c := range cases {
|
|
w := httptest.NewRecorder()
|
|
r, _ := http.NewRequest("GET", path, nil)
|
|
r = addParam(r, "from", c.from)
|
|
r = addParam(r, "to", c.to)
|
|
auth, err := server.authdb.New(serverNS, c.from)
|
|
if err != nil {
|
|
t.Fatalf("cannot authdb new: %v", err)
|
|
}
|
|
r.SetBasicAuth(c.from, auth.Accessor+":"+auth.Token)
|
|
server.ServeHTTP(w, r)
|
|
if w.Code != c.status {
|
|
t.Errorf("%d: wrong code for %s: got %d, expected %d", i, path, w.Code, c.status)
|
|
}
|
|
if w.Code != http.StatusOK {
|
|
continue
|
|
}
|
|
b, err := ioutil.ReadAll(w.Body)
|
|
if err != nil {
|
|
t.Fatalf("%d: cannot read body: %v", i, err)
|
|
}
|
|
if !bytes.Contains(b, []byte(generated.Accessor)) {
|
|
t.Errorf("%d: response didn't contain accessor: got %s, want %s", i, b, generated.Accessor)
|
|
}
|
|
}
|
|
}
|
|
|
|
// echoHTTP is a trivial handler that writes the request's URL path back as the
// response body. Any write error is discarded.
func echoHTTP(w http.ResponseWriter, r *http.Request) {
	echoed := []byte(r.URL.Path)
	_, _ = w.Write(echoed)
}
|
|
|
|
// addParam simulates a route parameter on a test request. It does two things:
//   - It stores value in the request header under key. Header.Set
//     canonicalizes the key.
//   - It appends value to the URL path as a new "/"-separated segment. An empty
//     value therefore leaves an empty segment, such as a trailing slash.
//
// The same request is returned so calls can be chained.
func addParam(r *http.Request, key, value string) *http.Request {
	r.Header.Set(key, value)
	r.URL.Path = r.URL.Path + "/" + value
	return r
}
|
|
|
|
func TestServerAuthenticate(t *testing.T) {
|
|
server, _, _ := mockServer()
|
|
|
|
name := "name"
|
|
server.authdb.Register(serverNS)
|
|
token, err := server.authdb.New(serverNS, name)
|
|
if err != nil {
|
|
t.Fatalf("cannot authdb new: %v", err)
|
|
}
|
|
|
|
nilHandle := func(http.ResponseWriter, *http.Request) {}
|
|
authFunc := server.authenticate(nilHandle)
|
|
|
|
cases := []struct {
|
|
name string
|
|
token string
|
|
accessor string
|
|
code int
|
|
}{
|
|
{
|
|
name: "bad",
|
|
token: token.Token,
|
|
accessor: token.Accessor,
|
|
code: http.StatusUnauthorized,
|
|
},
|
|
{
|
|
name: name,
|
|
token: token.Token,
|
|
accessor: "bad",
|
|
code: http.StatusUnauthorized,
|
|
},
|
|
{
|
|
name: name,
|
|
token: "bad",
|
|
accessor: token.Accessor,
|
|
code: http.StatusUnauthorized,
|
|
},
|
|
{
|
|
name: name,
|
|
token: token.Token,
|
|
accessor: token.Accessor,
|
|
code: http.StatusOK,
|
|
},
|
|
}
|
|
|
|
for i, c := range cases {
|
|
w := httptest.NewRecorder()
|
|
r, _ := http.NewRequest("GET", "/any", nil)
|
|
r.SetBasicAuth(c.name, c.accessor+":"+c.token)
|
|
authFunc(w, r)
|
|
if w.Code != c.code {
|
|
t.Errorf("[case %d] failed auth: got %v, wanted %v", i, w.Code, c.code)
|
|
}
|
|
}
|
|
}
|