1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
|
package api
import (
	"net"
	"net/http"
	"strings"
	"sync"
	"time"
)
// ipRateLimiter provides per-IP token-bucket rate limiting.
//
// A bucket is created lazily the first time an IP is seen. Keys come from
// request data, so the map would otherwise grow without bound. To prevent
// that, allow sweeps it at most once per ipLimiterSweepInterval and drops
// buckets that have refilled to capacity. A fresh bucket also starts full,
// so dropping a full bucket does not change which later requests are allowed.
type ipRateLimiter struct {
	mu        sync.Mutex
	limiters  map[string]*tokenBucket
	rate      float64   // tokens replenished per second
	burst     int       // maximum token capacity
	lastSweep time.Time // when idle buckets were last evicted; guarded by mu
}

// ipLimiterSweepInterval is the minimum time between scans for idle buckets.
const ipLimiterSweepInterval = time.Minute

// newIPRateLimiter creates a limiter with the given replenishment rate (tokens/sec)
// and burst capacity. Use rate=0 to disable replenishment (tokens never refill).
// In that mode buckets are never evicted, because dropping one would reset its budget.
func newIPRateLimiter(rate float64, burst int) *ipRateLimiter {
	return &ipRateLimiter{
		limiters:  make(map[string]*tokenBucket),
		rate:      rate,
		burst:     burst,
		lastSweep: time.Now(),
	}
}

// allow reports whether a request from ip may proceed, consuming one token
// from that IP's bucket if so. The limiter-wide lock is held only while
// looking up (or creating) the bucket and running the periodic sweep. The
// token check itself runs under the bucket's own lock.
func (l *ipRateLimiter) allow(ip string) bool {
	l.mu.Lock()
	now := time.Now()
	if now.Sub(l.lastSweep) >= ipLimiterSweepInterval {
		l.evictFullLocked(now)
		l.lastSweep = now
	}
	b, ok := l.limiters[ip]
	if !ok {
		// New buckets start full, so a first-time client gets the whole burst.
		b = &tokenBucket{
			tokens:   float64(l.burst),
			capacity: float64(l.burst),
			rate:     l.rate,
			lastTime: now,
		}
		l.limiters[ip] = b
	}
	l.mu.Unlock()
	return b.allow()
}

// evictFullLocked deletes every bucket that has refilled to capacity as of now.
// The caller must hold l.mu. Locks are taken in the order l.mu then b.mu, and
// tokenBucket.allow only ever takes b.mu, so this ordering cannot deadlock.
//
// A request that fetched a bucket just before the bucket is evicted still
// spends a token from the detached (full) bucket. A small number of extra
// requests per idle IP can therefore slip through around a sweep. This is
// accepted in exchange for bounded memory.
func (l *ipRateLimiter) evictFullLocked(now time.Time) {
	if l.rate <= 0 {
		// Buckets never refill; evicting one would hand the IP a fresh budget.
		return
	}
	for ip, b := range l.limiters {
		b.mu.Lock()
		full := b.tokens+now.Sub(b.lastTime).Seconds()*b.rate >= b.capacity
		b.mu.Unlock()
		if full {
			delete(l.limiters, ip) // deleting during range is permitted in Go
		}
	}
}
// middleware wraps next so that each request is first charged against the
// bucket of its client IP, as resolved by realIP. A request over the limit
// receives HTTP 429 with a JSON error body and is not passed to next.
func (l *ipRateLimiter) middleware(next http.Handler) http.Handler {
	limited := func(w http.ResponseWriter, r *http.Request) {
		if l.allow(realIP(r)) {
			next.ServeHTTP(w, r)
			return
		}
		writeJSON(w, http.StatusTooManyRequests, map[string]string{"error": "rate limit exceeded"})
	}
	return http.HandlerFunc(limited)
}
// tokenBucket tracks the remaining request budget for a single key.
// Tokens accrue at rate per second, up to capacity, and each successful
// allow spends one. All fields are guarded by mu.
type tokenBucket struct {
	mu       sync.Mutex
	tokens   float64   // currently available tokens (may be fractional)
	capacity float64   // upper bound on tokens
	rate     float64   // tokens per second
	lastTime time.Time // when tokens was last brought up to date; zero means never
}

// allow reports whether a token is available and, if so, consumes it.
// Before checking, it credits tokens for the time elapsed since the previous
// call, capped at capacity. If lastTime is zero, no credit is given on this call.
func (b *tokenBucket) allow() bool {
	b.mu.Lock()
	defer b.mu.Unlock()

	now := time.Now()
	prev := b.lastTime
	b.lastTime = now

	if !prev.IsZero() {
		gained := now.Sub(prev).Seconds() * b.rate
		b.tokens = min(b.capacity, b.tokens+gained)
	}

	granted := b.tokens >= 1.0
	if granted {
		b.tokens -= 1
	}
	return granted
}
// realIP extracts the client IP used as the rate-limiting key for r.
//
// Lookup order:
//  1. the first entry of X-Forwarded-For,
//  2. X-Real-IP,
//  3. the host part of r.RemoteAddr, or r.RemoteAddr itself if it has no port.
//
// Header values are trimmed and must parse as an IP address. A malformed or
// empty value is skipped, so junk headers cannot mint arbitrary bucket keys.
// Parsed addresses are returned in canonical form (net.IP.String), so that
// different spellings of the same address share a bucket. If RemoteAddr does
// not parse as an IP, it is returned unchanged.
//
// SECURITY NOTE(review): both headers are client-controlled and are trusted
// unconditionally here. That is only safe if every request passes through a
// proxy that overwrites them. Otherwise a client can rotate header values to
// evade the per-IP limit. Confirm the deployment topology.
func realIP(r *http.Request) string {
	if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
		first, _, _ := strings.Cut(xff, ",")
		if ip := canonicalIPString(first); ip != "" {
			return ip
		}
	}
	if ip := canonicalIPString(r.Header.Get("X-Real-IP")); ip != "" {
		return ip
	}
	host, _, err := net.SplitHostPort(r.RemoteAddr)
	if err != nil {
		// No port (or otherwise unsplittable): fall back to the raw value.
		host = r.RemoteAddr
	}
	if ip := canonicalIPString(host); ip != "" {
		return ip
	}
	return host
}

// canonicalIPString trims s and returns its canonical textual IP form, or ""
// if s is not a valid IP address.
func canonicalIPString(s string) string {
	ip := net.ParseIP(strings.TrimSpace(s))
	if ip == nil {
		return ""
	}
	return ip.String()
}
|