diff options
Diffstat (limited to 'internal')
| -rw-r--r-- | internal/api/websocket.go | 37 | ||||
| -rw-r--r-- | internal/api/websocket_test.go | 5 | ||||
| -rw-r--r-- | internal/executor/claude_test.go | 137 |
3 files changed, 162 insertions, 17 deletions
diff --git a/internal/api/websocket.go b/internal/api/websocket.go index b5bf728..25522dc 100644 --- a/internal/api/websocket.go +++ b/internal/api/websocket.go @@ -16,33 +16,40 @@ import ( var ( wsPingInterval = 30 * time.Second wsPingDeadline = 10 * time.Second - - // maxWsClients caps the number of concurrent WebSocket connections. - // Exposed as a var so tests can override it. - maxWsClients = 1000 ) +const defaultMaxWsClients = 1000 + // Hub manages WebSocket connections and broadcasts messages. type Hub struct { - mu sync.RWMutex - clients map[*websocket.Conn]bool - logger *slog.Logger + mu sync.RWMutex + clients map[*websocket.Conn]bool + maxClients int + logger *slog.Logger } func NewHub() *Hub { return &Hub{ - clients: make(map[*websocket.Conn]bool), - logger: slog.Default(), + clients: make(map[*websocket.Conn]bool), + maxClients: defaultMaxWsClients, + logger: slog.Default(), } } // Run is a no-op loop kept for future cleanup/heartbeat logic. func (h *Hub) Run() {} +// SetMaxClients configures the maximum number of concurrent WebSocket clients. +func (h *Hub) SetMaxClients(n int) { + h.mu.Lock() + h.maxClients = n + h.mu.Unlock() +} + func (h *Hub) Register(ws *websocket.Conn) error { h.mu.Lock() defer h.mu.Unlock() - if len(h.clients) >= maxWsClients { + if len(h.clients) >= h.maxClients { return errors.New("max WebSocket clients reached") } h.clients[ws] = true @@ -74,7 +81,10 @@ func (h *Hub) ClientCount() int { } func (s *Server) handleWebSocket(w http.ResponseWriter, r *http.Request) { - if s.hub.ClientCount() >= maxWsClients { + s.hub.mu.RLock() + atCap := len(s.hub.clients) >= s.hub.maxClients + s.hub.mu.RUnlock() + if atCap { http.Error(w, "too many connections", http.StatusServiceUnavailable) return } @@ -104,15 +114,16 @@ func (s *Server) handleWebSocket(w http.ResponseWriter, r *http.Request) { // causing the read loop below to exit and unregister the client. done := make(chan struct{}) defer close(done) + pingInterval, pingDeadline := wsPingInterval, wsPingDeadline // capture before goroutine starts go func() { - ticker := time.NewTicker(wsPingInterval) + ticker := time.NewTicker(pingInterval) defer ticker.Stop() for { select { case <-done: return case <-ticker.C: - ws.SetWriteDeadline(time.Now().Add(wsPingDeadline)) + ws.SetWriteDeadline(time.Now().Add(pingDeadline)) err := websocket.Message.Send(ws, "ping") ws.SetWriteDeadline(time.Time{}) if err != nil { diff --git a/internal/api/websocket_test.go b/internal/api/websocket_test.go index 72b83f2..682a555 100644 --- a/internal/api/websocket_test.go +++ b/internal/api/websocket_test.go @@ -99,11 +99,8 @@ func TestWebSocket_NoTokenConfigured(t *testing.T) { // TestWebSocket_RejectsConnectionWhenAtMaxClients verifies that when the hub // is at capacity, new WebSocket upgrade requests are rejected with 503. func TestWebSocket_RejectsConnectionWhenAtMaxClients(t *testing.T) { - orig := maxWsClients - maxWsClients = 0 // immediately at capacity - t.Cleanup(func() { maxWsClients = orig }) - srv, _ := testServer(t) + srv.hub.SetMaxClients(0) // immediately at capacity srv.StartHub() req := httptest.NewRequest("GET", "/api/ws", nil) diff --git a/internal/executor/claude_test.go b/internal/executor/claude_test.go index 79f294e..1f6e5be 100644 --- a/internal/executor/claude_test.go +++ b/internal/executor/claude_test.go @@ -4,6 +4,8 @@ import ( "context" "io" "log/slog" + "os" + "os/exec" "path/filepath" "runtime" "strings" @@ -341,3 +343,138 @@ func TestExecOnce_NoGoroutineLeak_OnNaturalExit(t *testing.T) { baseline, after, after-baseline) } } + +// initGitRepo creates a git repo in dir with one commit so it is clonable. +func initGitRepo(t *testing.T, dir string) { + t.Helper() + cmds := [][]string{ + {"git", "-C", dir, "init"}, + {"git", "-C", dir, "config", "user.email", "test@test"}, + {"git", "-C", dir, "config", "user.name", "test"}, + {"git", "-C", dir, "commit", "--allow-empty", "-m", "init"}, + } + for _, args := range cmds { + if out, err := exec.Command(args[0], args[1:]...).CombinedOutput(); err != nil { + t.Fatalf("%v: %v\n%s", args, err, out) + } + } +} + +func TestSandboxCloneSource_PrefersLocalRemote(t *testing.T) { + dir := t.TempDir() + initGitRepo(t, dir) + // Add a "local" remote pointing to a bare repo. + bare := t.TempDir() + exec.Command("git", "init", "--bare", bare).Run() + exec.Command("git", "-C", dir, "remote", "add", "local", bare).Run() + exec.Command("git", "-C", dir, "remote", "add", "origin", "https://example.com/repo").Run() + + got := sandboxCloneSource(dir) + if got != bare { + t.Errorf("expected bare repo path %q, got %q", bare, got) + } +} + +func TestSandboxCloneSource_FallsBackToOrigin(t *testing.T) { + dir := t.TempDir() + initGitRepo(t, dir) + originURL := "https://example.com/origin-repo" + exec.Command("git", "-C", dir, "remote", "add", "origin", originURL).Run() + + got := sandboxCloneSource(dir) + if got != originURL { + t.Errorf("expected origin URL %q, got %q", originURL, got) + } +} + +func TestSandboxCloneSource_FallsBackToProjectDir(t *testing.T) { + dir := t.TempDir() + initGitRepo(t, dir) + // No remotes configured. + got := sandboxCloneSource(dir) + if got != dir { + t.Errorf("expected projectDir %q (no remotes), got %q", dir, got) + } +} + +func TestSetupSandbox_ClonesGitRepo(t *testing.T) { + src := t.TempDir() + initGitRepo(t, src) + + sandbox, err := setupSandbox(src) + if err != nil { + t.Fatalf("setupSandbox: %v", err) + } + t.Cleanup(func() { os.RemoveAll(sandbox) }) + + // Verify sandbox is a git repo with at least one commit. + out, err := exec.Command("git", "-C", sandbox, "log", "--oneline").Output() + if err != nil { + t.Fatalf("git log in sandbox: %v", err) + } + if len(strings.TrimSpace(string(out))) == 0 { + t.Error("expected at least one commit in sandbox, got empty log") + } +} + +func TestSetupSandbox_InitialisesNonGitDir(t *testing.T) { + // A plain directory (not a git repo) should be initialised then cloned. + src := t.TempDir() + + sandbox, err := setupSandbox(src) + if err != nil { + t.Fatalf("setupSandbox on plain dir: %v", err) + } + t.Cleanup(func() { os.RemoveAll(sandbox) }) + + if _, err := os.Stat(filepath.Join(sandbox, ".git")); err != nil { + t.Errorf("sandbox should be a git repo: %v", err) + } +} + +func TestTeardownSandbox_UncommittedChanges_ReturnsError(t *testing.T) { + src := t.TempDir() + initGitRepo(t, src) + sandbox, err := setupSandbox(src) + if err != nil { + t.Fatalf("setupSandbox: %v", err) + } + t.Cleanup(func() { os.RemoveAll(sandbox) }) + + // Leave an uncommitted file in the sandbox. + if err := os.WriteFile(filepath.Join(sandbox, "dirty.txt"), []byte("oops"), 0644); err != nil { + t.Fatal(err) + } + + logger := slog.New(slog.NewTextHandler(io.Discard, nil)) + err = teardownSandbox(src, sandbox, logger) + if err == nil { + t.Fatal("expected error for uncommitted changes, got nil") + } + if !strings.Contains(err.Error(), "uncommitted changes") { + t.Errorf("expected 'uncommitted changes' in error, got: %v", err) + } + // Sandbox should be preserved (not removed) on error. + if _, statErr := os.Stat(sandbox); os.IsNotExist(statErr) { + t.Error("sandbox was removed despite error; should be preserved for debugging") + } +} + +func TestTeardownSandbox_CleanSandboxWithNoNewCommits_RemovesSandbox(t *testing.T) { + src := t.TempDir() + initGitRepo(t, src) + sandbox, err := setupSandbox(src) + if err != nil { + t.Fatalf("setupSandbox: %v", err) + } + + logger := slog.New(slog.NewTextHandler(io.Discard, nil)) + // Sandbox has no new commits beyond origin; teardown should succeed and remove it. + if err := teardownSandbox(src, sandbox, logger); err != nil { + t.Fatalf("teardownSandbox: %v", err) + } + if _, statErr := os.Stat(sandbox); !os.IsNotExist(statErr) { + t.Error("sandbox should have been removed after clean teardown") + os.RemoveAll(sandbox) + } +} |
