package middleware

import (
	"net"
	"net/http"
	"strings"
	"sync"
	"time"
)

// SecurityHeaders returns middleware that sets a fixed set of security
// response headers on every response before calling the next handler.
//
// When debug is false, Strict-Transport-Security is set as well; when debug is
// true it is omitted.
func SecurityHeaders(debug bool) func(http.Handler) http.Handler {
	// Content Security Policy - allow self, inline scripts and styles
	// (Tailwind), external images, and embeds.
	const csp = "default-src 'self'; " +
		"img-src 'self' https: data:; " +
		"script-src 'self' 'unsafe-inline' https://unpkg.com; " +
		"style-src 'self' 'unsafe-inline' https://fonts.googleapis.com; " +
		"font-src 'self' https://fonts.gstatic.com; " +
		"frame-src https://www.youtube.com https://embed.windy.com; " +
		"connect-src 'self'"

	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			h := w.Header()
			h.Set("X-Content-Type-Options", "nosniff")
			h.Set("X-Frame-Options", "DENY")
			h.Set("X-XSS-Protection", "1; mode=block")
			h.Set("Referrer-Policy", "strict-origin-when-cross-origin")
			h.Set("Permissions-Policy", "geolocation=(), microphone=(), camera=()")

			// Only set HSTS in production (when not debug).
			if !debug {
				h.Set("Strict-Transport-Security", "max-age=63072000; includeSubDomains")
			}

			h.Set("Content-Security-Policy", csp)

			next.ServeHTTP(w, r)
		})
	}
}

// RateLimiter provides IP-based rate limiting over a sliding time window.
//
// A RateLimiter must be created with NewRateLimiter and must not be copied
// after first use, because it contains a sync.Mutex.
type RateLimiter struct {
	requests map[string][]time.Time // timestamps of admitted requests per key, oldest first
	mu       sync.Mutex             // guards requests
	limit    int                    // maximum admitted requests per key within window
	window   time.Duration          // how long an admitted request counts against limit
	done     chan struct{}          // closed by Stop to end the cleanup goroutine
	stopOnce sync.Once              // makes Stop safe to call more than once
}

// NewRateLimiter creates a rate limiter that admits at most limit requests per
// key within any span of length window.
//
// When window is positive, a background goroutine is started that removes
// expired entries once per window; call Stop to end it. When window is not
// positive no goroutine is started, no history is kept, and Allow returns true
// exactly when limit is positive.
func NewRateLimiter(limit int, window time.Duration) *RateLimiter {
	rl := &RateLimiter{
		requests: make(map[string][]time.Time),
		limit:    limit,
		window:   window,
		done:     make(chan struct{}),
	}

	// time.NewTicker panics on a non-positive interval, and with such a window
	// Allow stores nothing, so there would be nothing to clean up anyway.
	if window > 0 {
		go rl.cleanup()
	}

	return rl
}

// Stop ends the background cleanup goroutine. It is safe to call Stop more
// than once. Allow keeps working after Stop, but expired entries for keys that
// are never seen again are no longer removed.
func (rl *RateLimiter) Stop() {
	rl.stopOnce.Do(func() {
		if rl.done != nil {
			close(rl.done)
		}
	})
}

// cleanup removes old entries once per window until Stop is called.
func (rl *RateLimiter) cleanup() {
	ticker := time.NewTicker(rl.window)
	defer ticker.Stop()

	for {
		select {
		case <-rl.done:
			return
		case <-ticker.C:
			rl.sweep(time.Now())
		}
	}
}

// sweep drops every timestamp that has left the window ending at now and
// deletes keys that have no timestamps left.
func (rl *RateLimiter) sweep(now time.Time) {
	rl.mu.Lock()
	defer rl.mu.Unlock()

	// Deleting from and assigning to existing keys of a map while ranging over
	// it is permitted in Go.
	for ip, times := range rl.requests {
		kept := rl.prune(times, now)
		if len(kept) == 0 {
			delete(rl.requests, ip)
			continue
		}
		rl.requests[ip] = kept
	}
}

// prune returns the timestamps in times that are still inside the window
// ending at now. It filters in place and reuses the backing array of times, so
// the caller must hold rl.mu and must not use times afterwards.
func (rl *RateLimiter) prune(times []time.Time, now time.Time) []time.Time {
	kept := times[:0]
	for _, t := range times {
		if now.Sub(t) < rl.window {
			kept = append(kept, t)
		}
	}
	return kept
}

// Allow checks if a request from the given IP is allowed. It returns true and
// records the request when fewer than limit requests from ip were admitted
// within the last window; otherwise it returns false and records nothing.
func (rl *RateLimiter) Allow(ip string) bool {
	rl.mu.Lock()
	defer rl.mu.Unlock()

	// A non-positive window never contains a past request, so nothing counts
	// against the limit and nothing needs to be stored.
	if rl.window <= 0 {
		return rl.limit > 0
	}

	now := time.Now()
	recent := rl.prune(rl.requests[ip], now)

	if len(recent) >= rl.limit {
		rl.requests[ip] = recent
		return false
	}

	rl.requests[ip] = append(recent, now)
	return true
}

// Limit returns middleware that rate limits requests by client IP. Requests
// over the limit receive a 429 Too Many Requests response and are not passed
// to next.
func (rl *RateLimiter) Limit(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if !rl.Allow(getIP(r)) {
			http.Error(w, "Too many requests. Please try again later.", http.StatusTooManyRequests)
			return
		}
		next.ServeHTTP(w, r)
	})
}

// getIP extracts the client IP from the request. It prefers the first entry of
// X-Forwarded-For, then X-Real-IP, and finally the host part of RemoteAddr.
//
// NOTE(review): both headers are taken at face value. Unless a trusted reverse
// proxy overwrites them, a client can set them to evade the limiter - confirm
// the deployment always sits behind such a proxy.
func getIP(r *http.Request) string {
	// X-Forwarded-For may carry a comma-separated chain ("client, proxy1, ...");
	// only the first entry identifies the original client.
	if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
		if i := strings.IndexByte(xff, ','); i >= 0 {
			xff = xff[:i]
		}
		if ip := strings.TrimSpace(xff); ip != "" {
			return ip
		}
	}

	if xri := strings.TrimSpace(r.Header.Get("X-Real-IP")); xri != "" {
		return xri
	}

	// RemoteAddr is "host:port". The port changes per connection, so it must
	// be stripped or every new connection would get its own rate-limit bucket.
	host, _, err := net.SplitHostPort(r.RemoteAddr)
	if err != nil {
		// No port present (or unparsable): fall back to the raw value.
		return r.RemoteAddr
	}
	return host
}