1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
|
package api
import (
	"crypto/subtle"
	"errors"
	"log/slog"
	"net/http"
	"strings"
	"sync"
	"time"

	"golang.org/x/net/websocket"
)
// Heartbeat timing for WebSocket connections. These are variables rather
// than constants so tests can shorten them.
var (
	// wsPingInterval is the period between heartbeat pings sent to a client.
	wsPingInterval = 30 * time.Second
	// wsPingDeadline is the write deadline applied to each heartbeat ping.
	wsPingDeadline = 10 * time.Second
)

// maxWsClients is the upper bound on concurrently registered WebSocket
// connections. It is a variable so tests can lower the limit.
var maxWsClients = 1000
// Hub manages WebSocket connections and broadcasts messages.
//
// Create a Hub with NewHub: Register writes into the clients map, so the
// zero value (nil map) is not usable. A Hub contains a sync.RWMutex and must
// not be copied after first use.
type Hub struct {
// mu guards clients.
mu sync.RWMutex
// clients is the set of registered connections; Register stores true for
// every key it adds.
clients map[*websocket.Conn]bool
// logger receives write errors reported by Broadcast. NewHub sets it to
// slog.Default().
logger *slog.Logger
}
// NewHub returns a Hub with an empty client set that logs through the
// slog default logger in effect at the time of the call.
func NewHub() *Hub {
	hub := new(Hub)
	hub.clients = map[*websocket.Conn]bool{}
	hub.logger = slog.Default()
	return hub
}
// Run is a placeholder: its body is empty, so it returns immediately and
// does not loop. It is kept so that cleanup/heartbeat logic can be added
// here later.
func (h *Hub) Run() {}
// Register adds ws to the hub's client set so that it receives subsequent
// broadcasts. It returns an error, and leaves the set unchanged, when the
// hub already holds maxWsClients connections.
func (h *Hub) Register(ws *websocket.Conn) error {
	h.mu.Lock()
	defer h.mu.Unlock()

	if len(h.clients) < maxWsClients {
		h.clients[ws] = true
		return nil
	}
	return errors.New("max WebSocket clients reached")
}
// Unregister removes ws from the hub's client set. Removing a connection
// that is not registered is a no-op.
func (h *Hub) Unregister(ws *websocket.Conn) {
	h.mu.Lock()
	defer h.mu.Unlock()
	delete(h.clients, ws)
}
// Broadcast writes msg to every connection that is registered at the moment
// of the call. Write failures are logged and do not stop delivery to the
// remaining connections; a failing connection is neither closed nor removed
// here.
//
// The client set is copied under the read lock and the writes are performed
// after the lock is released. Holding the lock across network writes would
// let one slow peer stall Register and Unregister, and a waiting writer on a
// sync.RWMutex in turn blocks new readers such as ClientCount. As a
// consequence, a connection unregistered while a broadcast is in flight may
// still receive (or fail to receive) that one message.
func (h *Hub) Broadcast(msg []byte) {
	h.mu.RLock()
	conns := hubClientKeys(h.clients)
	h.mu.RUnlock()

	for _, conn := range conns {
		if _, err := conn.Write(msg); err != nil {
			h.logger.Error("websocket write error", "error", err)
		}
	}
}

// hubClientKeys returns the keys of m as a newly allocated slice, in
// unspecified order. The caller must hold whatever lock protects m.
func hubClientKeys[K comparable, V any](m map[K]V) []K {
	keys := make([]K, 0, len(m))
	for k := range m {
		keys = append(keys, k)
	}
	return keys
}
// ClientCount reports how many connections are currently registered with
// the hub.
func (h *Hub) ClientCount() int {
	h.mu.RLock()
	n := len(h.clients)
	h.mu.RUnlock()
	return n
}
// handleWebSocket upgrades the request to a WebSocket connection and keeps
// that connection registered with the server's hub until the client
// disconnects or a heartbeat ping fails.
//
// Before the upgrade the request is rejected with 503 when the hub already
// reports maxWsClients connections, and with 401 when s.apiToken is non-empty
// and the request does not present a matching token (see wsRequestToken for
// where the token is read from).
func (s *Server) handleWebSocket(w http.ResponseWriter, r *http.Request) {
	if s.hub.ClientCount() >= maxWsClients {
		http.Error(w, "too many connections", http.StatusServiceUnavailable)
		return
	}

	if s.apiToken != "" {
		token := wsRequestToken(r)
		// Constant-time comparison: a plain != returns as soon as the first
		// differing byte is found, which lets an attacker probe the token
		// through response timing.
		if subtle.ConstantTimeCompare([]byte(token), []byte(s.apiToken)) != 1 {
			http.Error(w, "unauthorized", http.StatusUnauthorized)
			return
		}
	}

	handler := websocket.Handler(func(ws *websocket.Conn) {
		// The count check above is only an early rejection; Register can
		// still fail, in which case the connection is dropped by returning.
		if err := s.hub.Register(ws); err != nil {
			return
		}
		defer s.hub.Unregister(ws)

		// done stops the ping goroutine when this handler returns.
		done := make(chan struct{})
		defer close(done)
		go wsPingLoop(ws, done)

		// Incoming data is discarded; reading only serves to block until the
		// client disconnects or the ping loop closes the connection.
		buf := make([]byte, 1024)
		for {
			if _, err := ws.Read(buf); err != nil {
				return
			}
		}
	})
	handler.ServeHTTP(w, r)
}

// wsRequestToken returns the API token presented by r: the "token" query
// parameter if it is non-empty, otherwise the text following "Bearer " in the
// Authorization header, otherwise the empty string.
func wsRequestToken(r *http.Request) string {
	if token := r.URL.Query().Get("token"); token != "" {
		return token
	}
	if token, ok := strings.CutPrefix(r.Header.Get("Authorization"), "Bearer "); ok {
		return token
	}
	return ""
}

// wsPingLoop sends "ping" to ws every wsPingInterval until done is closed.
// Each send gets wsPingDeadline to complete. If a send fails (including by
// exceeding the deadline) the connection is closed, which makes the
// handler's read loop fail and unregister the client.
func wsPingLoop(ws *websocket.Conn, done <-chan struct{}) {
	ticker := time.NewTicker(wsPingInterval)
	defer ticker.Stop()

	for {
		select {
		case <-done:
			return
		case <-ticker.C:
			// Deadline errors are deliberately ignored: the ping is
			// best-effort and a broken connection surfaces via Send.
			_ = ws.SetWriteDeadline(time.Now().Add(wsPingDeadline))
			err := websocket.Message.Send(ws, "ping")
			// Clear the deadline so other writers are not bound by it.
			_ = ws.SetWriteDeadline(time.Time{})
			if err != nil {
				_ = ws.Close()
				return
			}
		}
	}
}
|