package server import ( "context" "encoding/json" "log" "net/http" "sync" "github.com/google/uuid" "github.com/gorilla/websocket" "golang.org/x/time/rate" ) type session struct { ctx context.Context can context.CancelFunc ws *websocket.Conn wg sync.WaitGroup cb func(message) error id string scatterc chan (message) } var upgrader = websocket.Upgrader{} func newSession(w http.ResponseWriter, r *http.Request, cb func(message) error) (*session, error) { c, err := upgrader.Upgrade(w, r, nil) ctx, can := context.WithCancel(r.Context()) return &session{ ctx: ctx, can: can, ws: c, cb: cb, id: uuid.New().String(), scatterc: make(chan message, 20), }, err } func (s *session) Close() { if s.ws != nil { s.ws.Close() } s.ws = nil s.can() s.wg.Wait() } func (s *session) Run() error { defer s.Close() go s.gather() go s.scatter() <-s.ctx.Done() return s.ctx.Err() } func (s *session) gather() { s.while(func() error { mt, msg, err := s.ws.ReadMessage() if err != nil { return err } if mt != 1 { return nil } var m message if err := json.Unmarshal(msg, &m); err != nil { return err } log.Printf("gathered %+v (%s)", m, msg) return s.cb(m) }) } func (s *session) scatter() { s.while(func() error { select { case m := <-s.scatterc: log.Printf("scattering %+v", m) b, _ := json.Marshal(m) return s.ws.WriteMessage(1, b) case <-s.ctx.Done(): return s.ctx.Err() } }) } func (s *session) while(foo func() error) { defer s.can() s.wg.Add(1) defer s.wg.Done() l := rate.NewLimiter(20, 1) for l.Wait(s.ctx) == nil { if err := foo(); err != nil { log.Println(err) return } } }