summaryrefslogtreecommitdiff
path: root/internal/api/websocket.go
blob: 25522dc9e8268c6fad6205fb4fc709d1277c20e1 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
package api

import (
	"crypto/subtle"
	"errors"
	"log/slog"
	"net/http"
	"strings"
	"sync"
	"time"

	"golang.org/x/net/websocket"
)

// Heartbeat timing for WebSocket connections. These are package-level
// variables rather than constants so tests can shorten them.
var wsPingInterval = 30 * time.Second // delay between pings on one connection
var wsPingDeadline = 10 * time.Second // write deadline applied to each ping

// defaultMaxWsClients is the connection cap NewHub starts with;
// SetMaxClients can change it afterwards.
const defaultMaxWsClients = 1000

// Hub manages WebSocket connections and broadcasts messages.
//
// Construct it with NewHub. In the zero value, clients is a nil map and
// maxClients is 0, so Register rejects every connection.
type Hub struct {
	// mu guards clients and maxClients.
	mu         sync.RWMutex
	// clients is the set of registered connections; Register stores true.
	clients    map[*websocket.Conn]bool
	// maxClients is the cap Register enforces: new connections are refused
	// once len(clients) reaches it.
	maxClients int
	// logger receives write errors from Broadcast; it is set by NewHub.
	logger     *slog.Logger
}

// NewHub returns an empty Hub that accepts up to defaultMaxWsClients
// connections and logs through slog.Default().
func NewHub() *Hub {
	h := new(Hub)
	h.clients = make(map[*websocket.Conn]bool)
	h.maxClients = defaultMaxWsClients
	h.logger = slog.Default()
	return h
}

// Run currently does nothing and returns immediately. It is kept as a hook
// for future hub-wide cleanup/heartbeat logic; per-connection heartbeats are
// sent from the goroutine started in handleWebSocket.
func (h *Hub) Run() {}

// SetMaxClients configures the maximum number of concurrent WebSocket clients.
//
// The limit is consulted only by later Register calls: connections that are
// already registered are kept even if they now exceed it. A value of n <= 0
// makes Register reject every new client.
func (h *Hub) SetMaxClients(n int) {
	h.mu.Lock()
	defer h.mu.Unlock()
	h.maxClients = n
}

// Register adds ws to the hub's client set. When the hub already holds
// maxClients connections, it leaves the set unchanged and returns an error.
func (h *Hub) Register(ws *websocket.Conn) error {
	h.mu.Lock()
	defer h.mu.Unlock()

	if len(h.clients) < h.maxClients {
		h.clients[ws] = true
		return nil
	}
	return errors.New("max WebSocket clients reached")
}

// Unregister removes ws from the hub's client set. Removing a connection
// that was never registered is a no-op.
func (h *Hub) Unregister(ws *websocket.Conn) {
	h.mu.Lock()
	defer h.mu.Unlock()
	delete(h.clients, ws)
}

// Broadcast sends msg to every connected WebSocket client.
//
// The client set is copied under the read lock, and the lock is released
// before any network I/O. A slow or stalled client therefore cannot block
// Register, Unregister, ClientCount, or other lock users. Clients are still
// written one after another, so a stalled client can delay the clients after
// it and the return of Broadcast itself.
//
// A client that disconnects between the snapshot and its write just yields a
// logged write error. On any write error the connection is closed, because a
// failed write may leave a partial frame on the wire. For connections
// registered by handleWebSocket, closing makes the read loop fail, which
// unregisters the client.
func (h *Hub) Broadcast(msg []byte) {
	h.mu.RLock()
	conns := make([]*websocket.Conn, 0, len(h.clients))
	for conn := range h.clients {
		conns = append(conns, conn)
	}
	h.mu.RUnlock()

	for _, conn := range conns {
		if _, err := conn.Write(msg); err != nil {
			h.logger.Error("websocket write error", "error", err)
			conn.Close()
		}
	}
}

// ClientCount reports how many WebSocket clients are currently registered.
func (h *Hub) ClientCount() int {
	h.mu.RLock()
	n := len(h.clients)
	h.mu.RUnlock()
	return n
}

// handleWebSocket upgrades r to a WebSocket connection, registers it with
// s.hub, and keeps it open until the client disconnects or a heartbeat fails.
//
// It rejects the request in two cases:
//   - 503, when the hub is already at capacity.
//   - 401, when s.apiToken is set and the request has neither a matching
//     ?token= query parameter nor a matching "Authorization: Bearer" header.
//
// The capacity check here is only an early-out before the handshake.
// Register re-checks under the hub lock, because other connections may be
// accepted in between.
func (s *Server) handleWebSocket(w http.ResponseWriter, r *http.Request) {
	s.hub.mu.RLock()
	atCap := len(s.hub.clients) >= s.hub.maxClients
	s.hub.mu.RUnlock()
	if atCap {
		http.Error(w, "too many connections", http.StatusServiceUnavailable)
		return
	}

	if s.apiToken != "" {
		// A non-empty query parameter takes precedence; the header is the fallback.
		token := r.URL.Query().Get("token")
		if token == "" {
			auth := r.Header.Get("Authorization")
			if strings.HasPrefix(auth, "Bearer ") {
				token = strings.TrimPrefix(auth, "Bearer ")
			}
		}
		// Compare in constant time so response timing does not reveal how
		// much of the client-supplied token matched.
		if subtle.ConstantTimeCompare([]byte(token), []byte(s.apiToken)) != 1 {
			http.Error(w, "unauthorized", http.StatusUnauthorized)
			return
		}
	}

	handler := websocket.Handler(func(ws *websocket.Conn) {
		if err := s.hub.Register(ws); err != nil {
			// The hub filled up between the pre-handshake check and now.
			// Returning ends the handler without serving this client.
			s.hub.logger.Warn("websocket client rejected", "error", err)
			return
		}
		defer s.hub.Unregister(ws)

		// Ping goroutine: detect dead connections by sending periodic pings.
		// "ping" is sent by Message.Send with a string, i.e. as a text data
		// message rather than a WebSocket control frame, so clients receive it
		// as an ordinary message.
		// A send failure (including the write deadline expiring) closes the
		// conn. That makes the read loop below fail, which unregisters the
		// client.
		done := make(chan struct{})
		defer close(done)
		// Capture the package-level settings once, before the goroutine
		// starts, so the goroutine never reads the shared vars itself.
		pingInterval, pingDeadline := wsPingInterval, wsPingDeadline
		go func() {
			ticker := time.NewTicker(pingInterval)
			defer ticker.Stop()
			for {
				select {
				case <-done:
					return
				case <-ticker.C:
					ws.SetWriteDeadline(time.Now().Add(pingDeadline))
					err := websocket.Message.Send(ws, "ping")
					// Clear the deadline so other writes on this conn are not
					// bounded by this ping's deadline.
					ws.SetWriteDeadline(time.Time{})
					if err != nil {
						ws.Close()
						return
					}
				}
			}
		}()

		// Inbound data is read and discarded. The loop exists only to notice
		// when the client disconnects or the ping goroutine closes the conn.
		// There is no read deadline, so a silently vanished peer is detected
		// only once a ping write fails.
		buf := make([]byte, 1024)
		for {
			if _, err := ws.Read(buf); err != nil {
				break
			}
		}
	})
	handler.ServeHTTP(w, r)
}