From 058ff7d699f088edb851336928dd3eea2934cc07 Mon Sep 17 00:00:00 2001 From: Peter Stone Date: Wed, 28 Jan 2026 22:18:40 -1000 Subject: Refactor agent handlers for simplicity and clarity - Reuse BuildTimeline() from timeline_logic.go instead of duplicating fetch logic (~60 lines removed) - Add section headers for code organization - Extract isSessionExpired() and renderAgentTemplate() helpers - Move AgentRequestPayload from websocket.go to agent.go - Use config.Now() and config.Today() for consistent timezone handling Co-Authored-By: Claude Opus 4.5 --- internal/handlers/agent.go | 560 +++++++++++++++++++++++++++++++ internal/handlers/agent_test.go | 706 ++++++++++++++++++++++++++++++++++++++++ internal/handlers/websocket.go | 216 ++++++++++++ 3 files changed, 1482 insertions(+) create mode 100644 internal/handlers/agent.go create mode 100644 internal/handlers/agent_test.go create mode 100644 internal/handlers/websocket.go (limited to 'internal') diff --git a/internal/handlers/agent.go b/internal/handlers/agent.go new file mode 100644 index 0000000..6f47524 --- /dev/null +++ b/internal/handlers/agent.go @@ -0,0 +1,560 @@ +package handlers + +import ( + "context" + "crypto/rand" + "encoding/base64" + "encoding/json" + "net/http" + "time" + + "task-dashboard/internal/config" + "task-dashboard/internal/models" +) + +// ----------------------------------------------------------------------------- +// Constants +// ----------------------------------------------------------------------------- + +const ( + AgentRequestExpiry = 5 * time.Minute + AgentSessionTTL = 1 * time.Hour + TokenBytes = 32 +) + +// Context key for agent session in request context +type contextKey string + +const agentSessionContextKey contextKey = "agent_session" + +// ----------------------------------------------------------------------------- +// Types +// ----------------------------------------------------------------------------- + +// AgentRequestPayload is sent via WebSocket when an agent requests access +type AgentRequestPayload struct { + RequestToken string `json:"request_token"` + AgentName string `json:"agent_name"` + AgentID string `json:"agent_id"` + TrustLevel models.AgentTrustLevel `json:"trust_level"` + ExpiresAt time.Time `json:"expires_at"` +} + +// agentContextItem is the JSON-serializable timeline item for agent context API +type agentContextItem struct { + ID string `json:"id"` + Source string `json:"source"` + Type string `json:"type"` + Title string `json:"title"` + Description string `json:"description,omitempty"` + Due *time.Time `json:"due,omitempty"` + Priority int `json:"priority,omitempty"` + Completable bool `json:"completable"` + URL string `json:"url,omitempty"` +} + +// ----------------------------------------------------------------------------- +// Helpers +// ----------------------------------------------------------------------------- + +// generateToken creates a cryptographically random token +func generateToken() (string, error) { + b := make([]byte, TokenBytes) + if _, err := rand.Read(b); err != nil { + return "", err + } + return base64.URLEncoding.EncodeToString(b), nil +} + +// isSessionExpired checks if a pending session has expired +func isSessionExpired(session *models.AgentSession) bool { + return time.Now().After(session.ExpiresAt) && session.Status == "pending" +} + +// timelineItemToAgentItem converts a TimelineItem to the agent API format +func timelineItemToAgentItem(item models.TimelineItem) agentContextItem { + t := item.Time + return agentContextItem{ + ID: item.ID, + Source: item.Source, + Type: string(item.Type), + Title: item.Title, + Description: item.Description, + Due: &t, + Completable: item.Type == models.TimelineItemTypeTask || item.Type == models.TimelineItemTypeCard || item.Type == models.TimelineItemTypeGTask, + URL: item.URL, + } +} + +// renderAgentTemplate renders an agent template with common error handling +func (h *Handler) renderAgentTemplate(w http.ResponseWriter, templateName string, data interface{}) { + w.Header().Set("Content-Type", "text/html; charset=utf-8") + if h.templates == nil { + h.renderAgentError(w, "Templates not loaded", http.StatusInternalServerError) + return + } + if err := h.templates.ExecuteTemplate(w, templateName, data); err != nil { + h.renderAgentError(w, "Template error", http.StatusInternalServerError) + } +} + +// ----------------------------------------------------------------------------- +// Auth Handlers +// ----------------------------------------------------------------------------- + +// HandleAgentAuthRequest handles POST /agent/auth/request +func (h *Handler) HandleAgentAuthRequest(w http.ResponseWriter, r *http.Request) { + var req models.AgentAuthRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, "Invalid request body", http.StatusBadRequest) + return + } + + if req.Name == "" || req.AgentID == "" { + http.Error(w, "name and agent_id are required", http.StatusBadRequest) + return + } + + // Invalidate any previous sessions for this agent + if err := h.store.InvalidatePreviousAgentSessions(req.AgentID); err != nil { + http.Error(w, "Internal server error", http.StatusInternalServerError) + return + } + + // Generate request token + requestToken, err := generateToken() + if err != nil { + http.Error(w, "Failed to generate token", http.StatusInternalServerError) + return + } + + // Create pending session + session := &models.AgentSession{ + RequestToken: requestToken, + AgentName: req.Name, + AgentID: req.AgentID, + ExpiresAt: time.Now().Add(AgentRequestExpiry), + } + if err := h.store.CreateAgentSession(session); err != nil { + http.Error(w, "Failed to create session", http.StatusInternalServerError) + return + } + + // Check trust level for WebSocket notification + trustLevel, err := h.store.CheckAgentTrust(req.Name, req.AgentID) + if err != nil { + trustLevel = models.AgentTrustNew + } + + // Broadcast to connected browsers via WebSocket + h.BroadcastAgentRequest(session, trustLevel) + + // Return response + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(models.AgentAuthResponse{ + RequestToken: requestToken, + Status: "pending", + }) +} + +// HandleAgentAuthPoll handles GET /agent/auth/poll +func (h *Handler) HandleAgentAuthPoll(w http.ResponseWriter, r *http.Request) { + token := r.URL.Query().Get("token") + if token == "" { + http.Error(w, "token parameter required", http.StatusBadRequest) + return + } + + session, err := h.store.GetAgentSessionByRequestToken(token) + if err != nil { + http.Error(w, "Internal server error", http.StatusInternalServerError) + return + } + if session == nil { + http.Error(w, "Session not found", http.StatusNotFound) + return + } + + if isSessionExpired(session) { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(models.AgentPollResponse{Status: "expired"}) + return + } + + resp := models.AgentPollResponse{Status: session.Status} + + if session.Status == "approved" && session.SessionToken != "" { + resp.SessionToken = session.SessionToken + resp.ExpiresAt = session.SessionExpiresAt + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) +} + +// HandleAgentAuthApprove handles POST /agent/auth/approve (browser auth required) +func (h *Handler) HandleAgentAuthApprove(w http.ResponseWriter, r *http.Request) { + var req struct { + RequestToken string `json:"request_token"` + } + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, "Invalid request body", http.StatusBadRequest) + return + } + + if req.RequestToken == "" { + http.Error(w, "request_token required", http.StatusBadRequest) + return + } + + // Verify session exists and is pending + session, err := h.store.GetAgentSessionByRequestToken(req.RequestToken) + if err != nil { + http.Error(w, "Internal server error", http.StatusInternalServerError) + return + } + if session == nil { + http.Error(w, "Session not found", http.StatusNotFound) + return + } + if session.Status != "pending" { + http.Error(w, "Session already processed", http.StatusConflict) + return + } + if time.Now().After(session.ExpiresAt) { + http.Error(w, "Session expired", http.StatusGone) + return + } + + // Generate session token + sessionToken, err := generateToken() + if err != nil { + http.Error(w, "Failed to generate session token", http.StatusInternalServerError) + return + } + + sessionExpiresAt := time.Now().Add(AgentSessionTTL) + + // Approve the session + if err := h.store.ApproveAgentSession(req.RequestToken, sessionToken, sessionExpiresAt); err != nil { + http.Error(w, "Failed to approve session", http.StatusInternalServerError) + return + } + + // Register/update agent in the trusted agents table + if err := h.store.CreateOrUpdateAgent(session.AgentName, session.AgentID); err != nil { + // Log but don't fail - the session was approved + // This just affects future trust level checks + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]string{"status": "approved"}) +} + +// HandleAgentAuthDeny handles POST /agent/auth/deny (browser auth required) +func (h *Handler) HandleAgentAuthDeny(w http.ResponseWriter, r *http.Request) { + var req struct { + RequestToken string `json:"request_token"` + } + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, "Invalid request body", http.StatusBadRequest) + return + } + + if req.RequestToken == "" { + http.Error(w, "request_token required", http.StatusBadRequest) + return + } + + if err := h.store.DenyAgentSession(req.RequestToken); err != nil { + http.Error(w, "Failed to deny session", http.StatusInternalServerError) + return + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]string{"status": "denied"}) +} + +// ----------------------------------------------------------------------------- +// Context Handlers +// ----------------------------------------------------------------------------- + +// HandleAgentContext handles GET /agent/context (agent auth required) +func (h *Handler) HandleAgentContext(w http.ResponseWriter, r *http.Request) { + session := r.Context().Value(agentSessionContextKey).(*models.AgentSession) + _ = h.store.UpdateAgentLastSeen(session.AgentID) + + now := config.Now() + startDate := config.Today() + endDate := startDate.Add(7 * 24 * time.Hour) + + timeline := h.buildAgentContext(r.Context(), startDate, endDate) + + resp := map[string]interface{}{ + "generated_at": now.Format(time.RFC3339), + "range": map[string]string{ + "start": startDate.Format("2006-01-02"), + "end": endDate.Format("2006-01-02"), + }, + "timeline": timeline, + "summary": h.buildContextSummary(timeline, startDate), + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) +} + +// buildAgentContext builds the context timeline by reusing BuildTimeline +func (h *Handler) buildAgentContext(ctx context.Context, start, end time.Time) []agentContextItem { + // Reuse the main BuildTimeline function (excludes live API calls for Google services) + timelineItems, err := BuildTimeline(ctx, h.store, nil, nil, start, end) + if err != nil { + return nil + } + + // Convert to agent API format, filtering completed items + var items []agentContextItem + for _, item := range timelineItems { + if item.IsCompleted { + continue + } + items = append(items, timelineItemToAgentItem(item)) + } + return items +} + +// buildContextSummary builds summary statistics for the agent context +func (h *Handler) buildContextSummary(items []agentContextItem, today time.Time) map[string]interface{} { + bySource := make(map[string]int) + var overdue, todayCount int + endOfToday := today.Add(24 * time.Hour) + + for _, item := range items { + bySource[item.Source]++ + if item.Due != nil { + if item.Due.Before(today) { + overdue++ + } else if item.Due.Before(endOfToday) { + todayCount++ + } + } + } + + return map[string]interface{}{ + "total_items": len(items), + "by_source": bySource, + "overdue": overdue, + "today": todayCount, + } +} + +// ----------------------------------------------------------------------------- +// Middleware +// ----------------------------------------------------------------------------- + +// AgentAuthMiddleware verifies agent session token +func (h *Handler) AgentAuthMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + authHeader := r.Header.Get("Authorization") + if authHeader == "" || len(authHeader) < 8 || authHeader[:7] != "Bearer " { + http.Error(w, "Authorization header required", http.StatusUnauthorized) + return + } + + token := authHeader[7:] + + session, err := h.store.GetAgentSessionBySessionToken(token) + if err != nil || session == nil { + http.Error(w, "Invalid session token", http.StatusUnauthorized) + return + } + + // Check session expiry + if session.SessionExpiresAt != nil && time.Now().After(*session.SessionExpiresAt) { + http.Error(w, "Session expired", http.StatusUnauthorized) + return + } + + // Add session to context + ctx := context.WithValue(r.Context(), agentSessionContextKey, session) + next.ServeHTTP(w, r.WithContext(ctx)) + }) +} + +// ----------------------------------------------------------------------------- +// Web Handlers (HTML pages for browser-only agents) +// ----------------------------------------------------------------------------- + +// HandleAgentWebRequest handles GET /agent/web/request for browser-only agents +func (h *Handler) HandleAgentWebRequest(w http.ResponseWriter, r *http.Request) { + name := r.URL.Query().Get("name") + agentID := r.URL.Query().Get("agent_id") + + if name == "" || agentID == "" { + h.renderAgentError(w, "name and agent_id query parameters are required", http.StatusBadRequest) + return + } + + // Check for existing pending session + existingSession, err := h.store.GetPendingAgentSessionByAgentID(agentID) + if err != nil { + h.renderAgentError(w, "Internal server error", http.StatusInternalServerError) + return + } + + if existingSession != nil { + // Return existing pending session + h.renderAgentRequest(w, existingSession) + return + } + + // Invalidate any previous sessions for this agent + if err := h.store.InvalidatePreviousAgentSessions(agentID); err != nil { + h.renderAgentError(w, "Internal server error", http.StatusInternalServerError) + return + } + + // Generate request token + requestToken, err := generateToken() + if err != nil { + h.renderAgentError(w, "Failed to generate token", http.StatusInternalServerError) + return + } + + // Create pending session + session := &models.AgentSession{ + RequestToken: requestToken, + AgentName: name, + AgentID: agentID, + ExpiresAt: time.Now().Add(AgentRequestExpiry), + } + if err := h.store.CreateAgentSession(session); err != nil { + h.renderAgentError(w, "Failed to create session", http.StatusInternalServerError) + return + } + + // Check trust level for WebSocket notification + trustLevel, err := h.store.CheckAgentTrust(name, agentID) + if err != nil { + trustLevel = models.AgentTrustNew + } + + // Broadcast to connected browsers via WebSocket + h.BroadcastAgentRequest(session, trustLevel) + + h.renderAgentRequest(w, session) +} + +// HandleAgentWebStatus handles GET /agent/web/status for browser-only agents +// Returns HTML page with approval status and session token if approved +func (h *Handler) HandleAgentWebStatus(w http.ResponseWriter, r *http.Request) { + token := r.URL.Query().Get("token") + if token == "" { + h.renderAgentError(w, "token query parameter required", http.StatusBadRequest) + return + } + + session, err := h.store.GetAgentSessionByRequestToken(token) + if err != nil { + h.renderAgentError(w, "Internal server error", http.StatusInternalServerError) + return + } + if session == nil { + h.renderAgentError(w, "Session not found", http.StatusNotFound) + return + } + + h.renderAgentStatus(w, session) +} + +// HandleAgentWebContext handles GET /agent/web/context for browser-only agents +func (h *Handler) HandleAgentWebContext(w http.ResponseWriter, r *http.Request) { + sessionToken := r.URL.Query().Get("session") + if sessionToken == "" { + h.renderAgentError(w, "session query parameter required", http.StatusBadRequest) + return + } + + session, err := h.store.GetAgentSessionBySessionToken(sessionToken) + if err != nil || session == nil { + h.renderAgentError(w, "Invalid session token", http.StatusUnauthorized) + return + } + + if session.SessionExpiresAt != nil && time.Now().After(*session.SessionExpiresAt) { + h.renderAgentError(w, "Session expired", http.StatusUnauthorized) + return + } + + _ = h.store.UpdateAgentLastSeen(session.AgentID) + + now := config.Now() + startDate := config.Today() + endDate := startDate.Add(7 * 24 * time.Hour) + + timeline := h.buildAgentContext(r.Context(), startDate, endDate) + h.renderAgentContext(w, session, timeline, startDate, endDate, now) +} + +// renderAgentError renders an error page for agent web endpoints +func (h *Handler) renderAgentError(w http.ResponseWriter, message string, status int) { + w.Header().Set("Content-Type", "text/html; charset=utf-8") + w.WriteHeader(status) + if h.templates != nil { + _ = h.templates.ExecuteTemplate(w, "agent-error.html", map[string]interface{}{ + "Error": message, + "Status": status, + }) + } else { + // Fallback if template not loaded + w.Write([]byte(`Error

