diff --git a/main.go b/main.go index a754e62..2879dae 100644 --- a/main.go +++ b/main.go @@ -9,6 +9,7 @@ import ( "log" "math/rand" "net/http" + "strconv" "strings" "time" @@ -31,23 +32,45 @@ func main() { links := bytes.Split(b, []byte{'\n'}) log.Print(port) limiter := rate.NewLimiter(1, 1) - if err := http.ListenAndServe(fmt.Sprintf(":%d", port), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if err := http.ListenAndServe(fmt.Sprintf(":%d", port), ServeHTTP(limiter, links, proxy)); err != nil { + panic(err) + } +} + +func ServeHTTP(limiter *rate.Limiter, links [][]byte, proxy bool) http.HandlerFunc { + handler := serveHTTP(limiter, links, proxy) + return func(w http.ResponseWriter, r *http.Request) { + s := r.URL.Query().Get("n") + if s == "" { + s = "1" + } + n, err := strconv.Atoi(s) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + for i := 0; i < n; i++ { + handler(w, r) + } + } +} + +func serveHTTP(limiter *rate.Limiter, links [][]byte, proxy bool) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { limiter.Wait(r.Context()) link := "" for len(link) == 0 { link = string(links[rand.Intn(len(links))]) } if strings.HasPrefix(link, "http") { - serveHTTP(w, r, proxy, link) + _serveHTTP(w, r, proxy, link) } else if strings.HasPrefix(link, "literal://") { w.Write([]byte(strings.TrimPrefix(link, "literal://") + "\n")) } - })); err != nil { - panic(err) } } -func serveHTTP(w http.ResponseWriter, r *http.Request, proxy bool, link string) { +func _serveHTTP(w http.ResponseWriter, r *http.Request, proxy bool, link string) { if proxy { resp, err := http.Get(link) if err != nil {