package api import ( "errors" "log/slog" "net/http" "strings" "sync" "time" "golang.org/x/net/websocket" ) // wsPingInterval and wsPingDeadline control heartbeat timing. // Exposed as vars so tests can override them without rebuilding. var ( wsPingInterval = 30 * time.Second wsPingDeadline = 10 * time.Second // maxWsClients caps the number of concurrent WebSocket connections. // Exposed as a var so tests can override it. maxWsClients = 1000 ) // Hub manages WebSocket connections and broadcasts messages. type Hub struct { mu sync.RWMutex clients map[*websocket.Conn]bool logger *slog.Logger } func NewHub() *Hub { return &Hub{ clients: make(map[*websocket.Conn]bool), logger: slog.Default(), } } // Run is a no-op loop kept for future cleanup/heartbeat logic. func (h *Hub) Run() {} func (h *Hub) Register(ws *websocket.Conn) error { h.mu.Lock() defer h.mu.Unlock() if len(h.clients) >= maxWsClients { return errors.New("max WebSocket clients reached") } h.clients[ws] = true return nil } func (h *Hub) Unregister(ws *websocket.Conn) { h.mu.Lock() delete(h.clients, ws) h.mu.Unlock() } // Broadcast sends a message to all connected WebSocket clients. func (h *Hub) Broadcast(msg []byte) { h.mu.RLock() defer h.mu.RUnlock() for conn := range h.clients { if _, err := conn.Write(msg); err != nil { h.logger.Error("websocket write error", "error", err) } } } // ClientCount returns the number of connected clients. func (h *Hub) ClientCount() int { h.mu.RLock() defer h.mu.RUnlock() return len(h.clients) } func (s *Server) handleWebSocket(w http.ResponseWriter, r *http.Request) { if s.hub.ClientCount() >= maxWsClients { http.Error(w, "too many connections", http.StatusServiceUnavailable) return } if s.apiToken != "" { token := r.URL.Query().Get("token") if token == "" { auth := r.Header.Get("Authorization") if strings.HasPrefix(auth, "Bearer ") { token = strings.TrimPrefix(auth, "Bearer ") } } if token != s.apiToken { http.Error(w, "unauthorized", http.StatusUnauthorized) return } } handler := websocket.Handler(func(ws *websocket.Conn) { if err := s.hub.Register(ws); err != nil { return } defer s.hub.Unregister(ws) // Ping goroutine: detect dead connections by sending periodic pings. // A write failure (including write deadline exceeded) closes the conn, // causing the read loop below to exit and unregister the client. done := make(chan struct{}) defer close(done) go func() { ticker := time.NewTicker(wsPingInterval) defer ticker.Stop() for { select { case <-done: return case <-ticker.C: ws.SetWriteDeadline(time.Now().Add(wsPingDeadline)) err := websocket.Message.Send(ws, "ping") ws.SetWriteDeadline(time.Time{}) if err != nil { ws.Close() return } } } }() // Keep connection alive until client disconnects or ping fails. buf := make([]byte, 1024) for { if _, err := ws.Read(buf); err != nil { break } } }) handler.ServeHTTP(w, r) }