diff --git a/go.mod b/go.mod
index a10a9f1..02c5cda 100644
--- a/go.mod
+++ b/go.mod
@@ -6,3 +6,5 @@ require (
github.com/gorilla/websocket v1.5.3
golang.org/x/time v0.14.0
)
+
+require github.com/google/uuid v1.6.0 // indirect
diff --git a/go.sum b/go.sum
index 296e66c..4290850 100644
--- a/go.sum
+++ b/go.sum
@@ -1,3 +1,5 @@
+github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
+github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI=
diff --git a/src/public/index.html b/src/public/index.html
index c6c0c2f..2dd0649 100644
--- a/src/public/index.html
+++ b/src/public/index.html
@@ -1,77 +1,55 @@
-
-
-
-
-
-
-|
- Click "Open" to create a connection to the server,
-"Send" to send a message to the server and "Close" to close the connection.
-You can change the message and send multiple times.
-
-
- |
-
- |
-
+
+
+
+
+
+
+
diff --git a/src/server/message.go b/src/server/message.go
new file mode 100644
index 0000000..6eb841e
--- /dev/null
+++ b/src/server/message.go
@@ -0,0 +1,5 @@
+package server
+
+type message struct {
+ Text string
+}
diff --git a/src/server/server.go b/src/server/server.go
index 8229ace..855ef6d 100644
--- a/src/server/server.go
+++ b/src/server/server.go
@@ -5,13 +5,15 @@ import (
"net/http"
)
-type Server struct{}
-
-func NewServer() Server {
- return Server{}
+type Server struct {
+ sessions []*session
}
-func (s Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
+func NewServer() *Server {
+ return &Server{}
+}
+
+func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/ws":
if err := s.WS(w, r); err != nil {
@@ -22,12 +24,35 @@ func (s Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
}
-func (s Server) WS(w http.ResponseWriter, r *http.Request) error {
- sess, err := newSession(w, r)
+func (s *Server) WS(w http.ResponseWriter, r *http.Request) error {
+ sess, err := newSession(w, r, nil)
if err != nil {
return err
}
defer sess.Close()
+ sess.cb = func(m message) error {
+ log.Printf("cbing to all other sessions %+v", m)
+ for i := range s.sessions {
+ if s.sessions[i].id != sess.id {
+ select {
+ case s.sessions[i].scatterc <- m:
+ case <-s.sessions[i].ctx.Done():
+ }
+ }
+ }
+ return nil
+ }
+
+ s.sessions = append(s.sessions, sess)
+ defer func() {
+ for i := range s.sessions {
+ if s.sessions[i].id == sess.id {
+ s.sessions = append(s.sessions[:i], s.sessions[i+1:]...)
+ return
+ }
+ }
+ }()
+
return sess.Run()
}
diff --git a/src/server/session.go b/src/server/session.go
index 84e6e28..fab1277 100644
--- a/src/server/session.go
+++ b/src/server/session.go
@@ -2,30 +2,38 @@ package server
import (
"context"
+ "encoding/json"
"log"
"net/http"
"sync"
+ "github.com/google/uuid"
"github.com/gorilla/websocket"
"golang.org/x/time/rate"
)
type session struct {
- ctx context.Context
- can context.CancelFunc
- ws *websocket.Conn
- wg sync.WaitGroup
+ ctx context.Context
+ can context.CancelFunc
+ ws *websocket.Conn
+ wg sync.WaitGroup
+ cb func(message) error
+ id string
+ scatterc chan (message)
}
var upgrader = websocket.Upgrader{}
-func newSession(w http.ResponseWriter, r *http.Request) (*session, error) {
+func newSession(w http.ResponseWriter, r *http.Request, cb func(message) error) (*session, error) {
c, err := upgrader.Upgrade(w, r, nil)
ctx, can := context.WithCancel(r.Context())
return &session{
- ctx: ctx,
- can: can,
- ws: c,
+ ctx: ctx,
+ can: can,
+ ws: c,
+ cb: cb,
+ id: uuid.New().String(),
+ scatterc: make(chan message, 20),
}, err
}
@@ -50,18 +58,33 @@ func (s *session) Run() error {
func (s *session) gather() {
s.while(func() error {
- mt, message, err := s.ws.ReadMessage()
+ mt, msg, err := s.ws.ReadMessage()
if err != nil {
return err
}
- log.Println(" read:", mt, message) // TODO
- return nil
+ if mt != 1 {
+ return nil
+ }
+
+ var m message
+ if err := json.Unmarshal(msg, &m); err != nil {
+ return err
+ }
+
+ log.Printf("gathered %+v", m)
+ return s.cb(m)
})
}
func (s *session) scatter() {
s.while(func() error {
- return s.ws.WriteMessage(1, []byte("message")) // TODO
+ select {
+ case m := <-s.scatterc:
+ log.Printf("scattering %+v", m)
+ return s.ws.WriteMessage(1, []byte(m.Text))
+ case <-s.ctx.Done():
+ return s.ctx.Err()
+ }
})
}