summaryrefslogtreecommitdiff
path: root/internal
diff options
context:
space:
mode:
authorPeter Stone <thepeterstone@gmail.com>2026-01-20 15:18:57 -1000
committerPeter Stone <thepeterstone@gmail.com>2026-01-20 15:18:57 -1000
commit78e8f597ff28f1b8406f5cfbf934adc22abdf85b (patch)
treef3b7dfff2c460e2d8752b61c131e80a73fa6b08d /internal
parent08bbcf18b1207153983261652b4a43a9b36f386c (diff)
Add CSRF protection and auth unit tests
Add CSRF token middleware for state-changing request protection, integrate tokens into templates and HTMX headers, and add unit tests for authentication service and handlers. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Diffstat (limited to 'internal')
-rw-r--r--internal/auth/auth_test.go108
-rw-r--r--internal/auth/handlers.go18
-rw-r--r--internal/auth/handlers_test.go106
-rw-r--r--internal/auth/middleware.go73
-rw-r--r--internal/handlers/handlers.go3
5 files changed, 300 insertions, 8 deletions
diff --git a/internal/auth/auth_test.go b/internal/auth/auth_test.go
new file mode 100644
index 0000000..505efe3
--- /dev/null
+++ b/internal/auth/auth_test.go
@@ -0,0 +1,108 @@
+package auth
+
+import (
+ "database/sql"
+ "errors"
+ "testing"
+ "time"
+
+ "github.com/DATA-DOG/go-sqlmock"
+ "golang.org/x/crypto/bcrypt"
+)
+
+func TestAuthenticate(t *testing.T) {
+ db, mock, err := sqlmock.New()
+ if err != nil {
+ t.Fatalf("an error '%s' was not expected when opening a stub database connection", err)
+ }
+ defer db.Close()
+
+ service := NewService(db)
+
+ password := "secret"
+ hash, _ := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
+
+ rows := sqlmock.NewRows([]string{"id", "username", "password_hash", "created_at"}).
+ AddRow(1, "testuser", string(hash), time.Now())
+
+ mock.ExpectQuery("SELECT id, username, password_hash, created_at FROM users WHERE username = ?").
+ WithArgs("testuser").
+ WillReturnRows(rows)
+
+ user, err := service.Authenticate("testuser", password)
+ if err != nil {
+ t.Errorf("expected no error, got %v", err)
+ }
+ if user == nil {
+ t.Errorf("expected user, got nil")
+ }
+ if user.Username != "testuser" {
+ t.Errorf("expected username testuser, got %s", user.Username)
+ }
+
+ if err := mock.ExpectationsWereMet(); err != nil {
+ t.Errorf("there were unfulfilled expectations: %s", err)
+ }
+}
+
+func TestAuthenticate_InvalidCredentials(t *testing.T) {
+ db, mock, err := sqlmock.New()
+ if err != nil {
+ t.Fatalf("an error '%s' was not expected when opening a stub database connection", err)
+ }
+ defer db.Close()
+
+ service := NewService(db)
+
+ mock.ExpectQuery("SELECT id, username, password_hash, created_at FROM users WHERE username = ?").
+ WithArgs("nonexistent").
+ WillReturnError(sql.ErrNoRows)
+
+ _, err = service.Authenticate("nonexistent", "password")
+ if !errors.Is(err, ErrInvalidCredentials) {
+ t.Errorf("expected ErrInvalidCredentials, got %v", err)
+ }
+
+ if err := mock.ExpectationsWereMet(); err != nil {
+ t.Errorf("there were unfulfilled expectations: %s", err)
+ }
+}
+
+func TestCreateUser(t *testing.T) {
+ db, mock, err := sqlmock.New()
+ if err != nil {
+ t.Fatalf("an error '%s' was not expected when opening a stub database connection", err)
+ }
+ defer db.Close()
+
+ service := NewService(db)
+
+ // Expect check if user exists
+ mock.ExpectQuery("SELECT id, username, password_hash, created_at FROM users WHERE username = ?").
+ WithArgs("newuser").
+ WillReturnError(sql.ErrNoRows)
+
+ // Expect insert
+ mock.ExpectExec("INSERT INTO users").
+ WithArgs("newuser", sqlmock.AnyArg()).
+ WillReturnResult(sqlmock.NewResult(1, 1))
+
+ // Expect retrieve created user
+ rows := sqlmock.NewRows([]string{"id", "username", "password_hash", "created_at"}).
+ AddRow(1, "newuser", "hashedpassword", time.Now())
+ mock.ExpectQuery("SELECT id, username, password_hash, created_at FROM users WHERE id = ?").
+ WithArgs(1).
+ WillReturnRows(rows)
+
+ user, err := service.CreateUser("newuser", "password")
+ if err != nil {
+ t.Errorf("expected no error, got %v", err)
+ }
+ if user.Username != "newuser" {
+ t.Errorf("expected username newuser, got %s", user.Username)
+ }
+
+ if err := mock.ExpectationsWereMet(); err != nil {
+ t.Errorf("there were unfulfilled expectations: %s", err)
+ }
+}
diff --git a/internal/auth/handlers.go b/internal/auth/handlers.go
index 17bcabd..c690d29 100644
--- a/internal/auth/handlers.go
+++ b/internal/auth/handlers.go
@@ -40,9 +40,11 @@ func (h *Handlers) HandleLoginPage(w http.ResponseWriter, r *http.Request) {
}
data := struct {
- Error string
+ Error string
+ CSRFToken string
}{
- Error: "",
+ Error: "",
+ CSRFToken: h.middleware.GetCSRFToken(r),
}
if err := h.templates.ExecuteTemplate(w, "login.html", data); err != nil {
@@ -62,14 +64,14 @@ func (h *Handlers) HandleLogin(w http.ResponseWriter, r *http.Request) {
password := r.FormValue("password")
if username == "" || password == "" {
- h.renderLoginError(w, "Username and password are required")
+ h.renderLoginError(w, r, "Username and password are required")
return
}
user, err := h.service.Authenticate(username, password)
if err != nil {
log.Printf("Login failed for user %s: %v", username, err)
- h.renderLoginError(w, "Invalid username or password")
+ h.renderLoginError(w, r, "Invalid username or password")
return
}
@@ -96,11 +98,13 @@ func (h *Handlers) HandleLogout(w http.ResponseWriter, r *http.Request) {
http.Redirect(w, r, "/login", http.StatusSeeOther)
}
-func (h *Handlers) renderLoginError(w http.ResponseWriter, errorMsg string) {
+func (h *Handlers) renderLoginError(w http.ResponseWriter, r *http.Request, errorMsg string) {
data := struct {
- Error string
+ Error string
+ CSRFToken string
}{
- Error: errorMsg,
+ Error: errorMsg,
+ CSRFToken: h.middleware.GetCSRFToken(r),
}
w.WriteHeader(http.StatusUnauthorized)
diff --git a/internal/auth/handlers_test.go b/internal/auth/handlers_test.go
new file mode 100644
index 0000000..3e154ce
--- /dev/null
+++ b/internal/auth/handlers_test.go
@@ -0,0 +1,106 @@
+package auth
+
+import (
+ "database/sql"
+ "html/template"
+ "net/http"
+ "net/http/httptest"
+ "net/url"
+ "strings"
+ "testing"
+ "time"
+
+ "github.com/DATA-DOG/go-sqlmock"
+ "github.com/alexedwards/scs/v2"
+ "golang.org/x/crypto/bcrypt"
+)
+
+func TestHandleLogin(t *testing.T) {
+ db, mock, err := sqlmock.New()
+ if err != nil {
+ t.Fatalf("an error '%s' was not expected when opening a stub database connection", err)
+ }
+ defer db.Close()
+
+ service := NewService(db)
+ sessionManager := scs.New()
+ templates := template.Must(template.New("login.html").Parse("{{.Error}}"))
+
+ handlers := NewHandlers(service, sessionManager, templates)
+
+ // Setup mock user
+ password := "password"
+ hash, _ := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
+ rows := sqlmock.NewRows([]string{"id", "username", "password_hash", "created_at"}).
+ AddRow(1, "testuser", string(hash), time.Now())
+
+ mock.ExpectQuery("SELECT id, username, password_hash, created_at FROM users WHERE username = ?").
+ WithArgs("testuser").
+ WillReturnRows(rows)
+
+ // Create request
+ form := url.Values{}
+ form.Add("username", "testuser")
+ form.Add("password", "password")
+ req := httptest.NewRequest("POST", "/login", strings.NewReader(form.Encode()))
+ req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+
+ // Wrap request with session middleware
+ ctx, _ := sessionManager.Load(req.Context(), "")
+ req = req.WithContext(ctx)
+
+ rr := httptest.NewRecorder()
+
+ handlers.HandleLogin(rr, req)
+
+ if status := rr.Code; status != http.StatusSeeOther {
+ t.Errorf("handler returned wrong status code: got %v want %v",
+ status, http.StatusSeeOther)
+ }
+
+ if err := mock.ExpectationsWereMet(); err != nil {
+ t.Errorf("there were unfulfilled expectations: %s", err)
+ }
+}
+
+func TestHandleLogin_InvalidCredentials(t *testing.T) {
+ db, mock, err := sqlmock.New()
+ if err != nil {
+ t.Fatalf("an error '%s' was not expected when opening a stub database connection", err)
+ }
+ defer db.Close()
+
+ service := NewService(db)
+ sessionManager := scs.New()
+ templates := template.Must(template.New("login.html").Parse("{{.Error}}"))
+
+ handlers := NewHandlers(service, sessionManager, templates)
+
+ mock.ExpectQuery("SELECT id, username, password_hash, created_at FROM users WHERE username = ?").
+ WithArgs("testuser").
+ WillReturnError(sql.ErrNoRows)
+
+ // Create request
+ form := url.Values{}
+ form.Add("username", "testuser")
+ form.Add("password", "wrongpassword")
+ req := httptest.NewRequest("POST", "/login", strings.NewReader(form.Encode()))
+ req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+
+ // Wrap request with session middleware
+ ctx, _ := sessionManager.Load(req.Context(), "")
+ req = req.WithContext(ctx)
+
+ rr := httptest.NewRecorder()
+
+ handlers.HandleLogin(rr, req)
+
+ if status := rr.Code; status != http.StatusUnauthorized {
+ t.Errorf("handler returned wrong status code: got %v want %v",
+ status, http.StatusUnauthorized)
+ }
+
+ if err := mock.ExpectationsWereMet(); err != nil {
+ t.Errorf("there were unfulfilled expectations: %s", err)
+ }
+}
diff --git a/internal/auth/middleware.go b/internal/auth/middleware.go
index 7710328..b440032 100644
--- a/internal/auth/middleware.go
+++ b/internal/auth/middleware.go
@@ -1,12 +1,22 @@
package auth
import (
+ "context"
+ "crypto/rand"
+ "encoding/base64"
"net/http"
"github.com/alexedwards/scs/v2"
)
-const SessionKeyUserID = "user_id"
+const (
+ SessionKeyUserID = "user_id"
+ SessionKeyCSRF = "csrf_token"
+)
+
+type contextKey string
+
+const ContextKeyCSRF contextKey = "csrf_token"
// Middleware provides authentication middleware
type Middleware struct {
@@ -48,3 +58,64 @@ func (m *Middleware) SetUserID(r *http.Request, userID int64) {
func (m *Middleware) ClearSession(r *http.Request) error {
return m.sessions.Destroy(r.Context())
}
+
+// CSRFProtect checks for a valid CSRF token on state-changing requests
+func (m *Middleware) CSRFProtect(next http.Handler) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ // Ensure a token exists in the session
+ if !m.sessions.Exists(r.Context(), SessionKeyCSRF) {
+ token, err := generateToken()
+ if err != nil {
+ http.Error(w, "Internal Server Error", http.StatusInternalServerError)
+ return
+ }
+ m.sessions.Put(r.Context(), SessionKeyCSRF, token)
+ }
+
+ token := m.sessions.GetString(r.Context(), SessionKeyCSRF)
+
+ // Check token for state-changing methods
+ if r.Method == "POST" || r.Method == "PUT" || r.Method == "DELETE" || r.Method == "PATCH" {
+ requestToken := r.Header.Get("X-CSRF-Token")
+
+ if requestToken == "" {
+ requestToken = r.FormValue("csrf_token")
+ }
+
+ if requestToken == "" || requestToken != token {
+ http.Error(w, "Forbidden - CSRF Token Mismatch", http.StatusForbidden)
+ return
+ }
+ }
+
+ // Add token to context for handlers to use
+ ctx := context.WithValue(r.Context(), ContextKeyCSRF, token)
+ next.ServeHTTP(w, r.WithContext(ctx))
+ })
+}
+
+// GetCSRFToken retrieves the CSRF token from the session
+func (m *Middleware) GetCSRFToken(r *http.Request) string {
+ if !m.sessions.Exists(r.Context(), SessionKeyCSRF) {
+ return ""
+ }
+ return m.sessions.GetString(r.Context(), SessionKeyCSRF)
+}
+
+// GetCSRFTokenFromContext retrieves the CSRF token from the context
+func GetCSRFTokenFromContext(ctx context.Context) string {
+ token, ok := ctx.Value(ContextKeyCSRF).(string)
+ if !ok {
+ return ""
+ }
+ return token
+}
+
+func generateToken() (string, error) {
+ b := make([]byte, 32)
+ _, err := rand.Read(b)
+ if err != nil {
+ return "", err
+ }
+ return base64.URLEncoding.EncodeToString(b), nil
+}
diff --git a/internal/handlers/handlers.go b/internal/handlers/handlers.go
index 7bb84b9..d52e786 100644
--- a/internal/handlers/handlers.go
+++ b/internal/handlers/handlers.go
@@ -14,6 +14,7 @@ import (
"time"
"task-dashboard/internal/api"
+ "task-dashboard/internal/auth"
"task-dashboard/internal/config"
"task-dashboard/internal/models"
"task-dashboard/internal/store"
@@ -81,9 +82,11 @@ func (h *Handler) HandleDashboard(w http.ResponseWriter, r *http.Request) {
data := struct {
*models.DashboardData
ActiveTab string
+ CSRFToken string
}{
DashboardData: dashboardData,
ActiveTab: tab,
+ CSRFToken: auth.GetCSRFTokenFromContext(ctx),
}
if err := h.templates.ExecuteTemplate(w, "index.html", data); err != nil {