From 88502651de0ff7011301adaa202b3440a2b6838e Mon Sep 17 00:00:00 2001 From: Bel LaPointe Date: Tue, 9 Oct 2018 18:32:46 -0600 Subject: [PATCH] Add GET handler to server --- main.go | 69 +++++++++++++++++++++++++++------------ main_test.go | 92 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 141 insertions(+), 20 deletions(-) create mode 100644 main_test.go diff --git a/main.go b/main.go index fd7896e..e797c35 100644 --- a/main.go +++ b/main.go @@ -1,6 +1,7 @@ package main import ( + "errors" "local1/logger" "local3/rssmon2/config" "local3/rssmon2/monitor" @@ -13,6 +14,10 @@ import ( const nsForFeeds = "FEEDS" func main() { + core() +} + +func core() { config := config.New() var sclient store.Client @@ -28,43 +33,43 @@ func main() { if !ok { f, err := rss.New(url, "", "", time.Minute) if err != nil { - logger.Log("cannot identify unknown feed triggered in monitor: %q: %v", url, err) + logger.Logf("cannot identify unknown feed triggered in monitor: %q: %v", url, err) return } b, err := sclient.Get(nsForFeeds, f.ID()) if err != nil { - logger.Log("cannot get unknown feed triggered in monitor: %q: %v", url, err) + logger.Logf("cannot get unknown feed triggered in monitor: %q: %v", url, err) return } feed, err = rss.Deserialize(b) if err != nil { - logger.Log("cannot deserialize feed triggered in monitor: %q: %v", url, err) + logger.Logf("cannot deserialize feed triggered in monitor: %q: %v", url, err) return } } items, err := allFeeds[url].Update() if err != nil { - logger.Log("can't update old RSS %q: %v", url, err) + logger.Logf("can't update old RSS %q: %v", url, err) return } b, err := feed.Serialize() if err != nil { - logger.Log("can't serialize to save RSS %q: %v", url, err) + logger.Logf("can't serialize to save RSS %q: %v", url, err) return } if err := sclient.Set(nsForFeeds, feed.ID(), b); err != nil { - logger.Log("can't save RSS %q.%q: %v", nsForFeeds, feed.ID(), err) + logger.Logf("can't save RSS %q.%q: %v", nsForFeeds, feed.ID(), err) return } logger.Log("Saved feed", feed) for i := range items { b, err := items[i].Serialize() if err != nil { - logger.Log("can't save rss item %q.%q: %v", url, items[i].Link, err) + logger.Logf("can't save rss item %q.%q: %v", url, items[i].Link, err) return } if err := sclient.Set(feed.ID(), items[i].ID(), b); err != nil { - logger.Log("can't save rss item %q.%q: %v", feed.ID(), items[i].ID(), err) + logger.Logf("can't save rss item %q.%q: %v", feed.ID(), items[i].ID(), err) return } logger.Log("Saved feed item", feed.ID(), items[i].ID(), items[i]) @@ -78,22 +83,46 @@ func main() { } defer mon.Stop() - server, err := server.New(config.Port, func(url string, itemFilter, contentFilter string, interval time.Duration) { - feed, err := rss.New(url, itemFilter, contentFilter, interval) - if err != nil { - logger.Log("can't create new RSS %q: %v", url, err) - return - } - allFeeds[url] = feed - if err := mon.Submit(url, feed.Interval); err != nil { - logger.Log("Cannot accept new feed %q: %v", url, err) - } - }) + server, err := server.New(config.Port, + func(url string, itemFilter, contentFilter string, interval time.Duration) { + feed, err := rss.New(url, itemFilter, contentFilter, interval) + if err != nil { + logger.Logf("can't create new RSS %q: %v", url, err) + return + } + allFeeds[url] = feed + if err := mon.Submit(url, feed.Interval); err != nil { + logger.Logf("Cannot accept new feed %q: %v", url, err) + } + }, + func(url string, n int) (string, error) { + feed, ok := allFeeds[url] + if !ok { + return "", errors.New("unknown feed " + url) + } + itemKeys, err := sclient.List(feed.ID(), "", false, n) + if err != nil { + return "", err + } + items := make([]*rss.Item, len(itemKeys)) + for i := range itemKeys { + b, err := sclient.Get(feed.ID(), itemKeys[i]) + if err != nil { + return "", errors.New("cannot get feed item " + itemKeys[i]) + } + items[i], err = rss.DeserializeItem(b) + if err != nil { + return "", errors.New("cannot deserialize feed item" + itemKeys[i]) + } + } + return rss.ToRSS(feed, items) + }, + ) if err != nil { panic(err) } - oldFeeds, err := sclient.List(nsForFeeds, "") + oldFeeds, err := sclient.List(nsForFeeds, "", true, -1) if err != nil { panic(err) } diff --git a/main_test.go b/main_test.go new file mode 100644 index 0000000..477f06d --- /dev/null +++ b/main_test.go @@ -0,0 +1,92 @@ +package main + +import ( + "bytes" + "io" + "io/ioutil" + "net/http" + "net/http/httptest" + "net/url" + "os" + "strings" + "testing" + "time" +) + +func Test_Core(t *testing.T) { + server := httptest.NewUnstartedServer(nil) + server.Close() + wasDBPath := os.Getenv("DBPATH") + wasPort := os.Getenv("PORT") + defer os.Setenv("DBPATH", wasDBPath) + defer os.Setenv("PORT", wasPort) + tmpf, err := ioutil.TempFile("", "testdb.db") + if err != nil { + t.Fatalf("cannot create temp db file: %v", err) + } + defer os.Remove(tmpf.Name()) + os.Setenv("DBPATH", tmpf.Name()) + os.Setenv("PORT", ":"+strings.Split(server.Listener.Addr().String(), ":")[1]) + go core() + time.Sleep(time.Second * 5) + + client := &http.Client{ + Timeout: time.Second * 5, + } + testhost := "http://localhost" + os.Getenv("PORT") + + cases := []struct { + method string + path string + body string + status int + pre func() + post func() + }{ + { + method: "post", + path: "api/feed", + body: `{"url":"https://utw.me/feed/", "refresh":"1m", "items":"Fate", "content":"25"}`, + status: 200, + post: func() { time.Sleep(time.Second * 10) }, + }, + { + method: "get", + path: "api/feed", + body: "https://utw.me/feed/", + status: 200, + }, + } + for _, c := range cases { + c.method = strings.ToUpper(c.method) + if c.pre != nil { + c.pre() + } + loc, err := url.Parse(testhost + "/" + c.path) + if err != nil { + t.Errorf("cannot create loc %s+%s: %v", testhost, c.path, err) + } + var body io.Reader = nil + if c.method == "GET" && c.body != "" { + v := url.Values{} + v.Add("url", c.body) + loc.RawQuery = v.Encode() + } else if c.body != "" { + body = bytes.NewBuffer([]byte(c.body)) + } + req, err := http.NewRequest(c.method, loc.String(), body) + if err != nil { + t.Errorf("cannot create request %s:%s: %v", c.method, loc.String(), err) + } + if resp, err := client.Do(req); err != nil { + t.Errorf("cannot %s to %s: %v", c.method, loc.String(), err) + } else if resp.StatusCode != c.status { + t.Errorf("wrong %s status to %s: %v, expected %v", c.method, loc.String(), resp.StatusCode, c.status) + } else if _, err := ioutil.ReadAll(resp.Body); err != nil { + t.Errorf("cannot read body on %s to %s: %v", c.method, loc.String(), err) + } + if c.post != nil { + c.post() + } + } +}