summaryrefslogtreecommitdiff
path: root/internal/api
diff options
context:
space:
mode:
Diffstat (limited to 'internal/api')
-rw-r--r--internal/api/websocket.go37
-rw-r--r--internal/api/websocket_test.go5
2 files changed, 25 insertions, 17 deletions
diff --git a/internal/api/websocket.go b/internal/api/websocket.go
index b5bf728..25522dc 100644
--- a/internal/api/websocket.go
+++ b/internal/api/websocket.go
@@ -16,33 +16,40 @@ import (
var (
wsPingInterval = 30 * time.Second
wsPingDeadline = 10 * time.Second
-
- // maxWsClients caps the number of concurrent WebSocket connections.
- // Exposed as a var so tests can override it.
- maxWsClients = 1000
)
+const defaultMaxWsClients = 1000
+
// Hub manages WebSocket connections and broadcasts messages.
type Hub struct {
- mu sync.RWMutex
- clients map[*websocket.Conn]bool
- logger *slog.Logger
+ mu sync.RWMutex
+ clients map[*websocket.Conn]bool
+ maxClients int
+ logger *slog.Logger
}
func NewHub() *Hub {
return &Hub{
- clients: make(map[*websocket.Conn]bool),
- logger: slog.Default(),
+ clients: make(map[*websocket.Conn]bool),
+ maxClients: defaultMaxWsClients,
+ logger: slog.Default(),
}
}
// Run is a no-op loop kept for future cleanup/heartbeat logic.
func (h *Hub) Run() {}
+// SetMaxClients configures the maximum number of concurrent WebSocket clients.
+func (h *Hub) SetMaxClients(n int) {
+ h.mu.Lock()
+ h.maxClients = n
+ h.mu.Unlock()
+}
+
func (h *Hub) Register(ws *websocket.Conn) error {
h.mu.Lock()
defer h.mu.Unlock()
- if len(h.clients) >= maxWsClients {
+ if len(h.clients) >= h.maxClients {
return errors.New("max WebSocket clients reached")
}
h.clients[ws] = true
@@ -74,7 +81,10 @@ func (h *Hub) ClientCount() int {
}
func (s *Server) handleWebSocket(w http.ResponseWriter, r *http.Request) {
- if s.hub.ClientCount() >= maxWsClients {
+ s.hub.mu.RLock()
+ atCap := len(s.hub.clients) >= s.hub.maxClients
+ s.hub.mu.RUnlock()
+ if atCap {
http.Error(w, "too many connections", http.StatusServiceUnavailable)
return
}
@@ -104,15 +114,16 @@ func (s *Server) handleWebSocket(w http.ResponseWriter, r *http.Request) {
// causing the read loop below to exit and unregister the client.
done := make(chan struct{})
defer close(done)
+ pingInterval, pingDeadline := wsPingInterval, wsPingDeadline // capture before goroutine starts
go func() {
- ticker := time.NewTicker(wsPingInterval)
+ ticker := time.NewTicker(pingInterval)
defer ticker.Stop()
for {
select {
case <-done:
return
case <-ticker.C:
- ws.SetWriteDeadline(time.Now().Add(wsPingDeadline))
+ ws.SetWriteDeadline(time.Now().Add(pingDeadline))
err := websocket.Message.Send(ws, "ping")
ws.SetWriteDeadline(time.Time{})
if err != nil {
diff --git a/internal/api/websocket_test.go b/internal/api/websocket_test.go
index 72b83f2..682a555 100644
--- a/internal/api/websocket_test.go
+++ b/internal/api/websocket_test.go
@@ -99,11 +99,8 @@ func TestWebSocket_NoTokenConfigured(t *testing.T) {
// TestWebSocket_RejectsConnectionWhenAtMaxClients verifies that when the hub
// is at capacity, new WebSocket upgrade requests are rejected with 503.
func TestWebSocket_RejectsConnectionWhenAtMaxClients(t *testing.T) {
- orig := maxWsClients
- maxWsClients = 0 // immediately at capacity
- t.Cleanup(func() { maxWsClients = orig })
-
srv, _ := testServer(t)
+ srv.hub.SetMaxClients(0) // immediately at capacity
srv.StartHub()
req := httptest.NewRequest("GET", "/api/ws", nil)