diff options
Diffstat (limited to 'internal/handlers/agent_test.go')
| -rw-r--r-- | internal/handlers/agent_test.go | 706 |
1 files changed, 706 insertions, 0 deletions
diff --git a/internal/handlers/agent_test.go b/internal/handlers/agent_test.go new file mode 100644 index 0000000..7828650 --- /dev/null +++ b/internal/handlers/agent_test.go @@ -0,0 +1,706 @@ +package handlers + +import ( + "bytes" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" + + "task-dashboard/internal/config" + "task-dashboard/internal/models" +) + +func TestHandleAgentAuthRequest(t *testing.T) { + db, cleanup := setupTestDB(t) + defer cleanup() + + cfg := &config.Config{} + h := &Handler{store: db, config: cfg} + + tests := []struct { + name string + requestBody interface{} + expectedStatus int + checkResponse func(t *testing.T, resp *models.AgentAuthResponse) + }{ + { + name: "valid request", + requestBody: models.AgentAuthRequest{ + Name: "TestAgent", + AgentID: "test-uuid-12345", + }, + expectedStatus: http.StatusOK, + checkResponse: func(t *testing.T, resp *models.AgentAuthResponse) { + if resp.RequestToken == "" { + t.Error("Expected request_token in response") + } + if resp.Status != "pending" { + t.Errorf("Expected status 'pending', got '%s'", resp.Status) + } + }, + }, + { + name: "missing name", + requestBody: models.AgentAuthRequest{ + AgentID: "test-uuid-12345", + }, + expectedStatus: http.StatusBadRequest, + }, + { + name: "missing agent_id", + requestBody: models.AgentAuthRequest{ + Name: "TestAgent", + }, + expectedStatus: http.StatusBadRequest, + }, + { + name: "invalid JSON", + requestBody: "not json", + expectedStatus: http.StatusBadRequest, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var body []byte + var err error + if s, ok := tt.requestBody.(string); ok { + body = []byte(s) + } else { + body, err = json.Marshal(tt.requestBody) + if err != nil { + t.Fatalf("Failed to marshal request body: %v", err) + } + } + + req := httptest.NewRequest(http.MethodPost, "/agent/auth/request", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + + h.HandleAgentAuthRequest(w, req) + + if w.Code != tt.expectedStatus { + t.Errorf("Expected status %d, got %d", tt.expectedStatus, w.Code) + } + + if tt.checkResponse != nil && w.Code == http.StatusOK { + var resp models.AgentAuthResponse + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("Failed to unmarshal response: %v", err) + } + tt.checkResponse(t, &resp) + } + }) + } +} + +func TestHandleAgentAuthPoll(t *testing.T) { + db, cleanup := setupTestDB(t) + defer cleanup() + + cfg := &config.Config{} + h := &Handler{store: db, config: cfg} + + // Create a pending session first + session := &models.AgentSession{ + RequestToken: "test-token-123", + AgentName: "TestAgent", + AgentID: "agent-uuid", + ExpiresAt: time.Now().Add(5 * time.Minute), + } + if err := db.CreateAgentSession(session); err != nil { + t.Fatalf("Failed to create test session: %v", err) + } + + tests := []struct { + name string + token string + expectedStatus int + expectedState string + }{ + { + name: "valid pending session", + token: "test-token-123", + expectedStatus: http.StatusOK, + expectedState: "pending", + }, + { + name: "missing token", + token: "", + expectedStatus: http.StatusBadRequest, + }, + { + name: "non-existent token", + token: "non-existent", + expectedStatus: http.StatusNotFound, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + url := "/agent/auth/poll" + if tt.token != "" { + url += "?token=" + tt.token + } + + req := httptest.NewRequest(http.MethodGet, url, nil) + w := httptest.NewRecorder() + + h.HandleAgentAuthPoll(w, req) + + if w.Code != tt.expectedStatus { + t.Errorf("Expected status %d, got %d", tt.expectedStatus, w.Code) + } + + if tt.expectedStatus == http.StatusOK { + var resp models.AgentPollResponse + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("Failed to unmarshal response: %v", err) + } + if resp.Status != tt.expectedState { + t.Errorf("Expected status '%s', got '%s'", tt.expectedState, resp.Status) + } + } + }) + } +} + +func TestHandleAgentAuthApprove(t *testing.T) { + db, cleanup := setupTestDB(t) + defer cleanup() + + cfg := &config.Config{} + h := &Handler{store: db, config: cfg} + + // Create a pending session + session := &models.AgentSession{ + RequestToken: "approve-test-token", + AgentName: "ApproveTestAgent", + AgentID: "approve-agent-uuid", + ExpiresAt: time.Now().Add(5 * time.Minute), + } + if err := db.CreateAgentSession(session); err != nil { + t.Fatalf("Failed to create test session: %v", err) + } + + // Test approval + body, _ := json.Marshal(map[string]string{"request_token": "approve-test-token"}) + req := httptest.NewRequest(http.MethodPost, "/agent/auth/approve", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + + h.HandleAgentAuthApprove(w, req) + + if w.Code != http.StatusOK { + t.Errorf("Expected status 200, got %d: %s", w.Code, w.Body.String()) + } + + // Verify session is now approved + updated, err := db.GetAgentSessionByRequestToken("approve-test-token") + if err != nil { + t.Fatalf("Failed to get session: %v", err) + } + if updated.Status != "approved" { + t.Errorf("Expected session status 'approved', got '%s'", updated.Status) + } + if updated.SessionToken == "" { + t.Error("Expected session_token to be set") + } + + // Verify agent was registered + agent, err := db.GetAgentByAgentID("approve-agent-uuid") + if err != nil { + t.Fatalf("Failed to get agent: %v", err) + } + if agent == nil { + t.Error("Expected agent to be created") + } +} + +func TestHandleAgentAuthDeny(t *testing.T) { + db, cleanup := setupTestDB(t) + defer cleanup() + + cfg := &config.Config{} + h := &Handler{store: db, config: cfg} + + // Create a pending session + session := &models.AgentSession{ + RequestToken: "deny-test-token", + AgentName: "DenyTestAgent", + AgentID: "deny-agent-uuid", + ExpiresAt: time.Now().Add(5 * time.Minute), + } + if err := db.CreateAgentSession(session); err != nil { + t.Fatalf("Failed to create test session: %v", err) + } + + // Test denial + body, _ := json.Marshal(map[string]string{"request_token": "deny-test-token"}) + req := httptest.NewRequest(http.MethodPost, "/agent/auth/deny", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + + h.HandleAgentAuthDeny(w, req) + + if w.Code != http.StatusOK { + t.Errorf("Expected status 200, got %d: %s", w.Code, w.Body.String()) + } + + // Verify session is now denied + updated, err := db.GetAgentSessionByRequestToken("deny-test-token") + if err != nil { + t.Fatalf("Failed to get session: %v", err) + } + if updated.Status != "denied" { + t.Errorf("Expected session status 'denied', got '%s'", updated.Status) + } +} + +func TestAgentAuthMiddleware(t *testing.T) { + db, cleanup := setupTestDB(t) + defer cleanup() + + cfg := &config.Config{} + h := &Handler{store: db, config: cfg} + + // Create an approved session + sessionToken := "valid-session-token" + sessionExpiry := time.Now().Add(1 * time.Hour) + session := &models.AgentSession{ + RequestToken: "middleware-test-token", + AgentName: "MiddlewareTestAgent", + AgentID: "middleware-agent-uuid", + ExpiresAt: time.Now().Add(5 * time.Minute), + } + if err := db.CreateAgentSession(session); err != nil { + t.Fatalf("Failed to create test session: %v", err) + } + if err := db.ApproveAgentSession("middleware-test-token", sessionToken, sessionExpiry); err != nil { + t.Fatalf("Failed to approve session: %v", err) + } + + // Create a test handler that the middleware protects + testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("protected content")) + }) + + wrappedHandler := h.AgentAuthMiddleware(testHandler) + + tests := []struct { + name string + authHeader string + expectedStatus int + }{ + { + name: "valid token", + authHeader: "Bearer " + sessionToken, + expectedStatus: http.StatusOK, + }, + { + name: "missing header", + authHeader: "", + expectedStatus: http.StatusUnauthorized, + }, + { + name: "invalid format", + authHeader: "Token " + sessionToken, + expectedStatus: http.StatusUnauthorized, + }, + { + name: "invalid token", + authHeader: "Bearer invalid-token", + expectedStatus: http.StatusUnauthorized, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/agent/context", nil) + if tt.authHeader != "" { + req.Header.Set("Authorization", tt.authHeader) + } + w := httptest.NewRecorder() + + wrappedHandler.ServeHTTP(w, req) + + if w.Code != tt.expectedStatus { + t.Errorf("Expected status %d, got %d", tt.expectedStatus, w.Code) + } + }) + } +} + +func TestAgentTrustLevel(t *testing.T) { + db, cleanup := setupTestDB(t) + defer cleanup() + + // Test new agent (never seen before) + trust, err := db.CheckAgentTrust("NewAgent", "new-uuid") + if err != nil { + t.Fatalf("Failed to check trust: %v", err) + } + if trust != models.AgentTrustNew { + t.Errorf("Expected trust level 'new', got '%s'", trust) + } + + // Register the agent + if err := db.CreateOrUpdateAgent("NewAgent", "new-uuid"); err != nil { + t.Fatalf("Failed to register agent: %v", err) + } + + // Test recognized agent (same name + ID) + trust, err = db.CheckAgentTrust("NewAgent", "new-uuid") + if err != nil { + t.Fatalf("Failed to check trust: %v", err) + } + if trust != models.AgentTrustRecognized { + t.Errorf("Expected trust level 'recognized', got '%s'", trust) + } + + // Test suspicious agent (same name, different ID) + trust, err = db.CheckAgentTrust("NewAgent", "different-uuid") + if err != nil { + t.Fatalf("Failed to check trust: %v", err) + } + if trust != models.AgentTrustSuspicious { + t.Errorf("Expected trust level 'suspicious', got '%s'", trust) + } +} + +func TestInvalidatePreviousSessions(t *testing.T) { + db, cleanup := setupTestDB(t) + defer cleanup() + + agentID := "invalidate-test-agent" + + // Create multiple sessions for the same agent + for i := 0; i < 3; i++ { + session := &models.AgentSession{ + RequestToken: "token-" + string(rune('a'+i)), + AgentName: "TestAgent", + AgentID: agentID, + ExpiresAt: time.Now().Add(5 * time.Minute), + } + if err := db.CreateAgentSession(session); err != nil { + t.Fatalf("Failed to create session %d: %v", i, err) + } + } + + // Invalidate all sessions for this agent + if err := db.InvalidatePreviousAgentSessions(agentID); err != nil { + t.Fatalf("Failed to invalidate sessions: %v", err) + } + + // Verify all sessions are expired + for _, token := range []string{"token-a", "token-b", "token-c"} { + session, err := db.GetAgentSessionByRequestToken(token) + if err != nil { + t.Fatalf("Failed to get session: %v", err) + } + if session.Status != "expired" { + t.Errorf("Expected session %s to be expired, got '%s'", token, session.Status) + } + } +} + +func TestHandleAgentWebRequest(t *testing.T) { + db, cleanup := setupTestDB(t) + defer cleanup() + + cfg := &config.Config{} + h := &Handler{store: db, config: cfg, templates: loadTestTemplates(t)} + + tests := []struct { + name string + queryParams string + expectedStatus int + checkBody func(t *testing.T, body string) + }{ + { + name: "valid request", + queryParams: "?name=TestAgent&agent_id=web-test-uuid", + expectedStatus: http.StatusOK, + checkBody: func(t *testing.T, body string) { + if body == "" { + t.Error("Expected non-empty response body") + } + // Should contain JSON data in script tag + if !contains(body, "application/json") { + t.Error("Expected JSON data in response") + } + if !contains(body, "request_token") { + t.Error("Expected request_token in response") + } + }, + }, + { + name: "missing name", + queryParams: "?agent_id=test-uuid", + expectedStatus: http.StatusBadRequest, + }, + { + name: "missing agent_id", + queryParams: "?name=TestAgent", + expectedStatus: http.StatusBadRequest, + }, + { + name: "missing both", + queryParams: "", + expectedStatus: http.StatusBadRequest, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/agent/web/request"+tt.queryParams, nil) + w := httptest.NewRecorder() + + h.HandleAgentWebRequest(w, req) + + if w.Code != tt.expectedStatus { + t.Errorf("Expected status %d, got %d: %s", tt.expectedStatus, w.Code, w.Body.String()) + } + + if tt.checkBody != nil && w.Code == http.StatusOK { + tt.checkBody(t, w.Body.String()) + } + }) + } +} + +func TestHandleAgentWebRequestReturnsExistingPending(t *testing.T) { + db, cleanup := setupTestDB(t) + defer cleanup() + + cfg := &config.Config{} + h := &Handler{store: db, config: cfg, templates: loadTestTemplates(t)} + + agentID := "reuse-pending-uuid" + + // First request creates a session + req1 := httptest.NewRequest(http.MethodGet, "/agent/web/request?name=TestAgent&agent_id="+agentID, nil) + w1 := httptest.NewRecorder() + h.HandleAgentWebRequest(w1, req1) + + if w1.Code != http.StatusOK { + t.Fatalf("First request failed: %d, body: %s", w1.Code, w1.Body.String()) + } + body1 := w1.Body.String() + + // Verify session was created + if !contains(body1, "request_token") { + t.Fatal("First response doesn't contain request_token") + } + + // Second request should return the same pending session + req2 := httptest.NewRequest(http.MethodGet, "/agent/web/request?name=TestAgent&agent_id="+agentID, nil) + w2 := httptest.NewRecorder() + h.HandleAgentWebRequest(w2, req2) + + if w2.Code != http.StatusOK { + t.Fatalf("Second request failed: %d", w2.Code) + } + body2 := w2.Body.String() + + // Both responses should be identical (same session reused) + if body1 != body2 { + t.Error("Expected same response for existing pending session (session should be reused)") + } +} + +func TestHandleAgentWebStatus(t *testing.T) { + db, cleanup := setupTestDB(t) + defer cleanup() + + cfg := &config.Config{} + h := &Handler{store: db, config: cfg, templates: loadTestTemplates(t)} + + // Create a pending session + session := &models.AgentSession{ + RequestToken: "web-status-test-token", + AgentName: "WebStatusTestAgent", + AgentID: "web-status-agent-uuid", + ExpiresAt: time.Now().Add(5 * time.Minute), + } + if err := db.CreateAgentSession(session); err != nil { + t.Fatalf("Failed to create test session: %v", err) + } + + tests := []struct { + name string + token string + expectedStatus int + checkBody func(t *testing.T, body string) + }{ + { + name: "valid pending session", + token: "web-status-test-token", + expectedStatus: http.StatusOK, + checkBody: func(t *testing.T, body string) { + if !contains(body, `"status": "pending"`) { + t.Error("Expected status 'pending' in response") + } + }, + }, + { + name: "missing token", + token: "", + expectedStatus: http.StatusBadRequest, + }, + { + name: "non-existent token", + token: "non-existent-token", + expectedStatus: http.StatusNotFound, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + url := "/agent/web/status" + if tt.token != "" { + url += "?token=" + tt.token + } + + req := httptest.NewRequest(http.MethodGet, url, nil) + w := httptest.NewRecorder() + + h.HandleAgentWebStatus(w, req) + + if w.Code != tt.expectedStatus { + t.Errorf("Expected status %d, got %d: %s", tt.expectedStatus, w.Code, w.Body.String()) + } + + if tt.checkBody != nil && w.Code == http.StatusOK { + tt.checkBody(t, w.Body.String()) + } + }) + } +} + +func TestHandleAgentWebStatusApproved(t *testing.T) { + db, cleanup := setupTestDB(t) + defer cleanup() + + cfg := &config.Config{} + h := &Handler{store: db, config: cfg, templates: loadTestTemplates(t)} + + // Create and approve a session + session := &models.AgentSession{ + RequestToken: "web-approved-test-token", + AgentName: "WebApprovedTestAgent", + AgentID: "web-approved-agent-uuid", + ExpiresAt: time.Now().Add(5 * time.Minute), + } + if err := db.CreateAgentSession(session); err != nil { + t.Fatalf("Failed to create test session: %v", err) + } + sessionToken := "web-session-token-xyz" + if err := db.ApproveAgentSession("web-approved-test-token", sessionToken, time.Now().Add(1*time.Hour)); err != nil { + t.Fatalf("Failed to approve session: %v", err) + } + + req := httptest.NewRequest(http.MethodGet, "/agent/web/status?token=web-approved-test-token", nil) + w := httptest.NewRecorder() + + h.HandleAgentWebStatus(w, req) + + if w.Code != http.StatusOK { + t.Errorf("Expected status 200, got %d: %s", w.Code, w.Body.String()) + } + + body := w.Body.String() + if !contains(body, `"status": "approved"`) { + t.Error("Expected status 'approved' in response") + } + if !contains(body, `"session_token": "`+sessionToken) { + t.Error("Expected session_token in response") + } + if !contains(body, `"context_url":`) { + t.Error("Expected context_url in response") + } +} + +func TestHandleAgentWebContext(t *testing.T) { + db, cleanup := setupTestDB(t) + defer cleanup() + + cfg := &config.Config{} + h := &Handler{store: db, config: cfg, templates: loadTestTemplates(t)} + + // Create and approve a session + session := &models.AgentSession{ + RequestToken: "web-context-test-token", + AgentName: "WebContextTestAgent", + AgentID: "web-context-agent-uuid", + ExpiresAt: time.Now().Add(5 * time.Minute), + } + if err := db.CreateAgentSession(session); err != nil { + t.Fatalf("Failed to create test session: %v", err) + } + sessionToken := "web-context-session-token" + if err := db.ApproveAgentSession("web-context-test-token", sessionToken, time.Now().Add(1*time.Hour)); err != nil { + t.Fatalf("Failed to approve session: %v", err) + } + + tests := []struct { + name string + sessionToken string + expectedStatus int + checkBody func(t *testing.T, body string) + }{ + { + name: "valid session", + sessionToken: sessionToken, + expectedStatus: http.StatusOK, + checkBody: func(t *testing.T, body string) { + if !contains(body, "generated_at") { + t.Error("Expected generated_at in response") + } + if !contains(body, "timeline") { + t.Error("Expected timeline in response") + } + }, + }, + { + name: "missing session", + sessionToken: "", + expectedStatus: http.StatusBadRequest, + }, + { + name: "invalid session", + sessionToken: "invalid-session-token", + expectedStatus: http.StatusUnauthorized, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + url := "/agent/web/context" + if tt.sessionToken != "" { + url += "?session=" + tt.sessionToken + } + + req := httptest.NewRequest(http.MethodGet, url, nil) + w := httptest.NewRecorder() + + h.HandleAgentWebContext(w, req) + + if w.Code != tt.expectedStatus { + t.Errorf("Expected status %d, got %d: %s", tt.expectedStatus, w.Code, w.Body.String()) + } + + if tt.checkBody != nil && w.Code == http.StatusOK { + tt.checkBody(t, w.Body.String()) + } + }) + } +} + +// contains is a helper function to check if a string contains a substring +func contains(s, substr string) bool { + return bytes.Contains([]byte(s), []byte(substr)) +} |
