// Package auth provides session-based authentication and CSRF protection
// middleware built on top of an scs session manager.
package auth

import (
	"context"
	"crypto/rand"
	"crypto/subtle"
	"encoding/base64"
	"net/http"

	"github.com/alexedwards/scs/v2"
)

// Session keys used by this package.
const (
	// SessionKeyUserID holds the authenticated user's int64 ID.
	SessionKeyUserID = "user_id"
	// SessionKeyCSRF holds the per-session CSRF token string.
	SessionKeyCSRF = "csrf_token"
)

// contextKey is unexported so keys from other packages cannot collide with it.
type contextKey string

// ContextKeyCSRF is the request-context key under which CSRFProtect stores
// the session's CSRF token for downstream handlers.
const ContextKeyCSRF contextKey = "csrf_token"

// Middleware provides authentication and CSRF middleware backed by an scs
// session manager.
type Middleware struct {
	sessions *scs.SessionManager
}

// NewMiddleware creates a Middleware that stores its state in sessions.
func NewMiddleware(sessions *scs.SessionManager) *Middleware {
	return &Middleware{sessions: sessions}
}

// RequireAuth wraps next so that unauthenticated requests are redirected to
// /login with 303 See Other instead of reaching next.
func (m *Middleware) RequireAuth(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if !m.IsAuthenticated(r) {
			http.Redirect(w, r, "/login", http.StatusSeeOther)
			return
		}
		next.ServeHTTP(w, r)
	})
}

// IsAuthenticated reports whether the request's session contains a user ID.
func (m *Middleware) IsAuthenticated(r *http.Request) bool {
	return m.sessions.Exists(r.Context(), SessionKeyUserID)
}

// GetUserID returns the authenticated user's ID from the session, or 0 when
// the session holds no (int64) user ID.
func (m *Middleware) GetUserID(r *http.Request) int64 {
	return m.sessions.GetInt64(r.Context(), SessionKeyUserID)
}

// SetUserID stores userID in the session.
//
// It does NOT renew the session token. Calling it directly after a login
// leaves the pre-login token valid, which allows session fixation; prefer
// Login, which renews the token first.
func (m *Middleware) SetUserID(r *http.Request, userID int64) {
	m.sessions.Put(r.Context(), SessionKeyUserID, userID)
}

// Login marks userID as authenticated after a successful credential check.
//
// The session token is renewed before the user ID is stored. A token that
// existed before the privilege change (possibly planted by an attacker)
// therefore does not become an authenticated session. scs's RenewToken keeps
// the existing session data, so values such as the CSRF token carry over.
// The error from RenewToken, if any, is returned and no user ID is set.
func (m *Middleware) Login(r *http.Request, userID int64) error {
	if err := m.sessions.RenewToken(r.Context()); err != nil {
		return err
	}
	m.sessions.Put(r.Context(), SessionKeyUserID, userID)
	return nil
}

// ClearSession destroys the entire session (called on logout). This removes
// every key, including the user ID and the CSRF token, not just the user ID.
func (m *Middleware) ClearSession(r *http.Request) error {
	return m.sessions.Destroy(r.Context())
}

// CSRFProtect ensures the session has a CSRF token. On every request whose
// method is not "safe" (see isSafeMethod), it rejects the request with 403
// unless the X-CSRF-Token header or the csrf_token form value matches the
// session token. The token is also placed in the request context under
// ContextKeyCSRF for downstream handlers.
func (m *Middleware) CSRFProtect(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		ctx := r.Context()

		// Issue a token lazily. A missing key and an empty/non-string value are
		// treated the same way; otherwise a blank stored value would make every
		// state-changing request fail forever.
		token := m.sessions.GetString(ctx, SessionKeyCSRF)
		if token == "" {
			generated, err := generateToken()
			if err != nil {
				http.Error(w, "Internal Server Error", http.StatusInternalServerError)
				return
			}
			m.sessions.Put(ctx, SessionKeyCSRF, generated)
			token = generated
		}

		// Validate every method that may change state, not just a fixed
		// blocklist, so custom or unexpected methods cannot bypass the check.
		if !isSafeMethod(r.Method) {
			requestToken := r.Header.Get("X-CSRF-Token")
			if requestToken == "" {
				// NOTE(review): FormValue also consults the URL query string;
				// tokens in URLs can leak via logs/Referer. Consider
				// PostFormValue if no client relies on query-string tokens.
				requestToken = r.FormValue("csrf_token")
			}
			// Constant-time comparison avoids leaking the token via timing.
			if requestToken == "" || subtle.ConstantTimeCompare([]byte(requestToken), []byte(token)) != 1 {
				http.Error(w, "Forbidden - CSRF Token Mismatch", http.StatusForbidden)
				return
			}
		}

		next.ServeHTTP(w, r.WithContext(context.WithValue(ctx, ContextKeyCSRF, token)))
	})
}

// isSafeMethod reports whether method is one of the RFC 9110 "safe" methods
// (GET, HEAD, OPTIONS, TRACE). Safe methods must not change server state, so
// they are exempt from CSRF validation. All other methods require a token.
func isSafeMethod(method string) bool {
	switch method {
	case http.MethodGet, http.MethodHead, http.MethodOptions, http.MethodTrace:
		return true
	}
	return false
}

// GetCSRFToken returns the CSRF token stored in the session, or "" if the
// session holds no (string) token.
func (m *Middleware) GetCSRFToken(r *http.Request) string {
	return m.sessions.GetString(r.Context(), SessionKeyCSRF)
}

// GetCSRFTokenFromContext returns the CSRF token that CSRFProtect stored in
// ctx, or "" if none is present.
func GetCSRFTokenFromContext(ctx context.Context) string {
	token, ok := ctx.Value(ContextKeyCSRF).(string)
	if !ok {
		return ""
	}
	return token
}

// generateToken returns 32 bytes from crypto/rand encoded as padded,
// URL-safe base64 (44 characters).
func generateToken() (string, error) {
	b := make([]byte, 32)
	if _, err := rand.Read(b); err != nil {
		return "", err
	}
	return base64.URLEncoding.EncodeToString(b), nil
}