summaryrefslogtreecommitdiff
path: root/internal/api/websocket_test.go
blob: 72b83f2110906c425dcb16620e8a0c8c85259d22 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
package api

import (
	"net/http"
	"net/http/httptest"
	"strings"
	"testing"
	"time"

	"golang.org/x/net/websocket"
)

// TestWebSocket_RejectsConnectionWithoutToken checks that, once an API token
// is configured, a request to /api/ws carrying no token at all is refused
// with 401 Unauthorized.
func TestWebSocket_RejectsConnectionWithoutToken(t *testing.T) {
	server, _ := testServer(t)
	server.SetAPIToken("secret-token")

	// A plain GET (no Upgrade headers, no credentials) stands in for a
	// tokenless WebSocket upgrade attempt.
	rec := httptest.NewRecorder()
	server.Handler().ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/api/ws", nil))

	if got := rec.Code; got != http.StatusUnauthorized {
		t.Errorf("want 401, got %d", got)
	}
}

// TestWebSocket_RejectsConnectionWithWrongToken checks that a query-string
// token which does not match the configured API token is refused with
// 401 Unauthorized.
func TestWebSocket_RejectsConnectionWithWrongToken(t *testing.T) {
	const target = "/api/ws?token=wrong-token"

	server, _ := testServer(t)
	server.SetAPIToken("secret-token")

	rec := httptest.NewRecorder()
	server.Handler().ServeHTTP(rec, httptest.NewRequest(http.MethodGet, target, nil))

	if got := rec.Code; got != http.StatusUnauthorized {
		t.Errorf("want 401, got %d", got)
	}
}

// TestWebSocket_AcceptsConnectionWithValidQueryToken checks that a real
// WebSocket handshake succeeds when the configured API token is passed in the
// "token" query parameter.
func TestWebSocket_AcceptsConnectionWithValidQueryToken(t *testing.T) {
	server, _ := testServer(t)
	server.SetAPIToken("secret-token")
	server.StartHub()

	httpSrv := httptest.NewServer(server.Handler())
	t.Cleanup(httpSrv.Close)

	// httptest serves plain http://, so swapping the scheme yields ws://.
	endpoint := strings.Replace(httpSrv.URL, "http", "ws", 1) + "/api/ws?token=secret-token"
	conn, err := websocket.Dial(endpoint, "", "http://localhost/")
	if err != nil {
		t.Fatalf("expected connection to succeed with valid token: %v", err)
	}
	defer conn.Close()
}

// TestWebSocket_AcceptsConnectionWithBearerToken checks that a WebSocket
// handshake succeeds when the API token is sent as an
// "Authorization: Bearer <token>" request header rather than in the URL.
func TestWebSocket_AcceptsConnectionWithBearerToken(t *testing.T) {
	server, _ := testServer(t)
	server.SetAPIToken("secret-token")
	server.StartHub()

	httpSrv := httptest.NewServer(server.Handler())
	t.Cleanup(httpSrv.Close)

	hostPart := strings.TrimPrefix(httpSrv.URL, "http")
	config, err := websocket.NewConfig("ws"+hostPart+"/api/ws", "http://localhost/")
	if err != nil {
		t.Fatalf("config: %v", err)
	}

	// Replace the handshake headers wholesale with just the bearer credential.
	headers := make(http.Header)
	headers.Set("Authorization", "Bearer secret-token")
	config.Header = headers

	conn, err := websocket.DialConfig(config)
	if err != nil {
		t.Fatalf("expected connection to succeed with Bearer token: %v", err)
	}
	defer conn.Close()
}

// TestWebSocket_NoTokenConfigured checks that authentication is disabled when
// no API token has been set: a handshake with no credentials must succeed.
func TestWebSocket_NoTokenConfigured(t *testing.T) {
	server, _ := testServer(t)
	// SetAPIToken is deliberately not called, leaving auth off.
	server.StartHub()

	httpSrv := httptest.NewServer(server.Handler())
	t.Cleanup(httpSrv.Close)

	endpoint := strings.Replace(httpSrv.URL, "http", "ws", 1) + "/api/ws"
	conn, err := websocket.Dial(endpoint, "", "http://localhost/")
	if err != nil {
		t.Fatalf("expected connection without token when auth disabled: %v", err)
	}
	defer conn.Close()
}

// TestWebSocket_RejectsConnectionWhenAtMaxClients checks that /api/ws answers
// 503 Service Unavailable when the hub has no room for another client. The
// client limit is forced to zero so the hub is full before anyone connects.
func TestWebSocket_RejectsConnectionWhenAtMaxClients(t *testing.T) {
	saved := maxWsClients
	t.Cleanup(func() { maxWsClients = saved })
	maxWsClients = 0 // capacity already exhausted

	server, _ := testServer(t)
	server.StartHub()

	rec := httptest.NewRecorder()
	server.Handler().ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/api/ws", nil))

	if got := rec.Code; got != http.StatusServiceUnavailable {
		t.Errorf("want 503, got %d", got)
	}
}

