summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorPeter Stone <thepeterstone@gmail.com>2026-01-26 07:01:25 -1000
committerPeter Stone <thepeterstone@gmail.com>2026-01-26 07:01:25 -1000
commit8c2c88f90039e87b29ce32cd31b7b0361b5803d0 (patch)
tree6099e498084b876d343b071bbdf2cb62838eae7d
parentf5b997bfc4c77ef262726d14b30d387eb7acd1c6 (diff)
Phase 1: Critical security fixes
- Remove default password fallback - require DEFAULT_PASS in all environments - Fix XSS vulnerabilities in HTML generation (handlers.go:795,920) - Add security headers middleware (X-Frame-Options, CSP, HSTS, etc.) - Add rate limiting on login endpoint (5 req/15min per IP) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
-rw-r--r--cmd/dashboard/main.go17
-rw-r--r--internal/handlers/handlers.go8
-rw-r--r--internal/middleware/security.go132
3 files changed, 147 insertions, 10 deletions
diff --git a/cmd/dashboard/main.go b/cmd/dashboard/main.go
index 40a5002..d7da061 100644
--- a/cmd/dashboard/main.go
+++ b/cmd/dashboard/main.go
@@ -21,6 +21,7 @@ import (
"task-dashboard/internal/auth"
"task-dashboard/internal/config"
"task-dashboard/internal/handlers"
+ appmiddleware "task-dashboard/internal/middleware"
"task-dashboard/internal/store"
)
@@ -59,11 +60,7 @@ func main() {
defaultUser = "admin"
}
if defaultPass == "" {
- if !cfg.Debug {
- log.Fatal("CRITICAL: DEFAULT_PASS must be set in production. Set DEBUG=true for development.")
- }
- log.Println("WARNING: Using default password - set DEFAULT_PASS for production")
- defaultPass = "changeme"
+ log.Fatal("CRITICAL: DEFAULT_PASS environment variable must be set. Cannot start without a password.")
}
if err := authService.EnsureDefaultUser(defaultUser, defaultPass); err != nil {
log.Printf("Warning: failed to ensure default user: %v", err)
@@ -112,12 +109,16 @@ func main() {
r.Use(middleware.Logger)
r.Use(middleware.Recoverer)
r.Use(middleware.Timeout(60 * time.Second))
- r.Use(sessionManager.LoadAndSave) // Session middleware must be applied globally
- r.Use(authHandlers.Middleware().CSRFProtect) // CSRF protection
+ r.Use(appmiddleware.SecurityHeaders(cfg.Debug)) // Security headers
+ r.Use(sessionManager.LoadAndSave) // Session middleware must be applied globally
+ r.Use(authHandlers.Middleware().CSRFProtect) // CSRF protection
+
+ // Rate limiter for auth endpoints (5 requests per 15 minutes per IP)
+ authRateLimiter := appmiddleware.NewRateLimiter(5, 15*time.Minute)
// Public routes (no auth required)
r.Get("/login", authHandlers.HandleLoginPage)
- r.Post("/login", authHandlers.HandleLogin)
+ r.With(authRateLimiter.Limit).Post("/login", authHandlers.HandleLogin)
r.Post("/logout", authHandlers.HandleLogout)
// Serve static files (public)
diff --git a/internal/handlers/handlers.go b/internal/handlers/handlers.go
index a169478..635a69d 100644
--- a/internal/handlers/handlers.go
+++ b/internal/handlers/handlers.go
@@ -792,7 +792,9 @@ func (h *Handler) HandleGetListsOptions(w http.ResponseWriter, r *http.Request)
w.Header().Set("Content-Type", "text/html")
for _, list := range lists {
- _, _ = fmt.Fprintf(w, `<option value="%s">%s</option>`, list.ID, list.Name)
+ _, _ = fmt.Fprintf(w, `<option value="%s">%s</option>`,
+ template.HTMLEscapeString(list.ID),
+ template.HTMLEscapeString(list.Name))
}
}
@@ -917,7 +919,9 @@ func (h *Handler) HandleGetShoppingLists(w http.ResponseWriter, r *http.Request)
w.Header().Set("Content-Type", "text/html")
for _, list := range lists {
- _, _ = fmt.Fprintf(w, `<option value="%s">%s</option>`, list.ID, list.Name)
+ _, _ = fmt.Fprintf(w, `<option value="%s">%s</option>`,
+ template.HTMLEscapeString(list.ID),
+ template.HTMLEscapeString(list.Name))
}
}
diff --git a/internal/middleware/security.go b/internal/middleware/security.go
new file mode 100644
index 0000000..159a0e6
--- /dev/null
+++ b/internal/middleware/security.go
@@ -0,0 +1,132 @@
+package middleware
+
+import (
+ "net/http"
+ "sync"
+ "time"
+)
+
+// SecurityHeaders adds security headers to all responses
+func SecurityHeaders(debug bool) func(http.Handler) http.Handler {
+ return func(next http.Handler) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("X-Content-Type-Options", "nosniff")
+ w.Header().Set("X-Frame-Options", "DENY")
+ w.Header().Set("X-XSS-Protection", "1; mode=block")
+ w.Header().Set("Referrer-Policy", "strict-origin-when-cross-origin")
+ w.Header().Set("Permissions-Policy", "geolocation=(), microphone=(), camera=()")
+
+ // Only set HSTS in production (when not debug)
+ if !debug {
+ w.Header().Set("Strict-Transport-Security", "max-age=63072000; includeSubDomains")
+ }
+
+ // Content Security Policy - allow self, inline styles (Tailwind), and external images
+ w.Header().Set("Content-Security-Policy",
+ "default-src 'self'; "+
+ "img-src 'self' https: data:; "+
+ "script-src 'self' 'unsafe-inline' https://unpkg.com; "+
+ "style-src 'self' 'unsafe-inline'; "+
+ "font-src 'self' https:; "+
+ "connect-src 'self'")
+
+ next.ServeHTTP(w, r)
+ })
+ }
+}
+
+// RateLimiter provides IP-based rate limiting
+type RateLimiter struct {
+ requests map[string][]time.Time
+ mu sync.Mutex
+ limit int
+ window time.Duration
+}
+
+// NewRateLimiter creates a new rate limiter
+func NewRateLimiter(limit int, window time.Duration) *RateLimiter {
+ rl := &RateLimiter{
+ requests: make(map[string][]time.Time),
+ limit: limit,
+ window: window,
+ }
+ // Start cleanup goroutine
+ go rl.cleanup()
+ return rl
+}
+
+// cleanup removes old entries periodically
+func (rl *RateLimiter) cleanup() {
+ ticker := time.NewTicker(rl.window)
+ defer ticker.Stop()
+ for range ticker.C {
+ rl.mu.Lock()
+ now := time.Now()
+ for ip, times := range rl.requests {
+ var valid []time.Time
+ for _, t := range times {
+ if now.Sub(t) < rl.window {
+ valid = append(valid, t)
+ }
+ }
+ if len(valid) == 0 {
+ delete(rl.requests, ip)
+ } else {
+ rl.requests[ip] = valid
+ }
+ }
+ rl.mu.Unlock()
+ }
+}
+
+// Allow checks if a request from the given IP is allowed
+func (rl *RateLimiter) Allow(ip string) bool {
+ rl.mu.Lock()
+ defer rl.mu.Unlock()
+
+ now := time.Now()
+ times := rl.requests[ip]
+
+ // Filter to only requests within the window
+ var valid []time.Time
+ for _, t := range times {
+ if now.Sub(t) < rl.window {
+ valid = append(valid, t)
+ }
+ }
+
+ if len(valid) >= rl.limit {
+ rl.requests[ip] = valid
+ return false
+ }
+
+ valid = append(valid, now)
+ rl.requests[ip] = valid
+ return true
+}
+
+// Limit returns middleware that rate limits requests
+func (rl *RateLimiter) Limit(next http.Handler) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ ip := getIP(r)
+ if !rl.Allow(ip) {
+ http.Error(w, "Too many requests. Please try again later.", http.StatusTooManyRequests)
+ return
+ }
+ next.ServeHTTP(w, r)
+ })
+}
+
+// getIP extracts the client IP from the request
+func getIP(r *http.Request) string {
+ // Check X-Forwarded-For header (common with reverse proxies)
+ if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
+ return xff
+ }
+ // Check X-Real-IP header
+ if xri := r.Header.Get("X-Real-IP"); xri != "" {
+ return xri
+ }
+ // Fall back to RemoteAddr
+ return r.RemoteAddr
+}