summaryrefslogtreecommitdiff
path: root/internal/api/websocket.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/api/websocket.go')
-rw-r--r--internal/api/websocket.go37
1 files changed, 24 insertions, 13 deletions
diff --git a/internal/api/websocket.go b/internal/api/websocket.go
index b5bf728..25522dc 100644
--- a/internal/api/websocket.go
+++ b/internal/api/websocket.go
@@ -16,33 +16,40 @@ import (
var (
wsPingInterval = 30 * time.Second
wsPingDeadline = 10 * time.Second
-
- // maxWsClients caps the number of concurrent WebSocket connections.
- // Exposed as a var so tests can override it.
- maxWsClients = 1000
)
+const defaultMaxWsClients = 1000
+
// Hub manages WebSocket connections and broadcasts messages.
type Hub struct {
- mu sync.RWMutex
- clients map[*websocket.Conn]bool
- logger *slog.Logger
+ mu sync.RWMutex
+ clients map[*websocket.Conn]bool
+ maxClients int
+ logger *slog.Logger
}
func NewHub() *Hub {
return &Hub{
- clients: make(map[*websocket.Conn]bool),
- logger: slog.Default(),
+ clients: make(map[*websocket.Conn]bool),
+ maxClients: defaultMaxWsClients,
+ logger: slog.Default(),
}
}
// Run is a no-op loop kept for future cleanup/heartbeat logic.
func (h *Hub) Run() {}
+// SetMaxClients configures the maximum number of concurrent WebSocket clients.
+func (h *Hub) SetMaxClients(n int) {
+ h.mu.Lock()
+ h.maxClients = n
+ h.mu.Unlock()
+}
+
func (h *Hub) Register(ws *websocket.Conn) error {
h.mu.Lock()
defer h.mu.Unlock()
- if len(h.clients) >= maxWsClients {
+ if len(h.clients) >= h.maxClients {
return errors.New("max WebSocket clients reached")
}
h.clients[ws] = true
@@ -74,7 +81,10 @@ func (h *Hub) ClientCount() int {
}
func (s *Server) handleWebSocket(w http.ResponseWriter, r *http.Request) {
- if s.hub.ClientCount() >= maxWsClients {
+ s.hub.mu.RLock()
+ atCap := len(s.hub.clients) >= s.hub.maxClients
+ s.hub.mu.RUnlock()
+ if atCap {
http.Error(w, "too many connections", http.StatusServiceUnavailable)
return
}
@@ -104,15 +114,16 @@ func (s *Server) handleWebSocket(w http.ResponseWriter, r *http.Request) {
// causing the read loop below to exit and unregister the client.
done := make(chan struct{})
defer close(done)
+ pingInterval, pingDeadline := wsPingInterval, wsPingDeadline // capture before goroutine starts
go func() {
- ticker := time.NewTicker(wsPingInterval)
+ ticker := time.NewTicker(pingInterval)
defer ticker.Stop()
for {
select {
case <-done:
return
case <-ticker.C:
- ws.SetWriteDeadline(time.Now().Add(wsPingDeadline))
+ ws.SetWriteDeadline(time.Now().Add(pingDeadline))
err := websocket.Message.Send(ws, "ping")
ws.SetWriteDeadline(time.Time{})
if err != nil {