diff options
| author | Peter Stone <thepeterstone@gmail.com> | 2026-03-08 21:03:50 +0000 |
|---|---|---|
| committer | Peter Stone <thepeterstone@gmail.com> | 2026-03-08 21:03:50 +0000 |
| commit | 632ea5a44731af94b6238f330a3b5440906c8ae7 (patch) | |
| tree | d8c780412598d66b89ef390b5729e379fdfd9d5b /internal/api/websocket.go | |
| parent | 406247b14985ab57902e8e42898dc8cb8960290d (diff) | |
| parent | 93a4c852bf726b00e8014d385165f847763fa214 (diff) | |
merge: pull latest from master and resolve conflicts
- Resolve conflicts in API server, CLI, and executor.
- Maintain Gemini classification and assignment logic.
- Update UI to use generic agent config and project_dir.
- Fix ProjectDir/WorkingDir inconsistencies in Gemini runner.
- All tests passing after merge.
Diffstat (limited to 'internal/api/websocket.go')
| -rw-r--r-- | internal/api/websocket.go | 71 |
1 files changed, 67 insertions, 4 deletions
diff --git a/internal/api/websocket.go b/internal/api/websocket.go index 6bd8c88..b5bf728 100644 --- a/internal/api/websocket.go +++ b/internal/api/websocket.go @@ -1,13 +1,27 @@ package api import ( + "errors" "log/slog" "net/http" + "strings" "sync" + "time" "golang.org/x/net/websocket" ) +// wsPingInterval and wsPingDeadline control heartbeat timing. +// Exposed as vars so tests can override them without rebuilding. +var ( + wsPingInterval = 30 * time.Second + wsPingDeadline = 10 * time.Second + + // maxWsClients caps the number of concurrent WebSocket connections. + // Exposed as a var so tests can override it. + maxWsClients = 1000 +) + // Hub manages WebSocket connections and broadcasts messages. type Hub struct { mu sync.RWMutex @@ -25,10 +39,14 @@ func NewHub() *Hub { // Run is a no-op loop kept for future cleanup/heartbeat logic. func (h *Hub) Run() {} -func (h *Hub) Register(ws *websocket.Conn) { +func (h *Hub) Register(ws *websocket.Conn) error { h.mu.Lock() + defer h.mu.Unlock() + if len(h.clients) >= maxWsClients { + return errors.New("max WebSocket clients reached") + } h.clients[ws] = true - h.mu.Unlock() + return nil } func (h *Hub) Unregister(ws *websocket.Conn) { @@ -56,11 +74,56 @@ func (h *Hub) ClientCount() int { } func (s *Server) handleWebSocket(w http.ResponseWriter, r *http.Request) { + if s.hub.ClientCount() >= maxWsClients { + http.Error(w, "too many connections", http.StatusServiceUnavailable) + return + } + + if s.apiToken != "" { + token := r.URL.Query().Get("token") + if token == "" { + auth := r.Header.Get("Authorization") + if strings.HasPrefix(auth, "Bearer ") { + token = strings.TrimPrefix(auth, "Bearer ") + } + } + if token != s.apiToken { + http.Error(w, "unauthorized", http.StatusUnauthorized) + return + } + } + handler := websocket.Handler(func(ws *websocket.Conn) { - s.hub.Register(ws) + if err := s.hub.Register(ws); err != nil { + return + } defer s.hub.Unregister(ws) - // Keep connection alive until client disconnects. + // Ping goroutine: detect dead connections by sending periodic pings. + // A write failure (including write deadline exceeded) closes the conn, + // causing the read loop below to exit and unregister the client. + done := make(chan struct{}) + defer close(done) + go func() { + ticker := time.NewTicker(wsPingInterval) + defer ticker.Stop() + for { + select { + case <-done: + return + case <-ticker.C: + ws.SetWriteDeadline(time.Now().Add(wsPingDeadline)) + err := websocket.Message.Send(ws, "ping") + ws.SetWriteDeadline(time.Time{}) + if err != nil { + ws.Close() + return + } + } + } + }() + + // Keep connection alive until client disconnects or ping fails. buf := make([]byte, 1024) for { if _, err := ws.Read(buf); err != nil { |
