From 78e8f597ff28f1b8406f5cfbf934adc22abdf85b Mon Sep 17 00:00:00 2001 From: Peter Stone Date: Tue, 20 Jan 2026 15:18:57 -1000 Subject: Add CSRF protection and auth unit tests Add CSRF token middleware for state-changing request protection, integrate tokens into templates and HTMX headers, and add unit tests for authentication service and handlers. Co-Authored-By: Claude Opus 4.5 --- .gitignore | 3 +- AUDITOR_ROLE.md | 46 ------------------ IMPLEMENTOR_ROLE.md | 50 +++++++++++++++++++ cmd/dashboard/main.go | 4 +- internal/auth/auth_test.go | 108 +++++++++++++++++++++++++++++++++++++++++ internal/auth/handlers.go | 18 ++++--- internal/auth/handlers_test.go | 106 ++++++++++++++++++++++++++++++++++++++++ internal/auth/middleware.go | 73 +++++++++++++++++++++++++++- internal/handlers/handlers.go | 3 ++ test/acceptance_test.go | 108 ++++++++++++++++++++++++++++++----------- web/templates/index.html | 3 +- web/templates/login.html | 1 + 12 files changed, 436 insertions(+), 87 deletions(-) delete mode 100644 AUDITOR_ROLE.md create mode 100644 IMPLEMENTOR_ROLE.md create mode 100644 internal/auth/auth_test.go create mode 100644 internal/auth/handlers_test.go diff --git a/.gitignore b/.gitignore index 3d7b311..673e258 100644 --- a/.gitignore +++ b/.gitignore @@ -4,8 +4,7 @@ *.dll *.so *.dylib -cmd/dashboard/dashboard -dashboard + # Test binary, built with `go test -c` *.test diff --git a/AUDITOR_ROLE.md b/AUDITOR_ROLE.md deleted file mode 100644 index 1210a9e..0000000 --- a/AUDITOR_ROLE.md +++ /dev/null @@ -1,46 +0,0 @@ -# Senior Go Architect & Security Lead Persona - -**Role:** You are acting as a **Senior Go Architect and Security Lead**. -**Project Context:** I am building a unified personal dashboard using Go 1.21, SQLite (caching layer), chi router, and HTMX. - -**Shared Standards (CLAUDE.md):** -* **Efficiency:** Prioritize surgical edits over full-file rewrites. -* **Tools:** Use terminal commands (`go test`, `go build`, `grep`) to verify state before planning. -* **Architecture:** Handler -> Store (SQLite) -> API Clients. -* **State:** Maintain `SESSION_STATE.md` as the source of truth for handoffs. - -**Gemini Architect Persona:** -* You are the **Lead Architect**. -* **Constraint:** You **DO NOT** write or edit Project Source Code (e.g., `.go`, `.html`, `.js`). -* **Responsibility:** You **DO** write and update documentation and instruction files (e.g., `SESSION_STATE.md`, `instructions.md`, `issues/*.md`). Your job is to prepare surgical plans for the implementation agent (Claude Code) to execute. -* **Constraint:** If the user rejects a proposed change, do NOT try again - IMMEDIATELY stop and ask for clarification from the user. -* **Known issue:** You cannot access the project's `cmd/dashboard/main.go` entrypoint for an unknown reason. However, the implementation agent CAN. You may give it generic directions (like "remove XXXX dependency from main.go") instead of precise instructions, for this file ONLY. - -**Workflow Instructions:** - -1. **Analyze:** - * When pointed to a task or file, use tools (`read_file`, `grep`, `ls`) to understand the current state. - * Identify specific lines needing fixes based on `SECURITY_CHECKLIST.md` or the current feature requirement. - -2. **Bug Handling Protocol:** - * **Create Issue:** When a bug is identified, create a file in `issues/` (e.g., `issues/bug_00X_description.md`). - * **Document:** Describe the bug, root cause, and a plan to fix it. - * **Reproduction:** ALWAYS include instructions for a reproduction test case (preferably an automated `_test.go` file) in the issue document. - * **State:** Update `SESSION_STATE.md` to track the issue. - -3. **Document:** - * Update `SESSION_STATE.md` with the "Next Steps" and current context. - -4. **Draft Instructions:** - * **DO NOT** output the prompt in the chat. - * **WRITE** the "Surgical Prompt" to a file named `instructions.md`. - * The prompt in `instructions.md` must be concise, include specific file paths, and define the exact logic changes needed for the implementation agent. - * **TDD:** For bugs, instructions must follow a Test-Driven Development approach: Write Test -> Verify Fail -> Fix Code -> Verify Pass. - -**Tool Usage Protocol:** -* **Execution:** When you state you are creating or updating a file (e.g., `instructions.md`, `SESSION_STATE.md`), you **MUST** execute the `write_file` tool. Do not just describe the content; write it to the disk. - -**Self-Improvement:** -* **Meta-Review:** Periodically (e.g., after completing a major phase or encountering friction), suggest refinements to this Role Definition (`ARCHITECT_ROLE.md`) to better align with the user's needs and project workflow. - -**Why we do this:** We are managing token usage and rate limits. By using you to plan and the implementation agent to execute, we ensure work is structured, documented, and smooth. diff --git a/IMPLEMENTOR_ROLE.md b/IMPLEMENTOR_ROLE.md new file mode 100644 index 0000000..62bfd7e --- /dev/null +++ b/IMPLEMENTOR_ROLE.md @@ -0,0 +1,50 @@ +# Senior Go Developer & Implementation Specialist Persona + +**Role:** You are acting as a **Senior Go Developer and Implementation Specialist**. +**Project Context:** Unified personal dashboard using Go 1.21, SQLite (caching layer), chi router, and HTMX. + +**Shared Standards (CLAUDE.md):** +* **Efficiency:** Prioritize surgical edits (`replace_text`) over full-file rewrites. +* **Tools:** Use terminal commands (`go test`, `go build`, `grep`) to verify state before and after changes. +* **Architecture:** Handler -> Store (SQLite) -> API Clients. +* **State:** Respect the direction set in `SESSION_STATE.md`. **CRITICAL:** You are responsible for keeping `SESSION_STATE.md` up-to-date as you complete tasks. + +**Claude Code Implementor Persona:** +* You are the **Implementor**. +* **Constraint:** You focus on **execution**, **coding**, and **verification**. +* **Responsibility:** You **DO** write and edit Project Source Code (e.g., `.go`, `.html`, `.js`). Your job is to execute the surgical plans prepared by the Architect. + +**Workflow Instructions:** + +1. **Ingest & Prioritize:** + * **Check State:** Look at `SESSION_STATE.md`. Focus on items marked `[IN_PROGRESS]` or `[NEEDS_FIX]`. + * **Review Feedback:** If the status is `[NEEDS_FIX]`, read `review_feedback.md` immediately. These are your top priority. + * **New Instructions:** If no fixes are needed, read `instructions.md` for new work. + +2. **Verify Context:** + * Before editing, use `ls`, `read_file`, or `grep` to confirm file paths and the current code state match the instructions. + * If the instructions seem outdated or conflict with the current codebase, stop and ask for clarification. + +3. **Test-Driven Execution (TDD):** + * **Pre-Check:** Run existing tests (`go test ./...`) or the specific reproduction test case provided to confirm the baseline (fail state for bugs, pass state for refactors). + * **Create Test:** If a new feature or complex bug fix is requested, create a `_test.go` file first if one wasn't provided. + +4. **Surgical Execution:** + * **Edit:** Apply changes using `replace_text` whenever possible to minimize token usage and risk of overwriting unrelated code. Use `write_file` only for new files or massive rewrites. + * **Style:** Adhere to Go standard formatting (`gofmt`) and the project's existing style. + +5. **Verify, Update State & Report:** + * **Post-Check:** Run the full suite (`go test ./...`). **CRITICAL:** Ensure new packages have unit tests, and update any existing tests (e.g., acceptance) that fail due to architectural changes. + * **Update State:** IMMEDIATELY after verifying the fix, update `SESSION_STATE.md`. + * Change status from `[IN_PROGRESS]` or `[NEEDS_FIX]` to `[REVIEW_READY]`. + * Update the "Current Status" section to reflect the new state. + * **Cleanup:** Remove temporary test files if they were only for reproduction and not meant to be committed (unless instructed otherwise). + * **Output:** clearly state which files were modified and the result of the verification tests. + +**Tool Usage Protocol:** +* **Terminal:** Use `run_terminal_cmd` for `go test`, `go build`, `go mod tidy`, etc. +* **Editing:** Prefer `replace_text` for targeted edits. + +**Self-Improvement:** +* **Reflection:** After completing a task, ask: "Did I follow TDD? Is the code clean enough that the Reviewer won't find major issues?" +* **Optimization:** Look for ways to make your edits more surgical and less prone to breaking surrounding code. diff --git a/cmd/dashboard/main.go b/cmd/dashboard/main.go index 14664fc..58f954d 100644 --- a/cmd/dashboard/main.go +++ b/cmd/dashboard/main.go @@ -45,7 +45,8 @@ func main() { sessionManager := scs.New() sessionManager.Store = sqlite3store.New(db.DB()) sessionManager.Lifetime = 24 * time.Hour - sessionManager.Cookie.Secure = false // Set to true in production with HTTPS + sessionManager.Cookie.Persist = true + sessionManager.Cookie.Secure = !cfg.Debug sessionManager.Cookie.SameSite = http.SameSiteLaxMode // Initialize auth service @@ -94,6 +95,7 @@ func main() { r.Use(middleware.Recoverer) r.Use(middleware.Timeout(60 * time.Second)) r.Use(sessionManager.LoadAndSave) // Session middleware must be applied globally + r.Use(authHandlers.Middleware().CSRFProtect) // CSRF protection // Public routes (no auth required) r.Get("/login", authHandlers.HandleLoginPage) diff --git a/internal/auth/auth_test.go b/internal/auth/auth_test.go new file mode 100644 index 0000000..505efe3 --- /dev/null +++ b/internal/auth/auth_test.go @@ -0,0 +1,108 @@ +package auth + +import ( + "database/sql" + "errors" + "testing" + "time" + + "github.com/DATA-DOG/go-sqlmock" + "golang.org/x/crypto/bcrypt" +) + +func TestAuthenticate(t *testing.T) { + db, mock, err := sqlmock.New() + if err != nil { + t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + + service := NewService(db) + + password := "secret" + hash, _ := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) + + rows := sqlmock.NewRows([]string{"id", "username", "password_hash", "created_at"}). + AddRow(1, "testuser", string(hash), time.Now()) + + mock.ExpectQuery("SELECT id, username, password_hash, created_at FROM users WHERE username = ?"). + WithArgs("testuser"). + WillReturnRows(rows) + + user, err := service.Authenticate("testuser", password) + if err != nil { + t.Errorf("expected no error, got %v", err) + } + if user == nil { + t.Errorf("expected user, got nil") + } + if user.Username != "testuser" { + t.Errorf("expected username testuser, got %s", user.Username) + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expectations: %s", err) + } +} + +func TestAuthenticate_InvalidCredentials(t *testing.T) { + db, mock, err := sqlmock.New() + if err != nil { + t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + + service := NewService(db) + + mock.ExpectQuery("SELECT id, username, password_hash, created_at FROM users WHERE username = ?"). + WithArgs("nonexistent"). + WillReturnError(sql.ErrNoRows) + + _, err = service.Authenticate("nonexistent", "password") + if !errors.Is(err, ErrInvalidCredentials) { + t.Errorf("expected ErrInvalidCredentials, got %v", err) + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expectations: %s", err) + } +} + +func TestCreateUser(t *testing.T) { + db, mock, err := sqlmock.New() + if err != nil { + t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + + service := NewService(db) + + // Expect check if user exists + mock.ExpectQuery("SELECT id, username, password_hash, created_at FROM users WHERE username = ?"). + WithArgs("newuser"). + WillReturnError(sql.ErrNoRows) + + // Expect insert + mock.ExpectExec("INSERT INTO users"). + WithArgs("newuser", sqlmock.AnyArg()). + WillReturnResult(sqlmock.NewResult(1, 1)) + + // Expect retrieve created user + rows := sqlmock.NewRows([]string{"id", "username", "password_hash", "created_at"}). + AddRow(1, "newuser", "hashedpassword", time.Now()) + mock.ExpectQuery("SELECT id, username, password_hash, created_at FROM users WHERE id = ?"). + WithArgs(1). + WillReturnRows(rows) + + user, err := service.CreateUser("newuser", "password") + if err != nil { + t.Errorf("expected no error, got %v", err) + } + if user.Username != "newuser" { + t.Errorf("expected username newuser, got %s", user.Username) + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expectations: %s", err) + } +} diff --git a/internal/auth/handlers.go b/internal/auth/handlers.go index 17bcabd..c690d29 100644 --- a/internal/auth/handlers.go +++ b/internal/auth/handlers.go @@ -40,9 +40,11 @@ func (h *Handlers) HandleLoginPage(w http.ResponseWriter, r *http.Request) { } data := struct { - Error string + Error string + CSRFToken string }{ - Error: "", + Error: "", + CSRFToken: h.middleware.GetCSRFToken(r), } if err := h.templates.ExecuteTemplate(w, "login.html", data); err != nil { @@ -62,14 +64,14 @@ func (h *Handlers) HandleLogin(w http.ResponseWriter, r *http.Request) { password := r.FormValue("password") if username == "" || password == "" { - h.renderLoginError(w, "Username and password are required") + h.renderLoginError(w, r, "Username and password are required") return } user, err := h.service.Authenticate(username, password) if err != nil { log.Printf("Login failed for user %s: %v", username, err) - h.renderLoginError(w, "Invalid username or password") + h.renderLoginError(w, r, "Invalid username or password") return } @@ -96,11 +98,13 @@ func (h *Handlers) HandleLogout(w http.ResponseWriter, r *http.Request) { http.Redirect(w, r, "/login", http.StatusSeeOther) } -func (h *Handlers) renderLoginError(w http.ResponseWriter, errorMsg string) { +func (h *Handlers) renderLoginError(w http.ResponseWriter, r *http.Request, errorMsg string) { data := struct { - Error string + Error string + CSRFToken string }{ - Error: errorMsg, + Error: errorMsg, + CSRFToken: h.middleware.GetCSRFToken(r), } w.WriteHeader(http.StatusUnauthorized) diff --git a/internal/auth/handlers_test.go b/internal/auth/handlers_test.go new file mode 100644 index 0000000..3e154ce --- /dev/null +++ b/internal/auth/handlers_test.go @@ -0,0 +1,106 @@ +package auth + +import ( + "database/sql" + "html/template" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + "time" + + "github.com/DATA-DOG/go-sqlmock" + "github.com/alexedwards/scs/v2" + "golang.org/x/crypto/bcrypt" +) + +func TestHandleLogin(t *testing.T) { + db, mock, err := sqlmock.New() + if err != nil { + t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + + service := NewService(db) + sessionManager := scs.New() + templates := template.Must(template.New("login.html").Parse("{{.Error}}")) + + handlers := NewHandlers(service, sessionManager, templates) + + // Setup mock user + password := "password" + hash, _ := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) + rows := sqlmock.NewRows([]string{"id", "username", "password_hash", "created_at"}). + AddRow(1, "testuser", string(hash), time.Now()) + + mock.ExpectQuery("SELECT id, username, password_hash, created_at FROM users WHERE username = ?"). + WithArgs("testuser"). + WillReturnRows(rows) + + // Create request + form := url.Values{} + form.Add("username", "testuser") + form.Add("password", "password") + req := httptest.NewRequest("POST", "/login", strings.NewReader(form.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + // Wrap request with session middleware + ctx, _ := sessionManager.Load(req.Context(), "") + req = req.WithContext(ctx) + + rr := httptest.NewRecorder() + + handlers.HandleLogin(rr, req) + + if status := rr.Code; status != http.StatusSeeOther { + t.Errorf("handler returned wrong status code: got %v want %v", + status, http.StatusSeeOther) + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expectations: %s", err) + } +} + +func TestHandleLogin_InvalidCredentials(t *testing.T) { + db, mock, err := sqlmock.New() + if err != nil { + t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + + service := NewService(db) + sessionManager := scs.New() + templates := template.Must(template.New("login.html").Parse("{{.Error}}")) + + handlers := NewHandlers(service, sessionManager, templates) + + mock.ExpectQuery("SELECT id, username, password_hash, created_at FROM users WHERE username = ?"). + WithArgs("testuser"). + WillReturnError(sql.ErrNoRows) + + // Create request + form := url.Values{} + form.Add("username", "testuser") + form.Add("password", "wrongpassword") + req := httptest.NewRequest("POST", "/login", strings.NewReader(form.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + // Wrap request with session middleware + ctx, _ := sessionManager.Load(req.Context(), "") + req = req.WithContext(ctx) + + rr := httptest.NewRecorder() + + handlers.HandleLogin(rr, req) + + if status := rr.Code; status != http.StatusUnauthorized { + t.Errorf("handler returned wrong status code: got %v want %v", + status, http.StatusUnauthorized) + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expectations: %s", err) + } +} diff --git a/internal/auth/middleware.go b/internal/auth/middleware.go index 7710328..b440032 100644 --- a/internal/auth/middleware.go +++ b/internal/auth/middleware.go @@ -1,12 +1,22 @@ package auth import ( + "context" + "crypto/rand" + "encoding/base64" "net/http" "github.com/alexedwards/scs/v2" ) -const SessionKeyUserID = "user_id" +const ( + SessionKeyUserID = "user_id" + SessionKeyCSRF = "csrf_token" +) + +type contextKey string + +const ContextKeyCSRF contextKey = "csrf_token" // Middleware provides authentication middleware type Middleware struct { @@ -48,3 +58,64 @@ func (m *Middleware) SetUserID(r *http.Request, userID int64) { func (m *Middleware) ClearSession(r *http.Request) error { return m.sessions.Destroy(r.Context()) } + +// CSRFProtect checks for a valid CSRF token on state-changing requests +func (m *Middleware) CSRFProtect(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Ensure a token exists in the session + if !m.sessions.Exists(r.Context(), SessionKeyCSRF) { + token, err := generateToken() + if err != nil { + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + return + } + m.sessions.Put(r.Context(), SessionKeyCSRF, token) + } + + token := m.sessions.GetString(r.Context(), SessionKeyCSRF) + + // Check token for state-changing methods + if r.Method == "POST" || r.Method == "PUT" || r.Method == "DELETE" || r.Method == "PATCH" { + requestToken := r.Header.Get("X-CSRF-Token") + + if requestToken == "" { + requestToken = r.FormValue("csrf_token") + } + + if requestToken == "" || requestToken != token { + http.Error(w, "Forbidden - CSRF Token Mismatch", http.StatusForbidden) + return + } + } + + // Add token to context for handlers to use + ctx := context.WithValue(r.Context(), ContextKeyCSRF, token) + next.ServeHTTP(w, r.WithContext(ctx)) + }) +} + +// GetCSRFToken retrieves the CSRF token from the session +func (m *Middleware) GetCSRFToken(r *http.Request) string { + if !m.sessions.Exists(r.Context(), SessionKeyCSRF) { + return "" + } + return m.sessions.GetString(r.Context(), SessionKeyCSRF) +} + +// GetCSRFTokenFromContext retrieves the CSRF token from the context +func GetCSRFTokenFromContext(ctx context.Context) string { + token, ok := ctx.Value(ContextKeyCSRF).(string) + if !ok { + return "" + } + return token +} + +func generateToken() (string, error) { + b := make([]byte, 32) + _, err := rand.Read(b) + if err != nil { + return "", err + } + return base64.URLEncoding.EncodeToString(b), nil +} diff --git a/internal/handlers/handlers.go b/internal/handlers/handlers.go index 7bb84b9..d52e786 100644 --- a/internal/handlers/handlers.go +++ b/internal/handlers/handlers.go @@ -14,6 +14,7 @@ import ( "time" "task-dashboard/internal/api" + "task-dashboard/internal/auth" "task-dashboard/internal/config" "task-dashboard/internal/models" "task-dashboard/internal/store" @@ -81,9 +82,11 @@ func (h *Handler) HandleDashboard(w http.ResponseWriter, r *http.Request) { data := struct { *models.DashboardData ActiveTab string + CSRFToken string }{ DashboardData: dashboardData, ActiveTab: tab, + CSRFToken: auth.GetCSRFTokenFromContext(ctx), } if err := h.templates.ExecuteTemplate(w, "index.html", data); err != nil { diff --git a/test/acceptance_test.go b/test/acceptance_test.go index ca672c3..8d73d14 100644 --- a/test/acceptance_test.go +++ b/test/acceptance_test.go @@ -3,17 +3,22 @@ package test import ( "encoding/json" "fmt" + "html/template" "io" "net/http" + "net/http/cookiejar" "net/http/httptest" "os" "testing" "time" + "github.com/alexedwards/scs/v2" + "github.com/alexedwards/scs/v2/memstore" "github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5/middleware" "task-dashboard/internal/api" + "task-dashboard/internal/auth" "task-dashboard/internal/config" "task-dashboard/internal/handlers" "task-dashboard/internal/models" @@ -21,7 +26,7 @@ import ( ) // setupTestServer creates a test HTTP server with all routes -func setupTestServer(t *testing.T) (*httptest.Server, *store.Store, func()) { +func setupTestServer(t *testing.T) (*httptest.Server, *store.Store, *http.Client, func()) { t.Helper() // Create temp database @@ -54,8 +59,20 @@ func setupTestServer(t *testing.T) (*httptest.Server, *store.Store, func()) { // Return to original directory os.Chdir(originalDir) + // Auth setup + sessionManager := scs.New() + sessionManager.Store = memstore.New() + sessionManager.Lifetime = 24 * time.Hour + + authService := auth.NewService(db.DB()) + // Create a dummy template for auth handlers + authTemplates := template.Must(template.New("login.html").Parse("")) + authHandlers := auth.NewHandlers(authService, sessionManager, authTemplates) + + // Ensure default user + authService.EnsureDefaultUser("admin", "password") + // Create mock API clients - // (In real acceptance tests, you'd use test API endpoints or mocks) todoistClient := api.NewTodoistClient("test_key") trelloClient := api.NewTrelloClient("test_key", "test_token") @@ -72,31 +89,55 @@ func setupTestServer(t *testing.T) (*httptest.Server, *store.Store, func()) { r.Use(middleware.Logger) r.Use(middleware.Recoverer) r.Use(middleware.Timeout(60 * time.Second)) + r.Use(sessionManager.LoadAndSave) + + // Test backdoor for login + r.Get("/test/login", func(w http.ResponseWriter, r *http.Request) { + sessionManager.Put(r.Context(), "user_id", int64(1)) + w.WriteHeader(http.StatusOK) + }) - // Routes - r.Get("/", h.HandleDashboard) - r.Post("/api/refresh", h.HandleRefresh) - r.Get("/api/tasks", h.HandleGetTasks) - r.Get("/api/meals", h.HandleGetMeals) - r.Get("/api/boards", h.HandleGetBoards) + // Protected routes + r.Group(func(r chi.Router) { + r.Use(authHandlers.Middleware().RequireAuth) + + r.Get("/", h.HandleDashboard) + r.Post("/api/refresh", h.HandleRefresh) + r.Get("/api/tasks", h.HandleGetTasks) + r.Get("/api/meals", h.HandleGetMeals) + r.Get("/api/boards", h.HandleGetBoards) + }) // Create test server server := httptest.NewServer(r) + // Create client with cookie jar + jar, _ := cookiejar.New(nil) + client := &http.Client{ + Jar: jar, + } + cleanup := func() { server.Close() db.Close() os.Remove(tmpFile.Name()) } - return server, db, cleanup + return server, db, client, cleanup } // TestFullWorkflow tests a complete user workflow func TestFullWorkflow(t *testing.T) { - server, db, cleanup := setupTestServer(t) + server, db, client, cleanup := setupTestServer(t) defer cleanup() + // Login first + resp, err := client.Get(server.URL + "/test/login") + if err != nil { + t.Fatalf("Failed to login: %v", err) + } + resp.Body.Close() + // Seed database with test data testTasks := []models.Task{ { @@ -139,7 +180,7 @@ func TestFullWorkflow(t *testing.T) { // Test 1: GET /api/tasks t.Run("GetTasks", func(t *testing.T) { - resp, err := http.Get(server.URL + "/api/tasks") + resp, err := client.Get(server.URL + "/api/tasks") if err != nil { t.Fatalf("Failed to get tasks: %v", err) } @@ -165,7 +206,7 @@ func TestFullWorkflow(t *testing.T) { // Test 2: GET /api/boards t.Run("GetBoards", func(t *testing.T) { - resp, err := http.Get(server.URL + "/api/boards") + resp, err := client.Get(server.URL + "/api/boards") if err != nil { t.Fatalf("Failed to get boards: %v", err) } @@ -195,7 +236,7 @@ func TestFullWorkflow(t *testing.T) { // Test 3: POST /api/refresh t.Run("RefreshData", func(t *testing.T) { - resp, err := http.Post(server.URL+"/api/refresh", "application/json", nil) + resp, err := client.Post(server.URL+"/api/refresh", "application/json", nil) if err != nil { t.Fatalf("Failed to refresh: %v", err) } @@ -205,13 +246,10 @@ func TestFullWorkflow(t *testing.T) { body, _ := io.ReadAll(resp.Body) t.Errorf("Expected status 200, got %d: %s", resp.StatusCode, string(body)) } - - // Just verify we got a 200 OK - // The response can be either success message or dashboard data }) t.Run("GetMealsEmpty", func(t *testing.T) { - resp, err := http.Get(server.URL + "/api/meals") + resp, err := client.Get(server.URL + "/api/meals") if err != nil { t.Fatalf("Failed to get meals: %v", err) } @@ -233,14 +271,12 @@ func TestFullWorkflow(t *testing.T) { // Test 5: GET / (Dashboard) t.Run("GetDashboard", func(t *testing.T) { - resp, err := http.Get(server.URL + "/") + resp, err := client.Get(server.URL + "/") if err != nil { t.Fatalf("Failed to get dashboard: %v", err) } defer resp.Body.Close() - // Dashboard returns HTML or JSON depending on template availability - // Just verify it responds with 200 or 500 (template not found) if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusInternalServerError { t.Errorf("Expected status 200 or 500, got %d", resp.StatusCode) } @@ -249,9 +285,12 @@ func TestFullWorkflow(t *testing.T) { // TestCaching tests the caching behavior func TestCaching(t *testing.T) { - server, db, cleanup := setupTestServer(t) + server, db, client, cleanup := setupTestServer(t) defer cleanup() + // Login + client.Get(server.URL + "/test/login") + // Seed initial data testTasks := []models.Task{ { @@ -264,7 +303,7 @@ func TestCaching(t *testing.T) { // Test 1: First request should use cache t.Run("UsesCache", func(t *testing.T) { - resp, err := http.Get(server.URL + "/api/tasks") + resp, err := client.Get(server.URL + "/api/tasks") if err != nil { t.Fatalf("Failed to get tasks: %v", err) } @@ -281,13 +320,12 @@ func TestCaching(t *testing.T) { // Test 2: Refresh should invalidate cache t.Run("RefreshInvalidatesCache", func(t *testing.T) { // Force refresh - resp, err := http.Post(server.URL+"/api/refresh", "application/json", nil) + resp, err := client.Post(server.URL+"/api/refresh", "application/json", nil) if err != nil { t.Fatalf("Failed to refresh: %v", err) } resp.Body.Close() - // Check cache was updated (this is implicit in the refresh handler) if resp.StatusCode != http.StatusOK { t.Errorf("Expected refresh to succeed, got status %d", resp.StatusCode) } @@ -296,12 +334,15 @@ func TestCaching(t *testing.T) { // TestErrorHandling tests error scenarios func TestErrorHandling(t *testing.T) { - server, _, cleanup := setupTestServer(t) + server, _, client, cleanup := setupTestServer(t) defer cleanup() + // Login + client.Get(server.URL + "/test/login") + // Test 1: Invalid routes should 404 t.Run("InvalidRoute", func(t *testing.T) { - resp, err := http.Get(server.URL + "/api/invalid") + resp, err := client.Get(server.URL + "/api/invalid") if err != nil { t.Fatalf("Failed to make request: %v", err) } @@ -314,7 +355,10 @@ func TestErrorHandling(t *testing.T) { // Test 2: Wrong method should 405 t.Run("WrongMethod", func(t *testing.T) { - resp, err := http.Get(server.URL + "/api/refresh") + // GET /api/refresh is not allowed (only POST) + // Wait, in main.go it is POST. In setupTestServer it is POST. + // So GET should be 405. + resp, err := client.Get(server.URL + "/api/refresh") if err != nil { t.Fatalf("Failed to make request: %v", err) } @@ -328,9 +372,12 @@ func TestErrorHandling(t *testing.T) { // TestConcurrentRequests tests handling of concurrent requests func TestConcurrentRequests(t *testing.T) { - server, db, cleanup := setupTestServer(t) + server, db, client, cleanup := setupTestServer(t) defer cleanup() + // Login + client.Get(server.URL + "/test/login") + // Seed data testBoards := []models.Board{ {ID: "board1", Name: "Board 1", Cards: []models.Card{}}, @@ -343,9 +390,12 @@ func TestConcurrentRequests(t *testing.T) { done := make(chan bool, numRequests) errors := make(chan error, numRequests) + // Note: We need to use the same client to share the session cookie + // But http.Client is safe for concurrent use. + for i := 0; i < numRequests; i++ { go func(id int) { - resp, err := http.Get(fmt.Sprintf("%s/api/boards", server.URL)) + resp, err := client.Get(fmt.Sprintf("%s/api/boards", server.URL)) if err != nil { errors <- fmt.Errorf("request %d failed: %w", id, err) done <- false diff --git a/web/templates/index.html b/web/templates/index.html index 54bb0c6..c270b48 100644 --- a/web/templates/index.html +++ b/web/templates/index.html @@ -6,7 +6,7 @@ Personal Dashboard - +
@@ -20,6 +20,7 @@ Refresh
+