From 287256a20d7fa797ab76b40a10185f4f7bd84c67 Mon Sep 17 00:00:00 2001 From: Bel LaPointe <153096461+breel-render@users.noreply.github.com> Date: Mon, 9 Mar 2026 09:41:12 -0600 Subject: [PATCH] impl with.CachedHTTP --- http.go | 140 +++++++++++++++++++++++++++++++++++++++++++++++++++ http_test.go | 119 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 259 insertions(+) create mode 100644 http.go create mode 100644 http_test.go diff --git a/http.go b/http.go new file mode 100644 index 0000000..06c3a0c --- /dev/null +++ b/http.go @@ -0,0 +1,140 @@ +package with + +import ( + "bytes" + "context" + "database/sql" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "time" +) + +type RoundTripper struct { + http.RoundTripper + kv SQLKV +} + +func CachedHTTP(ctx context.Context, foo func(*http.Client) error) error { + return Sqlite(ctx, ":memory:", func(db *sql.DB) error { + return KV(ctx, db, func(kv SQLKV) error { + return foo(&http.Client{ + Timeout: time.Minute, + Transport: RoundTripper{ + RoundTripper: &http.Transport{ + DisableKeepAlives: true, + }, + kv: kv, + }, + }) + }) + }) +} + +func (c RoundTripper) RoundTrip(r *http.Request) (*http.Response, error) { + req := newCacheableRequest(r) + if v, err := c.kv.Get(r.Context(), req.cacheK()); err != nil { + } else if resp := parseCacheableResponse(v); resp != nil { + return resp.response(), nil + } + + resp, err := c.RoundTripper.RoundTrip(r) + if err != nil { + return nil, err + } + c.kv.Set(r.Context(), req.cacheK(), newCacheableResponse(resp).cacheV()) + + return resp, nil +} + +type cacheableRequest struct { + URLHost string + URLPath string + URLQuery cacheableHTTPHeader + Header cacheableHTTPHeader + Body string +} + +func newCacheableRequest(r *http.Request) cacheableRequest { + defer r.Body.Close() + b, _ := io.ReadAll(r.Body) + r.Body = io.NopCloser(bytes.NewReader(b)) + return cacheableRequest{ + URLHost: r.URL.Host, + URLPath: r.URL.Path, + URLQuery: newCacheableHTTPHeader(r.URL.Query()), + Header: newCacheableHTTPHeader(r.Header), + Body: string(b), + } +} + +func (c cacheableRequest) cacheK() string { + return fmt.Sprint(c) +} + +type cacheableResponse struct { + Code int + Header cacheableHTTPHeader + Body string +} + +func newCacheableResponse(resp *http.Response) cacheableResponse { + defer resp.Body.Close() + b, _ := io.ReadAll(resp.Body) + resp.Body = io.NopCloser(bytes.NewReader(b)) + return cacheableResponse{ + Code: resp.StatusCode, + Header: newCacheableHTTPHeader(resp.Header), + Body: string(b), + } +} + +func parseCacheableResponse(b []byte) *cacheableResponse { + var c cacheableResponse + if err := json.Unmarshal(b, &c); err != nil { + return nil + } + return &c +} + +func (c cacheableResponse) cacheV() []byte { + b, _ := json.Marshal(c) + return b +} + +func (c cacheableResponse) response() *http.Response { + return &http.Response{ + StatusCode: c.Code, + Header: c.Header.header(), + Body: io.NopCloser(bytes.NewReader([]byte(c.Body))), + } +} + +type cacheableHTTPHeader [][]string + +func newCacheableHTTPHeader(m map[string][]string) cacheableHTTPHeader { + result := [][]string{} + for k, v := range m { + result = append(result, append([]string{k}, v...)) + } + return result +} + +func (c cacheableHTTPHeader) header() http.Header { + return http.Header(c.m()) +} + +func (c cacheableHTTPHeader) query() url.Values { + return url.Values(c.m()) +} + +func (c cacheableHTTPHeader) m() map[string][]string { + m := map[string][]string{} + for _, v := range c { + v := v + m[v[0]] = v[1:] + } + return m +} diff --git a/http_test.go b/http_test.go new file mode 100644 index 0000000..908dcef --- /dev/null +++ b/http_test.go @@ -0,0 +1,119 @@ +package with_test + +import ( + "context" + "io" + "net/http" + "net/http/httptest" + "slices" + "strings" + "testing" + + "gitea.bel.blue/bel/with" +) + +func TestCachedHTTP(t *testing.T) { + called := false + code := http.StatusAccepted + headers := http.Header{ + "K": []string{"v"}, + "K2": []string{"v2a", "v2b"}, + } + body := "response body" + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + called = true + for k, v := range headers { + for _, subv := range v { + w.Header().Add(k, subv) + } + } + w.WriteHeader(code) + w.Write([]byte(body)) + })) + defer server.Close() + + testResp := func(resp *http.Response) { + if resp.StatusCode != code { + t.Errorf("initial request wrong status code: %d", resp.StatusCode) + } else if !slices.Equal(resp.Header["K"], headers["K"]) { + t.Errorf("initial request wrong headers[k]: %+v in %+v", resp.Header["K"], resp.Header) + } else if !slices.Equal(resp.Header["K2"], headers["K2"]) { + t.Errorf("initial request wrong headers[k2]: %+v in %+v", resp.Header["K2"], resp.Header) + } else if b, _ := io.ReadAll(resp.Body); string(b) != body { + t.Errorf("initial request wrong resp body: %q", b) + } + } + + if err := with.CachedHTTP(context.Background(), func(c *http.Client) error { + req := func() *http.Request { + req, _ := http.NewRequest(http.MethodGet, server.URL+"/my/path", strings.NewReader("my body")) + return req + } + + if resp, err := c.Do(req()); err != nil { + t.Fatalf("failed initial request: %v", err) + } else if !called { + t.Errorf("initial request didnt hit server") + } else { + testResp(resp) + } + + called = false + if resp, err := c.Do(req()); err != nil { + t.Fatalf("failed second request: %v", err) + } else if called { + t.Errorf("second request didnt hit cache") + } else { + testResp(resp) + } + + called = false + reqDiffURLPath := req() + reqDiffURLPath.URL.Path += "/teehee" + if resp, err := c.Do(reqDiffURLPath); err != nil { + t.Fatalf("failed diff url path request: %v", err) + } else if !called { + t.Errorf("new initial hit cache") + } else { + testResp(resp) + } + + called = false + reqDiffURLQuery := req() + reqDiffURLQuery.URL.RawQuery += "hello=world" + if resp, err := c.Do(reqDiffURLQuery); err != nil { + t.Fatalf("failed diff url query request: %v", err) + } else if !called { + t.Errorf("new initial hit cache") + } else { + testResp(resp) + } + + called = false + reqDiffHeader := req() + reqDiffHeader.Header.Set("Misc", "Misc") + if resp, err := c.Do(reqDiffHeader); err != nil { + t.Fatalf("failed diff header request: %v", err) + } else if !called { + t.Errorf("new initial hit cache") + } else { + testResp(resp) + } + + called = false + reqDiffBody := req() + reqDiffBody.Body = io.NopCloser(strings.NewReader("diff")) + reqDiffBody.ContentLength = int64(len("diff")) + if resp, err := c.Do(reqDiffBody); err != nil { + t.Fatalf("failed diff body request: %v", err) + } else if !called { + t.Errorf("new initial hit cache") + } else { + testResp(resp) + } + + return nil + }); err != nil { + t.Fatal(err) + } +}