commit 32326de6b2a36ef34f30ed4ea2cab0376e6f00f3 Author: Bel LaPointe Date: Mon Mar 18 09:27:11 2019 -0600 Router a good diff --git a/router.go b/router.go new file mode 100644 index 0000000..dc9925f --- /dev/null +++ b/router.go @@ -0,0 +1,46 @@ +package router + +import ( + "errors" + "net/http" +) + +type Router struct { + t *tree +} + +func New() *Router { + return &Router{ + t: newTree(), + } +} + +func (rt *Router) Add(path string, foo http.HandlerFunc) error { + return rt.t.Insert(path, foo) +} + +func (rt *Router) ServeHTTP(w http.ResponseWriter, r *http.Request) { + foo := rt.t.Lookup(r.URL.Path) + if foo == nil { + http.NotFound(w, r) + return + } + foo(w, r) +} + +func Params(r *http.Request, toSet ...*string) error { + params := r.Header[WildcardHeader] + if len(params) != len(toSet) { + return errors.New("missing params") + } + for i := range params { + if len(params[i]) < 1 { + return errors.New("empty params") + } + if toSet[i] == nil { + return errors.New("cannot set nil param") + } + *toSet[i] = params[i] + } + return nil +} diff --git a/router_test.go b/router_test.go new file mode 100644 index 0000000..6a5bbcf --- /dev/null +++ b/router_test.go @@ -0,0 +1,118 @@ +package router + +import ( + "fmt" + "io/ioutil" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +func TestRouter(t *testing.T) { + rt := New() + + paths := []string{ + "/not/found", + "/hello", + "/world", + "/hello/world", + "/hello/{}/other", + } + + for i, p := range paths { + if i == 0 { + continue + } + if err := rt.Add(p, getHandler(p)); err != nil { + t.Errorf("router cannot add %v: %v", p, err) + } + } + if err := rt.Add(paths[len(paths)-1], getHandler(paths[len(paths)-1])); err == nil { + t.Errorf("router can re-add %v: %v", paths[len(paths)-1], err) + } + + for i, p := range paths { + w := httptest.NewRecorder() + gpath := strings.Replace(p, "{}", "seq", -1) + if req, err := http.NewRequest("GET", gpath[1:]+"/", nil); err != nil { + t.Fatalf("cannot make http req: %v", err) + } else { + rt.ServeHTTP(w, req) + b, err := ioutil.ReadAll(w.Body) + if i > 0 && (err != nil || w.Code != 200 || string(b) != p) { + t.Errorf("did not check %v: %v %v %q %q", p, err, w.Code, b, p) + } + } + } +} + +func getHandler(s string) func(http.ResponseWriter, *http.Request) { + return func(w http.ResponseWriter, r *http.Request) { + fmt.Fprint(w, s) + } +} + +func TestParams(t *testing.T) { + var s string + cases := []struct { + actual int + put string + want int + assign *string + err bool + }{ + { + put: "a", + assign: &s, + actual: 0, + want: 1, + err: true, + }, + { + put: "a", + assign: &s, + actual: 1, + want: 0, + err: true, + }, + { + put: "a", + assign: &s, + actual: 1, + want: 1, + err: false, + }, + { + put: "i", + assign: nil, + actual: 1, + want: 1, + err: true, + }, + { + put: "", + assign: &s, + actual: 1, + want: 1, + err: true, + }, + } + + for z, c := range cases { + r, _ := http.NewRequest("GET", "localhost:123", nil) + actual := []string{} + for i := 0; i < c.actual; i++ { + actual = append(actual, c.put) + } + r.Header[WildcardHeader] = actual + want := []*string{} + for i := 0; i < c.want; i++ { + want = append(want, c.assign) + } + err := Params(r, want...) + if (err != nil) != c.err { + t.Errorf("[%d] get params didn't find err %v when %v vs %v: %v", z, c.err, c.actual, c.want, err) + } + } +} diff --git a/tree.go b/tree.go new file mode 100644 index 0000000..cc2ff0d --- /dev/null +++ b/tree.go @@ -0,0 +1,67 @@ +package router + +import ( + "errors" + "net/http" + "path" + "strings" +) + +var Wildcard = "{}" +var WildcardHeader = "__wildcard_value__" + +type tree struct { + next map[string]*tree + handler http.HandlerFunc +} + +func newTree() *tree { + return &tree{ + next: make(map[string]*tree), + handler: nil, + } +} + +func (t *tree) Lookup(path string) http.HandlerFunc { + if path == "/" || path == "" { + return t.handler + } + key, following := nextPathSegment(path) + if n, ok := t.next[key]; ok { + return n.Lookup(following) + } else if n, ok := t.next[Wildcard]; ok { + foo := n.Lookup(following) + if foo != nil { + return func(w http.ResponseWriter, r *http.Request) { + r.Header.Add(WildcardHeader, key) + foo(w, r) + } + } + } + return nil +} + +func (t *tree) Insert(path string, foo http.HandlerFunc) error { + if path == "/" { + if t.handler != nil { + return errors.New("occupied path") + } + t.handler = foo + return nil + } + key, following := nextPathSegment(path) + _, ok := t.next[key] + if !ok { + t.next[key] = newTree() + } + return t.next[key].Insert(following, foo) +} + +func nextPathSegment(p string) (string, string) { + p = path.Clean("/" + p) + i := strings.Index(p[1:], "/") + 1 + if i <= 0 { + return p[1:], "/" + } + return p[1:i], p[i:] +} diff --git a/tree_test.go b/tree_test.go new file mode 100644 index 0000000..765c293 --- /dev/null +++ b/tree_test.go @@ -0,0 +1,125 @@ +package router + +import ( + "fmt" + "io/ioutil" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +var nilHandle = func(w http.ResponseWriter, r *http.Request) {} + +func TestNewTree(t *testing.T) { + newTree() +} + +func TestTreeInsert(t *testing.T) { + tree := newTree() + if err := tree.Insert("/hello/world", nilHandle); err != nil { + t.Errorf("failed to insert first path: %v", err) + } + if err := tree.Insert("/hello/world", nilHandle); err == nil { + t.Errorf("succeeded to insert dupe path: %v", err) + } + if err := tree.Insert("/hello/", nilHandle); err != nil { + t.Errorf("failed to insert sub path: %v", err) + } + if err := tree.Insert("/world/", nilHandle); err != nil { + t.Errorf("failed to insert new path: %v", err) + } +} + +func TestTreeLookup(t *testing.T) { + tree := newTree() + subtree := newTree() + checked := false + subtree.handler = func(w http.ResponseWriter, r *http.Request) { + checked = true + } + tree.next["hi"] = subtree + foo := tree.Lookup("/hi/") + if foo == nil { + t.Errorf("cannot lookup path: %v", "/hi/") + } else { + foo(nil, nil) + } + if !checked { + t.Errorf("lookup returned wrong function") + } +} + +func TestTreeInsertLookup(t *testing.T) { + tree := newTree() + checked := false + foo := func(_ http.ResponseWriter, _ *http.Request) { + checked = true + } + + paths := []string{ + "/hello", + "/hello/world", + "/world", + } + + for _, p := range paths { + if err := tree.Insert(p, foo); err != nil { + t.Fatalf("cannot insert: %v", err) + } + } + + for _, p := range paths { + if bar := tree.Lookup(p[1:] + "/"); bar == nil { + t.Fatalf("cannot lookup: %v", p) + } else { + checked = false + bar(nil, nil) + if !checked { + t.Errorf("failed to call %v: %v", p, checked) + } + } + } +} + +func TestTreeWildcard(t *testing.T) { + tree := newTree() + checked := false + foo := func(w http.ResponseWriter, r *http.Request) { + checked = true + fmt.Fprintf(w, "%v", r.Header[WildcardHeader]) + } + + paths := []string{ + "/hello/{}", + "/hello/{}/{}/world", + } + + for _, p := range paths { + if err := tree.Insert(p, foo); err != nil { + t.Fatalf("cannot insert: %v", err) + } + } + + for _, p := range paths { + dpath := strings.Replace(p, "{}", "seq", -1) + if bar := tree.Lookup(dpath[1:] + "/"); bar == nil { + t.Fatalf("cannot lookup: %v", p) + } else { + checked = false + w := httptest.NewRecorder() + r, _ := http.NewRequest("GET", dpath, nil) + bar(w, r) + if !checked { + t.Errorf("failed to call %v: %v", p, checked) + } + b, err := ioutil.ReadAll(w.Body) + if err != nil { + t.Errorf("failed to read all: %v", err) + } + if strings.Count(string(b), "seq") != strings.Count(p, "{}") { + t.Errorf("failed to decode wildcards: %s", b) + } + } + } +}