package handlers

import (
	"bytes"
	"encoding/json"
	"net/http"
	"net/http/httptest"
	"strings"
	"testing"
	"time"

	"task-dashboard/internal/config"
	"task-dashboard/internal/models"
)

// TestHandleAgentAuthRequest drives POST /agent/auth/request through a table
// of bodies: a valid request, one missing the name, one missing agent_id, and
// a body that is not JSON. A 200 response must decode into an
// AgentAuthResponse with a non-empty request token and status "pending".
func TestHandleAgentAuthRequest(t *testing.T) {
	db, cleanup := setupTestDB(t)
	defer cleanup()

	cfg := &config.Config{}
	h := &Handler{store: db, config: cfg}

	tests := []struct {
		name           string
		requestBody    interface{} // a string is sent verbatim; anything else is JSON-encoded
		expectedStatus int
		checkResponse  func(t *testing.T, resp *models.AgentAuthResponse)
	}{
		{
			name: "valid request",
			requestBody: models.AgentAuthRequest{
				Name:    "TestAgent",
				AgentID: "test-uuid-12345",
			},
			expectedStatus: http.StatusOK,
			checkResponse: func(t *testing.T, resp *models.AgentAuthResponse) {
				if resp.RequestToken == "" {
					t.Error("Expected request_token in response")
				}
				if resp.Status != "pending" {
					t.Errorf("Expected status 'pending', got '%s'", resp.Status)
				}
			},
		},
		{
			name: "missing name",
			requestBody: models.AgentAuthRequest{
				AgentID: "test-uuid-12345",
			},
			expectedStatus: http.StatusBadRequest,
		},
		{
			name: "missing agent_id",
			requestBody: models.AgentAuthRequest{
				Name: "TestAgent",
			},
			expectedStatus: http.StatusBadRequest,
		},
		{
			name:           "invalid JSON",
			requestBody:    "not json",
			expectedStatus: http.StatusBadRequest,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			var body []byte
			var err error
			if s, ok := tt.requestBody.(string); ok {
				// Send raw strings untouched so malformed JSON reaches the handler.
				body = []byte(s)
			} else {
				body, err = json.Marshal(tt.requestBody)
				if err != nil {
					t.Fatalf("Failed to marshal request body: %v", err)
				}
			}

			req := httptest.NewRequest(http.MethodPost, "/agent/auth/request", bytes.NewReader(body))
			req.Header.Set("Content-Type", "application/json")
			w := httptest.NewRecorder()

			h.HandleAgentAuthRequest(w, req)

			if w.Code != tt.expectedStatus {
				t.Errorf("Expected status %d, got %d", tt.expectedStatus, w.Code)
			}

			// Only decode 200 responses; the error-body format is not asserted here.
			if tt.checkResponse != nil && w.Code == http.StatusOK {
				var resp models.AgentAuthResponse
				if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
					t.Fatalf("Failed to unmarshal response: %v", err)
				}
				tt.checkResponse(t, &resp)
			}
		})
	}
}

