diff --git a/server/server.go b/server/server.go index 6e53252..731429b 100644 --- a/server/server.go +++ b/server/server.go @@ -18,17 +18,19 @@ import ( type Server struct { addr string - newFeedHandler func(string, string, string, time.Duration) + newFeedHandler func(string, string, string, []string, time.Duration) getFeedHandler func(string, int) (string, error) getFeedItemHandler func(string) (string, error) + getFeedTagHandler func(string) (string, error) } -func New(addr string, newFeedHandler func(string, string, string, time.Duration), getFeedHandler func(string, int) (string, error), getFeedItemHandler func(string) (string, error)) (*Server, error) { +func New(addr string, newFeedHandler func(string, string, string, []string, time.Duration), getFeedHandler func(string, int) (string, error), getFeedItemHandler func(string) (string, error), getFeedTagHandler func(string) (string, error)) (*Server, error) { return &Server{ addr: addr, newFeedHandler: newFeedHandler, getFeedHandler: getFeedHandler, getFeedItemHandler: getFeedItemHandler, + getFeedTagHandler: getFeedTagHandler, }, nil } @@ -87,9 +89,12 @@ func (s *Server) api(w http.ResponseWriter, r *http.Request) { func (s *Server) feed(w http.ResponseWriter, r *http.Request) { switch r.Method { case "GET": - if advance(r) == "item" { + switch advance(r) { + case "item": s.getFeedItem(w, r) - } else { + case "tag": + s.getFeedTag(w, r) + case "": s.getFeed(w, r) } case "POST": @@ -104,10 +109,11 @@ func (s *Server) feed(w http.ResponseWriter, r *http.Request) { func (s *Server) newFeed(w http.ResponseWriter, r *http.Request) { defer r.Body.Close() newFeedBody := struct { - URL string `json:"url"` - Refresh string `json:"refresh"` - ItemFilter string `json:"items"` - ContentFilter string `json:"content"` + URL string `json:"url"` + Refresh string `json:"refresh"` + ItemFilter string `json:"items"` + ContentFilter string `json:"content"` + Tags []string `json:"tags"` }{ Refresh: "3h", } @@ -132,7 +138,23 @@ func (s *Server) newFeed(w http.ResponseWriter, r *http.Request) { s.bad(w, r) return } - s.newFeedHandler(newFeedBody.URL, newFeedBody.ItemFilter, newFeedBody.ContentFilter, interval) + s.newFeedHandler(newFeedBody.URL, newFeedBody.ItemFilter, newFeedBody.ContentFilter, newFeedBody.Tags, interval) +} + +func (s *Server) getFeedTag(w http.ResponseWriter, r *http.Request) { + url, err := url.ParseQuery(r.URL.RawQuery) + if err != nil { + logger.Logf("cannot get feed tag to read: %v", err) + s.mybad(w, r) + return + } + feedBody, err := s.getFeedTagHandler(url.Get("url")) + if err != nil { + logger.Logf("cannot get feed tag %s: %v", url.Get("url"), err) + s.mybad(w, r) + return + } + fmt.Fprintln(w, feedBody) } func (s *Server) getFeedItem(w http.ResponseWriter, r *http.Request) { diff --git a/server/server_test.go b/server/server_test.go index 55ecc4e..82fa0fc 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -4,50 +4,58 @@ import ( "bytes" "fmt" "net/http" + "net/http/httptest" + "strings" "syscall" "testing" "time" ) -const testPort = "39231" +var testPort = "39231" func Test_Server(t *testing.T) { - cases := []struct { - }{ - {}, - } + server := httptest.NewUnstartedServer(nil) + server.Close() + testPort = strings.Split(server.Listener.Addr().String(), ":")[1] - for _, _ = range cases { - var err error - s, err := New(testPort, func(string, string, string, time.Duration) {}, func(string, int) (string, error) { return "", nil }, func(string) (string, error) { return "", nil }) - if err != nil { - t.Errorf("failed to create server: %v", err) - } - go s.Serve() - time.Sleep(time.Second * 1) - if err := checkStatus("GET", "", http.StatusNotFound); err != nil { - t.Errorf(err.Error()) - } - if err := checkStatus("GET", "api", http.StatusNotFound); err != nil { - t.Errorf(err.Error()) - } - if err := checkStatus("GET", "api/feed", http.StatusOK); err != nil { - t.Errorf(err.Error()) - } - if err := checkStatus("POST", "api/feed", http.StatusBadRequest); err != nil { - t.Errorf(err.Error()) - } - if err := checkStatus("PUT", "api/feed", http.StatusBadRequest, "invalid json"); err != nil { - t.Errorf(err.Error()) - } - if err := checkStatus("POST", "api/feed", http.StatusBadRequest, `{"url":"hello/world", "refresh":"1m"}`); err != nil { - t.Errorf(err.Error()) - } - if err := checkStatus("PUT", "api/feed", http.StatusOK, `{"url":"localhost:1234", "refresh":"1m"}`); err != nil { - t.Errorf(err.Error()) - } - syscall.Kill(syscall.Getpid(), syscall.SIGINT) + var err error + s, err := New(testPort, func(string, string, string, []string, time.Duration) {}, func(string, int) (string, error) { return "", nil }, func(string) (string, error) { return "", nil }, func(string) (string, error) { return "", nil }) + if err != nil { + t.Errorf("failed to create server: %v", err) } + go s.Serve() + time.Sleep(time.Second * 1) + if err := checkStatus("GET", "", http.StatusNotFound); err != nil { + t.Errorf(err.Error()) + } + if err := checkStatus("GET", "api", http.StatusNotFound); err != nil { + t.Errorf(err.Error()) + } + if err := checkStatus("GET", "api/feed", http.StatusOK); err != nil { + t.Errorf(err.Error()) + } + if err := checkStatus("POST", "api/feed", http.StatusBadRequest); err != nil { + t.Errorf(err.Error()) + } + if err := checkStatus("PUT", "api/feed", http.StatusBadRequest, "invalid json"); err != nil { + t.Errorf(err.Error()) + } + if err := checkStatus("POST", "api/feed", http.StatusBadRequest, `{"url":"hello/world", "refresh":"1m"}`); err != nil { + t.Errorf(err.Error()) + } + if err := checkStatus("PUT", "api/feed", http.StatusOK, `{"url":"localhost:1234", "refresh":"1m", "tags":["a", "b"]}`); err != nil { + t.Errorf(err.Error()) + } + if err := checkStatus("GET", "api/feed?url=localhost_1234", http.StatusOK); err != nil { + t.Errorf(err.Error()) + } + if err := checkStatus("GET", "api/feed/item?url=localhost_1234", http.StatusOK); err != nil { + t.Errorf(err.Error()) + } + if err := checkStatus("GET", "api/feed/tag?url=b", http.StatusOK); err != nil { + t.Errorf(err.Error()) + } + syscall.Kill(syscall.Getpid(), syscall.SIGINT) } func checkStatus(method, path string, code int, body ...string) error {