diff options
| author | Peter Stone <thepeterstone@gmail.com> | 2026-01-20 15:18:57 -1000 |
|---|---|---|
| committer | Peter Stone <thepeterstone@gmail.com> | 2026-01-20 15:18:57 -1000 |
| commit | 78e8f597ff28f1b8406f5cfbf934adc22abdf85b (patch) | |
| tree | f3b7dfff2c460e2d8752b61c131e80a73fa6b08d /internal/auth | |
| parent | 08bbcf18b1207153983261652b4a43a9b36f386c (diff) | |
Add CSRF protection and auth unit tests
Add CSRF token middleware for state-changing request protection,
integrate tokens into templates and HTMX headers, and add unit
tests for authentication service and handlers.
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Diffstat (limited to 'internal/auth')
| -rw-r--r-- | internal/auth/auth_test.go | 108 | ||||
| -rw-r--r-- | internal/auth/handlers.go | 18 | ||||
| -rw-r--r-- | internal/auth/handlers_test.go | 106 | ||||
| -rw-r--r-- | internal/auth/middleware.go | 73 |
4 files changed, 297 insertions, 8 deletions
diff --git a/internal/auth/auth_test.go b/internal/auth/auth_test.go new file mode 100644 index 0000000..505efe3 --- /dev/null +++ b/internal/auth/auth_test.go @@ -0,0 +1,108 @@ +package auth + +import ( + "database/sql" + "errors" + "testing" + "time" + + "github.com/DATA-DOG/go-sqlmock" + "golang.org/x/crypto/bcrypt" +) + +func TestAuthenticate(t *testing.T) { + db, mock, err := sqlmock.New() + if err != nil { + t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + + service := NewService(db) + + password := "secret" + hash, _ := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) + + rows := sqlmock.NewRows([]string{"id", "username", "password_hash", "created_at"}). + AddRow(1, "testuser", string(hash), time.Now()) + + mock.ExpectQuery("SELECT id, username, password_hash, created_at FROM users WHERE username = ?"). + WithArgs("testuser"). + WillReturnRows(rows) + + user, err := service.Authenticate("testuser", password) + if err != nil { + t.Errorf("expected no error, got %v", err) + } + if user == nil { + t.Errorf("expected user, got nil") + } + if user.Username != "testuser" { + t.Errorf("expected username testuser, got %s", user.Username) + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expectations: %s", err) + } +} + +func TestAuthenticate_InvalidCredentials(t *testing.T) { + db, mock, err := sqlmock.New() + if err != nil { + t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + + service := NewService(db) + + mock.ExpectQuery("SELECT id, username, password_hash, created_at FROM users WHERE username = ?"). + WithArgs("nonexistent"). + WillReturnError(sql.ErrNoRows) + + _, err = service.Authenticate("nonexistent", "password") + if !errors.Is(err, ErrInvalidCredentials) { + t.Errorf("expected ErrInvalidCredentials, got %v", err) + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expectations: %s", err) + } +} + +func TestCreateUser(t *testing.T) { + db, mock, err := sqlmock.New() + if err != nil { + t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + + service := NewService(db) + + // Expect check if user exists + mock.ExpectQuery("SELECT id, username, password_hash, created_at FROM users WHERE username = ?"). + WithArgs("newuser"). + WillReturnError(sql.ErrNoRows) + + // Expect insert + mock.ExpectExec("INSERT INTO users"). + WithArgs("newuser", sqlmock.AnyArg()). + WillReturnResult(sqlmock.NewResult(1, 1)) + + // Expect retrieve created user + rows := sqlmock.NewRows([]string{"id", "username", "password_hash", "created_at"}). + AddRow(1, "newuser", "hashedpassword", time.Now()) + mock.ExpectQuery("SELECT id, username, password_hash, created_at FROM users WHERE id = ?"). + WithArgs(1). + WillReturnRows(rows) + + user, err := service.CreateUser("newuser", "password") + if err != nil { + t.Errorf("expected no error, got %v", err) + } + if user.Username != "newuser" { + t.Errorf("expected username newuser, got %s", user.Username) + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expectations: %s", err) + } +} diff --git a/internal/auth/handlers.go b/internal/auth/handlers.go index 17bcabd..c690d29 100644 --- a/internal/auth/handlers.go +++ b/internal/auth/handlers.go @@ -40,9 +40,11 @@ func (h *Handlers) HandleLoginPage(w http.ResponseWriter, r *http.Request) { } data := struct { - Error string + Error string + CSRFToken string }{ - Error: "", + Error: "", + CSRFToken: h.middleware.GetCSRFToken(r), } if err := h.templates.ExecuteTemplate(w, "login.html", data); err != nil { @@ -62,14 +64,14 @@ func (h *Handlers) HandleLogin(w http.ResponseWriter, r *http.Request) { password := r.FormValue("password") if username == "" || password == "" { - h.renderLoginError(w, "Username and password are required") + h.renderLoginError(w, r, "Username and password are required") return } user, err := h.service.Authenticate(username, password) if err != nil { log.Printf("Login failed for user %s: %v", username, err) - h.renderLoginError(w, "Invalid username or password") + h.renderLoginError(w, r, "Invalid username or password") return } @@ -96,11 +98,13 @@ func (h *Handlers) HandleLogout(w http.ResponseWriter, r *http.Request) { http.Redirect(w, r, "/login", http.StatusSeeOther) } -func (h *Handlers) renderLoginError(w http.ResponseWriter, errorMsg string) { +func (h *Handlers) renderLoginError(w http.ResponseWriter, r *http.Request, errorMsg string) { data := struct { - Error string + Error string + CSRFToken string }{ - Error: errorMsg, + Error: errorMsg, + CSRFToken: h.middleware.GetCSRFToken(r), } w.WriteHeader(http.StatusUnauthorized) diff --git a/internal/auth/handlers_test.go b/internal/auth/handlers_test.go new file mode 100644 index 0000000..3e154ce --- /dev/null +++ b/internal/auth/handlers_test.go @@ -0,0 +1,106 @@ +package auth + +import ( + "database/sql" + "html/template" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + "time" + + "github.com/DATA-DOG/go-sqlmock" + "github.com/alexedwards/scs/v2" + "golang.org/x/crypto/bcrypt" +) + +func TestHandleLogin(t *testing.T) { + db, mock, err := sqlmock.New() + if err != nil { + t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + + service := NewService(db) + sessionManager := scs.New() + templates := template.Must(template.New("login.html").Parse("{{.Error}}")) + + handlers := NewHandlers(service, sessionManager, templates) + + // Setup mock user + password := "password" + hash, _ := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) + rows := sqlmock.NewRows([]string{"id", "username", "password_hash", "created_at"}). + AddRow(1, "testuser", string(hash), time.Now()) + + mock.ExpectQuery("SELECT id, username, password_hash, created_at FROM users WHERE username = ?"). + WithArgs("testuser"). + WillReturnRows(rows) + + // Create request + form := url.Values{} + form.Add("username", "testuser") + form.Add("password", "password") + req := httptest.NewRequest("POST", "/login", strings.NewReader(form.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + // Wrap request with session middleware + ctx, _ := sessionManager.Load(req.Context(), "") + req = req.WithContext(ctx) + + rr := httptest.NewRecorder() + + handlers.HandleLogin(rr, req) + + if status := rr.Code; status != http.StatusSeeOther { + t.Errorf("handler returned wrong status code: got %v want %v", + status, http.StatusSeeOther) + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expectations: %s", err) + } +} + +func TestHandleLogin_InvalidCredentials(t *testing.T) { + db, mock, err := sqlmock.New() + if err != nil { + t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + + service := NewService(db) + sessionManager := scs.New() + templates := template.Must(template.New("login.html").Parse("{{.Error}}")) + + handlers := NewHandlers(service, sessionManager, templates) + + mock.ExpectQuery("SELECT id, username, password_hash, created_at FROM users WHERE username = ?"). + WithArgs("testuser"). + WillReturnError(sql.ErrNoRows) + + // Create request + form := url.Values{} + form.Add("username", "testuser") + form.Add("password", "wrongpassword") + req := httptest.NewRequest("POST", "/login", strings.NewReader(form.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + // Wrap request with session middleware + ctx, _ := sessionManager.Load(req.Context(), "") + req = req.WithContext(ctx) + + rr := httptest.NewRecorder() + + handlers.HandleLogin(rr, req) + + if status := rr.Code; status != http.StatusUnauthorized { + t.Errorf("handler returned wrong status code: got %v want %v", + status, http.StatusUnauthorized) + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expectations: %s", err) + } +} diff --git a/internal/auth/middleware.go b/internal/auth/middleware.go index 7710328..b440032 100644 --- a/internal/auth/middleware.go +++ b/internal/auth/middleware.go @@ -1,12 +1,22 @@ package auth import ( + "context" + "crypto/rand" + "encoding/base64" "net/http" "github.com/alexedwards/scs/v2" ) -const SessionKeyUserID = "user_id" +const ( + SessionKeyUserID = "user_id" + SessionKeyCSRF = "csrf_token" +) + +type contextKey string + +const ContextKeyCSRF contextKey = "csrf_token" // Middleware provides authentication middleware type Middleware struct { @@ -48,3 +58,64 @@ func (m *Middleware) SetUserID(r *http.Request, userID int64) { func (m *Middleware) ClearSession(r *http.Request) error { return m.sessions.Destroy(r.Context()) } + +// CSRFProtect checks for a valid CSRF token on state-changing requests +func (m *Middleware) CSRFProtect(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Ensure a token exists in the session + if !m.sessions.Exists(r.Context(), SessionKeyCSRF) { + token, err := generateToken() + if err != nil { + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + return + } + m.sessions.Put(r.Context(), SessionKeyCSRF, token) + } + + token := m.sessions.GetString(r.Context(), SessionKeyCSRF) + + // Check token for state-changing methods + if r.Method == "POST" || r.Method == "PUT" || r.Method == "DELETE" || r.Method == "PATCH" { + requestToken := r.Header.Get("X-CSRF-Token") + + if requestToken == "" { + requestToken = r.FormValue("csrf_token") + } + + if requestToken == "" || requestToken != token { + http.Error(w, "Forbidden - CSRF Token Mismatch", http.StatusForbidden) + return + } + } + + // Add token to context for handlers to use + ctx := context.WithValue(r.Context(), ContextKeyCSRF, token) + next.ServeHTTP(w, r.WithContext(ctx)) + }) +} + +// GetCSRFToken retrieves the CSRF token from the session +func (m *Middleware) GetCSRFToken(r *http.Request) string { + if !m.sessions.Exists(r.Context(), SessionKeyCSRF) { + return "" + } + return m.sessions.GetString(r.Context(), SessionKeyCSRF) +} + +// GetCSRFTokenFromContext retrieves the CSRF token from the context +func GetCSRFTokenFromContext(ctx context.Context) string { + token, ok := ctx.Value(ContextKeyCSRF).(string) + if !ok { + return "" + } + return token +} + +func generateToken() (string, error) { + b := make([]byte, 32) + _, err := rand.Read(b) + if err != nil { + return "", err + } + return base64.URLEncoding.EncodeToString(b), nil +} |
