From 78e8f597ff28f1b8406f5cfbf934adc22abdf85b Mon Sep 17 00:00:00 2001 From: Peter Stone Date: Tue, 20 Jan 2026 15:18:57 -1000 Subject: Add CSRF protection and auth unit tests Add CSRF token middleware for state-changing request protection, integrate tokens into templates and HTMX headers, and add unit tests for authentication service and handlers. Co-Authored-By: Claude Opus 4.5 --- test/acceptance_test.go | 108 +++++++++++++++++++++++++++++++++++------------- 1 file changed, 79 insertions(+), 29 deletions(-) (limited to 'test') diff --git a/test/acceptance_test.go b/test/acceptance_test.go index ca672c3..8d73d14 100644 --- a/test/acceptance_test.go +++ b/test/acceptance_test.go @@ -3,17 +3,22 @@ package test import ( "encoding/json" "fmt" + "html/template" "io" "net/http" + "net/http/cookiejar" "net/http/httptest" "os" "testing" "time" + "github.com/alexedwards/scs/v2" + "github.com/alexedwards/scs/v2/memstore" "github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5/middleware" "task-dashboard/internal/api" + "task-dashboard/internal/auth" "task-dashboard/internal/config" "task-dashboard/internal/handlers" "task-dashboard/internal/models" @@ -21,7 +26,7 @@ import ( ) // setupTestServer creates a test HTTP server with all routes -func setupTestServer(t *testing.T) (*httptest.Server, *store.Store, func()) { +func setupTestServer(t *testing.T) (*httptest.Server, *store.Store, *http.Client, func()) { t.Helper() // Create temp database @@ -54,8 +59,20 @@ func setupTestServer(t *testing.T) (*httptest.Server, *store.Store, func()) { // Return to original directory os.Chdir(originalDir) + // Auth setup + sessionManager := scs.New() + sessionManager.Store = memstore.New() + sessionManager.Lifetime = 24 * time.Hour + + authService := auth.NewService(db.DB()) + // Create a dummy template for auth handlers + authTemplates := template.Must(template.New("login.html").Parse("")) + authHandlers := auth.NewHandlers(authService, sessionManager, authTemplates) + + // Ensure default user + authService.EnsureDefaultUser("admin", "password") + // Create mock API clients - // (In real acceptance tests, you'd use test API endpoints or mocks) todoistClient := api.NewTodoistClient("test_key") trelloClient := api.NewTrelloClient("test_key", "test_token") @@ -72,31 +89,55 @@ func setupTestServer(t *testing.T) (*httptest.Server, *store.Store, func()) { r.Use(middleware.Logger) r.Use(middleware.Recoverer) r.Use(middleware.Timeout(60 * time.Second)) + r.Use(sessionManager.LoadAndSave) + + // Test backdoor for login + r.Get("/test/login", func(w http.ResponseWriter, r *http.Request) { + sessionManager.Put(r.Context(), "user_id", int64(1)) + w.WriteHeader(http.StatusOK) + }) - // Routes - r.Get("/", h.HandleDashboard) - r.Post("/api/refresh", h.HandleRefresh) - r.Get("/api/tasks", h.HandleGetTasks) - r.Get("/api/meals", h.HandleGetMeals) - r.Get("/api/boards", h.HandleGetBoards) + // Protected routes + r.Group(func(r chi.Router) { + r.Use(authHandlers.Middleware().RequireAuth) + + r.Get("/", h.HandleDashboard) + r.Post("/api/refresh", h.HandleRefresh) + r.Get("/api/tasks", h.HandleGetTasks) + r.Get("/api/meals", h.HandleGetMeals) + r.Get("/api/boards", h.HandleGetBoards) + }) // Create test server server := httptest.NewServer(r) + // Create client with cookie jar + jar, _ := cookiejar.New(nil) + client := &http.Client{ + Jar: jar, + } + cleanup := func() { server.Close() db.Close() os.Remove(tmpFile.Name()) } - return server, db, cleanup + return server, db, client, cleanup } // TestFullWorkflow tests a complete user workflow func TestFullWorkflow(t *testing.T) { - server, db, cleanup := setupTestServer(t) + server, db, client, cleanup := setupTestServer(t) defer cleanup() + // Login first + resp, err := client.Get(server.URL + "/test/login") + if err != nil { + t.Fatalf("Failed to login: %v", err) + } + resp.Body.Close() + // Seed database with test data testTasks := []models.Task{ { @@ -139,7 +180,7 @@ func TestFullWorkflow(t *testing.T) { // Test 1: GET /api/tasks t.Run("GetTasks", func(t *testing.T) { - resp, err := http.Get(server.URL + "/api/tasks") + resp, err := client.Get(server.URL + "/api/tasks") if err != nil { t.Fatalf("Failed to get tasks: %v", err) } @@ -165,7 +206,7 @@ func TestFullWorkflow(t *testing.T) { // Test 2: GET /api/boards t.Run("GetBoards", func(t *testing.T) { - resp, err := http.Get(server.URL + "/api/boards") + resp, err := client.Get(server.URL + "/api/boards") if err != nil { t.Fatalf("Failed to get boards: %v", err) } @@ -195,7 +236,7 @@ func TestFullWorkflow(t *testing.T) { // Test 3: POST /api/refresh t.Run("RefreshData", func(t *testing.T) { - resp, err := http.Post(server.URL+"/api/refresh", "application/json", nil) + resp, err := client.Post(server.URL+"/api/refresh", "application/json", nil) if err != nil { t.Fatalf("Failed to refresh: %v", err) } @@ -205,13 +246,10 @@ func TestFullWorkflow(t *testing.T) { body, _ := io.ReadAll(resp.Body) t.Errorf("Expected status 200, got %d: %s", resp.StatusCode, string(body)) } - - // Just verify we got a 200 OK - // The response can be either success message or dashboard data }) t.Run("GetMealsEmpty", func(t *testing.T) { - resp, err := http.Get(server.URL + "/api/meals") + resp, err := client.Get(server.URL + "/api/meals") if err != nil { t.Fatalf("Failed to get meals: %v", err) } @@ -233,14 +271,12 @@ func TestFullWorkflow(t *testing.T) { // Test 5: GET / (Dashboard) t.Run("GetDashboard", func(t *testing.T) { - resp, err := http.Get(server.URL + "/") + resp, err := client.Get(server.URL + "/") if err != nil { t.Fatalf("Failed to get dashboard: %v", err) } defer resp.Body.Close() - // Dashboard returns HTML or JSON depending on template availability - // Just verify it responds with 200 or 500 (template not found) if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusInternalServerError { t.Errorf("Expected status 200 or 500, got %d", resp.StatusCode) } @@ -249,9 +285,12 @@ func TestFullWorkflow(t *testing.T) { // TestCaching tests the caching behavior func TestCaching(t *testing.T) { - server, db, cleanup := setupTestServer(t) + server, db, client, cleanup := setupTestServer(t) defer cleanup() + // Login + client.Get(server.URL + "/test/login") + // Seed initial data testTasks := []models.Task{ { @@ -264,7 +303,7 @@ func TestCaching(t *testing.T) { // Test 1: First request should use cache t.Run("UsesCache", func(t *testing.T) { - resp, err := http.Get(server.URL + "/api/tasks") + resp, err := client.Get(server.URL + "/api/tasks") if err != nil { t.Fatalf("Failed to get tasks: %v", err) } @@ -281,13 +320,12 @@ func TestCaching(t *testing.T) { // Test 2: Refresh should invalidate cache t.Run("RefreshInvalidatesCache", func(t *testing.T) { // Force refresh - resp, err := http.Post(server.URL+"/api/refresh", "application/json", nil) + resp, err := client.Post(server.URL+"/api/refresh", "application/json", nil) if err != nil { t.Fatalf("Failed to refresh: %v", err) } resp.Body.Close() - // Check cache was updated (this is implicit in the refresh handler) if resp.StatusCode != http.StatusOK { t.Errorf("Expected refresh to succeed, got status %d", resp.StatusCode) } @@ -296,12 +334,15 @@ func TestCaching(t *testing.T) { // TestErrorHandling tests error scenarios func TestErrorHandling(t *testing.T) { - server, _, cleanup := setupTestServer(t) + server, _, client, cleanup := setupTestServer(t) defer cleanup() + // Login + client.Get(server.URL + "/test/login") + // Test 1: Invalid routes should 404 t.Run("InvalidRoute", func(t *testing.T) { - resp, err := http.Get(server.URL + "/api/invalid") + resp, err := client.Get(server.URL + "/api/invalid") if err != nil { t.Fatalf("Failed to make request: %v", err) } @@ -314,7 +355,10 @@ func TestErrorHandling(t *testing.T) { // Test 2: Wrong method should 405 t.Run("WrongMethod", func(t *testing.T) { - resp, err := http.Get(server.URL + "/api/refresh") + // GET /api/refresh is not allowed (only POST) + // Wait, in main.go it is POST. In setupTestServer it is POST. + // So GET should be 405. + resp, err := client.Get(server.URL + "/api/refresh") if err != nil { t.Fatalf("Failed to make request: %v", err) } @@ -328,9 +372,12 @@ func TestErrorHandling(t *testing.T) { // TestConcurrentRequests tests handling of concurrent requests func TestConcurrentRequests(t *testing.T) { - server, db, cleanup := setupTestServer(t) + server, db, client, cleanup := setupTestServer(t) defer cleanup() + // Login + client.Get(server.URL + "/test/login") + // Seed data testBoards := []models.Board{ {ID: "board1", Name: "Board 1", Cards: []models.Card{}}, @@ -343,9 +390,12 @@ func TestConcurrentRequests(t *testing.T) { done := make(chan bool, numRequests) errors := make(chan error, numRequests) + // Note: We need to use the same client to share the session cookie + // But http.Client is safe for concurrent use. + for i := 0; i < numRequests; i++ { go func(id int) { - resp, err := http.Get(fmt.Sprintf("%s/api/boards", server.URL)) + resp, err := client.Get(fmt.Sprintf("%s/api/boards", server.URL)) if err != nil { errors <- fmt.Errorf("request %d failed: %w", id, err) done <- false -- cgit v1.2.3