summaryrefslogtreecommitdiff
path: root/internal
diff options
context:
space:
mode:
authorPeter Stone <thepeterstone@gmail.com>2026-03-10 23:58:00 +0000
committerPeter Stone <thepeterstone@gmail.com>2026-03-10 23:58:00 +0000
commit1aea271363884f7d616d95971fd25b6c4d87de85 (patch)
treeb17e1614c3bea0e0d8e796e6df2ac608dea7d1ed /internal
parentce185cd10839879e566d0dcf4a14466f0148634f (diff)
test: sandbox coverage + fix WebSocket races
executor: add 7 tests for sandboxCloneSource, setupSandbox, and teardownSandbox (uncommitted-changes error, clean-no-commits removal). api: fix two data races in WebSocket tests — wsPingInterval/Deadline are now captured as locals before goroutine start; maxWsClients is moved from a package-level var into Hub.maxClients (with SetMaxClients method) so concurrent tests don't stomp each other. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Diffstat (limited to 'internal')
-rw-r--r--internal/api/websocket.go37
-rw-r--r--internal/api/websocket_test.go5
-rw-r--r--internal/executor/claude_test.go137
3 files changed, 162 insertions, 17 deletions
diff --git a/internal/api/websocket.go b/internal/api/websocket.go
index b5bf728..25522dc 100644
--- a/internal/api/websocket.go
+++ b/internal/api/websocket.go
@@ -16,33 +16,40 @@ import (
var (
wsPingInterval = 30 * time.Second
wsPingDeadline = 10 * time.Second
-
- // maxWsClients caps the number of concurrent WebSocket connections.
- // Exposed as a var so tests can override it.
- maxWsClients = 1000
)
+const defaultMaxWsClients = 1000
+
// Hub manages WebSocket connections and broadcasts messages.
type Hub struct {
- mu sync.RWMutex
- clients map[*websocket.Conn]bool
- logger *slog.Logger
+ mu sync.RWMutex
+ clients map[*websocket.Conn]bool
+ maxClients int
+ logger *slog.Logger
}
func NewHub() *Hub {
return &Hub{
- clients: make(map[*websocket.Conn]bool),
- logger: slog.Default(),
+ clients: make(map[*websocket.Conn]bool),
+ maxClients: defaultMaxWsClients,
+ logger: slog.Default(),
}
}
// Run is a no-op loop kept for future cleanup/heartbeat logic.
func (h *Hub) Run() {}
+// SetMaxClients configures the maximum number of concurrent WebSocket clients.
+func (h *Hub) SetMaxClients(n int) {
+ h.mu.Lock()
+ h.maxClients = n
+ h.mu.Unlock()
+}
+
func (h *Hub) Register(ws *websocket.Conn) error {
h.mu.Lock()
defer h.mu.Unlock()
- if len(h.clients) >= maxWsClients {
+ if len(h.clients) >= h.maxClients {
return errors.New("max WebSocket clients reached")
}
h.clients[ws] = true
@@ -74,7 +81,10 @@ func (h *Hub) ClientCount() int {
}
func (s *Server) handleWebSocket(w http.ResponseWriter, r *http.Request) {
- if s.hub.ClientCount() >= maxWsClients {
+ s.hub.mu.RLock()
+ atCap := len(s.hub.clients) >= s.hub.maxClients
+ s.hub.mu.RUnlock()
+ if atCap {
http.Error(w, "too many connections", http.StatusServiceUnavailable)
return
}
@@ -104,15 +114,16 @@ func (s *Server) handleWebSocket(w http.ResponseWriter, r *http.Request) {
// causing the read loop below to exit and unregister the client.
done := make(chan struct{})
defer close(done)
+ pingInterval, pingDeadline := wsPingInterval, wsPingDeadline // capture before goroutine starts
go func() {
- ticker := time.NewTicker(wsPingInterval)
+ ticker := time.NewTicker(pingInterval)
defer ticker.Stop()
for {
select {
case <-done:
return
case <-ticker.C:
- ws.SetWriteDeadline(time.Now().Add(wsPingDeadline))
+ ws.SetWriteDeadline(time.Now().Add(pingDeadline))
err := websocket.Message.Send(ws, "ping")
ws.SetWriteDeadline(time.Time{})
if err != nil {
diff --git a/internal/api/websocket_test.go b/internal/api/websocket_test.go
index 72b83f2..682a555 100644
--- a/internal/api/websocket_test.go
+++ b/internal/api/websocket_test.go
@@ -99,11 +99,8 @@ func TestWebSocket_NoTokenConfigured(t *testing.T) {
// TestWebSocket_RejectsConnectionWhenAtMaxClients verifies that when the hub
// is at capacity, new WebSocket upgrade requests are rejected with 503.
func TestWebSocket_RejectsConnectionWhenAtMaxClients(t *testing.T) {
- orig := maxWsClients
- maxWsClients = 0 // immediately at capacity
- t.Cleanup(func() { maxWsClients = orig })
-
srv, _ := testServer(t)
+ srv.hub.SetMaxClients(0) // immediately at capacity
srv.StartHub()
req := httptest.NewRequest("GET", "/api/ws", nil)
diff --git a/internal/executor/claude_test.go b/internal/executor/claude_test.go
index 79f294e..1f6e5be 100644
--- a/internal/executor/claude_test.go
+++ b/internal/executor/claude_test.go
@@ -4,6 +4,8 @@ import (
"context"
"io"
"log/slog"
+ "os"
+ "os/exec"
"path/filepath"
"runtime"
"strings"
@@ -341,3 +343,138 @@ func TestExecOnce_NoGoroutineLeak_OnNaturalExit(t *testing.T) {
baseline, after, after-baseline)
}
}
+
+// initGitRepo creates a git repo in dir with one commit so it is clonable.
+func initGitRepo(t *testing.T, dir string) {
+ t.Helper()
+ cmds := [][]string{
+ {"git", "-C", dir, "init"},
+ {"git", "-C", dir, "config", "user.email", "test@test"},
+ {"git", "-C", dir, "config", "user.name", "test"},
+ {"git", "-C", dir, "commit", "--allow-empty", "-m", "init"},
+ }
+ for _, args := range cmds {
+ if out, err := exec.Command(args[0], args[1:]...).CombinedOutput(); err != nil {
+ t.Fatalf("%v: %v\n%s", args, err, out)
+ }
+ }
+}
+
+func TestSandboxCloneSource_PrefersLocalRemote(t *testing.T) {
+ dir := t.TempDir()
+ initGitRepo(t, dir)
+ // Add a "local" remote pointing to a bare repo.
+ bare := t.TempDir()
+ exec.Command("git", "init", "--bare", bare).Run()
+ exec.Command("git", "-C", dir, "remote", "add", "local", bare).Run()
+ exec.Command("git", "-C", dir, "remote", "add", "origin", "https://example.com/repo").Run()
+
+ got := sandboxCloneSource(dir)
+ if got != bare {
+ t.Errorf("expected bare repo path %q, got %q", bare, got)
+ }
+}
+
+func TestSandboxCloneSource_FallsBackToOrigin(t *testing.T) {
+ dir := t.TempDir()
+ initGitRepo(t, dir)
+ originURL := "https://example.com/origin-repo"
+ exec.Command("git", "-C", dir, "remote", "add", "origin", originURL).Run()
+
+ got := sandboxCloneSource(dir)
+ if got != originURL {
+ t.Errorf("expected origin URL %q, got %q", originURL, got)
+ }
+}
+
+func TestSandboxCloneSource_FallsBackToProjectDir(t *testing.T) {
+ dir := t.TempDir()
+ initGitRepo(t, dir)
+ // No remotes configured.
+ got := sandboxCloneSource(dir)
+ if got != dir {
+ t.Errorf("expected projectDir %q (no remotes), got %q", dir, got)
+ }
+}
+
+func TestSetupSandbox_ClonesGitRepo(t *testing.T) {
+ src := t.TempDir()
+ initGitRepo(t, src)
+
+ sandbox, err := setupSandbox(src)
+ if err != nil {
+ t.Fatalf("setupSandbox: %v", err)
+ }
+ t.Cleanup(func() { os.RemoveAll(sandbox) })
+
+ // Verify sandbox is a git repo with at least one commit.
+ out, err := exec.Command("git", "-C", sandbox, "log", "--oneline").Output()
+ if err != nil {
+ t.Fatalf("git log in sandbox: %v", err)
+ }
+ if len(strings.TrimSpace(string(out))) == 0 {
+ t.Error("expected at least one commit in sandbox, got empty log")
+ }
+}
+
+func TestSetupSandbox_InitialisesNonGitDir(t *testing.T) {
+ // A plain directory (not a git repo) should be initialised then cloned.
+ src := t.TempDir()
+
+ sandbox, err := setupSandbox(src)
+ if err != nil {
+ t.Fatalf("setupSandbox on plain dir: %v", err)
+ }
+ t.Cleanup(func() { os.RemoveAll(sandbox) })
+
+ if _, err := os.Stat(filepath.Join(sandbox, ".git")); err != nil {
+ t.Errorf("sandbox should be a git repo: %v", err)
+ }
+}
+
+func TestTeardownSandbox_UncommittedChanges_ReturnsError(t *testing.T) {
+ src := t.TempDir()
+ initGitRepo(t, src)
+ sandbox, err := setupSandbox(src)
+ if err != nil {
+ t.Fatalf("setupSandbox: %v", err)
+ }
+ t.Cleanup(func() { os.RemoveAll(sandbox) })
+
+ // Leave an uncommitted file in the sandbox.
+ if err := os.WriteFile(filepath.Join(sandbox, "dirty.txt"), []byte("oops"), 0644); err != nil {
+ t.Fatal(err)
+ }
+
+ logger := slog.New(slog.NewTextHandler(io.Discard, nil))
+ err = teardownSandbox(src, sandbox, logger)
+ if err == nil {
+ t.Fatal("expected error for uncommitted changes, got nil")
+ }
+ if !strings.Contains(err.Error(), "uncommitted changes") {
+ t.Errorf("expected 'uncommitted changes' in error, got: %v", err)
+ }
+ // Sandbox should be preserved (not removed) on error.
+ if _, statErr := os.Stat(sandbox); os.IsNotExist(statErr) {
+ t.Error("sandbox was removed despite error; should be preserved for debugging")
+ }
+}
+
+func TestTeardownSandbox_CleanSandboxWithNoNewCommits_RemovesSandbox(t *testing.T) {
+ src := t.TempDir()
+ initGitRepo(t, src)
+ sandbox, err := setupSandbox(src)
+ if err != nil {
+ t.Fatalf("setupSandbox: %v", err)
+ }
+
+ logger := slog.New(slog.NewTextHandler(io.Discard, nil))
+ // Sandbox has no new commits beyond origin; teardown should succeed and remove it.
+ if err := teardownSandbox(src, sandbox, logger); err != nil {
+ t.Fatalf("teardownSandbox: %v", err)
+ }
+ if _, statErr := os.Stat(sandbox); !os.IsNotExist(statErr) {
+ t.Error("sandbox should have been removed after clean teardown")
+ os.RemoveAll(sandbox)
+ }
+}