diff options
Diffstat (limited to 'internal/middleware')
| -rw-r--r-- | internal/middleware/security.go | 132 |
1 files changed, 132 insertions, 0 deletions
diff --git a/internal/middleware/security.go b/internal/middleware/security.go new file mode 100644 index 0000000..159a0e6 --- /dev/null +++ b/internal/middleware/security.go @@ -0,0 +1,132 @@ +package middleware + +import ( + "net/http" + "sync" + "time" +) + +// SecurityHeaders adds security headers to all responses +func SecurityHeaders(debug bool) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-Content-Type-Options", "nosniff") + w.Header().Set("X-Frame-Options", "DENY") + w.Header().Set("X-XSS-Protection", "1; mode=block") + w.Header().Set("Referrer-Policy", "strict-origin-when-cross-origin") + w.Header().Set("Permissions-Policy", "geolocation=(), microphone=(), camera=()") + + // Only set HSTS in production (when not debug) + if !debug { + w.Header().Set("Strict-Transport-Security", "max-age=63072000; includeSubDomains") + } + + // Content Security Policy - allow self, inline styles (Tailwind), and external images + w.Header().Set("Content-Security-Policy", + "default-src 'self'; "+ + "img-src 'self' https: data:; "+ + "script-src 'self' 'unsafe-inline' https://unpkg.com; "+ + "style-src 'self' 'unsafe-inline'; "+ + "font-src 'self' https:; "+ + "connect-src 'self'") + + next.ServeHTTP(w, r) + }) + } +} + +// RateLimiter provides IP-based rate limiting +type RateLimiter struct { + requests map[string][]time.Time + mu sync.Mutex + limit int + window time.Duration +} + +// NewRateLimiter creates a new rate limiter +func NewRateLimiter(limit int, window time.Duration) *RateLimiter { + rl := &RateLimiter{ + requests: make(map[string][]time.Time), + limit: limit, + window: window, + } + // Start cleanup goroutine + go rl.cleanup() + return rl +} + +// cleanup removes old entries periodically +func (rl *RateLimiter) cleanup() { + ticker := time.NewTicker(rl.window) + defer ticker.Stop() + for range ticker.C { + rl.mu.Lock() + now := time.Now() + for ip, times := range rl.requests { + var valid []time.Time + for _, t := range times { + if now.Sub(t) < rl.window { + valid = append(valid, t) + } + } + if len(valid) == 0 { + delete(rl.requests, ip) + } else { + rl.requests[ip] = valid + } + } + rl.mu.Unlock() + } +} + +// Allow checks if a request from the given IP is allowed +func (rl *RateLimiter) Allow(ip string) bool { + rl.mu.Lock() + defer rl.mu.Unlock() + + now := time.Now() + times := rl.requests[ip] + + // Filter to only requests within the window + var valid []time.Time + for _, t := range times { + if now.Sub(t) < rl.window { + valid = append(valid, t) + } + } + + if len(valid) >= rl.limit { + rl.requests[ip] = valid + return false + } + + valid = append(valid, now) + rl.requests[ip] = valid + return true +} + +// Limit returns middleware that rate limits requests +func (rl *RateLimiter) Limit(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ip := getIP(r) + if !rl.Allow(ip) { + http.Error(w, "Too many requests. Please try again later.", http.StatusTooManyRequests) + return + } + next.ServeHTTP(w, r) + }) +} + +// getIP extracts the client IP from the request +func getIP(r *http.Request) string { + // Check X-Forwarded-For header (common with reverse proxies) + if xff := r.Header.Get("X-Forwarded-For"); xff != "" { + return xff + } + // Check X-Real-IP header + if xri := r.Header.Get("X-Real-IP"); xri != "" { + return xri + } + // Fall back to RemoteAddr + return r.RemoteAddr +} |