// TestHandleAgentAuthPoll seeds a session and polls GET /agent/auth/poll with
// its token (expecting 200 and status "pending"), with no token (400), and with
// an unknown token (404).
func TestHandleAgentAuthPoll(t *testing.T) {
	db, cleanup := setupTestDB(t)
	defer cleanup()

	cfg := &config.Config{}
	h := &Handler{store: db, config: cfg}

	// Create a pending session first
	session := &models.AgentSession{
		RequestToken: "test-token-123",
		AgentName:    "TestAgent",
		AgentID:      "agent-uuid",
		ExpiresAt:    time.Now().Add(5 * time.Minute),
	}
	if err := db.CreateAgentSession(session); err != nil {
		t.Fatalf("Failed to create test session: %v", err)
	}

	tests := []struct {
		name           string
		token          string
		expectedStatus int
		expectedState  string
	}{
		{
			name:           "valid pending session",
			token:          "test-token-123",
			expectedStatus: http.StatusOK,
			expectedState:  "pending",
		},
		{
			name:           "missing token",
			token:          "",
			expectedStatus: http.StatusBadRequest,
		},
		{
			name:           "non-existent token",
			token:          "non-existent",
			expectedStatus: http.StatusNotFound,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			url := "/agent/auth/poll"
			if tt.token != "" {
				url += "?token=" + tt.token
			}
			req := httptest.NewRequest(http.MethodGet, url, nil)
			w := httptest.NewRecorder()

			h.HandleAgentAuthPoll(w, req)

			if w.Code != tt.expectedStatus {
				t.Errorf("Expected status %d, got %d", tt.expectedStatus, w.Code)
			}

			// Also require an actual 200 so an unexpected error body is not
			// reported as a confusing unmarshal failure.
			if tt.expectedStatus == http.StatusOK && w.Code == http.StatusOK {
				var resp models.AgentPollResponse
				if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
					t.Fatalf("Failed to unmarshal response: %v", err)
				}
				if resp.Status != tt.expectedState {
					t.Errorf("Expected status '%s', got '%s'", tt.expectedState, resp.Status)
				}
			}
		})
	}
}

// TestHandleAgentAuthApprove approves a seeded session via
// POST /agent/auth/approve and verifies that the stored session becomes
// "approved" with a session token and that the agent is registered.
func TestHandleAgentAuthApprove(t *testing.T) {
	db, cleanup := setupTestDB(t)
	defer cleanup()

	cfg := &config.Config{}
	h := &Handler{store: db, config: cfg}

	// Create a pending session
	session := &models.AgentSession{
		RequestToken: "approve-test-token",
		AgentName:    "ApproveTestAgent",
		AgentID:      "approve-agent-uuid",
		ExpiresAt:    time.Now().Add(5 * time.Minute),
	}
	if err := db.CreateAgentSession(session); err != nil {
		t.Fatalf("Failed to create test session: %v", err)
	}

	// Test approval
	body, err := json.Marshal(map[string]string{"request_token": "approve-test-token"})
	if err != nil {
		t.Fatalf("Failed to marshal request body: %v", err)
	}
	req := httptest.NewRequest(http.MethodPost, "/agent/auth/approve", bytes.NewReader(body))
	req.Header.Set("Content-Type", "application/json")
	w := httptest.NewRecorder()

	h.HandleAgentAuthApprove(w, req)

	if w.Code != http.StatusOK {
		t.Errorf("Expected status 200, got %d: %s", w.Code, w.Body.String())
	}

	// Verify session is now approved
	updated, err := db.GetAgentSessionByRequestToken("approve-test-token")
	if err != nil {
		t.Fatalf("Failed to get session: %v", err)
	}
	// NOTE(review): guards against a (nil, nil) "not found" result, mirroring
	// the agent == nil check below — confirm the store's contract.
	if updated == nil {
		t.Fatal("Expected session to exist after approval")
	}
	if updated.Status != "approved" {
		t.Errorf("Expected session status 'approved', got '%s'", updated.Status)
	}
	if updated.SessionToken == "" {
		t.Error("Expected session_token to be set")
	}

	// Verify agent was registered
	agent, err := db.GetAgentByAgentID("approve-agent-uuid")
	if err != nil {
		t.Fatalf("Failed to get agent: %v", err)
	}
	if agent == nil {
		t.Error("Expected agent to be created")
	}
}

