diff --git a/src/pool/pool.go b/src/pool/pool.go new file mode 100644 index 0000000..5faee9f --- /dev/null +++ b/src/pool/pool.go @@ -0,0 +1,144 @@ +package pool + +import ( + "context" + "fmt" + "sync" + "time" +) + +type Pool struct { + wg sync.WaitGroup + lock sync.RWMutex + p int + errs []error + jobs chan job +} + +type job struct { + name string + foo func() error +} + +func New(p int) *Pool { + return &Pool{ + p: p, + wg: sync.WaitGroup{}, + lock: sync.RWMutex{}, + errs: []error{}, + } +} + +func (p *Pool) Go(ctx context.Context, name string, foo func() error) error { + p.spawn() + select { + case p.jobs <- job{foo: foo, name: name}: + case <-ctx.Done(): + } + return ctx.Err() +} + +func (p *Pool) Wait(ctx context.Context) error { + waited := make(chan bool) + defer close(waited) + go func() { + c := time.NewTicker(100 * time.Millisecond) + defer c.Stop() + + if p.jobs != nil { + for len(p.jobs) > 0 && ctx.Err() == nil { + select { + case <-ctx.Done(): + case <-c.C: + } + } + close(p.jobs) + } + + p.wg.Wait() + select { + case <-ctx.Done(): + case waited <- true: + } + }() + + select { + case <-ctx.Done(): + return ctx.Err() + case <-waited: + p.jobs = nil + } + + if len(p.errs) == 0 { + return nil + } + + result := "" + for _, err := range p.errs { + if err == nil { + continue + } + if result != "" { + result += "\n" + } + result += fmt.Sprintf("* %s", err.Error()) + } + p.errs = []error{} + + return fmt.Errorf("%s", result) +} + +func (p *Pool) spawn() { + if p.alive() { + return + } + + p.withLock(func() { + if p._alive() { + return + } + + p.jobs = make(chan job) + for i := int(0); i < p.p; i++ { + p.wg.Add(1) + go func() { + defer p.wg.Done() + p.doJobs() + }() + } + }) +} + +func (p *Pool) doJobs() { + for job := range p.jobs { + if err := job.foo(); err != nil { + p.withRLock(func() { + p.errs = append(p.errs, fmt.Errorf("%s: %w", job.name, err)) + }) + } + } +} + +func (p *Pool) alive() bool { + f := false + p.withRLock(func() { + f = p._alive() + }) + return f +} + +func (p *Pool) _alive() bool { + return p.jobs != nil +} + +func (p *Pool) withRLock(foo func()) { + p.lock.RLock() + defer p.lock.RUnlock() + foo() +} + +func (p *Pool) withLock(foo func()) { + p.lock.Lock() + defer p.lock.Unlock() + foo() +} diff --git a/src/pool/pool_test.go b/src/pool/pool_test.go new file mode 100644 index 0000000..c6ef43f --- /dev/null +++ b/src/pool/pool_test.go @@ -0,0 +1,53 @@ +package pool_test + +import ( + "context" + "show-rss/src/pool" + "strconv" + "sync/atomic" + "testing" + "time" +) + +func TestPool(t *testing.T) { + ctx, can := context.WithTimeout(context.Background(), 5*time.Second) + defer can() + + p := pool.New(1) + + if err := p.Wait(ctx); err != nil { + t.Fatalf("failed to wait for empty pool: %v", err) + } else if err := p.Wait(ctx); err != nil { + t.Fatalf("failed redundant wait for empty pool: %v", err) + } + + done := false + if err := p.Go(ctx, "first", func() error { + done = true + return nil + }); err != nil { + t.Fatalf("failed to go { return nil }: %v", err) + } + + if err := p.Wait(ctx); err != nil { + t.Fatalf("failed to wait for 1: %v", err) + } else if !done { + t.Fatalf("wait didnt actually run func: done=%v", done) + } + + n := &atomic.Uint32{} + for i := 0; i < 100; i++ { + if err := p.Go(ctx, strconv.Itoa(i), func() error { + n.Add(1) + return nil + }); err != nil { + t.Fatalf("failed to go { return nil }: %v", err) + } + } + + if err := p.Wait(ctx); err != nil { + t.Fatalf("failed to wait for 100: %v", err) + } else if n := n.Load(); n != 100 { + t.Fatalf("only called %d of 100", n) + } +}