package middleware

import (
	"net/http"
	"net/http/httptest"
	"testing"
	"time"
)

// TestSecurityHeaders_Debug checks that SecurityHeaders(true) sets the common
// security headers with their expected values and does not set
// Strict-Transport-Security.
func TestSecurityHeaders_Debug(t *testing.T) {
	handler := SecurityHeaders(true)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))

	req := httptest.NewRequest(http.MethodGet, "/", nil)
	rec := httptest.NewRecorder()
	handler.ServeHTTP(rec, req)

	// Common security headers and the exact value each must carry.
	common := []struct {
		name, want string
	}{
		{"X-Content-Type-Options", "nosniff"},
		{"X-Frame-Options", "DENY"},
		{"X-XSS-Protection", "1; mode=block"},
	}
	for _, h := range common {
		if got := rec.Header().Get(h.name); got != h.want {
			t.Errorf("%s = %q, want %q", h.name, got, h.want)
		}
	}

	// HSTS should NOT be set in debug mode.
	if got := rec.Header().Get("Strict-Transport-Security"); got != "" {
		t.Errorf("Strict-Transport-Security = %q, HSTS should not be set in debug mode", got)
	}
}

// TestSecurityHeaders_Production checks that SecurityHeaders(false) sets both
// Strict-Transport-Security and Content-Security-Policy (any non-empty value).
func TestSecurityHeaders_Production(t *testing.T) {
	handler := SecurityHeaders(false)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))

	req := httptest.NewRequest(http.MethodGet, "/", nil)
	rec := httptest.NewRecorder()
	handler.ServeHTTP(rec, req)

	// HSTS should be set in production.
	if rec.Header().Get("Strict-Transport-Security") == "" {
		t.Error("HSTS should be set in production mode")
	}

	// CSP should be set.
	if rec.Header().Get("Content-Security-Policy") == "" {
		t.Error("Expected Content-Security-Policy header")
	}
}

// TestRateLimiter_Allow checks that a limiter with limit 3 admits exactly three
// calls for one key within the window and rejects the fourth.
func TestRateLimiter_Allow(t *testing.T) {
	rl := &RateLimiter{
		requests: make(map[string][]time.Time),
		limit:    3,
		window:   time.Minute,
	}

	ip := "192.168.1.1"

	// The first rl.limit requests should be allowed.
	for i := 0; i < 3; i++ {
		if !rl.Allow(ip) {
			t.Errorf("Request %d should be allowed", i+1)
		}
	}

	// The next request exceeds the limit and should be denied.
	if rl.Allow(ip) {
		t.Error("4th request should be denied")
	}
}

// TestRateLimiter_WindowExpiry checks that once the window has elapsed, a key
// that hit its limit is admitted again.
func TestRateLimiter_WindowExpiry(t *testing.T) {
	rl := &RateLimiter{
		requests: make(map[string][]time.Time),
		limit:    2,
		window:   50 * time.Millisecond,
	}

	ip := "192.168.1.1"

	// Use up the limit. The results are asserted so that a limiter that
	// rejects everything cannot pass the "denied" check below by accident.
	for i := 0; i < 2; i++ {
		if !rl.Allow(ip) {
			t.Fatalf("Request %d should be allowed", i+1)
		}
	}

	// Should be denied.
	if rl.Allow(ip) {
		t.Fatal("Should be denied when limit reached")
	}

	// Wait for the window to expire. time.Sleep sleeps at least this long, so
	// the earlier requests are strictly older than rl.window afterwards.
	time.Sleep(rl.window + 10*time.Millisecond)

	// Should be allowed again.
	if !rl.Allow(ip) {
		t.Error("Should be allowed after window expires")
	}
}

// TestRateLimiter_Limit_Middleware checks that the Limit middleware passes the
// first request from an address through (200) and answers the second request
// from the same RemoteAddr with 429 when the limit is 1.
func TestRateLimiter_Limit_Middleware(t *testing.T) {
	rl := &RateLimiter{
		requests: make(map[string][]time.Time),
		limit:    1,
		window:   time.Minute,
	}

	handler := rl.Limit(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))

	// serve sends one GET from a fixed client address and returns the recorder.
	serve := func() *httptest.ResponseRecorder {
		req := httptest.NewRequest(http.MethodGet, "/", nil)
		req.RemoteAddr = "10.0.0.1:12345"
		rec := httptest.NewRecorder()
		handler.ServeHTTP(rec, req)
		return rec
	}

	// First request should pass.
	if rec := serve(); rec.Code != http.StatusOK {
		t.Errorf("First request should return 200, got %d", rec.Code)
	}

	// Second request should be rate limited.
	if rec := serve(); rec.Code != http.StatusTooManyRequests {
		t.Errorf("Second request should return 429, got %d", rec.Code)
	}
}

// TestNewRateLimiter checks that NewRateLimiter returns a non-nil limiter
// holding the given limit and window.
func TestNewRateLimiter(t *testing.T) {
	rl := NewRateLimiter(10, 100*time.Millisecond)
	if rl == nil {
		t.Fatal("NewRateLimiter returned nil")
	}

	if rl.limit != 10 {
		t.Errorf("Expected limit 10, got %d", rl.limit)
	}
	if rl.window != 100*time.Millisecond {
		t.Errorf("Expected window 100ms, got %v", rl.window)
	}

	// Let cleanup run once.
	// NOTE(review): presumably NewRateLimiter starts a background cleanup; this
	// sleep only gives it a chance to run. Nothing is asserted afterwards, so
	// the test cannot detect a broken cleanup. Consider asserting on rl's state.
	time.Sleep(150 * time.Millisecond)
}

// TestGetIP checks the precedence getIP uses to pick the client address:
// X-Forwarded-For, then X-Real-IP, then RemoteAddr.
//
// NOTE(review): the fallback case pins RemoteAddr being returned with its port.
// If getIP's result is used as the rate-limit key, each new connection (new
// ephemeral port) gets a fresh window. Consider net.SplitHostPort in getIP and
// updating that expectation. The header cases also show that client-supplied
// X-Forwarded-For / X-Real-IP are trusted as-is; confirm a trusted proxy
// overwrites them in deployment.
func TestGetIP(t *testing.T) {
	tests := []struct {
		name       string
		xff        string
		xri        string
		remoteAddr string
		expected   string
	}{
		{
			name:       "X-Forwarded-For takes priority",
			xff:        "1.2.3.4",
			xri:        "5.6.7.8",
			remoteAddr: "9.10.11.12",
			expected:   "1.2.3.4",
		},
		{
			name:       "X-Real-IP when no XFF",
			xff:        "",
			xri:        "5.6.7.8",
			remoteAddr: "9.10.11.12",
			expected:   "5.6.7.8",
		},
		{
			name:       "RemoteAddr as fallback",
			xff:        "",
			xri:        "",
			remoteAddr: "9.10.11.12:54321",
			expected:   "9.10.11.12:54321",
		},
	}

	for _, tc := range tests {
		tc := tc // per-iteration copy for pre-Go 1.22 toolchains
		t.Run(tc.name, func(t *testing.T) {
			req := httptest.NewRequest(http.MethodGet, "/", nil)
			req.RemoteAddr = tc.remoteAddr
			if tc.xff != "" {
				req.Header.Set("X-Forwarded-For", tc.xff)
			}
			if tc.xri != "" {
				req.Header.Set("X-Real-IP", tc.xri)
			}

			if ip := getIP(req); ip != tc.expected {
				t.Errorf("getIP() = %q, want %q", ip, tc.expected)
			}
		})
	}
}