From d2cebde4dc91923edc61146c6551bfd56fb9c9b3 Mon Sep 17 00:00:00 2001 From: Bel LaPointe Date: Wed, 13 May 2020 16:29:23 -0600 Subject: [PATCH] locking in prog --- pool.go | 21 ++++++++++++++++++--- ws.go | 2 +- 2 files changed, 19 insertions(+), 4 deletions(-) diff --git a/pool.go b/pool.go index 25bb37e..3fef169 100644 --- a/pool.go +++ b/pool.go @@ -8,17 +8,32 @@ import ( "github.com/gorilla/websocket" ) +type Conn struct { + ws websocket.Conn + lock sync.Mutex +} + type Pool struct { - conns *sync.Map //map[string]*websocket.Conn + lock *sync.RWMutex + conns *sync.Map //map[string]*Conn } func NewPool() *Pool { return &Pool{ - conns: &sync.Map{}, //map[string]*websocket.Conn{}, + conns: &sync.Map{}, + lock: &sync.RWMutex{}, } } +func (p *Pool) Push(id string, conn *websocket.Conn) { + p.lock.Lock() + defer p.lock.Unlock() + p.conns.Store(id, &Conn{ws: *conn}) +} + func (p *Pool) Broadcast(mt int, r io.Reader) error { + p.lock.RLock() + defer p.lock.RUnlock() // io.MultiWriter exists but I like this b, err := ioutil.ReadAll(r) if err != nil { @@ -28,7 +43,7 @@ func (p *Pool) Broadcast(mt int, r io.Reader) error { cnt := 0 p.conns.Range(func(k, v interface{}) bool { k = k.(string) - conn := v.(*websocket.Conn) + conn := &v.(*Conn).ws cnt += 1 w, err := conn.NextWriter(mt) if err != nil { diff --git a/ws.go b/ws.go index 1578d66..cd94aa0 100644 --- a/ws.go +++ b/ws.go @@ -49,7 +49,7 @@ func (ws *WS) serveHTTP(w http.ResponseWriter, r *http.Request) error { if err != nil { return err } - pool.conns.Store(id, conn) + pool.Push(id, conn) // conns.Store(id, conn) for { mt, reader, err := conn.NextReader() if err != nil {