Error

` + message + `

`)) + } +} + +// renderAgentRequest renders the request page with token info +func (h *Handler) renderAgentRequest(w http.ResponseWriter, session *models.AgentSession) { + h.renderAgentTemplate(w, "agent-request.html", map[string]interface{}{ + "RequestToken": session.RequestToken, + "AgentName": session.AgentName, + "AgentID": session.AgentID, + "Status": "pending", + "PollURL": "/agent/web/status?token=" + session.RequestToken, + "ExpiresAt": session.ExpiresAt.Format(time.RFC3339), + }) +} + +// renderAgentStatus renders the status page +func (h *Handler) renderAgentStatus(w http.ResponseWriter, session *models.AgentSession) { + status := session.Status + if isSessionExpired(session) { + status = "expired" + } + + data := map[string]interface{}{ + "RequestToken": session.RequestToken, + "AgentName": session.AgentName, + "Status": status, + } + + if status == "approved" && session.SessionToken != "" { + data["SessionToken"] = session.SessionToken + data["ContextURL"] = "/agent/web/context?session=" + session.SessionToken + if session.SessionExpiresAt != nil { + data["SessionExpiresAt"] = session.SessionExpiresAt.Format(time.RFC3339) + } + } + + h.renderAgentTemplate(w, "agent-status.html", data) +} + +// renderAgentContext renders the context page with timeline data +func (h *Handler) renderAgentContext(w http.ResponseWriter, session *models.AgentSession, timeline []agentContextItem, startDate, endDate, now time.Time) { + h.renderAgentTemplate(w, "agent-context.html", map[string]interface{}{ + "AgentName": session.AgentName, + "GeneratedAt": now.Format(time.RFC3339), + "RangeStart": startDate.Format("2006-01-02"), + "RangeEnd": endDate.Format("2006-01-02"), + "Timeline": timeline, + "Summary": h.buildContextSummary(timeline, startDate), + }) +} diff --git a/internal/handlers/agent_test.go b/internal/handlers/agent_test.go new file mode 100644 index 0000000..7828650 --- /dev/null +++ b/internal/handlers/agent_test.go @@ -0,0 +1,706 @@ +package handlers + +import ( + "bytes" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" + + "task-dashboard/internal/config" + "task-dashboard/internal/models" +) + +func TestHandleAgentAuthRequest(t *testing.T) { + db, cleanup := setupTestDB(t) + defer cleanup() + + cfg := &config.Config{} + h := &Handler{store: db, config: cfg} + + tests := []struct { + name string + requestBody interface{} + expectedStatus int + checkResponse func(t *testing.T, resp *models.AgentAuthResponse) + }{ + { + name: "valid request", + requestBody: models.AgentAuthRequest{ + Name: "TestAgent", + AgentID: "test-uuid-12345", + }, + expectedStatus: http.StatusOK, + checkResponse: func(t *testing.T, resp *models.AgentAuthResponse) { + if resp.RequestToken == "" { + t.Error("Expected request_token in response") + } + if resp.Status != "pending" { + t.Errorf("Expected status 'pending', got '%s'", resp.Status) + } + }, + }, + { + name: "missing name", + requestBody: models.AgentAuthRequest{ + AgentID: "test-uuid-12345", + }, + expectedStatus: http.StatusBadRequest, + }, + { + name: "missing agent_id", + requestBody: models.AgentAuthRequest{ + Name: "TestAgent", + }, + expectedStatus: http.StatusBadRequest, + }, + { + name: "invalid JSON", + requestBody: "not json", + expectedStatus: http.StatusBadRequest, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var body []byte + var err error + if s, ok := tt.requestBody.(string); ok { + body = []byte(s) + } else { + body, err = json.Marshal(tt.requestBody) + if err != nil { + t.Fatalf("Failed to marshal request body: %v", err) + } + } + + req := httptest.NewRequest(http.MethodPost, "/agent/auth/request", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + + h.HandleAgentAuthRequest(w, req) + + if w.Code != tt.expectedStatus { + t.Errorf("Expected status %d, got %d", tt.expectedStatus, w.Code) + } + + if tt.checkResponse != nil && w.Code == http.StatusOK { + var resp models.AgentAuthResponse + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("Failed to unmarshal response: %v", err) + } + tt.checkResponse(t, &resp) + } + }) + } +} + +func TestHandleAgentAuthPoll(t *testing.T) { + db, cleanup := setupTestDB(t) + defer cleanup() + + cfg := &config.Config{} + h := &Handler{store: db, config: cfg} + + // Create a pending session first + session := &models.AgentSession{ + RequestToken: "test-token-123", + AgentName: "TestAgent", + AgentID: "agent-uuid", + ExpiresAt: time.Now().Add(5 * time.Minute), + } + if err := db.CreateAgentSession(session); err != nil { + t.Fatalf("Failed to create test session: %v", err) + } + + tests := []struct { + name string + token string + expectedStatus int + expectedState string + }{ + { + name: "valid pending session", + token: "test-token-123", + expectedStatus: http.StatusOK, + expectedState: "pending", + }, + { + name: "missing token", + token: "", + expectedStatus: http.StatusBadRequest, + }, + { + name: "non-existent token", + token: "non-existent", + expectedStatus: http.StatusNotFound, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + url := "/agent/auth/poll" + if tt.token != "" { + url += "?token=" + tt.token + } + + req := httptest.NewRequest(http.MethodGet, url, nil) + w := httptest.NewRecorder() + + h.HandleAgentAuthPoll(w, req) + + if w.Code != tt.expectedStatus { + t.Errorf("Expected status %d, got %d", tt.expectedStatus, w.Code) + } + + if tt.expectedStatus == http.StatusOK { + var resp models.AgentPollResponse + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("Failed to unmarshal response: %v", err) + } + if resp.Status != tt.expectedState { + t.Errorf("Expected status '%s', got '%s'", tt.expectedState, resp.Status) + } + } + }) + } +} + +func TestHandleAgentAuthApprove(t *testing.T) { + db, cleanup := setupTestDB(t) + defer cleanup() + + cfg := &config.Config{} + h := &Handler{store: db, config: cfg} + + // Create a pending session + session := &models.AgentSession{ + RequestToken: "approve-test-token", + AgentName: "ApproveTestAgent", + AgentID: "approve-agent-uuid", + ExpiresAt: time.Now().Add(5 * time.Minute), + } + if err := db.CreateAgentSession(session); err != nil { + t.Fatalf("Failed to create test session: %v", err) + } + + // Test approval + body, _ := json.Marshal(map[string]string{"request_token": "approve-test-token"}) + req := httptest.NewRequest(http.MethodPost, "/agent/auth/approve", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + + h.HandleAgentAuthApprove(w, req) + + if w.Code != http.StatusOK { + t.Errorf("Expected status 200, got %d: %s", w.Code, w.Body.String()) + } + + // Verify session is now approved + updated, err := db.GetAgentSessionByRequestToken("approve-test-token") + if err != nil { + t.Fatalf("Failed to get session: %v", err) + } + if updated.Status != "approved" { + t.Errorf("Expected session status 'approved', got '%s'", updated.Status) + } + if updated.SessionToken == "" { + t.Error("Expected session_token to be set") + } + + // Verify agent was registered + agent, err := db.GetAgentByAgentID("approve-agent-uuid") + if err != nil { + t.Fatalf("Failed to get agent: %v", err) + } + if agent == nil { + t.Error("Expected agent to be created") + } +} + +func TestHandleAgentAuthDeny(t *testing.T) { + db, cleanup := setupTestDB(t) + defer cleanup() + + cfg := &config.Config{} + h := &Handler{store: db, config: cfg} + + // Create a pending session + session := &models.AgentSession{ + RequestToken: "deny-test-token", + AgentName: "DenyTestAgent", + AgentID: "deny-agent-uuid", + ExpiresAt: time.Now().Add(5 * time.Minute), + } + if err := db.CreateAgentSession(session); err != nil { + t.Fatalf("Failed to create test session: %v", err) + } + + // Test denial + body, _ := json.Marshal(map[string]string{"request_token": "deny-test-token"}) + req := httptest.NewRequest(http.MethodPost, "/agent/auth/deny", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + + h.HandleAgentAuthDeny(w, req) + + if w.Code != http.StatusOK { + t.Errorf("Expected status 200, got %d: %s", w.Code, w.Body.String()) + } + + // Verify session is now denied + updated, err := db.GetAgentSessionByRequestToken("deny-test-token") + if err != nil { + t.Fatalf("Failed to get session: %v", err) + } + if updated.Status != "denied" { + t.Errorf("Expected session status 'denied', got '%s'", updated.Status) + } +} + +func TestAgentAuthMiddleware(t *testing.T) { + db, cleanup := setupTestDB(t) + defer cleanup() + + cfg := &config.Config{} + h := &Handler{store: db, config: cfg} + + // Create an approved session + sessionToken := "valid-session-token" + sessionExpiry := time.Now().Add(1 * time.Hour) + session := &models.AgentSession{ + RequestToken: "middleware-test-token", + AgentName: "MiddlewareTestAgent", + AgentID: "middleware-agent-uuid", + ExpiresAt: time.Now().Add(5 * time.Minute), + } + if err := db.CreateAgentSession(session); err != nil { + t.Fatalf("Failed to create test session: %v", err) + } + if err := db.ApproveAgentSession("middleware-test-token", sessionToken, sessionExpiry); err != nil { + t.Fatalf("Failed to approve session: %v", err) + } + + // Create a test handler that the middleware protects + testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("protected content")) + }) + + wrappedHandler := h.AgentAuthMiddleware(testHandler) + + tests := []struct { + name string + authHeader string + expectedStatus int + }{ + { + name: "valid token", + authHeader: "Bearer " + sessionToken, + expectedStatus: http.StatusOK, + }, + { + name: "missing header", + authHeader: "", + expectedStatus: http.StatusUnauthorized, + }, + { + name: "invalid format", + authHeader: "Token " + sessionToken, + expectedStatus: http.StatusUnauthorized, + }, + { + name: "invalid token", + authHeader: "Bearer invalid-token", + expectedStatus: http.StatusUnauthorized, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/agent/context", nil) + if tt.authHeader != "" { + req.Header.Set("Authorization", tt.authHeader) + } + w := httptest.NewRecorder() + + wrappedHandler.ServeHTTP(w, req) + + if w.Code != tt.expectedStatus { + t.Errorf("Expected status %d, got %d", tt.expectedStatus, w.Code) + } + }) + } +} + +func TestAgentTrustLevel(t *testing.T) { + db, cleanup := setupTestDB(t) + defer cleanup() + + // Test new agent (never seen before) + trust, err := db.CheckAgentTrust("NewAgent", "new-uuid") + if err != nil { + t.Fatalf("Failed to check trust: %v", err) + } + if trust != models.AgentTrustNew { + t.Errorf("Expected trust level 'new', got '%s'", trust) + } + + // Register the agent + if err := db.CreateOrUpdateAgent("NewAgent", "new-uuid"); err != nil { + t.Fatalf("Failed to register agent: %v", err) + } + + // Test recognized agent (same name + ID) + trust, err = db.CheckAgentTrust("NewAgent", "new-uuid") + if err != nil { + t.Fatalf("Failed to check trust: %v", err) + } + if trust != models.AgentTrustRecognized { + t.Errorf("Expected trust level 'recognized', got '%s'", trust) + } + + // Test suspicious agent (same name, different ID) + trust, err = db.CheckAgentTrust("NewAgent", "different-uuid") + if err != nil { + t.Fatalf("Failed to check trust: %v", err) + } + if trust != models.AgentTrustSuspicious { + t.Errorf("Expected trust level 'suspicious', got '%s'", trust) + } +} + +func TestInvalidatePreviousSessions(t *testing.T) { + db, cleanup := setupTestDB(t) + defer cleanup() + + agentID := "invalidate-test-agent" + + // Create multiple sessions for the same agent + for i := 0; i < 3; i++ { + session := &models.AgentSession{ + RequestToken: "token-" + string(rune('a'+i)), + AgentName: "TestAgent", + AgentID: agentID, + ExpiresAt: time.Now().Add(5 * time.Minute), + } + if err := db.CreateAgentSession(session); err != nil { + t.Fatalf("Failed to create session %d: %v", i, err) + } + } + + // Invalidate all sessions for this agent + if err := db.InvalidatePreviousAgentSessions(agentID); err != nil { + t.Fatalf("Failed to invalidate sessions: %v", err) + } + + // Verify all sessions are expired + for _, token := range []string{"token-a", "token-b", "token-c"} { + session, err := db.GetAgentSessionByRequestToken(token) + if err != nil { + t.Fatalf("Failed to get session: %v", err) + } + if session.Status != "expired" { + t.Errorf("Expected session %s to be expired, got '%s'", token, session.Status) + } + } +} + +func TestHandleAgentWebRequest(t *testing.T) { + db, cleanup := setupTestDB(t) + defer cleanup() + + cfg := &config.Config{} + h := &Handler{store: db, config: cfg, templates: loadTestTemplates(t)} + + tests := []struct { + name string + queryParams string + expectedStatus int + checkBody func(t *testing.T, body string) + }{ + { + name: "valid request", + queryParams: "?name=TestAgent&agent_id=web-test-uuid", + expectedStatus: http.StatusOK, + checkBody: func(t *testing.T, body string) { + if body == "" { + t.Error("Expected non-empty response body") + } + // Should contain JSON data in script tag + if !contains(body, "application/json") { + t.Error("Expected JSON data in response") + } + if !contains(body, "request_token") { + t.Error("Expected request_token in response") + } + }, + }, + { + name: "missing name", + queryParams: "?agent_id=test-uuid", + expectedStatus: http.StatusBadRequest, + }, + { + name: "missing agent_id", + queryParams: "?name=TestAgent", + expectedStatus: http.StatusBadRequest, + }, + { + name: "missing both", + queryParams: "", + expectedStatus: http.StatusBadRequest, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/agent/web/request"+tt.queryParams, nil) + w := httptest.NewRecorder() + + h.HandleAgentWebRequest(w, req) + + if w.Code != tt.expectedStatus { + t.Errorf("Expected status %d, got %d: %s", tt.expectedStatus, w.Code, w.Body.String()) + } + + if tt.checkBody != nil && w.Code == http.StatusOK { + tt.checkBody(t, w.Body.String()) + } + }) + } +} + +func TestHandleAgentWebRequestReturnsExistingPending(t *testing.T) { + db, cleanup := setupTestDB(t) + defer cleanup() + + cfg := &config.Config{} + h := &Handler{store: db, config: cfg, templates: loadTestTemplates(t)} + + agentID := "reuse-pending-uuid" + + // First request creates a session + req1 := httptest.NewRequest(http.MethodGet, "/agent/web/request?name=TestAgent&agent_id="+agentID, nil) + w1 := httptest.NewRecorder() + h.HandleAgentWebRequest(w1, req1) + + if w1.Code != http.StatusOK { + t.Fatalf("First request failed: %d, body: %s", w1.Code, w1.Body.String()) + } + body1 := w1.Body.String() + + // Verify session was created + if !contains(body1, "request_token") { + t.Fatal("First response doesn't contain request_token") + } + + // Second request should return the same pending session + req2 := httptest.NewRequest(http.MethodGet, "/agent/web/request?name=TestAgent&agent_id="+agentID, nil) + w2 := httptest.NewRecorder() + h.HandleAgentWebRequest(w2, req2) + + if w2.Code != http.StatusOK { + t.Fatalf("Second request failed: %d", w2.Code) + } + body2 := w2.Body.String() + + // Both responses should be identical (same session reused) + if body1 != body2 { + t.Error("Expected same response for existing pending session (session should be reused)") + } +} + +func TestHandleAgentWebStatus(t *testing.T) { + db, cleanup := setupTestDB(t) + defer cleanup() + + cfg := &config.Config{} + h := &Handler{store: db, config: cfg, templates: loadTestTemplates(t)} + + // Create a pending session + session := &models.AgentSession{ + RequestToken: "web-status-test-token", + AgentName: "WebStatusTestAgent", + AgentID: "web-status-agent-uuid", + ExpiresAt: time.Now().Add(5 * time.Minute), + } + if err := db.CreateAgentSession(session); err != nil { + t.Fatalf("Failed to create test session: %v", err) + } + + tests := []struct { + name string + token string + expectedStatus int + checkBody func(t *testing.T, body string) + }{ + { + name: "valid pending session", + token: "web-status-test-token", + expectedStatus: http.StatusOK, + checkBody: func(t *testing.T, body string) { + if !contains(body, `"status": "pending"`) { + t.Error("Expected status 'pending' in response") + } + }, + }, + { + name: "missing token", + token: "", + expectedStatus: http.StatusBadRequest, + }, + { + name: "non-existent token", + token: "non-existent-token", + expectedStatus: http.StatusNotFound, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + url := "/agent/web/status" + if tt.token != "" { + url += "?token=" + tt.token + } + + req := httptest.NewRequest(http.MethodGet, url, nil) + w := httptest.NewRecorder() + + h.HandleAgentWebStatus(w, req) + + if w.Code != tt.expectedStatus { + t.Errorf("Expected status %d, got %d: %s", tt.expectedStatus, w.Code, w.Body.String()) + } + + if tt.checkBody != nil && w.Code == http.StatusOK { + tt.checkBody(t, w.Body.String()) + } + }) + } +} + +func TestHandleAgentWebStatusApproved(t *testing.T) { + db, cleanup := setupTestDB(t) + defer cleanup() + + cfg := &config.Config{} + h := &Handler{store: db, config: cfg, templates: loadTestTemplates(t)} + + // Create and approve a session + session := &models.AgentSession{ + RequestToken: "web-approved-test-token", + AgentName: "WebApprovedTestAgent", + AgentID: "web-approved-agent-uuid", + ExpiresAt: time.Now().Add(5 * time.Minute), + } + if err := db.CreateAgentSession(session); err != nil { + t.Fatalf("Failed to create test session: %v", err) + } + sessionToken := "web-session-token-xyz" + if err := db.ApproveAgentSession("web-approved-test-token", sessionToken, time.Now().Add(1*time.Hour)); err != nil { + t.Fatalf("Failed to approve session: %v", err) + } + + req := httptest.NewRequest(http.MethodGet, "/agent/web/status?token=web-approved-test-token", nil) + w := httptest.NewRecorder() + + h.HandleAgentWebStatus(w, req) + + if w.Code != http.StatusOK { + t.Errorf("Expected status 200, got %d: %s", w.Code, w.Body.String()) + } + + body := w.Body.String() + if !contains(body, `"status": "approved"`) { + t.Error("Expected status 'approved' in response") + } + if !contains(body, `"session_token": "`+sessionToken) { + t.Error("Expected session_token in response") + } + if !contains(body, `"context_url":`) { + t.Error("Expected context_url in response") + } +} + +func TestHandleAgentWebContext(t *testing.T) { + db, cleanup := setupTestDB(t) + defer cleanup() + + cfg := &config.Config{} + h := &Handler{store: db, config: cfg, templates: loadTestTemplates(t)} + + // Create and approve a session + session := &models.AgentSession{ + RequestToken: "web-context-test-token", + AgentName: "WebContextTestAgent", + AgentID: "web-context-agent-uuid", + ExpiresAt: time.Now().Add(5 * time.Minute), + } + if err := db.CreateAgentSession(session); err != nil { + t.Fatalf("Failed to create test session: %v", err) + } + sessionToken := "web-context-session-token" + if err := db.ApproveAgentSession("web-context-test-token", sessionToken, time.Now().Add(1*time.Hour)); err != nil { + t.Fatalf("Failed to approve session: %v", err) + } + + tests := []struct { + name string + sessionToken string + expectedStatus int + checkBody func(t *testing.T, body string) + }{ + { + name: "valid session", + sessionToken: sessionToken, + expectedStatus: http.StatusOK, + checkBody: func(t *testing.T, body string) { + if !contains(body, "generated_at") { + t.Error("Expected generated_at in response") + } + if !contains(body, "timeline") { + t.Error("Expected timeline in response") + } + }, + }, + { + name: "missing session", + sessionToken: "", + expectedStatus: http.StatusBadRequest, + }, + { + name: "invalid session", + sessionToken: "invalid-session-token", + expectedStatus: http.StatusUnauthorized, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + url := "/agent/web/context" + if tt.sessionToken != "" { + url += "?session=" + tt.sessionToken + } + + req := httptest.NewRequest(http.MethodGet, url, nil) + w := httptest.NewRecorder() + + h.HandleAgentWebContext(w, req) + + if w.Code != tt.expectedStatus { + t.Errorf("Expected status %d, got %d: %s", tt.expectedStatus, w.Code, w.Body.String()) + } + + if tt.checkBody != nil && w.Code == http.StatusOK { + tt.checkBody(t, w.Body.String()) + } + }) + } +} + +// contains is a helper function to check if a string contains a substring +func contains(s, substr string) bool { + return bytes.Contains([]byte(s), []byte(substr)) +} diff --git a/internal/handlers/websocket.go b/internal/handlers/websocket.go new file mode 100644 index 0000000..1677f88 --- /dev/null +++ b/internal/handlers/websocket.go @@ -0,0 +1,216 @@ +package handlers + +import ( + "encoding/json" + "log" + "net/http" + "sync" + "time" + + "github.com/gorilla/websocket" + + "task-dashboard/internal/models" +) + +var upgrader = websocket.Upgrader{ + ReadBufferSize: 1024, + WriteBufferSize: 1024, + CheckOrigin: func(r *http.Request) bool { + // Allow all connections (CORS disabled for WebSocket) + return true + }, +} + +// WSClient represents a connected browser +type WSClient struct { + conn *websocket.Conn + send chan []byte +} + +// WSHub manages WebSocket connections +type WSHub struct { + clients map[*WSClient]bool + broadcast chan []byte + register chan *WSClient + unregister chan *WSClient + mu sync.RWMutex +} + +var wsHub *WSHub +var hubOnce sync.Once + +// GetWSHub returns the singleton WebSocket hub +func GetWSHub() *WSHub { + hubOnce.Do(func() { + wsHub = &WSHub{ + clients: make(map[*WSClient]bool), + broadcast: make(chan []byte, 256), + register: make(chan *WSClient), + unregister: make(chan *WSClient), + } + go wsHub.run() + }) + return wsHub +} + +// run handles hub operations +func (h *WSHub) run() { + for { + select { + case client := <-h.register: + h.mu.Lock() + h.clients[client] = true + h.mu.Unlock() + log.Printf("WebSocket client connected, total: %d", len(h.clients)) + + case client := <-h.unregister: + h.mu.Lock() + if _, ok := h.clients[client]; ok { + delete(h.clients, client) + close(client.send) + } + h.mu.Unlock() + log.Printf("WebSocket client disconnected, total: %d", len(h.clients)) + + case message := <-h.broadcast: + h.mu.RLock() + for client := range h.clients { + select { + case client.send <- message: + default: + // Client buffer full, skip + } + } + h.mu.RUnlock() + } + } +} + +// Broadcast sends a message to all connected clients +func (h *WSHub) Broadcast(msg []byte) { + h.broadcast <- msg +} + +// ClientCount returns the number of connected clients +func (h *WSHub) ClientCount() int { + h.mu.RLock() + defer h.mu.RUnlock() + return len(h.clients) +} + +// WSMessage is the structure for WebSocket messages +type WSMessage struct { + Type string `json:"type"` + Payload interface{} `json:"payload"` +} + +// BroadcastAgentRequest sends an agent request notification to all browsers +func (h *Handler) BroadcastAgentRequest(session *models.AgentSession, trustLevel models.AgentTrustLevel) { + hub := GetWSHub() + + payload := AgentRequestPayload{ + RequestToken: session.RequestToken, + AgentName: session.AgentName, + AgentID: session.AgentID, + TrustLevel: trustLevel, + ExpiresAt: session.ExpiresAt, + } + + msg := WSMessage{ + Type: "agent_request", + Payload: payload, + } + + data, err := json.Marshal(msg) + if err != nil { + log.Printf("Failed to marshal WebSocket message: %v", err) + return + } + + hub.Broadcast(data) + log.Printf("Broadcasted agent request from %s to %d clients", session.AgentName, hub.ClientCount()) +} + +// HandleWebSocket handles the WebSocket connection at /ws/notifications +func (h *Handler) HandleWebSocket(w http.ResponseWriter, r *http.Request) { + // Check if the request is a WebSocket upgrade request + if !websocket.IsWebSocketUpgrade(r) { + http.Error(w, "Expected WebSocket Upgrade request", http.StatusBadRequest) + return + } + + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + log.Printf("WebSocket upgrade failed: %v", err) + return + } + + hub := GetWSHub() + client := &WSClient{ + conn: conn, + send: make(chan []byte, 256), + } + + hub.register <- client + + // Start goroutines for reading and writing + go client.writePump() + go client.readPump(hub) +} + +// writePump sends messages to the client +func (c *WSClient) writePump() { + ticker := time.NewTicker(30 * time.Second) + defer func() { + ticker.Stop() + c.conn.Close() + }() + + for { + select { + case message, ok := <-c.send: + c.conn.SetWriteDeadline(time.Now().Add(10 * time.Second)) + if !ok { + // Channel closed + c.conn.WriteMessage(websocket.CloseMessage, []byte{}) + return + } + + if err := c.conn.WriteMessage(websocket.TextMessage, message); err != nil { + return + } + + case <-ticker.C: + // Send ping + c.conn.SetWriteDeadline(time.Now().Add(10 * time.Second)) + if err := c.conn.WriteMessage(websocket.PingMessage, nil); err != nil { + return + } + } + } +} + +// readPump reads messages from the client +func (c *WSClient) readPump(hub *WSHub) { + defer func() { + hub.unregister <- c + c.conn.Close() + }() + + c.conn.SetReadLimit(512) + c.conn.SetReadDeadline(time.Now().Add(60 * time.Second)) + c.conn.SetPongHandler(func(string) error { + c.conn.SetReadDeadline(time.Now().Add(60 * time.Second)) + return nil + }) + + for { + _, _, err := c.conn.ReadMessage() + if err != nil { + if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) { + log.Printf("WebSocket error: %v", err) + } + break + } + } +} -- cgit v1.2.3