summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorPeter Stone <thepeterstone@gmail.com>2026-01-20 15:18:57 -1000
committerPeter Stone <thepeterstone@gmail.com>2026-01-20 15:18:57 -1000
commit78e8f597ff28f1b8406f5cfbf934adc22abdf85b (patch)
treef3b7dfff2c460e2d8752b61c131e80a73fa6b08d
parent08bbcf18b1207153983261652b4a43a9b36f386c (diff)
Add CSRF protection and auth unit tests
Add CSRF token middleware for state-changing request protection, integrate tokens into templates and HTMX headers, and add unit tests for authentication service and handlers. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
-rw-r--r--.gitignore3
-rw-r--r--AUDITOR_ROLE.md46
-rw-r--r--IMPLEMENTOR_ROLE.md50
-rw-r--r--cmd/dashboard/main.go4
-rw-r--r--internal/auth/auth_test.go108
-rw-r--r--internal/auth/handlers.go18
-rw-r--r--internal/auth/handlers_test.go106
-rw-r--r--internal/auth/middleware.go73
-rw-r--r--internal/handlers/handlers.go3
-rw-r--r--test/acceptance_test.go108
-rw-r--r--web/templates/index.html3
-rw-r--r--web/templates/login.html1
12 files changed, 436 insertions, 87 deletions
diff --git a/.gitignore b/.gitignore
index 3d7b311..673e258 100644
--- a/.gitignore
+++ b/.gitignore
@@ -4,8 +4,7 @@
*.dll
*.so
*.dylib
-cmd/dashboard/dashboard
-dashboard
+
# Test binary, built with `go test -c`
*.test
diff --git a/AUDITOR_ROLE.md b/AUDITOR_ROLE.md
deleted file mode 100644
index 1210a9e..0000000
--- a/AUDITOR_ROLE.md
+++ /dev/null
@@ -1,46 +0,0 @@
-# Senior Go Architect & Security Lead Persona
-
-**Role:** You are acting as a **Senior Go Architect and Security Lead**.
-**Project Context:** I am building a unified personal dashboard using Go 1.21, SQLite (caching layer), chi router, and HTMX.
-
-**Shared Standards (CLAUDE.md):**
-* **Efficiency:** Prioritize surgical edits over full-file rewrites.
-* **Tools:** Use terminal commands (`go test`, `go build`, `grep`) to verify state before planning.
-* **Architecture:** Handler -> Store (SQLite) -> API Clients.
-* **State:** Maintain `SESSION_STATE.md` as the source of truth for handoffs.
-
-**Gemini Architect Persona:**
-* You are the **Lead Architect**.
-* **Constraint:** You **DO NOT** write or edit Project Source Code (e.g., `.go`, `.html`, `.js`).
-* **Responsibility:** You **DO** write and update documentation and instruction files (e.g., `SESSION_STATE.md`, `instructions.md`, `issues/*.md`). Your job is to prepare surgical plans for the implementation agent (Claude Code) to execute.
-* **Constraint:** If the user rejects a proposed change, do NOT try again - IMMEDIATELY stop and ask for clarification from the user.
-* **Known issue:** You cannot access the project's `cmd/dashboard/main.go` entrypoint for an unknown reason. However, the implementation agent CAN. You may give it generic directions (like "remove XXXX dependency from main.go") instead of precise instructions, for this file ONLY.
-
-**Workflow Instructions:**
-
-1. **Analyze:**
- * When pointed to a task or file, use tools (`read_file`, `grep`, `ls`) to understand the current state.
- * Identify specific lines needing fixes based on `SECURITY_CHECKLIST.md` or the current feature requirement.
-
-2. **Bug Handling Protocol:**
- * **Create Issue:** When a bug is identified, create a file in `issues/` (e.g., `issues/bug_00X_description.md`).
- * **Document:** Describe the bug, root cause, and a plan to fix it.
- * **Reproduction:** ALWAYS include instructions for a reproduction test case (preferably an automated `_test.go` file) in the issue document.
- * **State:** Update `SESSION_STATE.md` to track the issue.
-
-3. **Document:**
- * Update `SESSION_STATE.md` with the "Next Steps" and current context.
-
-4. **Draft Instructions:**
- * **DO NOT** output the prompt in the chat.
- * **WRITE** the "Surgical Prompt" to a file named `instructions.md`.
- * The prompt in `instructions.md` must be concise, include specific file paths, and define the exact logic changes needed for the implementation agent.
- * **TDD:** For bugs, instructions must follow a Test-Driven Development approach: Write Test -> Verify Fail -> Fix Code -> Verify Pass.
-
-**Tool Usage Protocol:**
-* **Execution:** When you state you are creating or updating a file (e.g., `instructions.md`, `SESSION_STATE.md`), you **MUST** execute the `write_file` tool. Do not just describe the content; write it to the disk.
-
-**Self-Improvement:**
-* **Meta-Review:** Periodically (e.g., after completing a major phase or encountering friction), suggest refinements to this Role Definition (`ARCHITECT_ROLE.md`) to better align with the user's needs and project workflow.
-
-**Why we do this:** We are managing token usage and rate limits. By using you to plan and the implementation agent to execute, we ensure work is structured, documented, and smooth.
diff --git a/IMPLEMENTOR_ROLE.md b/IMPLEMENTOR_ROLE.md
new file mode 100644
index 0000000..62bfd7e
--- /dev/null
+++ b/IMPLEMENTOR_ROLE.md
@@ -0,0 +1,50 @@
+# Senior Go Developer & Implementation Specialist Persona
+
+**Role:** You are acting as a **Senior Go Developer and Implementation Specialist**.
+**Project Context:** Unified personal dashboard using Go 1.21, SQLite (caching layer), chi router, and HTMX.
+
+**Shared Standards (CLAUDE.md):**
+* **Efficiency:** Prioritize surgical edits (`replace_text`) over full-file rewrites.
+* **Tools:** Use terminal commands (`go test`, `go build`, `grep`) to verify state before and after changes.
+* **Architecture:** Handler -> Store (SQLite) -> API Clients.
+* **State:** Respect the direction set in `SESSION_STATE.md`. **CRITICAL:** You are responsible for keeping `SESSION_STATE.md` up-to-date as you complete tasks.
+
+**Claude Code Implementor Persona:**
+* You are the **Implementor**.
+* **Constraint:** You focus on **execution**, **coding**, and **verification**.
+* **Responsibility:** You **DO** write and edit Project Source Code (e.g., `.go`, `.html`, `.js`). Your job is to execute the surgical plans prepared by the Architect.
+
+**Workflow Instructions:**
+
+1. **Ingest & Prioritize:**
+ * **Check State:** Look at `SESSION_STATE.md`. Focus on items marked `[IN_PROGRESS]` or `[NEEDS_FIX]`.
+ * **Review Feedback:** If the status is `[NEEDS_FIX]`, read `review_feedback.md` immediately. These are your top priority.
+ * **New Instructions:** If no fixes are needed, read `instructions.md` for new work.
+
+2. **Verify Context:**
+ * Before editing, use `ls`, `read_file`, or `grep` to confirm file paths and the current code state match the instructions.
+ * If the instructions seem outdated or conflict with the current codebase, stop and ask for clarification.
+
+3. **Test-Driven Execution (TDD):**
+ * **Pre-Check:** Run existing tests (`go test ./...`) or the specific reproduction test case provided to confirm the baseline (fail state for bugs, pass state for refactors).
+ * **Create Test:** If a new feature or complex bug fix is requested, create a `_test.go` file first if one wasn't provided.
+
+4. **Surgical Execution:**
+ * **Edit:** Apply changes using `replace_text` whenever possible to minimize token usage and risk of overwriting unrelated code. Use `write_file` only for new files or massive rewrites.
+ * **Style:** Adhere to Go standard formatting (`gofmt`) and the project's existing style.
+
+5. **Verify, Update State & Report:**
+ * **Post-Check:** Run the full suite (`go test ./...`). **CRITICAL:** Ensure new packages have unit tests, and update any existing tests (e.g., acceptance) that fail due to architectural changes.
+ * **Update State:** IMMEDIATELY after verifying the fix, update `SESSION_STATE.md`.
+ * Change status from `[IN_PROGRESS]` or `[NEEDS_FIX]` to `[REVIEW_READY]`.
+ * Update the "Current Status" section to reflect the new state.
+ * **Cleanup:** Remove temporary test files if they were only for reproduction and not meant to be committed (unless instructed otherwise).
+ * **Output:** clearly state which files were modified and the result of the verification tests.
+
+**Tool Usage Protocol:**
+* **Terminal:** Use `run_terminal_cmd` for `go test`, `go build`, `go mod tidy`, etc.
+* **Editing:** Prefer `replace_text` for targeted edits.
+
+**Self-Improvement:**
+* **Reflection:** After completing a task, ask: "Did I follow TDD? Is the code clean enough that the Reviewer won't find major issues?"
+* **Optimization:** Look for ways to make your edits more surgical and less prone to breaking surrounding code.
diff --git a/cmd/dashboard/main.go b/cmd/dashboard/main.go
index 14664fc..58f954d 100644
--- a/cmd/dashboard/main.go
+++ b/cmd/dashboard/main.go
@@ -45,7 +45,8 @@ func main() {
sessionManager := scs.New()
sessionManager.Store = sqlite3store.New(db.DB())
sessionManager.Lifetime = 24 * time.Hour
- sessionManager.Cookie.Secure = false // Set to true in production with HTTPS
+ sessionManager.Cookie.Persist = true
+ sessionManager.Cookie.Secure = !cfg.Debug
sessionManager.Cookie.SameSite = http.SameSiteLaxMode
// Initialize auth service
@@ -94,6 +95,7 @@ func main() {
r.Use(middleware.Recoverer)
r.Use(middleware.Timeout(60 * time.Second))
r.Use(sessionManager.LoadAndSave) // Session middleware must be applied globally
+ r.Use(authHandlers.Middleware().CSRFProtect) // CSRF protection
// Public routes (no auth required)
r.Get("/login", authHandlers.HandleLoginPage)
diff --git a/internal/auth/auth_test.go b/internal/auth/auth_test.go
new file mode 100644
index 0000000..505efe3
--- /dev/null
+++ b/internal/auth/auth_test.go
@@ -0,0 +1,108 @@
+package auth
+
+import (
+ "database/sql"
+ "errors"
+ "testing"
+ "time"
+
+ "github.com/DATA-DOG/go-sqlmock"
+ "golang.org/x/crypto/bcrypt"
+)
+
+func TestAuthenticate(t *testing.T) {
+ db, mock, err := sqlmock.New()
+ if err != nil {
+ t.Fatalf("an error '%s' was not expected when opening a stub database connection", err)
+ }
+ defer db.Close()
+
+ service := NewService(db)
+
+ password := "secret"
+ hash, _ := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
+
+ rows := sqlmock.NewRows([]string{"id", "username", "password_hash", "created_at"}).
+ AddRow(1, "testuser", string(hash), time.Now())
+
+ mock.ExpectQuery("SELECT id, username, password_hash, created_at FROM users WHERE username = ?").
+ WithArgs("testuser").
+ WillReturnRows(rows)
+
+ user, err := service.Authenticate("testuser", password)
+ if err != nil {
+ t.Errorf("expected no error, got %v", err)
+ }
+ if user == nil {
+ t.Errorf("expected user, got nil")
+ }
+ if user.Username != "testuser" {
+ t.Errorf("expected username testuser, got %s", user.Username)
+ }
+
+ if err := mock.ExpectationsWereMet(); err != nil {
+ t.Errorf("there were unfulfilled expectations: %s", err)
+ }
+}
+
+func TestAuthenticate_InvalidCredentials(t *testing.T) {
+ db, mock, err := sqlmock.New()
+ if err != nil {
+ t.Fatalf("an error '%s' was not expected when opening a stub database connection", err)
+ }
+ defer db.Close()
+
+ service := NewService(db)
+
+ mock.ExpectQuery("SELECT id, username, password_hash, created_at FROM users WHERE username = ?").
+ WithArgs("nonexistent").
+ WillReturnError(sql.ErrNoRows)
+
+ _, err = service.Authenticate("nonexistent", "password")
+ if !errors.Is(err, ErrInvalidCredentials) {
+ t.Errorf("expected ErrInvalidCredentials, got %v", err)
+ }
+
+ if err := mock.ExpectationsWereMet(); err != nil {
+ t.Errorf("there were unfulfilled expectations: %s", err)
+ }
+}
+
+func TestCreateUser(t *testing.T) {
+ db, mock, err := sqlmock.New()
+ if err != nil {
+ t.Fatalf("an error '%s' was not expected when opening a stub database connection", err)
+ }
+ defer db.Close()
+
+ service := NewService(db)
+
+ // Expect check if user exists
+ mock.ExpectQuery("SELECT id, username, password_hash, created_at FROM users WHERE username = ?").
+ WithArgs("newuser").
+ WillReturnError(sql.ErrNoRows)
+
+ // Expect insert
+ mock.ExpectExec("INSERT INTO users").
+ WithArgs("newuser", sqlmock.AnyArg()).
+ WillReturnResult(sqlmock.NewResult(1, 1))
+
+ // Expect retrieve created user
+ rows := sqlmock.NewRows([]string{"id", "username", "password_hash", "created_at"}).
+ AddRow(1, "newuser", "hashedpassword", time.Now())
+ mock.ExpectQuery("SELECT id, username, password_hash, created_at FROM users WHERE id = ?").
+ WithArgs(1).
+ WillReturnRows(rows)
+
+ user, err := service.CreateUser("newuser", "password")
+ if err != nil {
+ t.Errorf("expected no error, got %v", err)
+ }
+ if user.Username != "newuser" {
+ t.Errorf("expected username newuser, got %s", user.Username)
+ }
+
+ if err := mock.ExpectationsWereMet(); err != nil {
+ t.Errorf("there were unfulfilled expectations: %s", err)
+ }
+}
diff --git a/internal/auth/handlers.go b/internal/auth/handlers.go
index 17bcabd..c690d29 100644
--- a/internal/auth/handlers.go
+++ b/internal/auth/handlers.go
@@ -40,9 +40,11 @@ func (h *Handlers) HandleLoginPage(w http.ResponseWriter, r *http.Request) {
}
data := struct {
- Error string
+ Error string
+ CSRFToken string
}{
- Error: "",
+ Error: "",
+ CSRFToken: h.middleware.GetCSRFToken(r),
}
if err := h.templates.ExecuteTemplate(w, "login.html", data); err != nil {
@@ -62,14 +64,14 @@ func (h *Handlers) HandleLogin(w http.ResponseWriter, r *http.Request) {
password := r.FormValue("password")
if username == "" || password == "" {
- h.renderLoginError(w, "Username and password are required")
+ h.renderLoginError(w, r, "Username and password are required")
return
}
user, err := h.service.Authenticate(username, password)
if err != nil {
log.Printf("Login failed for user %s: %v", username, err)
- h.renderLoginError(w, "Invalid username or password")
+ h.renderLoginError(w, r, "Invalid username or password")
return
}
@@ -96,11 +98,13 @@ func (h *Handlers) HandleLogout(w http.ResponseWriter, r *http.Request) {
http.Redirect(w, r, "/login", http.StatusSeeOther)
}
-func (h *Handlers) renderLoginError(w http.ResponseWriter, errorMsg string) {
+func (h *Handlers) renderLoginError(w http.ResponseWriter, r *http.Request, errorMsg string) {
data := struct {
- Error string
+ Error string
+ CSRFToken string
}{
- Error: errorMsg,
+ Error: errorMsg,
+ CSRFToken: h.middleware.GetCSRFToken(r),
}
w.WriteHeader(http.StatusUnauthorized)
diff --git a/internal/auth/handlers_test.go b/internal/auth/handlers_test.go
new file mode 100644
index 0000000..3e154ce
--- /dev/null
+++ b/internal/auth/handlers_test.go
@@ -0,0 +1,106 @@
+package auth
+
+import (
+ "database/sql"
+ "html/template"
+ "net/http"
+ "net/http/httptest"
+ "net/url"
+ "strings"
+ "testing"
+ "time"
+
+ "github.com/DATA-DOG/go-sqlmock"
+ "github.com/alexedwards/scs/v2"
+ "golang.org/x/crypto/bcrypt"
+)
+
+func TestHandleLogin(t *testing.T) {
+ db, mock, err := sqlmock.New()
+ if err != nil {
+ t.Fatalf("an error '%s' was not expected when opening a stub database connection", err)
+ }
+ defer db.Close()
+
+ service := NewService(db)
+ sessionManager := scs.New()
+ templates := template.Must(template.New("login.html").Parse("{{.Error}}"))
+
+ handlers := NewHandlers(service, sessionManager, templates)
+
+ // Setup mock user
+ password := "password"
+ hash, _ := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
+ rows := sqlmock.NewRows([]string{"id", "username", "password_hash", "created_at"}).
+ AddRow(1, "testuser", string(hash), time.Now())
+
+ mock.ExpectQuery("SELECT id, username, password_hash, created_at FROM users WHERE username = ?").
+ WithArgs("testuser").
+ WillReturnRows(rows)
+
+ // Create request
+ form := url.Values{}
+ form.Add("username", "testuser")
+ form.Add("password", "password")
+ req := httptest.NewRequest("POST", "/login", strings.NewReader(form.Encode()))
+ req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+
+ // Wrap request with session middleware
+ ctx, _ := sessionManager.Load(req.Context(), "")
+ req = req.WithContext(ctx)
+
+ rr := httptest.NewRecorder()
+
+ handlers.HandleLogin(rr, req)
+
+ if status := rr.Code; status != http.StatusSeeOther {
+ t.Errorf("handler returned wrong status code: got %v want %v",
+ status, http.StatusSeeOther)
+ }
+
+ if err := mock.ExpectationsWereMet(); err != nil {
+ t.Errorf("there were unfulfilled expectations: %s", err)
+ }
+}
+
+func TestHandleLogin_InvalidCredentials(t *testing.T) {
+ db, mock, err := sqlmock.New()
+ if err != nil {
+ t.Fatalf("an error '%s' was not expected when opening a stub database connection", err)
+ }
+ defer db.Close()
+
+ service := NewService(db)
+ sessionManager := scs.New()
+ templates := template.Must(template.New("login.html").Parse("{{.Error}}"))
+
+ handlers := NewHandlers(service, sessionManager, templates)
+
+ mock.ExpectQuery("SELECT id, username, password_hash, created_at FROM users WHERE username = ?").
+ WithArgs("testuser").
+ WillReturnError(sql.ErrNoRows)
+
+ // Create request
+ form := url.Values{}
+ form.Add("username", "testuser")
+ form.Add("password", "wrongpassword")
+ req := httptest.NewRequest("POST", "/login", strings.NewReader(form.Encode()))
+ req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+
+ // Wrap request with session middleware
+ ctx, _ := sessionManager.Load(req.Context(), "")
+ req = req.WithContext(ctx)
+
+ rr := httptest.NewRecorder()
+
+ handlers.HandleLogin(rr, req)
+
+ if status := rr.Code; status != http.StatusUnauthorized {
+ t.Errorf("handler returned wrong status code: got %v want %v",
+ status, http.StatusUnauthorized)
+ }
+
+ if err := mock.ExpectationsWereMet(); err != nil {
+ t.Errorf("there were unfulfilled expectations: %s", err)
+ }
+}
diff --git a/internal/auth/middleware.go b/internal/auth/middleware.go
index 7710328..b440032 100644
--- a/internal/auth/middleware.go
+++ b/internal/auth/middleware.go
@@ -1,12 +1,22 @@
package auth
import (
+ "context"
+ "crypto/rand"
+ "encoding/base64"
"net/http"
"github.com/alexedwards/scs/v2"
)
-const SessionKeyUserID = "user_id"
+const (
+ SessionKeyUserID = "user_id"
+ SessionKeyCSRF = "csrf_token"
+)
+
+type contextKey string
+
+const ContextKeyCSRF contextKey = "csrf_token"
// Middleware provides authentication middleware
type Middleware struct {
@@ -48,3 +58,64 @@ func (m *Middleware) SetUserID(r *http.Request, userID int64) {
func (m *Middleware) ClearSession(r *http.Request) error {
return m.sessions.Destroy(r.Context())
}
+
+// CSRFProtect checks for a valid CSRF token on state-changing requests
+func (m *Middleware) CSRFProtect(next http.Handler) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ // Ensure a token exists in the session
+ if !m.sessions.Exists(r.Context(), SessionKeyCSRF) {
+ token, err := generateToken()
+ if err != nil {
+ http.Error(w, "Internal Server Error", http.StatusInternalServerError)
+ return
+ }
+ m.sessions.Put(r.Context(), SessionKeyCSRF, token)
+ }
+
+ token := m.sessions.GetString(r.Context(), SessionKeyCSRF)
+
+ // Check token for state-changing methods
+ if r.Method == "POST" || r.Method == "PUT" || r.Method == "DELETE" || r.Method == "PATCH" {
+ requestToken := r.Header.Get("X-CSRF-Token")
+
+ if requestToken == "" {
+ requestToken = r.FormValue("csrf_token")
+ }
+
+ if requestToken == "" || requestToken != token {
+ http.Error(w, "Forbidden - CSRF Token Mismatch", http.StatusForbidden)
+ return
+ }
+ }
+
+ // Add token to context for handlers to use
+ ctx := context.WithValue(r.Context(), ContextKeyCSRF, token)
+ next.ServeHTTP(w, r.WithContext(ctx))
+ })
+}
+
+// GetCSRFToken retrieves the CSRF token from the session
+func (m *Middleware) GetCSRFToken(r *http.Request) string {
+ if !m.sessions.Exists(r.Context(), SessionKeyCSRF) {
+ return ""
+ }
+ return m.sessions.GetString(r.Context(), SessionKeyCSRF)
+}
+
+// GetCSRFTokenFromContext retrieves the CSRF token from the context
+func GetCSRFTokenFromContext(ctx context.Context) string {
+ token, ok := ctx.Value(ContextKeyCSRF).(string)
+ if !ok {
+ return ""
+ }
+ return token
+}
+
+func generateToken() (string, error) {
+ b := make([]byte, 32)
+ _, err := rand.Read(b)
+ if err != nil {
+ return "", err
+ }
+ return base64.URLEncoding.EncodeToString(b), nil
+}
diff --git a/internal/handlers/handlers.go b/internal/handlers/handlers.go
index 7bb84b9..d52e786 100644
--- a/internal/handlers/handlers.go
+++ b/internal/handlers/handlers.go
@@ -14,6 +14,7 @@ import (
"time"
"task-dashboard/internal/api"
+ "task-dashboard/internal/auth"
"task-dashboard/internal/config"
"task-dashboard/internal/models"
"task-dashboard/internal/store"
@@ -81,9 +82,11 @@ func (h *Handler) HandleDashboard(w http.ResponseWriter, r *http.Request) {
data := struct {
*models.DashboardData
ActiveTab string
+ CSRFToken string
}{
DashboardData: dashboardData,
ActiveTab: tab,
+ CSRFToken: auth.GetCSRFTokenFromContext(ctx),
}
if err := h.templates.ExecuteTemplate(w, "index.html", data); err != nil {
diff --git a/test/acceptance_test.go b/test/acceptance_test.go
index ca672c3..8d73d14 100644
--- a/test/acceptance_test.go
+++ b/test/acceptance_test.go
@@ -3,17 +3,22 @@ package test
import (
"encoding/json"
"fmt"
+ "html/template"
"io"
"net/http"
+ "net/http/cookiejar"
"net/http/httptest"
"os"
"testing"
"time"
+ "github.com/alexedwards/scs/v2"
+ "github.com/alexedwards/scs/v2/memstore"
"github.com/go-chi/chi/v5"
"github.com/go-chi/chi/v5/middleware"
"task-dashboard/internal/api"
+ "task-dashboard/internal/auth"
"task-dashboard/internal/config"
"task-dashboard/internal/handlers"
"task-dashboard/internal/models"
@@ -21,7 +26,7 @@ import (
)
// setupTestServer creates a test HTTP server with all routes
-func setupTestServer(t *testing.T) (*httptest.Server, *store.Store, func()) {
+func setupTestServer(t *testing.T) (*httptest.Server, *store.Store, *http.Client, func()) {
t.Helper()
// Create temp database
@@ -54,8 +59,20 @@ func setupTestServer(t *testing.T) (*httptest.Server, *store.Store, func()) {
// Return to original directory
os.Chdir(originalDir)
+ // Auth setup
+ sessionManager := scs.New()
+ sessionManager.Store = memstore.New()
+ sessionManager.Lifetime = 24 * time.Hour
+
+ authService := auth.NewService(db.DB())
+ // Create a dummy template for auth handlers
+ authTemplates := template.Must(template.New("login.html").Parse(""))
+ authHandlers := auth.NewHandlers(authService, sessionManager, authTemplates)
+
+ // Ensure default user
+ authService.EnsureDefaultUser("admin", "password")
+
// Create mock API clients
- // (In real acceptance tests, you'd use test API endpoints or mocks)
todoistClient := api.NewTodoistClient("test_key")
trelloClient := api.NewTrelloClient("test_key", "test_token")
@@ -72,31 +89,55 @@ func setupTestServer(t *testing.T) (*httptest.Server, *store.Store, func()) {
r.Use(middleware.Logger)
r.Use(middleware.Recoverer)
r.Use(middleware.Timeout(60 * time.Second))
+ r.Use(sessionManager.LoadAndSave)
+
+ // Test backdoor for login
+ r.Get("/test/login", func(w http.ResponseWriter, r *http.Request) {
+ sessionManager.Put(r.Context(), "user_id", int64(1))
+ w.WriteHeader(http.StatusOK)
+ })
- // Routes
- r.Get("/", h.HandleDashboard)
- r.Post("/api/refresh", h.HandleRefresh)
- r.Get("/api/tasks", h.HandleGetTasks)
- r.Get("/api/meals", h.HandleGetMeals)
- r.Get("/api/boards", h.HandleGetBoards)
+ // Protected routes
+ r.Group(func(r chi.Router) {
+ r.Use(authHandlers.Middleware().RequireAuth)
+
+ r.Get("/", h.HandleDashboard)
+ r.Post("/api/refresh", h.HandleRefresh)
+ r.Get("/api/tasks", h.HandleGetTasks)
+ r.Get("/api/meals", h.HandleGetMeals)
+ r.Get("/api/boards", h.HandleGetBoards)
+ })
// Create test server
server := httptest.NewServer(r)
+ // Create client with cookie jar
+ jar, _ := cookiejar.New(nil)
+ client := &http.Client{
+ Jar: jar,
+ }
+
cleanup := func() {
server.Close()
db.Close()
os.Remove(tmpFile.Name())
}
- return server, db, cleanup
+ return server, db, client, cleanup
}
// TestFullWorkflow tests a complete user workflow
func TestFullWorkflow(t *testing.T) {
- server, db, cleanup := setupTestServer(t)
+ server, db, client, cleanup := setupTestServer(t)
defer cleanup()
+ // Login first
+ resp, err := client.Get(server.URL + "/test/login")
+ if err != nil {
+ t.Fatalf("Failed to login: %v", err)
+ }
+ resp.Body.Close()
+
// Seed database with test data
testTasks := []models.Task{
{
@@ -139,7 +180,7 @@ func TestFullWorkflow(t *testing.T) {
// Test 1: GET /api/tasks
t.Run("GetTasks", func(t *testing.T) {
- resp, err := http.Get(server.URL + "/api/tasks")
+ resp, err := client.Get(server.URL + "/api/tasks")
if err != nil {
t.Fatalf("Failed to get tasks: %v", err)
}
@@ -165,7 +206,7 @@ func TestFullWorkflow(t *testing.T) {
// Test 2: GET /api/boards
t.Run("GetBoards", func(t *testing.T) {
- resp, err := http.Get(server.URL + "/api/boards")
+ resp, err := client.Get(server.URL + "/api/boards")
if err != nil {
t.Fatalf("Failed to get boards: %v", err)
}
@@ -195,7 +236,7 @@ func TestFullWorkflow(t *testing.T) {
// Test 3: POST /api/refresh
t.Run("RefreshData", func(t *testing.T) {
- resp, err := http.Post(server.URL+"/api/refresh", "application/json", nil)
+ resp, err := client.Post(server.URL+"/api/refresh", "application/json", nil)
if err != nil {
t.Fatalf("Failed to refresh: %v", err)
}
@@ -205,13 +246,10 @@ func TestFullWorkflow(t *testing.T) {
body, _ := io.ReadAll(resp.Body)
t.Errorf("Expected status 200, got %d: %s", resp.StatusCode, string(body))
}
-
- // Just verify we got a 200 OK
- // The response can be either success message or dashboard data
})
t.Run("GetMealsEmpty", func(t *testing.T) {
- resp, err := http.Get(server.URL + "/api/meals")
+ resp, err := client.Get(server.URL + "/api/meals")
if err != nil {
t.Fatalf("Failed to get meals: %v", err)
}
@@ -233,14 +271,12 @@ func TestFullWorkflow(t *testing.T) {
// Test 5: GET / (Dashboard)
t.Run("GetDashboard", func(t *testing.T) {
- resp, err := http.Get(server.URL + "/")
+ resp, err := client.Get(server.URL + "/")
if err != nil {
t.Fatalf("Failed to get dashboard: %v", err)
}
defer resp.Body.Close()
- // Dashboard returns HTML or JSON depending on template availability
- // Just verify it responds with 200 or 500 (template not found)
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusInternalServerError {
t.Errorf("Expected status 200 or 500, got %d", resp.StatusCode)
}
@@ -249,9 +285,12 @@ func TestFullWorkflow(t *testing.T) {
// TestCaching tests the caching behavior
func TestCaching(t *testing.T) {
- server, db, cleanup := setupTestServer(t)
+ server, db, client, cleanup := setupTestServer(t)
defer cleanup()
+ // Login
+ client.Get(server.URL + "/test/login")
+
// Seed initial data
testTasks := []models.Task{
{
@@ -264,7 +303,7 @@ func TestCaching(t *testing.T) {
// Test 1: First request should use cache
t.Run("UsesCache", func(t *testing.T) {
- resp, err := http.Get(server.URL + "/api/tasks")
+ resp, err := client.Get(server.URL + "/api/tasks")
if err != nil {
t.Fatalf("Failed to get tasks: %v", err)
}
@@ -281,13 +320,12 @@ func TestCaching(t *testing.T) {
// Test 2: Refresh should invalidate cache
t.Run("RefreshInvalidatesCache", func(t *testing.T) {
// Force refresh
- resp, err := http.Post(server.URL+"/api/refresh", "application/json", nil)
+ resp, err := client.Post(server.URL+"/api/refresh", "application/json", nil)
if err != nil {
t.Fatalf("Failed to refresh: %v", err)
}
resp.Body.Close()
- // Check cache was updated (this is implicit in the refresh handler)
if resp.StatusCode != http.StatusOK {
t.Errorf("Expected refresh to succeed, got status %d", resp.StatusCode)
}
@@ -296,12 +334,15 @@ func TestCaching(t *testing.T) {
// TestErrorHandling tests error scenarios
func TestErrorHandling(t *testing.T) {
- server, _, cleanup := setupTestServer(t)
+ server, _, client, cleanup := setupTestServer(t)
defer cleanup()
+ // Login
+ client.Get(server.URL + "/test/login")
+
// Test 1: Invalid routes should 404
t.Run("InvalidRoute", func(t *testing.T) {
- resp, err := http.Get(server.URL + "/api/invalid")
+ resp, err := client.Get(server.URL + "/api/invalid")
if err != nil {
t.Fatalf("Failed to make request: %v", err)
}
@@ -314,7 +355,10 @@ func TestErrorHandling(t *testing.T) {
// Test 2: Wrong method should 405
t.Run("WrongMethod", func(t *testing.T) {
- resp, err := http.Get(server.URL + "/api/refresh")
+ // GET /api/refresh is not allowed (only POST)
+ // Wait, in main.go it is POST. In setupTestServer it is POST.
+ // So GET should be 405.
+ resp, err := client.Get(server.URL + "/api/refresh")
if err != nil {
t.Fatalf("Failed to make request: %v", err)
}
@@ -328,9 +372,12 @@ func TestErrorHandling(t *testing.T) {
// TestConcurrentRequests tests handling of concurrent requests
func TestConcurrentRequests(t *testing.T) {
- server, db, cleanup := setupTestServer(t)
+ server, db, client, cleanup := setupTestServer(t)
defer cleanup()
+ // Login
+ client.Get(server.URL + "/test/login")
+
// Seed data
testBoards := []models.Board{
{ID: "board1", Name: "Board 1", Cards: []models.Card{}},
@@ -343,9 +390,12 @@ func TestConcurrentRequests(t *testing.T) {
done := make(chan bool, numRequests)
errors := make(chan error, numRequests)
+ // Note: We need to use the same client to share the session cookie
+ // But http.Client is safe for concurrent use.
+
for i := 0; i < numRequests; i++ {
go func(id int) {
- resp, err := http.Get(fmt.Sprintf("%s/api/boards", server.URL))
+ resp, err := client.Get(fmt.Sprintf("%s/api/boards", server.URL))
if err != nil {
errors <- fmt.Errorf("request %d failed: %w", id, err)
done <- false
diff --git a/web/templates/index.html b/web/templates/index.html
index 54bb0c6..c270b48 100644
--- a/web/templates/index.html
+++ b/web/templates/index.html
@@ -6,7 +6,7 @@
<title>Personal Dashboard</title>
<link rel="stylesheet" href="/static/css/output.css">
</head>
-<body class="min-h-screen">
+<body class="min-h-screen" hx-headers='{"X-CSRF-Token": "{{.CSRFToken}}"}'>
<div class="content-max-width py-8">
<!-- Header -->
<header class="mb-8 flex flex-col sm:flex-row justify-between items-start sm:items-center gap-4">
@@ -20,6 +20,7 @@
<span id="refresh-text">Refresh</span>
</button>
<form method="POST" action="/logout" class="no-print">
+ <input type="hidden" name="csrf_token" value="{{.CSRFToken}}">
<button type="submit"
class="text-gray-600 hover:text-gray-900 px-3 py-2 rounded-lg transition-colors font-medium">
Logout
diff --git a/web/templates/login.html b/web/templates/login.html
index e5ce9e4..c865ce5 100644
--- a/web/templates/login.html
+++ b/web/templates/login.html
@@ -18,6 +18,7 @@
{{end}}
<form method="POST" action="/login" class="space-y-6">
+ <input type="hidden" name="csrf_token" value="{{.CSRFToken}}">
<div>
<label for="username" class="block text-sm font-medium text-gray-700 mb-2">
Username