106 lines
1.7 KiB
Go
106 lines
1.7 KiB
Go
package server
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"log"
|
|
"net/http"
|
|
"sync"
|
|
|
|
"github.com/google/uuid"
|
|
"github.com/gorilla/websocket"
|
|
"golang.org/x/time/rate"
|
|
)
|
|
|
|
type session struct {
|
|
ctx context.Context
|
|
can context.CancelFunc
|
|
ws *websocket.Conn
|
|
wg sync.WaitGroup
|
|
cb func(message) error
|
|
id string
|
|
scatterc chan (message)
|
|
}
|
|
|
|
// upgrader promotes incoming HTTP requests to websocket connections. It uses
// the zero-value configuration, i.e. gorilla/websocket's default buffer
// sizes and default origin check.
var upgrader websocket.Upgrader
|
|
|
|
func newSession(w http.ResponseWriter, r *http.Request, cb func(message) error) (*session, error) {
|
|
c, err := upgrader.Upgrade(w, r, nil)
|
|
ctx, can := context.WithCancel(r.Context())
|
|
return &session{
|
|
ctx: ctx,
|
|
can: can,
|
|
ws: c,
|
|
cb: cb,
|
|
id: uuid.New().String(),
|
|
scatterc: make(chan message, 20),
|
|
}, err
|
|
}
|
|
|
|
func (s *session) Close() {
|
|
if s.ws != nil {
|
|
s.ws.Close()
|
|
}
|
|
s.ws = nil
|
|
s.can()
|
|
s.wg.Wait()
|
|
}
|
|
|
|
func (s *session) Run() error {
|
|
defer s.Close()
|
|
|
|
go s.gather()
|
|
go s.scatter()
|
|
|
|
<-s.ctx.Done()
|
|
return s.ctx.Err()
|
|
}
|
|
|
|
func (s *session) gather() {
|
|
s.while(func() error {
|
|
mt, msg, err := s.ws.ReadMessage()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if mt != 1 {
|
|
return nil
|
|
}
|
|
|
|
var m message
|
|
if err := json.Unmarshal(msg, &m); err != nil {
|
|
return err
|
|
}
|
|
|
|
log.Printf("gathered %+v (%s)", m, msg)
|
|
return s.cb(m)
|
|
})
|
|
}
|
|
|
|
func (s *session) scatter() {
|
|
s.while(func() error {
|
|
select {
|
|
case m := <-s.scatterc:
|
|
log.Printf("scattering %+v", m)
|
|
b, _ := json.Marshal(m)
|
|
return s.ws.WriteMessage(1, b)
|
|
case <-s.ctx.Done():
|
|
return s.ctx.Err()
|
|
}
|
|
})
|
|
}
|
|
|
|
func (s *session) while(foo func() error) {
|
|
defer s.can()
|
|
|
|
s.wg.Add(1)
|
|
defer s.wg.Done()
|
|
|
|
l := rate.NewLimiter(20, 1)
|
|
for l.Wait(s.ctx) == nil {
|
|
if err := foo(); err != nil {
|
|
log.Println(err)
|
|
return
|
|
}
|
|
}
|
|
}
|