diff --git a/main.go b/main.go index 2834f62..b5ed6c7 100644 --- a/main.go +++ b/main.go @@ -25,21 +25,15 @@ func main() { // need to load old from file allFeeds := make(map[string]*rss.Feed) mon, err := monitor.New(config.MonitorPort, func(url string) { - var err error - var items []*rss.Item feed, ok := allFeeds[url] if !ok { - feed, items, err = rss.New(url, "Blue", "") - if err != nil { - logger.Log("can't create new RSS %q: %v", url, err) - return - } - } else { - items, err = allFeeds[url].Update() - if err != nil { - logger.Log("can't update old RSS %q: %v", url, err) - return - } + logger.Log("unknown feed triggered in monitor: %q", url) + return + } + items, err := allFeeds[url].Update() + if err != nil { + logger.Log("can't update old RSS %q: %v", url, err) + return } b, err := feed.Serialize() if err != nil { @@ -71,20 +65,14 @@ func main() { panic(err) } defer mon.Stop() - /* - go func() { // API submissions to itemsNew - itemsNew <- *(monitor.NewItem("https://xkcd.com/rss.xml", time.Minute)) - }() - go func() { - for doneItem := range itemsDone { - if err := fetcher.FetchProcess(doneItem.URL); err != nil { - logger.Log(err) - } - } - }() - */ - server, err := server.New(config.Port, func(url string, interval time.Duration) { + server, err := server.New(config.Port, func(url string, itemFilter, contentFilter string, interval time.Duration) { + feed, err := rss.New(url, itemFilter, contentFilter) + if err != nil { + logger.Log("can't create new RSS %q: %v", url, err) + return + } + allFeeds[url] = feed if err := mon.Submit(url, interval); err != nil { logger.Log("Cannot accept new feed %q: %v", url, err) } diff --git a/rss/feed.go b/rss/feed.go index a7db5ae..28d0f55 100644 --- a/rss/feed.go +++ b/rss/feed.go @@ -35,12 +35,12 @@ func (feed *Feed) ID() string { return strings.Join(regexp.MustCompile("[a-zA-Z0-9]*").FindAllString(feed.Link, -1), "_") } -func New(source, itemFilter, contentFilter string) (*Feed, []*Item, error) { +func New(source, itemFilter, contentFilter string) (*Feed, error) { if _, err := regexp.Compile(itemFilter); err != nil { - return nil, nil, err + return nil, err } if _, err := regexp.Compile(contentFilter); err != nil { - return nil, nil, err + return nil, err } f := &Feed{ Items: []string{}, @@ -48,11 +48,7 @@ func New(source, itemFilter, contentFilter string) (*Feed, []*Item, error) { ContentFilter: contentFilter, Link: source, } - items, err := f.Update() - if err != nil { - return nil, nil, err - } - return f, items, nil + return f, nil } func Deserialize(src []byte) (*Feed, error) { diff --git a/rss/feed_test.go b/rss/feed_test.go index 2ba130d..566628f 100644 --- a/rss/feed_test.go +++ b/rss/feed_test.go @@ -31,10 +31,14 @@ func Test_RSSFeed(t *testing.T) { }, } for _, c := range cases { - feed, items, err := New(s.URL, c.itemFilter, c.contentFilter) + feed, err := New(s.URL, c.itemFilter, c.contentFilter) if err != nil { t.Errorf("couldn't create new feed %v: %v", feed, err) } + items, err := feed.Update() + if err != nil { + t.Errorf("cannot update feed %q: %v", s.URL, err) + } if len(items) != c.itemsOut { t.Errorf("couldn't get all items from feed: got %v, wanted %v", len(items), c.itemsOut) } diff --git a/server/server.go b/server/server.go index dad1d7b..e7074a5 100644 --- a/server/server.go +++ b/server/server.go @@ -1,10 +1,13 @@ package server import ( + "encoding/json" "net/http" + "net/url" "os" "os/signal" "path" + "regexp" "strings" "syscall" "time" @@ -12,10 +15,10 @@ import ( type Server struct { addr string - newItemHandler func(string, time.Duration) + newItemHandler func(string, string, string, time.Duration) } -func New(addr string, newItemHandler func(string, time.Duration)) (*Server, error) { +func New(addr string, newItemHandler func(string, string, string, time.Duration)) (*Server, error) { return &Server{ addr: addr, newItemHandler: newItemHandler, @@ -47,20 +50,28 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { case "api": s.api(w, r) default: - s.bad(w, r) + s.notFound(w, r) } } -func (s *Server) bad(w http.ResponseWriter, r *http.Request) { +func (s *Server) notFound(w http.ResponseWriter, r *http.Request) { http.NotFound(w, r) } +func (s *Server) bad(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusBadRequest) +} + +func (s *Server) mybad(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) +} + func (s *Server) api(w http.ResponseWriter, r *http.Request) { switch advance(r) { case "feed": s.feed(w, r) default: - s.bad(w, r) + s.notFound(w, r) } } @@ -68,12 +79,51 @@ func (s *Server) feed(w http.ResponseWriter, r *http.Request) { switch r.Method { case "GET": case "POST": + s.newItem(w, r) case "PUT": + s.newItem(w, r) default: - s.bad(w, r) + s.notFound(w, r) } } +func (s *Server) newItem(w http.ResponseWriter, r *http.Request) { + defer r.Body.Close() + newItemBody := struct { + URL string `json:"url"` + Refresh string `json:"refresh"` + ItemFilter string `json:"items"` + ContentFilter string `json:"content"` + }{} + if err := json.NewDecoder(r.Body).Decode(&newItemBody); err != nil { + s.bad(w, r) + return + } + interval, err := time.ParseDuration(newItemBody.Refresh) + if err != nil { + s.bad(w, r) + return + } + if !validURL(newItemBody.URL) { + s.bad(w, r) + return + } + if _, err := regexp.Compile(newItemBody.ItemFilter); err != nil { + s.bad(w, r) + return + } + if _, err := regexp.Compile(newItemBody.ContentFilter); err != nil { + s.bad(w, r) + return + } + s.newItemHandler(newItemBody.URL, newItemBody.ItemFilter, newItemBody.ContentFilter, interval) +} + +func validURL(loc string) bool { + _, err := url.ParseRequestURI(loc) + return err == nil +} + func advance(r *http.Request) string { p := path.Clean("/" + r.URL.Path) i := strings.Index(p[1:], "/") + 1 diff --git a/server/server_test.go b/server/server_test.go index aaa505b..8761966 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -1,6 +1,7 @@ package server import ( + "bytes" "fmt" "net/http" "syscall" @@ -18,7 +19,7 @@ func Test_Server(t *testing.T) { for _, _ = range cases { var err error - s, err := New(testPort) + s, err := New(testPort, func(string, string, string, time.Duration) {}) if err != nil { t.Errorf("failed to create server: %v", err) } @@ -33,19 +34,29 @@ func Test_Server(t *testing.T) { if err := checkStatus("GET", "api/feed", http.StatusOK); err != nil { t.Errorf(err.Error()) } - if err := checkStatus("POST", "api/feed", http.StatusOK); err != nil { + if err := checkStatus("POST", "api/feed", http.StatusBadRequest); err != nil { t.Errorf(err.Error()) } - if err := checkStatus("PUT", "api/feed", http.StatusOK); err != nil { + if err := checkStatus("PUT", "api/feed", http.StatusBadRequest, "invalid json"); err != nil { + t.Errorf(err.Error()) + } + if err := checkStatus("POST", "api/feed", http.StatusBadRequest, `{"url":"hello/world", "refresh":"1m"}`); err != nil { + t.Errorf(err.Error()) + } + if err := checkStatus("PUT", "api/feed", http.StatusOK, `{"url":"localhost:1234", "refresh":"1m"}`); err != nil { t.Errorf(err.Error()) } syscall.Kill(syscall.Getpid(), syscall.SIGINT) } } -func checkStatus(method, path string, code int) error { +func checkStatus(method, path string, code int, body ...string) error { + b := bytes.NewBuffer(nil) + if len(body) > 0 { + b = bytes.NewBufferString(body[0]) + } client := &http.Client{} - r, _ := http.NewRequest(method, "http://localhost:"+testPort+"/"+path, nil) + r, _ := http.NewRequest(method, "http://localhost:"+testPort+"/"+path, b) resp, err := client.Do(r) if err != nil { return fmt.Errorf("failed to %v server: %v", method, err)