package api import ( "net" "net/http" "sync" "time" ) // ipRateLimiter provides per-IP token-bucket rate limiting. type ipRateLimiter struct { mu sync.Mutex limiters map[string]*tokenBucket rate float64 // tokens replenished per second burst int // maximum token capacity } // newIPRateLimiter creates a limiter with the given replenishment rate (tokens/sec) // and burst capacity. Use rate=0 to disable replenishment (tokens never refill). func newIPRateLimiter(rate float64, burst int) *ipRateLimiter { return &ipRateLimiter{ limiters: make(map[string]*tokenBucket), rate: rate, burst: burst, } } func (l *ipRateLimiter) allow(ip string) bool { l.mu.Lock() b, ok := l.limiters[ip] if !ok { b = &tokenBucket{ tokens: float64(l.burst), capacity: float64(l.burst), rate: l.rate, lastTime: time.Now(), } l.limiters[ip] = b } l.mu.Unlock() return b.allow() } // middleware wraps h with per-IP rate limiting, returning 429 when exceeded. func (l *ipRateLimiter) middleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ip := realIP(r) if !l.allow(ip) { writeJSON(w, http.StatusTooManyRequests, map[string]string{"error": "rate limit exceeded"}) return } next.ServeHTTP(w, r) }) } // tokenBucket is a simple token-bucket rate limiter for a single key. type tokenBucket struct { mu sync.Mutex tokens float64 capacity float64 rate float64 // tokens per second lastTime time.Time } func (b *tokenBucket) allow() bool { b.mu.Lock() defer b.mu.Unlock() now := time.Now() if !b.lastTime.IsZero() { elapsed := now.Sub(b.lastTime).Seconds() b.tokens = min(b.capacity, b.tokens+elapsed*b.rate) } b.lastTime = now if b.tokens >= 1.0 { b.tokens-- return true } return false } // realIP extracts the client's real IP from a request. 
// It consults, in order, the first entry of X-Forwarded-For, then X-Real-IP,
// then the host part of r.RemoteAddr. A header value is used only if, after
// trimming spaces/tabs and an optional ":port" suffix, it parses as an IP
// address. It is returned in canonical form (net.IP.String), so equivalent
// spellings of one address map to the same rate-limit key. If RemoteAddr has no
// port, it is returned unchanged.
//
// SECURITY NOTE(review): X-Forwarded-For and X-Real-IP are set by the client
// unless a trusted reverse proxy overwrites them. When the server is reachable
// directly, a client can pick its own key and evade per-IP limits. Confirm the
// deployment topology or restrict these headers to trusted proxies.
func realIP(r *http.Request) string {
	// clientIP returns the canonical IP text for a header value, or "" if the
	// value (optionally with a port) is not an IP address.
	clientIP := func(s string) string {
		for s != "" && (s[0] == ' ' || s[0] == '\t') {
			s = s[1:]
		}
		for s != "" && (s[len(s)-1] == ' ' || s[len(s)-1] == '\t') {
			s = s[:len(s)-1]
		}
		ip := net.ParseIP(s)
		if ip == nil {
			// Some proxies append the client port ("1.2.3.4:5678", "[::1]:80").
			if host, _, err := net.SplitHostPort(s); err == nil {
				ip = net.ParseIP(host)
			}
		}
		if ip == nil {
			return ""
		}
		return ip.String()
	}

	if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
		// Only the left-most entry (the originating client) is considered.
		first := xff
		for i := 0; i < len(xff); i++ {
			if xff[i] == ',' {
				first = xff[:i]
				break
			}
		}
		if ip := clientIP(first); ip != "" {
			return ip
		}
	}
	if ip := clientIP(r.Header.Get("X-Real-IP")); ip != "" {
		return ip
	}
	host, _, err := net.SplitHostPort(r.RemoteAddr)
	if err != nil {
		return r.RemoteAddr
	}
	return host
}