// TestHandleAgentAuthDeny denies a seeded session via POST /agent/auth/deny
// and verifies that the stored session status becomes "denied".
func TestHandleAgentAuthDeny(t *testing.T) {
	db, cleanup := setupTestDB(t)
	defer cleanup()

	cfg := &config.Config{}
	h := &Handler{store: db, config: cfg}

	// Create a pending session
	session := &models.AgentSession{
		RequestToken: "deny-test-token",
		AgentName:    "DenyTestAgent",
		AgentID:      "deny-agent-uuid",
		ExpiresAt:    time.Now().Add(5 * time.Minute),
	}
	if err := db.CreateAgentSession(session); err != nil {
		t.Fatalf("Failed to create test session: %v", err)
	}

	// Test denial
	body, err := json.Marshal(map[string]string{"request_token": "deny-test-token"})
	if err != nil {
		t.Fatalf("Failed to marshal request body: %v", err)
	}
	req := httptest.NewRequest(http.MethodPost, "/agent/auth/deny", bytes.NewReader(body))
	req.Header.Set("Content-Type", "application/json")
	w := httptest.NewRecorder()

	h.HandleAgentAuthDeny(w, req)

	if w.Code != http.StatusOK {
		t.Errorf("Expected status 200, got %d: %s", w.Code, w.Body.String())
	}

	// Verify session is now denied
	updated, err := db.GetAgentSessionByRequestToken("deny-test-token")
	if err != nil {
		t.Fatalf("Failed to get session: %v", err)
	}
	// NOTE(review): guards against a (nil, nil) "not found" result — confirm
	// the store's contract.
	if updated == nil {
		t.Fatal("Expected session to exist after denial")
	}
	if updated.Status != "denied" {
		t.Errorf("Expected session status 'denied', got '%s'", updated.Status)
	}
}

// TestAgentAuthMiddleware wraps a trivial handler with AgentAuthMiddleware and
// checks that only an "Authorization: Bearer <approved session token>" header
// is let through. A missing header, a non-Bearer scheme, and an unknown token
// must all yield 401.
func TestAgentAuthMiddleware(t *testing.T) {
	db, cleanup := setupTestDB(t)
	defer cleanup()

	cfg := &config.Config{}
	h := &Handler{store: db, config: cfg}

	// Create an approved session
	sessionToken := "valid-session-token"
	sessionExpiry := time.Now().Add(1 * time.Hour)
	session := &models.AgentSession{
		RequestToken: "middleware-test-token",
		AgentName:    "MiddlewareTestAgent",
		AgentID:      "middleware-agent-uuid",
		ExpiresAt:    time.Now().Add(5 * time.Minute),
	}
	if err := db.CreateAgentSession(session); err != nil {
		t.Fatalf("Failed to create test session: %v", err)
	}
	if err := db.ApproveAgentSession("middleware-test-token", sessionToken, sessionExpiry); err != nil {
		t.Fatalf("Failed to approve session: %v", err)
	}

	// Create a test handler that the middleware protects
	testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
		w.Write([]byte("protected content"))
	})

	wrappedHandler := h.AgentAuthMiddleware(testHandler)

	tests := []struct {
		name           string
		authHeader     string
		expectedStatus int
	}{
		{
			name:           "valid token",
			authHeader:     "Bearer " + sessionToken,
			expectedStatus: http.StatusOK,
		},
		{
			name:           "missing header",
			authHeader:     "",
			expectedStatus: http.StatusUnauthorized,
		},
		{
			name:           "invalid format",
			authHeader:     "Token " + sessionToken,
			expectedStatus: http.StatusUnauthorized,
		},
		{
			name:           "invalid token",
			authHeader:     "Bearer invalid-token",
			expectedStatus: http.StatusUnauthorized,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			req := httptest.NewRequest(http.MethodGet, "/agent/context", nil)
			if tt.authHeader != "" {
				req.Header.Set("Authorization", tt.authHeader)
			}
			w := httptest.NewRecorder()

			wrappedHandler.ServeHTTP(w, req)

			if w.Code != tt.expectedStatus {
				t.Errorf("Expected status %d, got %d", tt.expectedStatus, w.Code)
			}
		})
	}
}

