a chat room

This commit is contained in:
bel
2025-10-14 22:04:30 -06:00
parent c9c4800d68
commit 13b583a77e
6 changed files with 128 additions and 93 deletions

5
src/server/message.go Normal file
View File

@@ -0,0 +1,5 @@
package server
type message struct {
Text string
}

View File

@@ -5,13 +5,15 @@ import (
"net/http"
)
type Server struct{}
func NewServer() Server {
return Server{}
type Server struct {
sessions []*session
}
func (s Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
func NewServer() *Server {
return &Server{}
}
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/ws":
if err := s.WS(w, r); err != nil {
@@ -22,12 +24,35 @@ func (s Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
}
func (s Server) WS(w http.ResponseWriter, r *http.Request) error {
sess, err := newSession(w, r)
func (s *Server) WS(w http.ResponseWriter, r *http.Request) error {
sess, err := newSession(w, r, nil)
if err != nil {
return err
}
defer sess.Close()
sess.cb = func(m message) error {
log.Printf("cbing to all other sessions %+v", m)
for i := range s.sessions {
if s.sessions[i].id != sess.id {
select {
case s.sessions[i].scatterc <- m:
case <-s.sessions[i].ctx.Done():
}
}
}
return nil
}
s.sessions = append(s.sessions, sess)
defer func() {
for i := range s.sessions {
if s.sessions[i].id == sess.id {
s.sessions = append(s.sessions[:i], s.sessions[i+1:]...)
return
}
}
}()
return sess.Run()
}

View File

@@ -2,30 +2,38 @@ package server
import (
"context"
"encoding/json"
"log"
"net/http"
"sync"
"github.com/google/uuid"
"github.com/gorilla/websocket"
"golang.org/x/time/rate"
)
type session struct {
ctx context.Context
can context.CancelFunc
ws *websocket.Conn
wg sync.WaitGroup
ctx context.Context
can context.CancelFunc
ws *websocket.Conn
wg sync.WaitGroup
cb func(message) error
id string
scatterc chan (message)
}
var upgrader = websocket.Upgrader{}
func newSession(w http.ResponseWriter, r *http.Request) (*session, error) {
func newSession(w http.ResponseWriter, r *http.Request, cb func(message) error) (*session, error) {
c, err := upgrader.Upgrade(w, r, nil)
ctx, can := context.WithCancel(r.Context())
return &session{
ctx: ctx,
can: can,
ws: c,
ctx: ctx,
can: can,
ws: c,
cb: cb,
id: uuid.New().String(),
scatterc: make(chan message, 20),
}, err
}
@@ -50,18 +58,33 @@ func (s *session) Run() error {
func (s *session) gather() {
s.while(func() error {
mt, message, err := s.ws.ReadMessage()
mt, msg, err := s.ws.ReadMessage()
if err != nil {
return err
}
log.Println(" read:", mt, message) // TODO
return nil
if mt != 1 {
return nil
}
var m message
if err := json.Unmarshal(msg, &m); err != nil {
return err
}
log.Printf("gathered %+v", m)
return s.cb(m)
})
}
func (s *session) scatter() {
s.while(func() error {
return s.ws.WriteMessage(1, []byte("message")) // TODO
select {
case m := <-s.scatterc:
log.Printf("scattering %+v", m)
return s.ws.WriteMessage(1, []byte(m.Text))
case <-s.ctx.Done():
return s.ctx.Err()
}
})
}