package handlers import ( "encoding/json" "log" "net/http" "sync" "time" "github.com/gorilla/websocket" "task-dashboard/internal/models" ) var upgrader = websocket.Upgrader{ ReadBufferSize: 1024, WriteBufferSize: 1024, CheckOrigin: func(r *http.Request) bool { // Allow all connections (CORS disabled for WebSocket) return true }, } // WSClient represents a connected browser type WSClient struct { conn *websocket.Conn send chan []byte } // WSHub manages WebSocket connections type WSHub struct { clients map[*WSClient]bool broadcast chan []byte register chan *WSClient unregister chan *WSClient mu sync.RWMutex } var wsHub *WSHub var hubOnce sync.Once // GetWSHub returns the singleton WebSocket hub func GetWSHub() *WSHub { hubOnce.Do(func() { wsHub = &WSHub{ clients: make(map[*WSClient]bool), broadcast: make(chan []byte, 256), register: make(chan *WSClient), unregister: make(chan *WSClient), } go wsHub.run() }) return wsHub } // run handles hub operations func (h *WSHub) run() { for { select { case client := <-h.register: h.mu.Lock() h.clients[client] = true h.mu.Unlock() log.Printf("WebSocket client connected, total: %d", len(h.clients)) case client := <-h.unregister: h.mu.Lock() if _, ok := h.clients[client]; ok { delete(h.clients, client) close(client.send) } h.mu.Unlock() log.Printf("WebSocket client disconnected, total: %d", len(h.clients)) case message := <-h.broadcast: h.mu.RLock() for client := range h.clients { select { case client.send <- message: default: // Client buffer full, skip } } h.mu.RUnlock() } } } // Broadcast sends a message to all connected clients func (h *WSHub) Broadcast(msg []byte) { h.broadcast <- msg } // ClientCount returns the number of connected clients func (h *WSHub) ClientCount() int { h.mu.RLock() defer h.mu.RUnlock() return len(h.clients) } // WSMessage is the structure for WebSocket messages type WSMessage struct { Type string `json:"type"` Payload interface{} `json:"payload"` } // BroadcastAgentRequest sends an agent request notification to all browsers func (h *Handler) BroadcastAgentRequest(session *models.AgentSession, trustLevel models.AgentTrustLevel) { hub := GetWSHub() payload := AgentRequestPayload{ RequestToken: session.RequestToken, AgentName: session.AgentName, AgentID: session.AgentID, TrustLevel: trustLevel, ExpiresAt: session.ExpiresAt, } msg := WSMessage{ Type: "agent_request", Payload: payload, } data, err := json.Marshal(msg) if err != nil { log.Printf("Failed to marshal WebSocket message: %v", err) return } hub.Broadcast(data) log.Printf("Broadcasted agent request from %s to %d clients", session.AgentName, hub.ClientCount()) } // HandleWebSocket handles the WebSocket connection at /ws/notifications func (h *Handler) HandleWebSocket(w http.ResponseWriter, r *http.Request) { // Check if the request is a WebSocket upgrade request if !websocket.IsWebSocketUpgrade(r) { http.Error(w, "Expected WebSocket Upgrade request", http.StatusBadRequest) return } conn, err := upgrader.Upgrade(w, r, nil) if err != nil { log.Printf("WebSocket upgrade failed: %v", err) return } hub := GetWSHub() client := &WSClient{ conn: conn, send: make(chan []byte, 256), } hub.register <- client // Start goroutines for reading and writing go client.writePump() go client.readPump(hub) } // writePump sends messages to the client func (c *WSClient) writePump() { ticker := time.NewTicker(30 * time.Second) defer func() { ticker.Stop() c.conn.Close() }() for { select { case message, ok := <-c.send: c.conn.SetWriteDeadline(time.Now().Add(10 * time.Second)) if !ok { // Channel closed c.conn.WriteMessage(websocket.CloseMessage, []byte{}) return } if err := c.conn.WriteMessage(websocket.TextMessage, message); err != nil { return } case <-ticker.C: // Send ping c.conn.SetWriteDeadline(time.Now().Add(10 * time.Second)) if err := c.conn.WriteMessage(websocket.PingMessage, nil); err != nil { return } } } } // readPump reads messages from the client func (c *WSClient) readPump(hub *WSHub) { defer func() { hub.unregister <- c c.conn.Close() }() c.conn.SetReadLimit(512) c.conn.SetReadDeadline(time.Now().Add(60 * time.Second)) c.conn.SetPongHandler(func(string) error { c.conn.SetReadDeadline(time.Now().Add(60 * time.Second)) return nil }) for { _, _, err := c.conn.ReadMessage() if err != nil { if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) { log.Printf("WebSocket error: %v", err) } break } } }