// TestAgentTrustLevel checks the store's trust classification. An unseen
// name/ID pair is "new". The same pair after registration is "recognized". A
// registered name presented with a different ID is "suspicious".
func TestAgentTrustLevel(t *testing.T) {
	db, cleanup := setupTestDB(t)
	defer cleanup()

	// Test new agent (never seen before)
	trust, err := db.CheckAgentTrust("NewAgent", "new-uuid")
	if err != nil {
		t.Fatalf("Failed to check trust: %v", err)
	}
	if trust != models.AgentTrustNew {
		t.Errorf("Expected trust level 'new', got '%s'", trust)
	}

	// Register the agent
	if err := db.CreateOrUpdateAgent("NewAgent", "new-uuid"); err != nil {
		t.Fatalf("Failed to register agent: %v", err)
	}

	// Test recognized agent (same name + ID)
	trust, err = db.CheckAgentTrust("NewAgent", "new-uuid")
	if err != nil {
		t.Fatalf("Failed to check trust: %v", err)
	}
	if trust != models.AgentTrustRecognized {
		t.Errorf("Expected trust level 'recognized', got '%s'", trust)
	}

	// Test suspicious agent (same name, different ID)
	trust, err = db.CheckAgentTrust("NewAgent", "different-uuid")
	if err != nil {
		t.Fatalf("Failed to check trust: %v", err)
	}
	if trust != models.AgentTrustSuspicious {
		t.Errorf("Expected trust level 'suspicious', got '%s'", trust)
	}
}

// TestInvalidatePreviousSessions creates three sessions for one agent, calls
// InvalidatePreviousAgentSessions, and expects every session to be "expired".
func TestInvalidatePreviousSessions(t *testing.T) {
	db, cleanup := setupTestDB(t)
	defer cleanup()

	agentID := "invalidate-test-agent"
	// One list drives both creation and verification so they cannot drift apart.
	tokens := []string{"token-a", "token-b", "token-c"}

	// Create multiple sessions for the same agent
	for i, token := range tokens {
		session := &models.AgentSession{
			RequestToken: token,
			AgentName:    "TestAgent",
			AgentID:      agentID,
			ExpiresAt:    time.Now().Add(5 * time.Minute),
		}
		if err := db.CreateAgentSession(session); err != nil {
			t.Fatalf("Failed to create session %d: %v", i, err)
		}
	}

	// Invalidate all sessions for this agent
	if err := db.InvalidatePreviousAgentSessions(agentID); err != nil {
		t.Fatalf("Failed to invalidate sessions: %v", err)
	}

	// Verify all sessions are expired
	for _, token := range tokens {
		session, err := db.GetAgentSessionByRequestToken(token)
		if err != nil {
			t.Fatalf("Failed to get session: %v", err)
		}
		if session == nil {
			t.Fatalf("Session %s not found", token)
		}
		if session.Status != "expired" {
			t.Errorf("Expected session %s to be expired, got '%s'", token, session.Status)
		}
	}
}

// TestHandleAgentWebRequest calls GET /agent/web/request with and without the
// required name/agent_id query parameters. Success must render
// agent-request.html; a missing parameter must yield 400.
func TestHandleAgentWebRequest(t *testing.T) {
	db, cleanup := setupTestDB(t)
	defer cleanup()

	cfg := &config.Config{}
	mock := newTestRenderer()
	h := &Handler{store: db, config: cfg, renderer: mock}

	tests := []struct {
		name             string
		queryParams      string
		expectedStatus   int
		expectedTemplate string
	}{
		{
			name:             "valid request",
			queryParams:      "?name=TestAgent&agent_id=web-test-uuid",
			expectedStatus:   http.StatusOK,
			expectedTemplate: "agent-request.html",
		},
		{
			name:           "missing name",
			queryParams:    "?agent_id=test-uuid",
			expectedStatus: http.StatusBadRequest,
		},
		{
			name:           "missing agent_id",
			queryParams:    "?name=TestAgent",
			expectedStatus: http.StatusBadRequest,
		},
		{
			name:           "missing both",
			queryParams:    "",
			expectedStatus: http.StatusBadRequest,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			mock.Calls = nil // Reset calls

			req := httptest.NewRequest(http.MethodGet, "/agent/web/request"+tt.queryParams, nil)
			w := httptest.NewRecorder()

			h.HandleAgentWebRequest(w, req)

			if w.Code != tt.expectedStatus {
				t.Errorf("Expected status %d, got %d: %s", tt.expectedStatus, w.Code, w.Body.String())
			}

			if tt.expectedTemplate != "" && w.Code == http.StatusOK {
				if len(mock.Calls) == 0 {
					t.Error("Expected render call")
				} else if mock.Calls[0].Name != tt.expectedTemplate {
					t.Errorf("Expected template %s, got %s", tt.expectedTemplate, mock.Calls[0].Name)
				}
			}
		})
	}
}

