summaryrefslogtreecommitdiff
path: root/internal/auth/middleware.go
diff options
context:
space:
mode:
authorPeter Stone <thepeterstone@gmail.com>2026-01-20 15:18:57 -1000
committerPeter Stone <thepeterstone@gmail.com>2026-01-20 15:18:57 -1000
commit78e8f597ff28f1b8406f5cfbf934adc22abdf85b (patch)
treef3b7dfff2c460e2d8752b61c131e80a73fa6b08d /internal/auth/middleware.go
parent08bbcf18b1207153983261652b4a43a9b36f386c (diff)
Add CSRF protection and auth unit tests
Add CSRF token middleware for state-changing request protection, integrate tokens into templates and HTMX headers, and add unit tests for authentication service and handlers. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Diffstat (limited to 'internal/auth/middleware.go')
-rw-r--r--internal/auth/middleware.go73
1 files changed, 72 insertions, 1 deletions
diff --git a/internal/auth/middleware.go b/internal/auth/middleware.go
index 7710328..b440032 100644
--- a/internal/auth/middleware.go
+++ b/internal/auth/middleware.go
@@ -1,12 +1,22 @@
package auth
import (
+ "context"
+ "crypto/rand"
+ "encoding/base64"
"net/http"
"github.com/alexedwards/scs/v2"
)
-const SessionKeyUserID = "user_id"
+const (
+ SessionKeyUserID = "user_id"
+ SessionKeyCSRF = "csrf_token"
+)
+
+type contextKey string
+
+const ContextKeyCSRF contextKey = "csrf_token"
// Middleware provides authentication middleware
type Middleware struct {
@@ -48,3 +58,64 @@ func (m *Middleware) SetUserID(r *http.Request, userID int64) {
func (m *Middleware) ClearSession(r *http.Request) error {
return m.sessions.Destroy(r.Context())
}
+
+// CSRFProtect checks for a valid CSRF token on state-changing requests
+func (m *Middleware) CSRFProtect(next http.Handler) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ // Ensure a token exists in the session
+ if !m.sessions.Exists(r.Context(), SessionKeyCSRF) {
+ token, err := generateToken()
+ if err != nil {
+ http.Error(w, "Internal Server Error", http.StatusInternalServerError)
+ return
+ }
+ m.sessions.Put(r.Context(), SessionKeyCSRF, token)
+ }
+
+ token := m.sessions.GetString(r.Context(), SessionKeyCSRF)
+
+ // Check token for state-changing methods
+ if r.Method == "POST" || r.Method == "PUT" || r.Method == "DELETE" || r.Method == "PATCH" {
+ requestToken := r.Header.Get("X-CSRF-Token")
+
+ if requestToken == "" {
+ requestToken = r.FormValue("csrf_token")
+ }
+
+ if requestToken == "" || requestToken != token {
+ http.Error(w, "Forbidden - CSRF Token Mismatch", http.StatusForbidden)
+ return
+ }
+ }
+
+ // Add token to context for handlers to use
+ ctx := context.WithValue(r.Context(), ContextKeyCSRF, token)
+ next.ServeHTTP(w, r.WithContext(ctx))
+ })
+}
+
+// GetCSRFToken retrieves the CSRF token from the session
+func (m *Middleware) GetCSRFToken(r *http.Request) string {
+ if !m.sessions.Exists(r.Context(), SessionKeyCSRF) {
+ return ""
+ }
+ return m.sessions.GetString(r.Context(), SessionKeyCSRF)
+}
+
+// GetCSRFTokenFromContext retrieves the CSRF token from the context
+func GetCSRFTokenFromContext(ctx context.Context) string {
+ token, ok := ctx.Value(ContextKeyCSRF).(string)
+ if !ok {
+ return ""
+ }
+ return token
+}
+
+func generateToken() (string, error) {
+ b := make([]byte, 32)
+ _, err := rand.Read(b)
+ if err != nil {
+ return "", err
+ }
+ return base64.URLEncoding.EncodeToString(b), nil
+}