summaryrefslogtreecommitdiff
path: root/internal/handlers/websocket.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/handlers/websocket.go')
-rw-r--r--internal/handlers/websocket.go216
1 files changed, 216 insertions, 0 deletions
diff --git a/internal/handlers/websocket.go b/internal/handlers/websocket.go
new file mode 100644
index 0000000..1677f88
--- /dev/null
+++ b/internal/handlers/websocket.go
@@ -0,0 +1,216 @@
+package handlers
+
+import (
+ "encoding/json"
+ "log"
+ "net/http"
+ "sync"
+ "time"
+
+ "github.com/gorilla/websocket"
+
+ "task-dashboard/internal/models"
+)
+
+var upgrader = websocket.Upgrader{
+ ReadBufferSize: 1024,
+ WriteBufferSize: 1024,
+ CheckOrigin: func(r *http.Request) bool {
+ // Allow all connections (CORS disabled for WebSocket)
+ return true
+ },
+}
+
+// WSClient represents a connected browser
+type WSClient struct {
+ conn *websocket.Conn
+ send chan []byte
+}
+
+// WSHub manages WebSocket connections
+type WSHub struct {
+ clients map[*WSClient]bool
+ broadcast chan []byte
+ register chan *WSClient
+ unregister chan *WSClient
+ mu sync.RWMutex
+}
+
+var wsHub *WSHub
+var hubOnce sync.Once
+
+// GetWSHub returns the singleton WebSocket hub
+func GetWSHub() *WSHub {
+ hubOnce.Do(func() {
+ wsHub = &WSHub{
+ clients: make(map[*WSClient]bool),
+ broadcast: make(chan []byte, 256),
+ register: make(chan *WSClient),
+ unregister: make(chan *WSClient),
+ }
+ go wsHub.run()
+ })
+ return wsHub
+}
+
+// run handles hub operations
+func (h *WSHub) run() {
+ for {
+ select {
+ case client := <-h.register:
+ h.mu.Lock()
+ h.clients[client] = true
+ h.mu.Unlock()
+ log.Printf("WebSocket client connected, total: %d", len(h.clients))
+
+ case client := <-h.unregister:
+ h.mu.Lock()
+ if _, ok := h.clients[client]; ok {
+ delete(h.clients, client)
+ close(client.send)
+ }
+ h.mu.Unlock()
+ log.Printf("WebSocket client disconnected, total: %d", len(h.clients))
+
+ case message := <-h.broadcast:
+ h.mu.RLock()
+ for client := range h.clients {
+ select {
+ case client.send <- message:
+ default:
+ // Client buffer full, skip
+ }
+ }
+ h.mu.RUnlock()
+ }
+ }
+}
+
+// Broadcast sends a message to all connected clients
+func (h *WSHub) Broadcast(msg []byte) {
+ h.broadcast <- msg
+}
+
+// ClientCount returns the number of connected clients
+func (h *WSHub) ClientCount() int {
+ h.mu.RLock()
+ defer h.mu.RUnlock()
+ return len(h.clients)
+}
+
+// WSMessage is the structure for WebSocket messages
+type WSMessage struct {
+ Type string `json:"type"`
+ Payload interface{} `json:"payload"`
+}
+
+// BroadcastAgentRequest sends an agent request notification to all browsers
+func (h *Handler) BroadcastAgentRequest(session *models.AgentSession, trustLevel models.AgentTrustLevel) {
+ hub := GetWSHub()
+
+ payload := AgentRequestPayload{
+ RequestToken: session.RequestToken,
+ AgentName: session.AgentName,
+ AgentID: session.AgentID,
+ TrustLevel: trustLevel,
+ ExpiresAt: session.ExpiresAt,
+ }
+
+ msg := WSMessage{
+ Type: "agent_request",
+ Payload: payload,
+ }
+
+ data, err := json.Marshal(msg)
+ if err != nil {
+ log.Printf("Failed to marshal WebSocket message: %v", err)
+ return
+ }
+
+ hub.Broadcast(data)
+ log.Printf("Broadcasted agent request from %s to %d clients", session.AgentName, hub.ClientCount())
+}
+
+// HandleWebSocket handles the WebSocket connection at /ws/notifications
+func (h *Handler) HandleWebSocket(w http.ResponseWriter, r *http.Request) {
+ // Check if the request is a WebSocket upgrade request
+ if !websocket.IsWebSocketUpgrade(r) {
+ http.Error(w, "Expected WebSocket Upgrade request", http.StatusBadRequest)
+ return
+ }
+
+ conn, err := upgrader.Upgrade(w, r, nil)
+ if err != nil {
+ log.Printf("WebSocket upgrade failed: %v", err)
+ return
+ }
+
+ hub := GetWSHub()
+ client := &WSClient{
+ conn: conn,
+ send: make(chan []byte, 256),
+ }
+
+ hub.register <- client
+
+ // Start goroutines for reading and writing
+ go client.writePump()
+ go client.readPump(hub)
+}
+
+// writePump sends messages to the client
+func (c *WSClient) writePump() {
+ ticker := time.NewTicker(30 * time.Second)
+ defer func() {
+ ticker.Stop()
+ c.conn.Close()
+ }()
+
+ for {
+ select {
+ case message, ok := <-c.send:
+ c.conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
+ if !ok {
+ // Channel closed
+ c.conn.WriteMessage(websocket.CloseMessage, []byte{})
+ return
+ }
+
+ if err := c.conn.WriteMessage(websocket.TextMessage, message); err != nil {
+ return
+ }
+
+ case <-ticker.C:
+ // Send ping
+ c.conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
+ if err := c.conn.WriteMessage(websocket.PingMessage, nil); err != nil {
+ return
+ }
+ }
+ }
+}
+
+// readPump reads messages from the client
+func (c *WSClient) readPump(hub *WSHub) {
+ defer func() {
+ hub.unregister <- c
+ c.conn.Close()
+ }()
+
+ c.conn.SetReadLimit(512)
+ c.conn.SetReadDeadline(time.Now().Add(60 * time.Second))
+ c.conn.SetPongHandler(func(string) error {
+ c.conn.SetReadDeadline(time.Now().Add(60 * time.Second))
+ return nil
+ })
+
+ for {
+ _, _, err := c.conn.ReadMessage()
+ if err != nil {
+ if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
+ log.Printf("WebSocket error: %v", err)
+ }
+ break
+ }
+ }
+}