// TestHandleAgentWebRequestReturnsExistingPending issues the same web request
// twice for one agent. Both renders must expose the same non-empty
// RequestToken, meaning the pending session is reused rather than recreated.
func TestHandleAgentWebRequestReturnsExistingPending(t *testing.T) {
	db, cleanup := setupTestDB(t)
	defer cleanup()

	cfg := &config.Config{}
	mock := newTestRenderer()
	h := &Handler{store: db, config: cfg, renderer: mock}

	agentID := "reuse-pending-uuid"

	// First request creates a session
	req1 := httptest.NewRequest(http.MethodGet, "/agent/web/request?name=TestAgent&agent_id="+agentID, nil)
	w1 := httptest.NewRecorder()
	h.HandleAgentWebRequest(w1, req1)

	if w1.Code != http.StatusOK {
		t.Fatalf("First request failed: %d, body: %s", w1.Code, w1.Body.String())
	}

	// Verify template was called
	if len(mock.Calls) == 0 || mock.Calls[0].Name != "agent-request.html" {
		t.Fatal("First request didn't render agent-request.html")
	}
	first, ok := mock.Calls[0].Data.(map[string]interface{})
	if !ok {
		t.Fatalf("Expected render data of type map[string]interface{}, got %T", mock.Calls[0].Data)
	}
	firstToken := first["RequestToken"]
	// Without this guard, two missing tokens (nil == nil) would make the reuse
	// comparison below pass without proving anything.
	if firstToken == nil || firstToken == "" {
		t.Fatal("First request rendered no RequestToken")
	}

	// Reset mock and make second request
	mock.Calls = nil
	req2 := httptest.NewRequest(http.MethodGet, "/agent/web/request?name=TestAgent&agent_id="+agentID, nil)
	w2 := httptest.NewRecorder()
	h.HandleAgentWebRequest(w2, req2)

	if w2.Code != http.StatusOK {
		t.Fatalf("Second request failed: %d", w2.Code)
	}

	// Verify template was called with same data (same session reused)
	if len(mock.Calls) == 0 {
		t.Fatal("Second request didn't render template")
	}
	second, ok := mock.Calls[0].Data.(map[string]interface{})
	if !ok {
		t.Fatalf("Expected render data of type map[string]interface{}, got %T", mock.Calls[0].Data)
	}

	if firstToken != second["RequestToken"] {
		t.Error("Expected same session to be reused (same request token)")
	}
}

