impl with.CachedHTTP
This commit is contained in:
140
http.go
Normal file
140
http.go
Normal file
@@ -0,0 +1,140 @@
|
|||||||
|
package with
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type RoundTripper struct {
|
||||||
|
http.RoundTripper
|
||||||
|
kv SQLKV
|
||||||
|
}
|
||||||
|
|
||||||
|
func CachedHTTP(ctx context.Context, foo func(*http.Client) error) error {
|
||||||
|
return Sqlite(ctx, ":memory:", func(db *sql.DB) error {
|
||||||
|
return KV(ctx, db, func(kv SQLKV) error {
|
||||||
|
return foo(&http.Client{
|
||||||
|
Timeout: time.Minute,
|
||||||
|
Transport: RoundTripper{
|
||||||
|
RoundTripper: &http.Transport{
|
||||||
|
DisableKeepAlives: true,
|
||||||
|
},
|
||||||
|
kv: kv,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c RoundTripper) RoundTrip(r *http.Request) (*http.Response, error) {
|
||||||
|
req := newCacheableRequest(r)
|
||||||
|
if v, err := c.kv.Get(r.Context(), req.cacheK()); err != nil {
|
||||||
|
} else if resp := parseCacheableResponse(v); resp != nil {
|
||||||
|
return resp.response(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := c.RoundTripper.RoundTrip(r)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
c.kv.Set(r.Context(), req.cacheK(), newCacheableResponse(resp).cacheV())
|
||||||
|
|
||||||
|
return resp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type cacheableRequest struct {
|
||||||
|
URLHost string
|
||||||
|
URLPath string
|
||||||
|
URLQuery cacheableHTTPHeader
|
||||||
|
Header cacheableHTTPHeader
|
||||||
|
Body string
|
||||||
|
}
|
||||||
|
|
||||||
|
func newCacheableRequest(r *http.Request) cacheableRequest {
|
||||||
|
defer r.Body.Close()
|
||||||
|
b, _ := io.ReadAll(r.Body)
|
||||||
|
r.Body = io.NopCloser(bytes.NewReader(b))
|
||||||
|
return cacheableRequest{
|
||||||
|
URLHost: r.URL.Host,
|
||||||
|
URLPath: r.URL.Path,
|
||||||
|
URLQuery: newCacheableHTTPHeader(r.URL.Query()),
|
||||||
|
Header: newCacheableHTTPHeader(r.Header),
|
||||||
|
Body: string(b),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c cacheableRequest) cacheK() string {
|
||||||
|
return fmt.Sprint(c)
|
||||||
|
}
|
||||||
|
|
||||||
|
type cacheableResponse struct {
|
||||||
|
Code int
|
||||||
|
Header cacheableHTTPHeader
|
||||||
|
Body string
|
||||||
|
}
|
||||||
|
|
||||||
|
func newCacheableResponse(resp *http.Response) cacheableResponse {
|
||||||
|
defer resp.Body.Close()
|
||||||
|
b, _ := io.ReadAll(resp.Body)
|
||||||
|
resp.Body = io.NopCloser(bytes.NewReader(b))
|
||||||
|
return cacheableResponse{
|
||||||
|
Code: resp.StatusCode,
|
||||||
|
Header: newCacheableHTTPHeader(resp.Header),
|
||||||
|
Body: string(b),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseCacheableResponse(b []byte) *cacheableResponse {
|
||||||
|
var c cacheableResponse
|
||||||
|
if err := json.Unmarshal(b, &c); err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return &c
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c cacheableResponse) cacheV() []byte {
|
||||||
|
b, _ := json.Marshal(c)
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c cacheableResponse) response() *http.Response {
|
||||||
|
return &http.Response{
|
||||||
|
StatusCode: c.Code,
|
||||||
|
Header: c.Header.header(),
|
||||||
|
Body: io.NopCloser(bytes.NewReader([]byte(c.Body))),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type cacheableHTTPHeader [][]string
|
||||||
|
|
||||||
|
func newCacheableHTTPHeader(m map[string][]string) cacheableHTTPHeader {
|
||||||
|
result := [][]string{}
|
||||||
|
for k, v := range m {
|
||||||
|
result = append(result, append([]string{k}, v...))
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c cacheableHTTPHeader) header() http.Header {
|
||||||
|
return http.Header(c.m())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c cacheableHTTPHeader) query() url.Values {
|
||||||
|
return url.Values(c.m())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c cacheableHTTPHeader) m() map[string][]string {
|
||||||
|
m := map[string][]string{}
|
||||||
|
for _, v := range c {
|
||||||
|
v := v
|
||||||
|
m[v[0]] = v[1:]
|
||||||
|
}
|
||||||
|
return m
|
||||||
|
}
|
||||||
119
http_test.go
Normal file
119
http_test.go
Normal file
@@ -0,0 +1,119 @@
|
|||||||
|
package with_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"slices"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"gitea.bel.blue/bel/with"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestCachedHTTP(t *testing.T) {
|
||||||
|
called := false
|
||||||
|
code := http.StatusAccepted
|
||||||
|
headers := http.Header{
|
||||||
|
"K": []string{"v"},
|
||||||
|
"K2": []string{"v2a", "v2b"},
|
||||||
|
}
|
||||||
|
body := "response body"
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
called = true
|
||||||
|
for k, v := range headers {
|
||||||
|
for _, subv := range v {
|
||||||
|
w.Header().Add(k, subv)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
w.WriteHeader(code)
|
||||||
|
w.Write([]byte(body))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
testResp := func(resp *http.Response) {
|
||||||
|
if resp.StatusCode != code {
|
||||||
|
t.Errorf("initial request wrong status code: %d", resp.StatusCode)
|
||||||
|
} else if !slices.Equal(resp.Header["K"], headers["K"]) {
|
||||||
|
t.Errorf("initial request wrong headers[k]: %+v in %+v", resp.Header["K"], resp.Header)
|
||||||
|
} else if !slices.Equal(resp.Header["K2"], headers["K2"]) {
|
||||||
|
t.Errorf("initial request wrong headers[k2]: %+v in %+v", resp.Header["K2"], resp.Header)
|
||||||
|
} else if b, _ := io.ReadAll(resp.Body); string(b) != body {
|
||||||
|
t.Errorf("initial request wrong resp body: %q", b)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := with.CachedHTTP(context.Background(), func(c *http.Client) error {
|
||||||
|
req := func() *http.Request {
|
||||||
|
req, _ := http.NewRequest(http.MethodGet, server.URL+"/my/path", strings.NewReader("my body"))
|
||||||
|
return req
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp, err := c.Do(req()); err != nil {
|
||||||
|
t.Fatalf("failed initial request: %v", err)
|
||||||
|
} else if !called {
|
||||||
|
t.Errorf("initial request didnt hit server")
|
||||||
|
} else {
|
||||||
|
testResp(resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
called = false
|
||||||
|
if resp, err := c.Do(req()); err != nil {
|
||||||
|
t.Fatalf("failed second request: %v", err)
|
||||||
|
} else if called {
|
||||||
|
t.Errorf("second request didnt hit cache")
|
||||||
|
} else {
|
||||||
|
testResp(resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
called = false
|
||||||
|
reqDiffURLPath := req()
|
||||||
|
reqDiffURLPath.URL.Path += "/teehee"
|
||||||
|
if resp, err := c.Do(reqDiffURLPath); err != nil {
|
||||||
|
t.Fatalf("failed diff url path request: %v", err)
|
||||||
|
} else if !called {
|
||||||
|
t.Errorf("new initial hit cache")
|
||||||
|
} else {
|
||||||
|
testResp(resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
called = false
|
||||||
|
reqDiffURLQuery := req()
|
||||||
|
reqDiffURLQuery.URL.RawQuery += "hello=world"
|
||||||
|
if resp, err := c.Do(reqDiffURLQuery); err != nil {
|
||||||
|
t.Fatalf("failed diff url query request: %v", err)
|
||||||
|
} else if !called {
|
||||||
|
t.Errorf("new initial hit cache")
|
||||||
|
} else {
|
||||||
|
testResp(resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
called = false
|
||||||
|
reqDiffHeader := req()
|
||||||
|
reqDiffHeader.Header.Set("Misc", "Misc")
|
||||||
|
if resp, err := c.Do(reqDiffHeader); err != nil {
|
||||||
|
t.Fatalf("failed diff header request: %v", err)
|
||||||
|
} else if !called {
|
||||||
|
t.Errorf("new initial hit cache")
|
||||||
|
} else {
|
||||||
|
testResp(resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
called = false
|
||||||
|
reqDiffBody := req()
|
||||||
|
reqDiffBody.Body = io.NopCloser(strings.NewReader("diff"))
|
||||||
|
reqDiffBody.ContentLength = int64(len("diff"))
|
||||||
|
if resp, err := c.Do(reqDiffBody); err != nil {
|
||||||
|
t.Fatalf("failed diff body request: %v", err)
|
||||||
|
} else if !called {
|
||||||
|
t.Errorf("new initial hit cache")
|
||||||
|
} else {
|
||||||
|
testResp(resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user