diff --git a/go.mod b/go.mod index a10a9f1..02c5cda 100644 --- a/go.mod +++ b/go.mod @@ -6,3 +6,5 @@ require ( github.com/gorilla/websocket v1.5.3 golang.org/x/time v0.14.0 ) + +require github.com/google/uuid v1.6.0 // indirect diff --git a/go.sum b/go.sum index 296e66c..4290850 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,5 @@ +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI= diff --git a/src/public/index.html b/src/public/index.html index c6c0c2f..2dd0649 100644 --- a/src/public/index.html +++ b/src/public/index.html @@ -1,77 +1,55 @@ - - - - - - -
-

Click "Open" to create a connection to the server, -"Send" to send a message to the server and "Close" to close the connection. -You can change the message and send multiple times. -

-

- - -

- -

-
-
-
- + + + + + +
+
+

+ +

+
+
+ diff --git a/src/server/message.go b/src/server/message.go new file mode 100644 index 0000000..6eb841e --- /dev/null +++ b/src/server/message.go @@ -0,0 +1,5 @@ +package server + +type message struct { + Text string +} diff --git a/src/server/server.go b/src/server/server.go index 8229ace..855ef6d 100644 --- a/src/server/server.go +++ b/src/server/server.go @@ -5,13 +5,15 @@ import ( "net/http" ) -type Server struct{} - -func NewServer() Server { - return Server{} +type Server struct { + sessions []*session } -func (s Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { +func NewServer() *Server { + return &Server{} +} + +func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { switch r.URL.Path { case "/ws": if err := s.WS(w, r); err != nil { @@ -22,12 +24,35 @@ func (s Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { } } -func (s Server) WS(w http.ResponseWriter, r *http.Request) error { - sess, err := newSession(w, r) +func (s *Server) WS(w http.ResponseWriter, r *http.Request) error { + sess, err := newSession(w, r, nil) if err != nil { return err } defer sess.Close() + sess.cb = func(m message) error { + log.Printf("cbing to all other sessions %+v", m) + for i := range s.sessions { + if s.sessions[i].id != sess.id { + select { + case s.sessions[i].scatterc <- m: + case <-s.sessions[i].ctx.Done(): + } + } + } + return nil + } + + s.sessions = append(s.sessions, sess) + defer func() { + for i := range s.sessions { + if s.sessions[i].id == sess.id { + s.sessions = append(s.sessions[:i], s.sessions[i+1:]...) + return + } + } + }() + return sess.Run() } diff --git a/src/server/session.go b/src/server/session.go index 84e6e28..fab1277 100644 --- a/src/server/session.go +++ b/src/server/session.go @@ -2,30 +2,38 @@ package server import ( "context" + "encoding/json" "log" "net/http" "sync" + "github.com/google/uuid" "github.com/gorilla/websocket" "golang.org/x/time/rate" ) type session struct { - ctx context.Context - can context.CancelFunc - ws *websocket.Conn - wg sync.WaitGroup + ctx context.Context + can context.CancelFunc + ws *websocket.Conn + wg sync.WaitGroup + cb func(message) error + id string + scatterc chan (message) } var upgrader = websocket.Upgrader{} -func newSession(w http.ResponseWriter, r *http.Request) (*session, error) { +func newSession(w http.ResponseWriter, r *http.Request, cb func(message) error) (*session, error) { c, err := upgrader.Upgrade(w, r, nil) ctx, can := context.WithCancel(r.Context()) return &session{ - ctx: ctx, - can: can, - ws: c, + ctx: ctx, + can: can, + ws: c, + cb: cb, + id: uuid.New().String(), + scatterc: make(chan message, 20), }, err } @@ -50,18 +58,33 @@ func (s *session) Run() error { func (s *session) gather() { s.while(func() error { - mt, message, err := s.ws.ReadMessage() + mt, msg, err := s.ws.ReadMessage() if err != nil { return err } - log.Println(" read:", mt, message) // TODO - return nil + if mt != 1 { + return nil + } + + var m message + if err := json.Unmarshal(msg, &m); err != nil { + return err + } + + log.Printf("gathered %+v", m) + return s.cb(m) }) } func (s *session) scatter() { s.while(func() error { - return s.ws.WriteMessage(1, []byte("message")) // TODO + select { + case m := <-s.scatterc: + log.Printf("scattering %+v", m) + return s.ws.WriteMessage(1, []byte(m.Text)) + case <-s.ctx.Done(): + return s.ctx.Err() + } }) }