summaryrefslogtreecommitdiff
path: root/internal/api/websocket.go
diff options
context:
space:
mode:
authorPeter Stone <thepeterstone@gmail.com>2026-03-08 20:40:15 +0000
committerPeter Stone <thepeterstone@gmail.com>2026-03-08 20:40:15 +0000
commit363fc9ead6276cba51b4a72b4349d49ce7ca0f3d (patch)
tree001ac77ba0896720fde1202dfb588b8d2bdc73fc /internal/api/websocket.go
parent2cf6d97593d8a45c412f7d546abbaaeb23db0fd1 (diff)
api: WebSocket auth, client cap, and ping keepalive
- Require bearer token on WebSocket connections when apiToken is set - Cap concurrent WebSocket clients at maxWsClients (1000, overridable) - Send periodic pings every 30s; close dead connections after 10s write deadline Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Diffstat (limited to 'internal/api/websocket.go')
-rw-r--r--internal/api/websocket.go71
1 files changed, 67 insertions, 4 deletions
diff --git a/internal/api/websocket.go b/internal/api/websocket.go
index 6bd8c88..b5bf728 100644
--- a/internal/api/websocket.go
+++ b/internal/api/websocket.go
@@ -1,13 +1,27 @@
package api
import (
+ "errors"
"log/slog"
"net/http"
+ "strings"
"sync"
+ "time"
"golang.org/x/net/websocket"
)
+// wsPingInterval and wsPingDeadline control heartbeat timing.
+// Exposed as vars so tests can override them without rebuilding.
+var (
+ wsPingInterval = 30 * time.Second
+ wsPingDeadline = 10 * time.Second
+
+ // maxWsClients caps the number of concurrent WebSocket connections.
+ // Exposed as a var so tests can override it.
+ maxWsClients = 1000
+)
+
// Hub manages WebSocket connections and broadcasts messages.
type Hub struct {
mu sync.RWMutex
@@ -25,10 +39,14 @@ func NewHub() *Hub {
// Run is a no-op loop kept for future cleanup/heartbeat logic.
func (h *Hub) Run() {}
-func (h *Hub) Register(ws *websocket.Conn) {
+func (h *Hub) Register(ws *websocket.Conn) error {
h.mu.Lock()
+ defer h.mu.Unlock()
+ if len(h.clients) >= maxWsClients {
+ return errors.New("max WebSocket clients reached")
+ }
h.clients[ws] = true
- h.mu.Unlock()
+ return nil
}
func (h *Hub) Unregister(ws *websocket.Conn) {
@@ -56,11 +74,56 @@ func (h *Hub) ClientCount() int {
}
func (s *Server) handleWebSocket(w http.ResponseWriter, r *http.Request) {
+ if s.hub.ClientCount() >= maxWsClients {
+ http.Error(w, "too many connections", http.StatusServiceUnavailable)
+ return
+ }
+
+ if s.apiToken != "" {
+ token := r.URL.Query().Get("token")
+ if token == "" {
+ auth := r.Header.Get("Authorization")
+ if strings.HasPrefix(auth, "Bearer ") {
+ token = strings.TrimPrefix(auth, "Bearer ")
+ }
+ }
+ if token != s.apiToken {
+ http.Error(w, "unauthorized", http.StatusUnauthorized)
+ return
+ }
+ }
+
handler := websocket.Handler(func(ws *websocket.Conn) {
- s.hub.Register(ws)
+ if err := s.hub.Register(ws); err != nil {
+ return
+ }
defer s.hub.Unregister(ws)
- // Keep connection alive until client disconnects.
+ // Ping goroutine: detect dead connections by sending periodic pings.
+ // A write failure (including write deadline exceeded) closes the conn,
+ // causing the read loop below to exit and unregister the client.
+ done := make(chan struct{})
+ defer close(done)
+ go func() {
+ ticker := time.NewTicker(wsPingInterval)
+ defer ticker.Stop()
+ for {
+ select {
+ case <-done:
+ return
+ case <-ticker.C:
+ ws.SetWriteDeadline(time.Now().Add(wsPingDeadline))
+ err := websocket.Message.Send(ws, "ping")
+ ws.SetWriteDeadline(time.Time{})
+ if err != nil {
+ ws.Close()
+ return
+ }
+ }
+ }
+ }()
+
+ // Keep connection alive until client disconnects or ping fails.
buf := make([]byte, 1024)
for {
if _, err := ws.Read(buf); err != nil {