main
Bel LaPointe 2025-04-24 20:21:28 -06:00
parent f7e82ff588
commit 5a51ebf884
2 changed files with 197 additions and 0 deletions

144
src/pool/pool.go Normal file
View File

@ -0,0 +1,144 @@
package pool
import (
"context"
"fmt"
"sync"
"time"
)
type Pool struct {
wg sync.WaitGroup
lock sync.RWMutex
p int
errs []error
jobs chan job
}
type job struct {
name string
foo func() error
}
func New(p int) *Pool {
return &Pool{
p: p,
wg: sync.WaitGroup{},
lock: sync.RWMutex{},
errs: []error{},
}
}
func (p *Pool) Go(ctx context.Context, name string, foo func() error) error {
p.spawn()
select {
case p.jobs <- job{foo: foo, name: name}:
case <-ctx.Done():
}
return ctx.Err()
}
func (p *Pool) Wait(ctx context.Context) error {
waited := make(chan bool)
defer close(waited)
go func() {
c := time.NewTicker(100 * time.Millisecond)
defer c.Stop()
if p.jobs != nil {
for len(p.jobs) > 0 && ctx.Err() == nil {
select {
case <-ctx.Done():
case <-c.C:
}
}
close(p.jobs)
}
p.wg.Wait()
select {
case <-ctx.Done():
case waited <- true:
}
}()
select {
case <-ctx.Done():
return ctx.Err()
case <-waited:
p.jobs = nil
}
if len(p.errs) == 0 {
return nil
}
result := ""
for _, err := range p.errs {
if err == nil {
continue
}
if result != "" {
result += "\n"
}
result += fmt.Sprintf("* %s", err.Error())
}
p.errs = []error{}
return fmt.Errorf("%s", result)
}
func (p *Pool) spawn() {
if p.alive() {
return
}
p.withLock(func() {
if p._alive() {
return
}
p.jobs = make(chan job)
for i := int(0); i < p.p; i++ {
p.wg.Add(1)
go func() {
defer p.wg.Done()
p.doJobs()
}()
}
})
}
func (p *Pool) doJobs() {
for job := range p.jobs {
if err := job.foo(); err != nil {
p.withRLock(func() {
p.errs = append(p.errs, fmt.Errorf("%s: %w", job.name, err))
})
}
}
}
func (p *Pool) alive() bool {
f := false
p.withRLock(func() {
f = p._alive()
})
return f
}
func (p *Pool) _alive() bool {
return p.jobs != nil
}
func (p *Pool) withRLock(foo func()) {
p.lock.RLock()
defer p.lock.RUnlock()
foo()
}
func (p *Pool) withLock(foo func()) {
p.lock.Lock()
defer p.lock.Unlock()
foo()
}

53
src/pool/pool_test.go Normal file
View File

@ -0,0 +1,53 @@
package pool_test
import (
"context"
"show-rss/src/pool"
"strconv"
"sync/atomic"
"testing"
"time"
)
func TestPool(t *testing.T) {
ctx, can := context.WithTimeout(context.Background(), 5*time.Second)
defer can()
p := pool.New(1)
if err := p.Wait(ctx); err != nil {
t.Fatalf("failed to wait for empty pool: %v", err)
} else if err := p.Wait(ctx); err != nil {
t.Fatalf("failed redundant wait for empty pool: %v", err)
}
done := false
if err := p.Go(ctx, "first", func() error {
done = true
return nil
}); err != nil {
t.Fatalf("failed to go { return nil }: %v", err)
}
if err := p.Wait(ctx); err != nil {
t.Fatalf("failed to wait for 1: %v", err)
} else if !done {
t.Fatalf("wait didnt actually run func: done=%v", done)
}
n := &atomic.Uint32{}
for i := 0; i < 100; i++ {
if err := p.Go(ctx, strconv.Itoa(i), func() error {
n.Add(1)
return nil
}); err != nil {
t.Fatalf("failed to go { return nil }: %v", err)
}
}
if err := p.Wait(ctx); err != nil {
t.Fatalf("failed to wait for 100: %v", err)
} else if n := n.Load(); n != 100 {
t.Fatalf("only called %d of 100", n)
}
}