From 3f1921b0237c8557299c81927e543d323b2643c0 Mon Sep 17 00:00:00 2001 From: Bel LaPointe <153096461+breel-render@users.noreply.github.com> Date: Tue, 12 Nov 2024 08:11:24 -0700 Subject: [PATCH] tests are good --- main.go | 75 +++++++++++++++++++++++++++++++++++++----------- main_test.go | 80 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 138 insertions(+), 17 deletions(-) create mode 100644 main_test.go diff --git a/main.go b/main.go index 6fb2972..e8e6fc9 100644 --- a/main.go +++ b/main.go @@ -45,13 +45,33 @@ func (h Handler) serveHTTP(w http.ResponseWriter, r *http.Request) error { return err } + var duplicate int + if err := h.idempotency.QueryRowContext(r.Context(), `SELECT 1 FROM payloads WHERE payload=$1;`, b).Scan(&duplicate); err != sql.ErrNoRows && err != nil { + log.Println("!", err) + } else if duplicate > 0 { + log.Println("+") + return nil + } + + if err := h.prune(time.Now().Add(-1 * time.Hour * 24 * 7)); err != nil { + log.Println("!", err) + } + resp, err := proxy(h.target, r, io.NopCloser(bytes.NewReader(b))) if err != nil { return err } defer resp.Body.Close() - return forward(w, resp) + if err := forward(w, resp); err != nil { + return err + } + + if _, err := h.idempotency.ExecContext(r.Context(), `INSERT INTO payloads (ts, payload) VALUES ($1, $2);`, time.Now(), b); err != nil { + log.Println("!", err) + } + + return nil } func main() { @@ -63,40 +83,47 @@ func main() { if err := fs.Parse(os.Args[1:]); err != nil { panic(err) } + if err := run(*p, *t, *y, *db); err != nil { + panic(err) + } +} - db, err := sql.Open("sqlite", *db) +func run(p int, t string, y string, db string) error { + idempotency, err := sql.Open("sqlite", db) if err != nil { - panic(err) + return err } - defer db.Close() - if err := db.PingContext(context.Background()); err != nil { - panic(err) + defer idempotency.Close() + if err := idempotency.PingContext(context.Background()); err != nil { + return err + } else if _, err := idempotency.ExecContext(context.Background(), `CREATE TABLE IF NOT EXISTS payloads (payload TEXT, ts TIMESTAMP NOT NULL)`); err != nil { + return err } - tmpl, err := template.New("").Parse(*t) + tmpl, err := template.New("").Parse(t) if err != nil { - panic(err) + return err } - u, err := url.Parse(*y) + u, err := url.Parse(y) if err != nil { - panic(err) + return err } - h := Handler{tmpl: tmpl, target: u, idempotency: db} - log.Println("listening on", *p) - if err := http.ListenAndServe(":"+strconv.Itoa(*p), h); err != nil { - panic(err) - } + h := Handler{tmpl: tmpl, target: u, idempotency: idempotency} + log.Println("listening on", p) + return http.ListenAndServe(":"+strconv.Itoa(p), h) } func adapt(r io.Reader, tmpl *template.Template) ([]byte, error) { b, _ := io.ReadAll(io.LimitReader(r, 1024*1024)) var v interface{} buff := bytes.NewBuffer(nil) - if err := json.Unmarshal(b, &v); err != nil { + if len(b) == 0 { + } else if err := json.Unmarshal(b, &v); err != nil { return nil, err - } else if err := tmpl.Execute(buff, v); err != nil { + } + if err := tmpl.Execute(buff, v); err != nil { return nil, err } log.Printf("%s => %s", b, buff.Bytes()) @@ -135,3 +162,17 @@ func forward(w http.ResponseWriter, resp *http.Response) error { io.Copy(w, resp.Body) return nil } + +func (h Handler) prune(ts time.Time) error { + result, err := h.idempotency.ExecContext(context.Background(), `DELETE FROM payloads WHERE ts < $1;`, ts) + if err != nil { + return err + } + + rows, _ := result.RowsAffected() + if rows > 1 { + log.Println("-", rows) + } + + return nil +} diff --git a/main_test.go b/main_test.go new file mode 100644 index 0000000..d0852af --- /dev/null +++ b/main_test.go @@ -0,0 +1,80 @@ +package main + +import ( + "fmt" + "net/http" + "net/http/httptest" + "path" + "strconv" + "strings" + "testing" + "time" +) + +func TestRun(t *testing.T) { + targetCalls := 0 + target := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + targetCalls += 1 + })) + defer target.Close() + + p := func() int { + s := httptest.NewServer(http.HandlerFunc(http.NotFound)) + s.Close() + u := s.URL + ps := strings.Split(u, ":")[2] + n, err := strconv.Atoi(ps) + if err != nil { + t.Fatal(err) + } + return n + }() + tmpl := `{{ . }}` + x := fmt.Sprintf(`http://localhost:%d`, p) + y := target.URL + db := path.Join(t.TempDir(), "db") + + do := func(method, body string) bool { + req, _ := http.NewRequest(method, x, strings.NewReader(body)) + resp, err := httpc.Do(req) + if err != nil { + return false + } + resp.Body.Close() + return true + } + + go func() { + if err := run(p, tmpl, y, db); err != nil { + t.Fatal(err) + } + }() + + for !do(http.MethodGet, "") { + time.Sleep(time.Millisecond * 100) + } + + if targetCalls != 1 { + t.Error("empty req body no called target") + } + + if !do(http.MethodGet, "") { + t.Error("couldnt get a second time with no body") + } else if targetCalls != 1 { + t.Error("no dedupe no body") + } + + if !do(http.MethodPost, "1") { + t.Error("couldnt get a third time with new body") + } else if targetCalls != 2 { + t.Error("deduped new body") + } else if !do(http.MethodPost, "1") { + t.Error("couldnt get a fourth time with new body again") + } else if targetCalls != 2 { + t.Error("no deduped new body again") + } else if !do(http.MethodPost, "2") { + t.Error("couldnt get a fifth time with new new body") + } else if targetCalls != 3 { + t.Error("deduped new new body") + } +}