diff options
Diffstat (limited to 'internal/middleware')
| -rw-r--r-- | internal/middleware/security_test.go | 200 |
1 files changed, 200 insertions, 0 deletions
diff --git a/internal/middleware/security_test.go b/internal/middleware/security_test.go new file mode 100644 index 0000000..1717418 --- /dev/null +++ b/internal/middleware/security_test.go @@ -0,0 +1,200 @@ +package middleware + +import ( + "net/http" + "net/http/httptest" + "testing" + "time" +) + +func TestSecurityHeaders_Debug(t *testing.T) { + handler := SecurityHeaders(true)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest("GET", "/", nil) + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + // Check common security headers + if rec.Header().Get("X-Content-Type-Options") != "nosniff" { + t.Error("Expected X-Content-Type-Options header") + } + if rec.Header().Get("X-Frame-Options") != "DENY" { + t.Error("Expected X-Frame-Options header") + } + if rec.Header().Get("X-XSS-Protection") != "1; mode=block" { + t.Error("Expected X-XSS-Protection header") + } + + // HSTS should NOT be set in debug mode + if rec.Header().Get("Strict-Transport-Security") != "" { + t.Error("HSTS should not be set in debug mode") + } +} + +func TestSecurityHeaders_Production(t *testing.T) { + handler := SecurityHeaders(false)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest("GET", "/", nil) + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + // HSTS should be set in production + if rec.Header().Get("Strict-Transport-Security") == "" { + t.Error("HSTS should be set in production mode") + } + + // CSP should be set + if rec.Header().Get("Content-Security-Policy") == "" { + t.Error("Expected Content-Security-Policy header") + } +} + +func TestRateLimiter_Allow(t *testing.T) { + rl := &RateLimiter{ + requests: make(map[string][]time.Time), + limit: 3, + window: time.Minute, + } + + ip := "192.168.1.1" + + // First 3 requests should be allowed + for i := 0; i < 3; i++ { + if !rl.Allow(ip) { + t.Errorf("Request %d should be allowed", i+1) + } + } + + // 4th request should be denied + if rl.Allow(ip) { + t.Error("4th request should be denied") + } +} + +func TestRateLimiter_WindowExpiry(t *testing.T) { + rl := &RateLimiter{ + requests: make(map[string][]time.Time), + limit: 2, + window: 50 * time.Millisecond, + } + + ip := "192.168.1.1" + + // Use up the limit + rl.Allow(ip) + rl.Allow(ip) + + // Should be denied + if rl.Allow(ip) { + t.Error("Should be denied when limit reached") + } + + // Wait for window to expire + time.Sleep(60 * time.Millisecond) + + // Should be allowed again + if !rl.Allow(ip) { + t.Error("Should be allowed after window expires") + } +} + +func TestRateLimiter_Limit_Middleware(t *testing.T) { + rl := &RateLimiter{ + requests: make(map[string][]time.Time), + limit: 1, + window: time.Minute, + } + + handler := rl.Limit(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + // First request should pass + req1 := httptest.NewRequest("GET", "/", nil) + req1.RemoteAddr = "10.0.0.1:12345" + rec1 := httptest.NewRecorder() + handler.ServeHTTP(rec1, req1) + + if rec1.Code != http.StatusOK { + t.Errorf("First request should return 200, got %d", rec1.Code) + } + + // Second request should be rate limited + req2 := httptest.NewRequest("GET", "/", nil) + req2.RemoteAddr = "10.0.0.1:12345" + rec2 := httptest.NewRecorder() + handler.ServeHTTP(rec2, req2) + + if rec2.Code != http.StatusTooManyRequests { + t.Errorf("Second request should return 429, got %d", rec2.Code) + } +} + +func TestNewRateLimiter(t *testing.T) { + rl := NewRateLimiter(10, 100*time.Millisecond) + if rl == nil { + t.Fatal("NewRateLimiter returned nil") + } + if rl.limit != 10 { + t.Errorf("Expected limit 10, got %d", rl.limit) + } + if rl.window != 100*time.Millisecond { + t.Errorf("Expected window 100ms, got %v", rl.window) + } + // Let cleanup run once + time.Sleep(150 * time.Millisecond) +} + +func TestGetIP(t *testing.T) { + tests := []struct { + name string + xff string + xri string + remoteAddr string + expected string + }{ + { + name: "X-Forwarded-For takes priority", + xff: "1.2.3.4", + xri: "5.6.7.8", + remoteAddr: "9.10.11.12", + expected: "1.2.3.4", + }, + { + name: "X-Real-IP when no XFF", + xff: "", + xri: "5.6.7.8", + remoteAddr: "9.10.11.12", + expected: "5.6.7.8", + }, + { + name: "RemoteAddr as fallback", + xff: "", + xri: "", + remoteAddr: "9.10.11.12:54321", + expected: "9.10.11.12:54321", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + req := httptest.NewRequest("GET", "/", nil) + req.RemoteAddr = tc.remoteAddr + if tc.xff != "" { + req.Header.Set("X-Forwarded-For", tc.xff) + } + if tc.xri != "" { + req.Header.Set("X-Real-IP", tc.xri) + } + + ip := getIP(req) + if ip != tc.expected { + t.Errorf("Expected IP %s, got %s", tc.expected, ip) + } + }) + } +} |
