diff options
Diffstat (limited to 'internal/auth/middleware.go')
| -rw-r--r-- | internal/auth/middleware.go | 73 |
1 files changed, 72 insertions, 1 deletions
diff --git a/internal/auth/middleware.go b/internal/auth/middleware.go index 7710328..b440032 100644 --- a/internal/auth/middleware.go +++ b/internal/auth/middleware.go @@ -1,12 +1,22 @@ package auth import ( + "context" + "crypto/rand" + "encoding/base64" "net/http" "github.com/alexedwards/scs/v2" ) -const SessionKeyUserID = "user_id" +const ( + SessionKeyUserID = "user_id" + SessionKeyCSRF = "csrf_token" +) + +type contextKey string + +const ContextKeyCSRF contextKey = "csrf_token" // Middleware provides authentication middleware type Middleware struct { @@ -48,3 +58,64 @@ func (m *Middleware) SetUserID(r *http.Request, userID int64) { func (m *Middleware) ClearSession(r *http.Request) error { return m.sessions.Destroy(r.Context()) } + +// CSRFProtect checks for a valid CSRF token on state-changing requests +func (m *Middleware) CSRFProtect(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Ensure a token exists in the session + if !m.sessions.Exists(r.Context(), SessionKeyCSRF) { + token, err := generateToken() + if err != nil { + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + return + } + m.sessions.Put(r.Context(), SessionKeyCSRF, token) + } + + token := m.sessions.GetString(r.Context(), SessionKeyCSRF) + + // Check token for state-changing methods + if r.Method == "POST" || r.Method == "PUT" || r.Method == "DELETE" || r.Method == "PATCH" { + requestToken := r.Header.Get("X-CSRF-Token") + + if requestToken == "" { + requestToken = r.FormValue("csrf_token") + } + + if requestToken == "" || requestToken != token { + http.Error(w, "Forbidden - CSRF Token Mismatch", http.StatusForbidden) + return + } + } + + // Add token to context for handlers to use + ctx := context.WithValue(r.Context(), ContextKeyCSRF, token) + next.ServeHTTP(w, r.WithContext(ctx)) + }) +} + +// GetCSRFToken retrieves the CSRF token from the session +func (m *Middleware) GetCSRFToken(r *http.Request) string { + if !m.sessions.Exists(r.Context(), SessionKeyCSRF) { + return "" + } + return m.sessions.GetString(r.Context(), SessionKeyCSRF) +} + +// GetCSRFTokenFromContext retrieves the CSRF token from the context +func GetCSRFTokenFromContext(ctx context.Context) string { + token, ok := ctx.Value(ContextKeyCSRF).(string) + if !ok { + return "" + } + return token +} + +func generateToken() (string, error) { + b := make([]byte, 32) + _, err := rand.Read(b) + if err != nil { + return "", err + } + return base64.URLEncoding.EncodeToString(b), nil +} |