// TestHandleAgentWebStatus calls GET /agent/web/status for a seeded session
// (expects agent-status.html rendered with Status "pending"), with no token
// (400), and with an unknown token (404).
func TestHandleAgentWebStatus(t *testing.T) {
	db, cleanup := setupTestDB(t)
	defer cleanup()

	cfg := &config.Config{}
	mock := newTestRenderer()
	h := &Handler{store: db, config: cfg, renderer: mock}

	// Create a pending session
	session := &models.AgentSession{
		RequestToken: "web-status-test-token",
		AgentName:    "WebStatusTestAgent",
		AgentID:      "web-status-agent-uuid",
		ExpiresAt:    time.Now().Add(5 * time.Minute),
	}
	if err := db.CreateAgentSession(session); err != nil {
		t.Fatalf("Failed to create test session: %v", err)
	}

	tests := []struct {
		name             string
		token            string
		expectedStatus   int
		expectedTemplate string
		checkData        func(t *testing.T, data map[string]interface{})
	}{
		{
			name:             "valid pending session",
			token:            "web-status-test-token",
			expectedStatus:   http.StatusOK,
			expectedTemplate: "agent-status.html",
			checkData: func(t *testing.T, data map[string]interface{}) {
				if data["Status"] != "pending" {
					t.Errorf("Expected status 'pending', got '%v'", data["Status"])
				}
			},
		},
		{
			name:           "missing token",
			token:          "",
			expectedStatus: http.StatusBadRequest,
		},
		{
			name:           "non-existent token",
			token:          "non-existent-token",
			expectedStatus: http.StatusNotFound,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			mock.Calls = nil

			url := "/agent/web/status"
			if tt.token != "" {
				url += "?token=" + tt.token
			}
			req := httptest.NewRequest(http.MethodGet, url, nil)
			w := httptest.NewRecorder()

			h.HandleAgentWebStatus(w, req)

			if w.Code != tt.expectedStatus {
				t.Errorf("Expected status %d, got %d: %s", tt.expectedStatus, w.Code, w.Body.String())
			}

			if tt.expectedTemplate != "" && w.Code == http.StatusOK {
				if len(mock.Calls) == 0 {
					t.Error("Expected render call")
				} else {
					if mock.Calls[0].Name != tt.expectedTemplate {
						t.Errorf("Expected template %s, got %s", tt.expectedTemplate, mock.Calls[0].Name)
					}
					if tt.checkData != nil {
						data, ok := mock.Calls[0].Data.(map[string]interface{})
						if !ok {
							t.Fatalf("Expected render data of type map[string]interface{}, got %T", mock.Calls[0].Data)
						}
						tt.checkData(t, data)
					}
				}
			}
		})
	}
}

// TestHandleAgentWebStatusApproved approves a session and checks that
// GET /agent/web/status renders agent-status.html with Status "approved", the
// issued SessionToken, and a ContextURL entry.
func TestHandleAgentWebStatusApproved(t *testing.T) {
	db, cleanup := setupTestDB(t)
	defer cleanup()

	cfg := &config.Config{}
	mock := newTestRenderer()
	h := &Handler{store: db, config: cfg, renderer: mock}

	// Create and approve a session
	session := &models.AgentSession{
		RequestToken: "web-approved-test-token",
		AgentName:    "WebApprovedTestAgent",
		AgentID:      "web-approved-agent-uuid",
		ExpiresAt:    time.Now().Add(5 * time.Minute),
	}
	if err := db.CreateAgentSession(session); err != nil {
		t.Fatalf("Failed to create test session: %v", err)
	}
	sessionToken := "web-session-token-xyz"
	if err := db.ApproveAgentSession("web-approved-test-token", sessionToken, time.Now().Add(1*time.Hour)); err != nil {
		t.Fatalf("Failed to approve session: %v", err)
	}

	req := httptest.NewRequest(http.MethodGet, "/agent/web/status?token=web-approved-test-token", nil)
	w := httptest.NewRecorder()

	h.HandleAgentWebStatus(w, req)

	if w.Code != http.StatusOK {
		t.Errorf("Expected status 200, got %d: %s", w.Code, w.Body.String())
	}

	if len(mock.Calls) == 0 {
		t.Fatal("Expected render call")
	}
	if mock.Calls[0].Name != "agent-status.html" {
		t.Errorf("Expected template agent-status.html, got %s", mock.Calls[0].Name)
	}

	data, ok := mock.Calls[0].Data.(map[string]interface{})
	if !ok {
		t.Fatalf("Expected render data of type map[string]interface{}, got %T", mock.Calls[0].Data)
	}
	if data["Status"] != "approved" {
		t.Errorf("Expected status 'approved', got '%v'", data["Status"])
	}
	if data["SessionToken"] != sessionToken {
		t.Errorf("Expected session_token '%s', got '%v'", sessionToken, data["SessionToken"])
	}
	if _, ok := data["ContextURL"]; !ok {
		t.Error("Expected ContextURL in data")
	}
}

