summaryrefslogtreecommitdiff
path: root/internal/middleware/security.go
diff options
context:
space:
mode:
authorPeter Stone <thepeterstone@gmail.com>2026-01-26 07:01:25 -1000
committerPeter Stone <thepeterstone@gmail.com>2026-01-26 07:01:25 -1000
commit8c2c88f90039e87b29ce32cd31b7b0361b5803d0 (patch)
tree6099e498084b876d343b071bbdf2cb62838eae7d /internal/middleware/security.go
parentf5b997bfc4c77ef262726d14b30d387eb7acd1c6 (diff)
Phase 1: Critical security fixes
- Remove default password fallback - require DEFAULT_PASS in all environments - Fix XSS vulnerabilities in HTML generation (handlers.go:795,920) - Add security headers middleware (X-Frame-Options, CSP, HSTS, etc.) - Add rate limiting on login endpoint (5 req/15min per IP) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Diffstat (limited to 'internal/middleware/security.go')
-rw-r--r--internal/middleware/security.go132
1 files changed, 132 insertions, 0 deletions
diff --git a/internal/middleware/security.go b/internal/middleware/security.go
new file mode 100644
index 0000000..159a0e6
--- /dev/null
+++ b/internal/middleware/security.go
@@ -0,0 +1,132 @@
+package middleware
+
+import (
+ "net/http"
+ "sync"
+ "time"
+)
+
+// SecurityHeaders adds security headers to all responses
+func SecurityHeaders(debug bool) func(http.Handler) http.Handler {
+ return func(next http.Handler) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("X-Content-Type-Options", "nosniff")
+ w.Header().Set("X-Frame-Options", "DENY")
+ w.Header().Set("X-XSS-Protection", "1; mode=block")
+ w.Header().Set("Referrer-Policy", "strict-origin-when-cross-origin")
+ w.Header().Set("Permissions-Policy", "geolocation=(), microphone=(), camera=()")
+
+ // Only set HSTS in production (when not debug)
+ if !debug {
+ w.Header().Set("Strict-Transport-Security", "max-age=63072000; includeSubDomains")
+ }
+
+ // Content Security Policy - allow self, inline styles (Tailwind), and external images
+ w.Header().Set("Content-Security-Policy",
+ "default-src 'self'; "+
+ "img-src 'self' https: data:; "+
+ "script-src 'self' 'unsafe-inline' https://unpkg.com; "+
+ "style-src 'self' 'unsafe-inline'; "+
+ "font-src 'self' https:; "+
+ "connect-src 'self'")
+
+ next.ServeHTTP(w, r)
+ })
+ }
+}
+
+// RateLimiter provides IP-based rate limiting
+type RateLimiter struct {
+ requests map[string][]time.Time
+ mu sync.Mutex
+ limit int
+ window time.Duration
+}
+
+// NewRateLimiter creates a new rate limiter
+func NewRateLimiter(limit int, window time.Duration) *RateLimiter {
+ rl := &RateLimiter{
+ requests: make(map[string][]time.Time),
+ limit: limit,
+ window: window,
+ }
+ // Start cleanup goroutine
+ go rl.cleanup()
+ return rl
+}
+
+// cleanup removes old entries periodically
+func (rl *RateLimiter) cleanup() {
+ ticker := time.NewTicker(rl.window)
+ defer ticker.Stop()
+ for range ticker.C {
+ rl.mu.Lock()
+ now := time.Now()
+ for ip, times := range rl.requests {
+ var valid []time.Time
+ for _, t := range times {
+ if now.Sub(t) < rl.window {
+ valid = append(valid, t)
+ }
+ }
+ if len(valid) == 0 {
+ delete(rl.requests, ip)
+ } else {
+ rl.requests[ip] = valid
+ }
+ }
+ rl.mu.Unlock()
+ }
+}
+
+// Allow checks if a request from the given IP is allowed
+func (rl *RateLimiter) Allow(ip string) bool {
+ rl.mu.Lock()
+ defer rl.mu.Unlock()
+
+ now := time.Now()
+ times := rl.requests[ip]
+
+ // Filter to only requests within the window
+ var valid []time.Time
+ for _, t := range times {
+ if now.Sub(t) < rl.window {
+ valid = append(valid, t)
+ }
+ }
+
+ if len(valid) >= rl.limit {
+ rl.requests[ip] = valid
+ return false
+ }
+
+ valid = append(valid, now)
+ rl.requests[ip] = valid
+ return true
+}
+
+// Limit returns middleware that rate limits requests
+func (rl *RateLimiter) Limit(next http.Handler) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ ip := getIP(r)
+ if !rl.Allow(ip) {
+ http.Error(w, "Too many requests. Please try again later.", http.StatusTooManyRequests)
+ return
+ }
+ next.ServeHTTP(w, r)
+ })
+}
+
+// getIP extracts the client IP from the request
+func getIP(r *http.Request) string {
+ // Check X-Forwarded-For header (common with reverse proxies)
+ if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
+ return xff
+ }
+ // Check X-Real-IP header
+ if xri := r.Header.Get("X-Real-IP"); xri != "" {
+ return xri
+ }
+ // Fall back to RemoteAddr
+ return r.RemoteAddr
+}