diff options
Diffstat (limited to 'internal/handlers/websocket.go')
| -rw-r--r-- | internal/handlers/websocket.go | 216 |
1 files changed, 216 insertions, 0 deletions
diff --git a/internal/handlers/websocket.go b/internal/handlers/websocket.go new file mode 100644 index 0000000..1677f88 --- /dev/null +++ b/internal/handlers/websocket.go @@ -0,0 +1,216 @@ +package handlers + +import ( + "encoding/json" + "log" + "net/http" + "sync" + "time" + + "github.com/gorilla/websocket" + + "task-dashboard/internal/models" +) + +var upgrader = websocket.Upgrader{ + ReadBufferSize: 1024, + WriteBufferSize: 1024, + CheckOrigin: func(r *http.Request) bool { + // Allow all connections (CORS disabled for WebSocket) + return true + }, +} + +// WSClient represents a connected browser +type WSClient struct { + conn *websocket.Conn + send chan []byte +} + +// WSHub manages WebSocket connections +type WSHub struct { + clients map[*WSClient]bool + broadcast chan []byte + register chan *WSClient + unregister chan *WSClient + mu sync.RWMutex +} + +var wsHub *WSHub +var hubOnce sync.Once + +// GetWSHub returns the singleton WebSocket hub +func GetWSHub() *WSHub { + hubOnce.Do(func() { + wsHub = &WSHub{ + clients: make(map[*WSClient]bool), + broadcast: make(chan []byte, 256), + register: make(chan *WSClient), + unregister: make(chan *WSClient), + } + go wsHub.run() + }) + return wsHub +} + +// run handles hub operations +func (h *WSHub) run() { + for { + select { + case client := <-h.register: + h.mu.Lock() + h.clients[client] = true + h.mu.Unlock() + log.Printf("WebSocket client connected, total: %d", len(h.clients)) + + case client := <-h.unregister: + h.mu.Lock() + if _, ok := h.clients[client]; ok { + delete(h.clients, client) + close(client.send) + } + h.mu.Unlock() + log.Printf("WebSocket client disconnected, total: %d", len(h.clients)) + + case message := <-h.broadcast: + h.mu.RLock() + for client := range h.clients { + select { + case client.send <- message: + default: + // Client buffer full, skip + } + } + h.mu.RUnlock() + } + } +} + +// Broadcast sends a message to all connected clients +func (h *WSHub) Broadcast(msg []byte) { + h.broadcast <- msg +} + +// ClientCount returns the number of connected clients +func (h *WSHub) ClientCount() int { + h.mu.RLock() + defer h.mu.RUnlock() + return len(h.clients) +} + +// WSMessage is the structure for WebSocket messages +type WSMessage struct { + Type string `json:"type"` + Payload interface{} `json:"payload"` +} + +// BroadcastAgentRequest sends an agent request notification to all browsers +func (h *Handler) BroadcastAgentRequest(session *models.AgentSession, trustLevel models.AgentTrustLevel) { + hub := GetWSHub() + + payload := AgentRequestPayload{ + RequestToken: session.RequestToken, + AgentName: session.AgentName, + AgentID: session.AgentID, + TrustLevel: trustLevel, + ExpiresAt: session.ExpiresAt, + } + + msg := WSMessage{ + Type: "agent_request", + Payload: payload, + } + + data, err := json.Marshal(msg) + if err != nil { + log.Printf("Failed to marshal WebSocket message: %v", err) + return + } + + hub.Broadcast(data) + log.Printf("Broadcasted agent request from %s to %d clients", session.AgentName, hub.ClientCount()) +} + +// HandleWebSocket handles the WebSocket connection at /ws/notifications +func (h *Handler) HandleWebSocket(w http.ResponseWriter, r *http.Request) { + // Check if the request is a WebSocket upgrade request + if !websocket.IsWebSocketUpgrade(r) { + http.Error(w, "Expected WebSocket Upgrade request", http.StatusBadRequest) + return + } + + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + log.Printf("WebSocket upgrade failed: %v", err) + return + } + + hub := GetWSHub() + client := &WSClient{ + conn: conn, + send: make(chan []byte, 256), + } + + hub.register <- client + + // Start goroutines for reading and writing + go client.writePump() + go client.readPump(hub) +} + +// writePump sends messages to the client +func (c *WSClient) writePump() { + ticker := time.NewTicker(30 * time.Second) + defer func() { + ticker.Stop() + c.conn.Close() + }() + + for { + select { + case message, ok := <-c.send: + c.conn.SetWriteDeadline(time.Now().Add(10 * time.Second)) + if !ok { + // Channel closed + c.conn.WriteMessage(websocket.CloseMessage, []byte{}) + return + } + + if err := c.conn.WriteMessage(websocket.TextMessage, message); err != nil { + return + } + + case <-ticker.C: + // Send ping + c.conn.SetWriteDeadline(time.Now().Add(10 * time.Second)) + if err := c.conn.WriteMessage(websocket.PingMessage, nil); err != nil { + return + } + } + } +} + +// readPump reads messages from the client +func (c *WSClient) readPump(hub *WSHub) { + defer func() { + hub.unregister <- c + c.conn.Close() + }() + + c.conn.SetReadLimit(512) + c.conn.SetReadDeadline(time.Now().Add(60 * time.Second)) + c.conn.SetPongHandler(func(string) error { + c.conn.SetReadDeadline(time.Now().Add(60 * time.Second)) + return nil + }) + + for { + _, _, err := c.conn.ReadMessage() + if err != nil { + if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) { + log.Printf("WebSocket error: %v", err) + } + break + } + } +} |