// TestWebSocket_StaleConnectionCleanedUp verifies that when a client
// disconnects (or the connection is closed), the hub unregisters it promptly.
// Short ping intervals are used so the test completes quickly even if the
// server depends on its ping goroutine to notice the departed peer.
func TestWebSocket_StaleConnectionCleanedUp(t *testing.T) {
	origInterval := wsPingInterval
	origDeadline := wsPingDeadline
	wsPingInterval = 20 * time.Millisecond
	wsPingDeadline = 20 * time.Millisecond
	t.Cleanup(func() {
		wsPingInterval = origInterval
		wsPingDeadline = origDeadline
	})

	srv, _ := testServer(t)
	srv.StartHub()
	ts := httptest.NewServer(srv.Handler())
	t.Cleanup(ts.Close)

	wsURL := "ws" + strings.TrimPrefix(ts.URL, "http") + "/api/ws"
	ws, err := websocket.Dial(wsURL, "", "http://localhost/")
	if err != nil {
		t.Fatalf("dial: %v", err)
	}

	// Wait for hub to register the client.
	deadline := time.Now().Add(200 * time.Millisecond)
	for time.Now().Before(deadline) {
		if srv.hub.ClientCount() == 1 {
			break
		}
		time.Sleep(5 * time.Millisecond)
	}
	if got := srv.hub.ClientCount(); got != 1 {
		// Close the client before failing so the connection is not left open
		// past the end of this test.
		ws.Close()
		t.Fatalf("before close: want 1 client, got %d", got)
	}

	// ws.Close writes a WebSocket close frame and then immediately closes the
	// TCP connection without waiting for the server's reply. The closing
	// handshake is therefore never completed, which approximates a client
	// crash / network drop (though it is not a fully silent disconnect).
	ws.Close()

	// Hub should unregister the client promptly.
	deadline = time.Now().Add(500 * time.Millisecond)
	for time.Now().Before(deadline) {
		if srv.hub.ClientCount() == 0 {
			return
		}
		time.Sleep(10 * time.Millisecond)
	}
	t.Fatalf("after close: expected 0 clients, got %d", srv.hub.ClientCount())
}

// TestWebSocket_PingWriteDeadlineEvictsStaleConn runs the server's ping
// goroutine with a 1ms write deadline for several ticks, then checks that the
// hub's client count stays consistent: one client while connected, zero once
// the client has disconnected.
//
// NOTE(review): despite its name, this test does not force a write timeout.
// The client never reads, but a few small ping frames normally fit in the
// loopback socket buffers, so each ping write usually completes within the
// deadline. If a ping does time out, the client may be evicted before the
// explicit close below; the final assertion holds either way. Reliably
// forcing eviction would need a peer whose receive buffer is actually full.
func TestWebSocket_PingWriteDeadlineEvictsStaleConn(t *testing.T) {
	const pingInterval = 30 * time.Millisecond

	origInterval := wsPingInterval
	origDeadline := wsPingDeadline
	wsPingInterval = pingInterval
	// Very short write deadline: a ping write that cannot complete almost
	// immediately fails.
	wsPingDeadline = 1 * time.Millisecond
	t.Cleanup(func() {
		wsPingInterval = origInterval
		wsPingDeadline = origDeadline
	})

	srv, _ := testServer(t)
	srv.StartHub()
	ts := httptest.NewServer(srv.Handler())
	t.Cleanup(ts.Close)

	wsURL := "ws" + strings.TrimPrefix(ts.URL, "http") + "/api/ws"
	ws, err := websocket.Dial(wsURL, "", "http://localhost/")
	if err != nil {
		t.Fatalf("dial: %v", err)
	}

	// Wait for registration.
	deadline := time.Now().Add(200 * time.Millisecond)
	for time.Now().Before(deadline) {
		if srv.hub.ClientCount() == 1 {
			break
		}
		time.Sleep(5 * time.Millisecond)
	}
	if got := srv.hub.ClientCount(); got != 1 {
		// Close exactly once on this path too, so the connection is not
		// left open past the end of the test.
		ws.Close()
		t.Fatalf("before stale: want 1 client, got %d", got)
	}

	// Let the ping goroutine fire several times under the 1ms deadline.
	// Closing right after registration would usually beat the first tick
	// and leave the ping path unexercised.
	time.Sleep(3 * pingInterval)

	ws.Close()

	deadline = time.Now().Add(500 * time.Millisecond)
	for time.Now().Before(deadline) {
		if srv.hub.ClientCount() == 0 {
			return
		}
		time.Sleep(10 * time.Millisecond)
	}
	t.Fatalf("expected 0 clients after stale eviction, got %d", srv.hub.ClientCount())
}