tts-room/src/server/session.go

106 lines
1.7 KiB
Go

package server
import (
"context"
"encoding/json"
"log"
"net/http"
"sync"
"github.com/google/uuid"
"github.com/gorilla/websocket"
"golang.org/x/time/rate"
)
// session represents a single websocket client. A read loop (gather) decodes
// incoming JSON messages and hands them to cb, and a write loop (scatter)
// encodes and sends the messages queued on scatterc.
type session struct {
	// ctx bounds the session's lifetime; can cancels it.
	ctx context.Context
	can context.CancelFunc

	// ws is the underlying websocket connection.
	ws *websocket.Conn

	// wg tracks the running read/write loops so Close can wait for them.
	wg sync.WaitGroup

	// cb receives every message decoded by the read loop.
	cb func(message) error

	// id is a random UUID string identifying this session.
	id string

	// scatterc queues outgoing messages for the write loop.
	scatterc chan message
}
// upgrader turns incoming HTTP requests into websocket connections. It is the
// zero-value Upgrader, so every setting (buffer sizes, handshake timeout,
// origin check) is the gorilla/websocket default.
var upgrader websocket.Upgrader
// newSession upgrades the HTTP request to a websocket connection and wraps it
// in a session whose context is derived from r.Context(). cb is called for
// each message the session's read loop decodes.
//
// If the upgrade fails, newSession returns a nil session together with the
// upgrader's error. (Per the gorilla/websocket docs the upgrader has already
// replied to the client with an HTTP error in that case.) No context is created
// on that path, so there is no cancel func left un-called.
func newSession(w http.ResponseWriter, r *http.Request, cb func(message) error) (*session, error) {
	c, err := upgrader.Upgrade(w, r, nil)
	if err != nil {
		return nil, err
	}
	ctx, can := context.WithCancel(r.Context())
	return &session{
		ctx:      ctx,
		can:      can,
		ws:       c,
		cb:       cb,
		id:       uuid.New().String(),
		scatterc: make(chan message, 20),
	}, nil
}
// Close tears the session down:
//   - it cancels the session context, which stops the rate-limited loops;
//   - it closes the websocket connection, which unblocks a pending read;
//   - it waits for the loops tracked by s.wg to exit.
//
// Close may be called more than once. Later calls are harmless: cancel funcs
// are idempotent, and closing an already-closed connection only returns an
// error.
//
// s.ws is deliberately NOT reset to nil. The read/write goroutines access
// s.ws without synchronization, so writing the field here would be a data
// race and could let a loop dereference a nil connection.
func (s *session) Close() {
	s.can()
	if s.ws != nil {
		// Nothing useful can be done with a close error during teardown
		// (e.g. the connection was already closed by an earlier Close).
		_ = s.ws.Close()
	}
	s.wg.Wait()
}
// Run starts the read loop (gather) and the write loop (scatter), then blocks
// until the session context is cancelled. That happens either because the
// parent request context ended or because one of the loops stopped. Run always
// closes the session before returning, and it returns the context's error.
func (s *session) Run() error {
	defer s.Close()
	// Register both loops before starting them. If Add ran inside the
	// goroutines, a Close racing with start-up could see a zero counter and
	// return from Wait while a loop is still about to run.
	s.wg.Add(2)
	go func() {
		defer s.wg.Done()
		s.gather()
	}()
	go func() {
		defer s.wg.Done()
		s.scatter()
	}()
	<-s.ctx.Done()
	return s.ctx.Err()
}

// gather reads frames from the websocket and decodes each text frame as a
// JSON message, which it passes to s.cb. Non-text frames are skipped. A read
// error, a decode error or a callback error ends the loop, and through while
// that ends the whole session.
func (s *session) gather() {
	s.while(func() error {
		mt, msg, err := s.ws.ReadMessage()
		if err != nil {
			return err
		}
		if mt != websocket.TextMessage {
			return nil
		}
		var m message
		if err := json.Unmarshal(msg, &m); err != nil {
			return err
		}
		log.Printf("gathered %+v (%s)", m, msg)
		return s.cb(m)
	})
}

// scatter sends each message queued on s.scatterc to the client as a JSON
// text frame. It stops on a write error or when the session context is
// cancelled.
func (s *session) scatter() {
	s.while(func() error {
		select {
		case m := <-s.scatterc:
			log.Printf("scattering %+v", m)
			b, err := json.Marshal(m)
			if err != nil {
				// Drop a message that cannot be encoded instead of sending an
				// empty frame; the session itself keeps running.
				log.Printf("dropping unencodable message %+v: %v", m, err)
				return nil
			}
			return s.ws.WriteMessage(websocket.TextMessage, b)
		case <-s.ctx.Done():
			return s.ctx.Err()
		}
	})
}

// while calls foo repeatedly, limited to 20 calls per second with a burst of
// 1. It stops when the session context is cancelled or when foo returns an
// error, which is logged. On return it cancels the session context, so losing
// either loop tears down the whole session.
// WaitGroup accounting is the caller's job (see Run).
func (s *session) while(foo func() error) {
	defer s.can()
	limiter := rate.NewLimiter(20, 1)
	for limiter.Wait(s.ctx) == nil {
		if err := foo(); err != nil {
			log.Println(err)
			return
		}
	}
}