// TestHandleAgentWebContext calls GET /agent/web/context with an approved
// session token, with none, and with an unknown one. These must respectively
// render agent-context.html (with GeneratedAt and Timeline), yield 400, and
// yield 401.
func TestHandleAgentWebContext(t *testing.T) {
	db, cleanup := setupTestDB(t)
	defer cleanup()

	cfg := &config.Config{}
	mock := newTestRenderer()
	h := &Handler{store: db, config: cfg, renderer: mock}

	// Create and approve a session
	session := &models.AgentSession{
		RequestToken: "web-context-test-token",
		AgentName:    "WebContextTestAgent",
		AgentID:      "web-context-agent-uuid",
		ExpiresAt:    time.Now().Add(5 * time.Minute),
	}
	if err := db.CreateAgentSession(session); err != nil {
		t.Fatalf("Failed to create test session: %v", err)
	}
	sessionToken := "web-context-session-token"
	if err := db.ApproveAgentSession("web-context-test-token", sessionToken, time.Now().Add(1*time.Hour)); err != nil {
		t.Fatalf("Failed to approve session: %v", err)
	}

	tests := []struct {
		name             string
		sessionToken     string
		expectedStatus   int
		expectedTemplate string
		checkData        func(t *testing.T, data map[string]interface{})
	}{
		{
			name:             "valid session",
			sessionToken:     sessionToken,
			expectedStatus:   http.StatusOK,
			expectedTemplate: "agent-context.html",
			checkData: func(t *testing.T, data map[string]interface{}) {
				if _, ok := data["GeneratedAt"]; !ok {
					t.Error("Expected GeneratedAt in data")
				}
				if _, ok := data["Timeline"]; !ok {
					t.Error("Expected Timeline in data")
				}
			},
		},
		{
			name:           "missing session",
			sessionToken:   "",
			expectedStatus: http.StatusBadRequest,
		},
		{
			name:           "invalid session",
			sessionToken:   "invalid-session-token",
			expectedStatus: http.StatusUnauthorized,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			mock.Calls = nil

			url := "/agent/web/context"
			if tt.sessionToken != "" {
				url += "?session=" + tt.sessionToken
			}
			req := httptest.NewRequest(http.MethodGet, url, nil)
			w := httptest.NewRecorder()

			h.HandleAgentWebContext(w, req)

			if w.Code != tt.expectedStatus {
				t.Errorf("Expected status %d, got %d: %s", tt.expectedStatus, w.Code, w.Body.String())
			}

			if tt.expectedTemplate != "" && w.Code == http.StatusOK {
				if len(mock.Calls) == 0 {
					t.Error("Expected render call")
				} else {
					if mock.Calls[0].Name != tt.expectedTemplate {
						t.Errorf("Expected template %s, got %s", tt.expectedTemplate, mock.Calls[0].Name)
					}
					if tt.checkData != nil {
						data, ok := mock.Calls[0].Data.(map[string]interface{})
						if !ok {
							t.Fatalf("Expected render data of type map[string]interface{}, got %T", mock.Calls[0].Data)
						}
						tt.checkData(t, data)
					}
				}
			}
		})
	}
}

// contains reports whether substr occurs within s. It is not referenced by the
// tests above; presumably other tests in this package use it — verify before
// removing.
func contains(s, substr string) bool {
	return strings.Contains(s, substr)
}