summaryrefslogtreecommitdiff
path: root/internal/api/ratelimit.go
blob: 089354c42e891f43e1fea65b16b5315079620eb4 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
package api

import (
	"net"
	"net/http"
	"strings"
	"sync"
	"time"
)

// ipRateLimiter provides per-IP token-bucket rate limiting.
//
// One tokenBucket is kept per key. So that the map does not grow without bound
// as new client addresses appear, buckets that have been idle long enough to
// refill to full capacity are evicted periodically. Evicting such a bucket is
// lossless: a freshly created bucket for the same key also starts at full
// capacity, so the key is limited exactly as before.
type ipRateLimiter struct {
	mu       sync.Mutex
	limiters map[string]*tokenBucket
	rate     float64 // tokens replenished per second
	burst    int     // maximum token capacity

	// sweepAt is the map size at which inserting a new key first sweeps idle
	// buckets. After each sweep it is reset to twice the surviving population
	// (never below ipLimiterMinSweep), which keeps sweep cost amortized O(1)
	// per new key. A zero value simply triggers one cheap sweep on the first
	// insertion, after which the threshold is set normally.
	sweepAt int
}

// ipLimiterMinSweep is the smallest map size at which idle buckets are swept.
const ipLimiterMinSweep = 1024

// newIPRateLimiter creates a limiter with the given replenishment rate (tokens/sec)
// and burst capacity. Use rate=0 to disable replenishment (tokens never refill).
// In that mode buckets are never evicted, since dropping one would reset the
// key's exhausted budget.
func newIPRateLimiter(rate float64, burst int) *ipRateLimiter {
	return &ipRateLimiter{
		limiters: make(map[string]*tokenBucket),
		rate:     rate,
		burst:    burst,
		sweepAt:  ipLimiterMinSweep,
	}
}

// allow reports whether a request keyed by ip may proceed, consuming one token
// from that key's bucket if so. Unknown keys get a new bucket that starts full.
func (l *ipRateLimiter) allow(ip string) bool {
	l.mu.Lock()
	b, ok := l.limiters[ip]
	if !ok {
		now := time.Now()
		if len(l.limiters) >= l.sweepAt {
			l.sweepIdleLocked(now)
		}
		b = &tokenBucket{
			tokens:   float64(l.burst),
			capacity: float64(l.burst),
			rate:     l.rate,
			lastTime: now,
		}
		l.limiters[ip] = b
	}
	l.mu.Unlock()
	// The bucket has its own lock, so l.mu is not held while it is consulted.
	return b.allow()
}

// sweepIdleLocked removes every bucket that would already be back at full
// capacity at time now, then resets the sweep threshold. The caller must hold
// l.mu. Lock order is always l.mu then b.mu; tokenBucket.allow never takes l.mu.
//
// A goroutine that fetched a bucket just before it was evicted may still draw
// one token from the orphaned bucket. That admits at most one extra request
// per such goroutine, and is accepted so that bucket checks can stay outside l.mu.
func (l *ipRateLimiter) sweepIdleLocked(now time.Time) {
	if l.rate > 0 {
		for key, b := range l.limiters {
			b.mu.Lock()
			// A zero lastTime means tokenBucket.allow applies no refill on its
			// next call, so such a bucket is not treated as refilled here.
			full := b.rate > 0 && !b.lastTime.IsZero() &&
				b.tokens+now.Sub(b.lastTime).Seconds()*b.rate >= b.capacity
			b.mu.Unlock()
			if full {
				delete(l.limiters, key) // deleting during range is permitted in Go
			}
		}
	}
	l.sweepAt = max(ipLimiterMinSweep, 2*len(l.limiters))
}

// middleware wraps next with per-IP rate limiting. The key is the address
// returned by realIP for each request. While that key still has budget the
// request is passed through to next. Otherwise next is not called, and a 429
// Too Many Requests response carrying an error payload is written via writeJSON.
func (l *ipRateLimiter) middleware(next http.Handler) http.Handler {
	limited := func(w http.ResponseWriter, r *http.Request) {
		if l.allow(realIP(r)) {
			next.ServeHTTP(w, r)
			return
		}
		writeJSON(w, http.StatusTooManyRequests, map[string]string{"error": "rate limit exceeded"})
	}
	return http.HandlerFunc(limited)
}

// tokenBucket holds token-bucket state for a single key. All access goes
// through allow, which serializes callers on mu.
type tokenBucket struct {
	mu       sync.Mutex
	tokens   float64   // currently available tokens; may be fractional
	capacity float64   // ceiling that refills are clamped to
	rate     float64   // tokens per second
	lastTime time.Time // when tokens was last updated; zero skips the refill step once
}

// allow first credits the bucket for the time elapsed since the previous call
// (clamped to capacity, and skipped if there was no previous timestamp). It
// then consumes one token if at least one is available and reports whether
// it did. The timestamp advances on every call, whether allowed or denied.
func (b *tokenBucket) allow() bool {
	b.mu.Lock()
	defer b.mu.Unlock()

	now := time.Now()
	prev := b.lastTime
	b.lastTime = now

	if !prev.IsZero() {
		refilled := b.tokens + now.Sub(prev).Seconds()*b.rate
		b.tokens = min(b.capacity, refilled)
	}

	ok := b.tokens >= 1
	if ok {
		b.tokens -= 1
	}
	return ok
}

// realIP returns the client IP address used as the rate-limit key.
//
// Sources are consulted in order:
//  1. the first entry of X-Forwarded-For;
//  2. X-Real-IP;
//  3. the host part of r.RemoteAddr, or RemoteAddr verbatim if it has no port.
//
// A header value is used only if it is an IP address, optionally followed by a
// ":port" suffix and surrounded by whitespace. It is returned in canonical form,
// so that spellings such as "::1" and "0:0:0:0:0:0:0:1" share one key.
// Malformed or empty header values are skipped rather than trusted. This
// prevents a leading comma from collapsing clients onto the empty key, and
// stops clients from minting arbitrarily long keys.
//
// SECURITY NOTE(review): both headers are client-controlled unless a trusted
// reverse proxy overwrites them. If this server is reachable directly, a client
// can rotate spoofed addresses to evade per-IP limits. Consider honoring these
// headers only when RemoteAddr belongs to a known proxy.
func realIP(r *http.Request) string {
	// headerIP normalizes one header-supplied address, returning "" if it is
	// not an IP address (with or without a port).
	headerIP := func(s string) string {
		s = strings.TrimSpace(s)
		ip := net.ParseIP(s)
		if ip == nil {
			host, _, err := net.SplitHostPort(s)
			if err != nil {
				return ""
			}
			if ip = net.ParseIP(host); ip == nil {
				return ""
			}
		}
		return ip.String()
	}

	// Only the left-most X-Forwarded-For entry is considered.
	first, _, _ := strings.Cut(r.Header.Get("X-Forwarded-For"), ",")
	if ip := headerIP(first); ip != "" {
		return ip
	}
	if ip := headerIP(r.Header.Get("X-Real-IP")); ip != "" {
		return ip
	}
	host, _, err := net.SplitHostPort(r.RemoteAddr)
	if err != nil {
		return r.RemoteAddr
	}
	return host
}