pool
parent
f7e82ff588
commit
5a51ebf884
|
|
@ -0,0 +1,144 @@
|
|||
package pool
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Pool struct {
|
||||
wg sync.WaitGroup
|
||||
lock sync.RWMutex
|
||||
p int
|
||||
errs []error
|
||||
jobs chan job
|
||||
}
|
||||
|
||||
type job struct {
|
||||
name string
|
||||
foo func() error
|
||||
}
|
||||
|
||||
func New(p int) *Pool {
|
||||
return &Pool{
|
||||
p: p,
|
||||
wg: sync.WaitGroup{},
|
||||
lock: sync.RWMutex{},
|
||||
errs: []error{},
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Pool) Go(ctx context.Context, name string, foo func() error) error {
|
||||
p.spawn()
|
||||
select {
|
||||
case p.jobs <- job{foo: foo, name: name}:
|
||||
case <-ctx.Done():
|
||||
}
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
func (p *Pool) Wait(ctx context.Context) error {
|
||||
waited := make(chan bool)
|
||||
defer close(waited)
|
||||
go func() {
|
||||
c := time.NewTicker(100 * time.Millisecond)
|
||||
defer c.Stop()
|
||||
|
||||
if p.jobs != nil {
|
||||
for len(p.jobs) > 0 && ctx.Err() == nil {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
case <-c.C:
|
||||
}
|
||||
}
|
||||
close(p.jobs)
|
||||
}
|
||||
|
||||
p.wg.Wait()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
case waited <- true:
|
||||
}
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-waited:
|
||||
p.jobs = nil
|
||||
}
|
||||
|
||||
if len(p.errs) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
result := ""
|
||||
for _, err := range p.errs {
|
||||
if err == nil {
|
||||
continue
|
||||
}
|
||||
if result != "" {
|
||||
result += "\n"
|
||||
}
|
||||
result += fmt.Sprintf("* %s", err.Error())
|
||||
}
|
||||
p.errs = []error{}
|
||||
|
||||
return fmt.Errorf("%s", result)
|
||||
}
|
||||
|
||||
func (p *Pool) spawn() {
|
||||
if p.alive() {
|
||||
return
|
||||
}
|
||||
|
||||
p.withLock(func() {
|
||||
if p._alive() {
|
||||
return
|
||||
}
|
||||
|
||||
p.jobs = make(chan job)
|
||||
for i := int(0); i < p.p; i++ {
|
||||
p.wg.Add(1)
|
||||
go func() {
|
||||
defer p.wg.Done()
|
||||
p.doJobs()
|
||||
}()
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (p *Pool) doJobs() {
|
||||
for job := range p.jobs {
|
||||
if err := job.foo(); err != nil {
|
||||
p.withRLock(func() {
|
||||
p.errs = append(p.errs, fmt.Errorf("%s: %w", job.name, err))
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Pool) alive() bool {
|
||||
f := false
|
||||
p.withRLock(func() {
|
||||
f = p._alive()
|
||||
})
|
||||
return f
|
||||
}
|
||||
|
||||
func (p *Pool) _alive() bool {
|
||||
return p.jobs != nil
|
||||
}
|
||||
|
||||
func (p *Pool) withRLock(foo func()) {
|
||||
p.lock.RLock()
|
||||
defer p.lock.RUnlock()
|
||||
foo()
|
||||
}
|
||||
|
||||
func (p *Pool) withLock(foo func()) {
|
||||
p.lock.Lock()
|
||||
defer p.lock.Unlock()
|
||||
foo()
|
||||
}
|
||||
|
|
@ -0,0 +1,53 @@
|
|||
package pool_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"show-rss/src/pool"
|
||||
"strconv"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestPool(t *testing.T) {
|
||||
ctx, can := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer can()
|
||||
|
||||
p := pool.New(1)
|
||||
|
||||
if err := p.Wait(ctx); err != nil {
|
||||
t.Fatalf("failed to wait for empty pool: %v", err)
|
||||
} else if err := p.Wait(ctx); err != nil {
|
||||
t.Fatalf("failed redundant wait for empty pool: %v", err)
|
||||
}
|
||||
|
||||
done := false
|
||||
if err := p.Go(ctx, "first", func() error {
|
||||
done = true
|
||||
return nil
|
||||
}); err != nil {
|
||||
t.Fatalf("failed to go { return nil }: %v", err)
|
||||
}
|
||||
|
||||
if err := p.Wait(ctx); err != nil {
|
||||
t.Fatalf("failed to wait for 1: %v", err)
|
||||
} else if !done {
|
||||
t.Fatalf("wait didnt actually run func: done=%v", done)
|
||||
}
|
||||
|
||||
n := &atomic.Uint32{}
|
||||
for i := 0; i < 100; i++ {
|
||||
if err := p.Go(ctx, strconv.Itoa(i), func() error {
|
||||
n.Add(1)
|
||||
return nil
|
||||
}); err != nil {
|
||||
t.Fatalf("failed to go { return nil }: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := p.Wait(ctx); err != nil {
|
||||
t.Fatalf("failed to wait for 100: %v", err)
|
||||
} else if n := n.Load(); n != 100 {
|
||||
t.Fatalf("only called %d of 100", n)
|
||||
}
|
||||
}
|
||||
Loading…
Reference in New Issue