diff options
Diffstat (limited to 'internal/api/websocket.go')
| -rw-r--r-- | internal/api/websocket.go | 37 |
1 files changed, 24 insertions, 13 deletions
diff --git a/internal/api/websocket.go b/internal/api/websocket.go index b5bf728..25522dc 100644 --- a/internal/api/websocket.go +++ b/internal/api/websocket.go @@ -16,33 +16,40 @@ import ( var ( wsPingInterval = 30 * time.Second wsPingDeadline = 10 * time.Second - - // maxWsClients caps the number of concurrent WebSocket connections. - // Exposed as a var so tests can override it. - maxWsClients = 1000 ) +const defaultMaxWsClients = 1000 + // Hub manages WebSocket connections and broadcasts messages. type Hub struct { - mu sync.RWMutex - clients map[*websocket.Conn]bool - logger *slog.Logger + mu sync.RWMutex + clients map[*websocket.Conn]bool + maxClients int + logger *slog.Logger } func NewHub() *Hub { return &Hub{ - clients: make(map[*websocket.Conn]bool), - logger: slog.Default(), + clients: make(map[*websocket.Conn]bool), + maxClients: defaultMaxWsClients, + logger: slog.Default(), } } // Run is a no-op loop kept for future cleanup/heartbeat logic. func (h *Hub) Run() {} +// SetMaxClients configures the maximum number of concurrent WebSocket clients. +func (h *Hub) SetMaxClients(n int) { + h.mu.Lock() + h.maxClients = n + h.mu.Unlock() +} + func (h *Hub) Register(ws *websocket.Conn) error { h.mu.Lock() defer h.mu.Unlock() - if len(h.clients) >= maxWsClients { + if len(h.clients) >= h.maxClients { return errors.New("max WebSocket clients reached") } h.clients[ws] = true @@ -74,7 +81,10 @@ func (h *Hub) ClientCount() int { } func (s *Server) handleWebSocket(w http.ResponseWriter, r *http.Request) { - if s.hub.ClientCount() >= maxWsClients { + s.hub.mu.RLock() + atCap := len(s.hub.clients) >= s.hub.maxClients + s.hub.mu.RUnlock() + if atCap { http.Error(w, "too many connections", http.StatusServiceUnavailable) return } @@ -104,15 +114,16 @@ func (s *Server) handleWebSocket(w http.ResponseWriter, r *http.Request) { // causing the read loop below to exit and unregister the client. done := make(chan struct{}) defer close(done) + pingInterval, pingDeadline := wsPingInterval, wsPingDeadline // capture before goroutine starts go func() { - ticker := time.NewTicker(wsPingInterval) + ticker := time.NewTicker(pingInterval) defer ticker.Stop() for { select { case <-done: return case <-ticker.C: - ws.SetWriteDeadline(time.Now().Add(wsPingDeadline)) + ws.SetWriteDeadline(time.Now().Add(pingDeadline)) err := websocket.Message.Send(ws, "ping") ws.SetWriteDeadline(time.Time{}) if err != nil { |
