summaryrefslogtreecommitdiff
path: root/internal/executor
diff options
context:
space:
mode:
author: Claudomator Agent <agent@claudomator.local> 2026-03-21 23:18:50 +0000
committer: Claudomator Agent <agent@claudomator.local> 2026-03-21 23:18:50 +0000
commit 8dca9bbb0baee59ffe0d3127180ef0958dda8b91 (patch)
tree e887036f4cce0f10694c5b9a29f4b4dc251769ba /internal/executor
parent 9e35f7e4087cfa6017cb65ec6a7036f394f5eb22 (diff)
feat: executor reliability — per-agent limit, drain gate, pre-flight creds, auth recovery
- maxPerAgent=1: only 1 in-flight execution per agent type at a time; excess tasks are requeued after 30s
- Drain gate: after 2 consecutive failures the agent is drained and a question is set on the task; reset on first success; POST /api/pool/agents/{agent}/undrain to acknowledge
- Pre-flight credential check: verify .credentials.json and .claude.json exist in agentHome before spinning up a container
- Auth error auto-recovery: detect auth errors (Not logged in, OAuth token has expired, etc.) and retry once after running sync-credentials and re-copying fresh credentials
- Extracted runContainer() helper from ContainerRunner.Run() to support the retry flow
- Wire CredentialSyncCmd in serve.go for all three ContainerRunner instances
- Tests: TestPool_MaxPerAgent_*, TestPool_ConsecutiveFailures_*, TestPool_Undrain_*, TestContainerRunner_Missing{Credentials,Settings}_FailsFast, TestIsAuthError_*, TestContainerRunner_AuthError_SyncsAndRetries

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Diffstat (limited to 'internal/executor')
-rw-r--r--  internal/executor/container.go       128
-rw-r--r--  internal/executor/container_test.go  171
-rw-r--r--  internal/executor/executor.go        116
-rw-r--r--  internal/executor/executor_test.go   263
4 files changed, 602 insertions, 76 deletions
diff --git a/internal/executor/container.go b/internal/executor/container.go
index 5421108..d9ed8ef 100644
--- a/internal/executor/container.go
+++ b/internal/executor/container.go
@@ -2,6 +2,7 @@ package executor
import (
"context"
+ "errors"
"fmt"
"log/slog"
"os"
@@ -22,14 +23,26 @@ type ContainerRunner struct {
LogDir string
APIURL string
DropsDir string
- SSHAuthSock string // optional path to host SSH agent
- ClaudeBinary string // optional path to claude binary in container
- GeminiBinary string // optional path to gemini binary in container
- ClaudeConfigDir string // host path to ~/.claude; mounted into container for auth credentials
+ SSHAuthSock string // optional path to host SSH agent
+ ClaudeBinary string // optional path to claude binary in container
+ GeminiBinary string // optional path to gemini binary in container
+ ClaudeConfigDir string // host path to ~/.claude; mounted into container for auth credentials
+ CredentialSyncCmd string // optional path to sync-credentials script for auth-error auto-recovery
// Command allows mocking exec.CommandContext for tests.
Command func(ctx context.Context, name string, arg ...string) *exec.Cmd
}
+func isAuthError(err error) bool {
+ if err == nil {
+ return false
+ }
+ s := err.Error()
+ return strings.Contains(s, "Not logged in") ||
+ strings.Contains(s, "OAuth token has expired") ||
+ strings.Contains(s, "authentication_error") ||
+ strings.Contains(s, "Please run /login")
+}
+
func (r *ContainerRunner) command(ctx context.Context, name string, arg ...string) *exec.Cmd {
if r.Command != nil {
return r.Command(ctx, name, arg...)
@@ -51,14 +64,6 @@ func (r *ContainerRunner) Run(ctx context.Context, t *task.Task, e *storage.Exec
return fmt.Errorf("task %s has no repository_url", t.ID)
}
- image := t.Agent.ContainerImage
- if image == "" {
- image = r.Image
- }
- if image == "" {
- image = "claudomator-agent:latest"
- }
-
// 1. Setup workspace on host
var workspace string
isResume := false
@@ -106,6 +111,81 @@ func (r *ContainerRunner) Run(ctx context.Context, t *task.Task, e *storage.Exec
}
e.SandboxDir = workspace
+ // Set up a writable $HOME staging dir so any agent tool (claude, gemini, etc.)
+ // can freely create subdirs (session-env, .gemini, .cache, …) without hitting
+ // a non-existent or read-only home. We copy only the claude credentials into it.
+ agentHome := filepath.Join(workspace, ".agent-home")
+ if err := os.MkdirAll(filepath.Join(agentHome, ".claude"), 0755); err != nil {
+ return fmt.Errorf("creating agent home staging dir: %w", err)
+ }
+ if err := os.MkdirAll(filepath.Join(agentHome, ".gemini"), 0755); err != nil {
+ return fmt.Errorf("creating .gemini dir: %w", err)
+ }
+ if r.ClaudeConfigDir != "" {
+ // credentials
+ if srcData, readErr := os.ReadFile(filepath.Join(r.ClaudeConfigDir, ".credentials.json")); readErr == nil {
+ _ = os.WriteFile(filepath.Join(agentHome, ".claude", ".credentials.json"), srcData, 0600)
+ }
+ // settings (used by claude CLI; copy so it can write updates without hitting the host)
+ if srcData, readErr := os.ReadFile(filepath.Join(r.ClaudeConfigDir, ".claude.json")); readErr == nil {
+ _ = os.WriteFile(filepath.Join(agentHome, ".claude.json"), srcData, 0644)
+ }
+ }
+
+ // Pre-flight: verify credentials were actually copied before spinning up a container.
+ if r.ClaudeConfigDir != "" {
+ credsPath := filepath.Join(agentHome, ".claude", ".credentials.json")
+ settingsPath := filepath.Join(agentHome, ".claude.json")
+ if _, err := os.Stat(credsPath); os.IsNotExist(err) {
+ return fmt.Errorf("credentials not found at %s — run sync-credentials", r.ClaudeConfigDir)
+ }
+ if _, err := os.Stat(settingsPath); os.IsNotExist(err) {
+ return fmt.Errorf("claude settings (.claude.json) not found at %s — run sync-credentials", r.ClaudeConfigDir)
+ }
+ }
+
+ // Run container (with auth retry on failure).
+ runErr := r.runContainer(ctx, t, e, workspace, agentHome, isResume)
+ if runErr != nil && isAuthError(runErr) && r.CredentialSyncCmd != "" {
+ r.Logger.Warn("auth failure detected, syncing credentials and retrying once", "taskID", t.ID)
+ syncOut, syncErr := r.command(ctx, r.CredentialSyncCmd).CombinedOutput()
+ if syncErr != nil {
+ r.Logger.Warn("sync-credentials failed", "error", syncErr, "output", string(syncOut))
+ }
+ // Re-copy credentials into agentHome with fresh files.
+ if srcData, readErr := os.ReadFile(filepath.Join(r.ClaudeConfigDir, ".credentials.json")); readErr == nil {
+ _ = os.WriteFile(filepath.Join(agentHome, ".claude", ".credentials.json"), srcData, 0600)
+ }
+ if srcData, readErr := os.ReadFile(filepath.Join(r.ClaudeConfigDir, ".claude.json")); readErr == nil {
+ _ = os.WriteFile(filepath.Join(agentHome, ".claude.json"), srcData, 0644)
+ }
+ runErr = r.runContainer(ctx, t, e, workspace, agentHome, isResume)
+ }
+
+ if runErr == nil {
+ success = true
+ }
+ var blockedErr *BlockedError
+ if errors.As(runErr, &blockedErr) {
+ isBlocked = true
+ success = true // preserve workspace for resumption
+ }
+ return runErr
+}
+
+// runContainer runs the docker container for the given task and handles log setup,
+// environment files, instructions, and post-execution git operations.
+func (r *ContainerRunner) runContainer(ctx context.Context, t *task.Task, e *storage.Execution, workspace, agentHome string, isResume bool) error {
+ repoURL := t.RepositoryURL
+
+ image := t.Agent.ContainerImage
+ if image == "" {
+ image = r.Image
+ }
+ if image == "" {
+ image = "claudomator-agent:latest"
+ }
+
// 3. Prepare logs
logDir := r.ExecLogDir(e.ID)
if logDir == "" {
@@ -145,27 +225,6 @@ func (r *ContainerRunner) Run(ctx context.Context, t *task.Task, e *storage.Exec
return fmt.Errorf("writing instructions: %w", err)
}
- // Set up a writable $HOME staging dir so any agent tool (claude, gemini, etc.)
- // can freely create subdirs (session-env, .gemini, .cache, …) without hitting
- // a non-existent or read-only home. We copy only the claude credentials into it.
- agentHome := filepath.Join(workspace, ".agent-home")
- if err := os.MkdirAll(filepath.Join(agentHome, ".claude"), 0755); err != nil {
- return fmt.Errorf("creating agent home staging dir: %w", err)
- }
- if err := os.MkdirAll(filepath.Join(agentHome, ".gemini"), 0755); err != nil {
- return fmt.Errorf("creating .gemini dir: %w", err)
- }
- if r.ClaudeConfigDir != "" {
- // credentials
- if srcData, readErr := os.ReadFile(filepath.Join(r.ClaudeConfigDir, ".credentials.json")); readErr == nil {
- _ = os.WriteFile(filepath.Join(agentHome, ".claude", ".credentials.json"), srcData, 0600)
- }
- // settings (used by claude CLI; copy so it can write updates without hitting the host)
- if srcData, readErr := os.ReadFile(filepath.Join(r.ClaudeConfigDir, ".claude.json")); readErr == nil {
- _ = os.WriteFile(filepath.Join(agentHome, ".claude.json"), srcData, 0644)
- }
- }
-
args := r.buildDockerArgs(workspace, agentHome, e.TaskID)
innerCmd := r.buildInnerCmd(t, e, isResume)
@@ -233,8 +292,6 @@ func (r *ContainerRunner) Run(ctx context.Context, t *task.Task, e *storage.Exec
r.Logger.Info("treating question file as completion report", "taskID", e.TaskID)
e.Summary = extractQuestionText(questionJSON)
} else {
- isBlocked = true
- success = true // We consider BLOCKED as a "success" for workspace preservation
if e.SessionID == "" {
r.Logger.Warn("missing session ID; resume will start fresh", "taskID", e.TaskID)
}
@@ -278,7 +335,6 @@ func (r *ContainerRunner) Run(ctx context.Context, t *task.Task, e *storage.Exec
}
r.Logger.Info("no new commits to push", "taskID", t.ID)
}
- success = true
}
if waitErr != nil {
diff --git a/internal/executor/container_test.go b/internal/executor/container_test.go
index be80b51..b6946ef 100644
--- a/internal/executor/container_test.go
+++ b/internal/executor/container_test.go
@@ -7,6 +7,7 @@ import (
"log/slog"
"os"
"os/exec"
+ "path/filepath"
"strings"
"testing"
@@ -343,3 +344,173 @@ func TestGitSafe_PrependsSafeDirectory(t *testing.T) {
}
}
}
+
+func TestContainerRunner_MissingCredentials_FailsFast(t *testing.T) {
+ logger := slog.New(slog.NewTextHandler(io.Discard, nil))
+
+ claudeConfigDir := t.TempDir()
+
+ // Set up ClaudeConfigDir with MISSING credentials (so pre-flight fails)
+ // Don't create .credentials.json
+ // But DO create .claude.json so the test isolates the credentials check
+ if err := os.WriteFile(filepath.Join(claudeConfigDir, ".claude.json"), []byte("{}"), 0644); err != nil {
+ t.Fatal(err)
+ }
+
+ runner := &ContainerRunner{
+ Logger: logger,
+ Image: "busybox",
+ ClaudeConfigDir: claudeConfigDir,
+ Command: func(ctx context.Context, name string, arg ...string) *exec.Cmd {
+ if name == "git" && len(arg) > 0 && arg[0] == "clone" {
+ dir := arg[len(arg)-1]
+ os.MkdirAll(dir, 0755)
+ return exec.Command("true")
+ }
+ return exec.Command("true")
+ },
+ }
+
+ tk := &task.Task{
+ ID: "test-missing-creds",
+ RepositoryURL: "https://github.com/example/repo.git",
+ Agent: task.AgentConfig{Type: "claude"},
+ }
+ e := &storage.Execution{ID: "test-exec", TaskID: "test-missing-creds"}
+
+ err := runner.Run(context.Background(), tk, e)
+ if err == nil {
+ t.Fatal("expected error due to missing credentials, got nil")
+ }
+ if !strings.Contains(err.Error(), "credentials not found") {
+ t.Errorf("expected 'credentials not found' error, got: %v", err)
+ }
+}
+
+func TestContainerRunner_MissingSettings_FailsFast(t *testing.T) {
+ logger := slog.New(slog.NewTextHandler(io.Discard, nil))
+
+ claudeConfigDir := t.TempDir()
+
+ // Only create credentials but NOT .claude.json
+ if err := os.WriteFile(filepath.Join(claudeConfigDir, ".credentials.json"), []byte("{}"), 0600); err != nil {
+ t.Fatal(err)
+ }
+
+ runner := &ContainerRunner{
+ Logger: logger,
+ Image: "busybox",
+ ClaudeConfigDir: claudeConfigDir,
+ Command: func(ctx context.Context, name string, arg ...string) *exec.Cmd {
+ if name == "git" && len(arg) > 0 && arg[0] == "clone" {
+ dir := arg[len(arg)-1]
+ os.MkdirAll(dir, 0755)
+ return exec.Command("true")
+ }
+ return exec.Command("true")
+ },
+ }
+
+ tk := &task.Task{
+ ID: "test-missing-settings",
+ RepositoryURL: "https://github.com/example/repo.git",
+ Agent: task.AgentConfig{Type: "claude"},
+ }
+ e := &storage.Execution{ID: "test-exec-2", TaskID: "test-missing-settings"}
+
+ err := runner.Run(context.Background(), tk, e)
+ if err == nil {
+ t.Fatal("expected error due to missing settings, got nil")
+ }
+ if !strings.Contains(err.Error(), "claude settings") {
+ t.Errorf("expected 'claude settings' error, got: %v", err)
+ }
+}
+
+func TestIsAuthError_DetectsAllVariants(t *testing.T) {
+ tests := []struct {
+ msg string
+ want bool
+ }{
+ {"Not logged in", true},
+ {"OAuth token has expired", true},
+ {"authentication_error: invalid token", true},
+ {"Please run /login to authenticate", true},
+ {"container execution failed: exit status 1", false},
+ {"git clone failed", false},
+ {"", false},
+ }
+ for _, tt := range tests {
+ var err error
+ if tt.msg != "" {
+ err = fmt.Errorf("%s", tt.msg)
+ }
+ got := isAuthError(err)
+ if got != tt.want {
+ t.Errorf("isAuthError(%q) = %v, want %v", tt.msg, got, tt.want)
+ }
+ }
+}
+
+func TestContainerRunner_AuthError_SyncsAndRetries(t *testing.T) {
+ logger := slog.New(slog.NewTextHandler(io.Discard, nil))
+
+ // Create a sync script that creates a marker file
+ syncDir := t.TempDir()
+ syncMarker := filepath.Join(syncDir, "sync-called")
+ syncScript := filepath.Join(syncDir, "sync-creds")
+ os.WriteFile(syncScript, []byte("#!/bin/sh\ntouch "+syncMarker+"\n"), 0755)
+
+ claudeConfigDir := t.TempDir()
+ // Create both credential files in ClaudeConfigDir
+ os.WriteFile(filepath.Join(claudeConfigDir, ".credentials.json"), []byte(`{"token":"fresh"}`), 0600)
+ os.WriteFile(filepath.Join(claudeConfigDir, ".claude.json"), []byte("{}"), 0644)
+
+ callCount := 0
+ runner := &ContainerRunner{
+ Logger: logger,
+ Image: "busybox",
+ ClaudeConfigDir: claudeConfigDir,
+ CredentialSyncCmd: syncScript,
+ Command: func(ctx context.Context, name string, arg ...string) *exec.Cmd {
+ if name == "git" {
+ if len(arg) > 0 && arg[0] == "clone" {
+ dir := arg[len(arg)-1]
+ os.MkdirAll(dir, 0755)
+ }
+ return exec.Command("true")
+ }
+ if name == "docker" {
+ callCount++
+ if callCount == 1 {
+ // First docker call fails with auth error
+ return exec.Command("sh", "-c", "echo 'Not logged in' >&2; exit 1")
+ }
+ // Second docker call "succeeds"
+ return exec.Command("sh", "-c", "exit 0")
+ }
+ if name == syncScript {
+ return exec.Command("sh", "-c", "touch "+syncMarker)
+ }
+ return exec.Command("true")
+ },
+ }
+
+ tk := &task.Task{
+ ID: "auth-retry-test",
+ RepositoryURL: "https://github.com/example/repo.git",
+ Agent: task.AgentConfig{Type: "claude", Instructions: "test"},
+ }
+ e := &storage.Execution{ID: "auth-retry-exec", TaskID: "auth-retry-test"}
+
+ // Run — first attempt will fail with auth error, triggering sync+retry
+ runner.Run(context.Background(), tk, e)
+ // We don't check error strictly since second run may also fail (git push etc.)
+ // What we care about is that docker was called twice and sync was called
+ if callCount < 2 {
+ t.Errorf("expected docker to be called at least twice (original + retry), got %d", callCount)
+ }
+ if _, err := os.Stat(syncMarker); os.IsNotExist(err) {
+ t.Error("expected sync-credentials to be called, but marker file not found")
+ }
+}
diff --git a/internal/executor/executor.go b/internal/executor/executor.go
index 1f2c27d..7513916 100644
--- a/internal/executor/executor.go
+++ b/internal/executor/executor.go
@@ -2,6 +2,7 @@ package executor
import (
"context"
+ "encoding/json"
"errors"
"fmt"
"log/slog"
@@ -55,21 +56,24 @@ type workItem struct {
// Pool manages a bounded set of concurrent task workers.
type Pool struct {
maxConcurrent int
+ maxPerAgent int
runners map[string]Runner
store Store
logger *slog.Logger
depPollInterval time.Duration // how often waitForDependencies polls; defaults to 5s
- mu sync.Mutex
- active int
- activePerAgent map[string]int
- rateLimited map[string]time.Time // agentType -> until
- cancels map[string]context.CancelFunc // taskID → cancel
- resultCh chan *Result
- workCh chan workItem // internal bounded queue; Submit enqueues here
- doneCh chan struct{} // signals when a worker slot is freed
- Questions *QuestionRegistry
- Classifier *Classifier
+ mu sync.Mutex
+ active int
+ activePerAgent map[string]int
+ rateLimited map[string]time.Time // agentType -> until
+ cancels map[string]context.CancelFunc // taskID → cancel
+ consecutiveFailures map[string]int // agentType -> count
+ drained map[string]bool // agentType -> true if halted pending human ack
+ resultCh chan *Result
+ workCh chan workItem // internal bounded queue; Submit enqueues here
+ doneCh chan struct{} // signals when a worker slot is freed
+ Questions *QuestionRegistry
+ Classifier *Classifier
}
// Result is emitted when a task execution completes.
@@ -84,18 +88,21 @@ func NewPool(maxConcurrent int, runners map[string]Runner, store Store, logger *
maxConcurrent = 1
}
p := &Pool{
- maxConcurrent: maxConcurrent,
- runners: runners,
- store: store,
- logger: logger,
- depPollInterval: 5 * time.Second,
- activePerAgent: make(map[string]int),
- rateLimited: make(map[string]time.Time),
- cancels: make(map[string]context.CancelFunc),
- resultCh: make(chan *Result, maxConcurrent*2),
- workCh: make(chan workItem, maxConcurrent*10+100),
- doneCh: make(chan struct{}, maxConcurrent),
- Questions: NewQuestionRegistry(),
+ maxConcurrent: maxConcurrent,
+ maxPerAgent: 1,
+ runners: runners,
+ store: store,
+ logger: logger,
+ depPollInterval: 5 * time.Second,
+ activePerAgent: make(map[string]int),
+ rateLimited: make(map[string]time.Time),
+ cancels: make(map[string]context.CancelFunc),
+ consecutiveFailures: make(map[string]int),
+ drained: make(map[string]bool),
+ resultCh: make(chan *Result, maxConcurrent*2),
+ workCh: make(chan workItem, maxConcurrent*10+100),
+ doneCh: make(chan struct{}, maxConcurrent),
+ Questions: NewQuestionRegistry(),
}
go p.dispatch()
return p
@@ -336,8 +343,29 @@ func (p *Pool) handleRunResult(ctx context.Context, t *task.Task, exec *storage.
if err := p.store.UpdateTaskState(t.ID, task.StateFailed); err != nil {
p.logger.Error("failed to update task state", "taskID", t.ID, "state", task.StateFailed, "error", err)
}
+ p.mu.Lock()
+ p.consecutiveFailures[agentType]++
+ failures := p.consecutiveFailures[agentType]
+ p.mu.Unlock()
+ if failures >= 2 {
+ p.mu.Lock()
+ p.drained[agentType] = true
+ p.mu.Unlock()
+ p.logger.Warn("agent drained after consecutive failures", "agent", agentType, "failures", failures)
+ questionJSON, _ := json.Marshal(map[string]string{
+ "question": fmt.Sprintf("Agent %q has failed %d times in a row (last error: %s). Acknowledge to resume.", agentType, failures, exec.ErrorMsg),
+ "options": "acknowledge",
+ })
+ if err := p.store.UpdateTaskQuestion(t.ID, string(questionJSON)); err != nil {
+ p.logger.Error("failed to set drain question", "error", err)
+ }
+ }
}
} else {
+ p.mu.Lock()
+ p.consecutiveFailures[agentType] = 0
+ p.drained[agentType] = false
+ p.mu.Unlock()
if t.ParentTaskID == "" {
subtasks, subErr := p.store.ListSubtasks(t.ID)
if subErr != nil {
@@ -392,6 +420,14 @@ func (p *Pool) handleRunResult(ctx context.Context, t *task.Task, exec *storage.
p.resultCh <- &Result{TaskID: t.ID, Execution: exec, Err: err}
}
+// UndrainingAgent resets the drain state and failure counter for the given agent type.
+func (p *Pool) UndrainingAgent(agentType string) {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+ p.drained[agentType] = false
+ p.consecutiveFailures[agentType] = 0
+}
+
// ActiveCount returns the number of currently running tasks.
func (p *Pool) ActiveCount() int {
p.mu.Lock()
@@ -520,13 +556,6 @@ func (p *Pool) execute(ctx context.Context, t *task.Task) {
agentType = "claude"
}
- p.mu.Lock()
- if deadline, ok := p.rateLimited[agentType]; ok && time.Now().After(deadline) {
- delete(p.rateLimited, agentType)
- }
- p.activePerAgent[agentType]++
- p.mu.Unlock()
-
defer func() {
p.mu.Lock()
p.active--
@@ -537,6 +566,35 @@ func (p *Pool) execute(ctx context.Context, t *task.Task) {
}
}()
+ p.mu.Lock()
+ if p.drained[agentType] {
+ p.mu.Unlock()
+ time.AfterFunc(2*time.Minute, func() { p.workCh <- workItem{ctx: ctx, task: t} })
+ return
+ }
+ if p.activePerAgent[agentType] >= p.maxPerAgent {
+ p.mu.Unlock()
+ time.AfterFunc(30*time.Second, func() { p.workCh <- workItem{ctx: ctx, task: t} })
+ return
+ }
+ if deadline, ok := p.rateLimited[agentType]; ok && time.Now().After(deadline) {
+ delete(p.rateLimited, agentType)
+ agentName := agentType
+ go func() {
+ ev := storage.AgentEvent{
+ ID: uuid.New().String(),
+ Agent: agentName,
+ Event: "available",
+ Timestamp: time.Now(),
+ }
+ if recErr := p.store.RecordAgentEvent(ev); recErr != nil {
+ p.logger.Warn("failed to record agent available event", "error", recErr)
+ }
+ }()
+ }
+ p.activePerAgent[agentType]++
+ p.mu.Unlock()
+
runner, err := p.getRunner(t)
if err != nil {
p.logger.Error("failed to get runner", "error", err, "taskID", t.ID)
diff --git a/internal/executor/executor_test.go b/internal/executor/executor_test.go
index 91d0137..ac09cfc 100644
--- a/internal/executor/executor_test.go
+++ b/internal/executor/executor_test.go
@@ -1071,17 +1071,20 @@ func (m *minimalMockStore) lastStateUpdate() (string, task.State, bool) {
func newPoolWithMockStore(store Store) *Pool {
logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))
return &Pool{
- maxConcurrent: 2,
- runners: map[string]Runner{"claude": &mockRunner{}},
- store: store,
- logger: logger,
- activePerAgent: make(map[string]int),
- rateLimited: make(map[string]time.Time),
- cancels: make(map[string]context.CancelFunc),
- resultCh: make(chan *Result, 4),
- workCh: make(chan workItem, 4),
- doneCh: make(chan struct{}, 2),
- Questions: NewQuestionRegistry(),
+ maxConcurrent: 2,
+ maxPerAgent: 1,
+ runners: map[string]Runner{"claude": &mockRunner{}},
+ store: store,
+ logger: logger,
+ activePerAgent: make(map[string]int),
+ rateLimited: make(map[string]time.Time),
+ cancels: make(map[string]context.CancelFunc),
+ consecutiveFailures: make(map[string]int),
+ drained: make(map[string]bool),
+ resultCh: make(chan *Result, 4),
+ workCh: make(chan workItem, 4),
+ doneCh: make(chan struct{}, 2),
+ Questions: NewQuestionRegistry(),
}
}
@@ -1418,3 +1421,241 @@ func TestExecute_MalformedChangestats(t *testing.T) {
t.Errorf("expected nil changestats for malformed output, got %+v", execs[0].Changestats)
}
}
+
+func TestPool_MaxPerAgent_BlocksSecondTask(t *testing.T) {
+ store := testStore(t)
+
+ var mu sync.Mutex
+ concurrentRuns := 0
+ maxConcurrent := 0
+
+ runner := &mockRunner{
+ delay: 100 * time.Millisecond,
+ onRun: func(tk *task.Task, e *storage.Execution) error {
+ mu.Lock()
+ concurrentRuns++
+ if concurrentRuns > maxConcurrent {
+ maxConcurrent = concurrentRuns
+ }
+ mu.Unlock()
+ time.Sleep(100 * time.Millisecond)
+ mu.Lock()
+ concurrentRuns--
+ mu.Unlock()
+ return nil
+ },
+ }
+ runners := map[string]Runner{"claude": runner}
+ logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))
+ pool := NewPool(2, runners, store, logger) // pool size 2, but maxPerAgent=1
+
+ tk1 := makeTask("mpa-1")
+ tk2 := makeTask("mpa-2")
+ store.CreateTask(tk1)
+ store.CreateTask(tk2)
+
+ pool.Submit(context.Background(), tk1)
+ pool.Submit(context.Background(), tk2)
+
+ for i := 0; i < 2; i++ {
+ select {
+ case <-pool.Results():
+ case <-time.After(10 * time.Second):
+ t.Fatal("timed out waiting for result")
+ }
+ }
+
+ mu.Lock()
+ got := maxConcurrent
+ mu.Unlock()
+ if got > 1 {
+ t.Errorf("maxPerAgent=1 violated: %d claude tasks ran concurrently", got)
+ }
+}
+
+func TestPool_MaxPerAgent_AllowsDifferentAgents(t *testing.T) {
+ store := testStore(t)
+
+ var mu sync.Mutex
+ concurrentRuns := 0
+ maxConcurrent := 0
+
+ makeSlowRunner := func() *mockRunner {
+ return &mockRunner{
+ onRun: func(tk *task.Task, e *storage.Execution) error {
+ mu.Lock()
+ concurrentRuns++
+ if concurrentRuns > maxConcurrent {
+ maxConcurrent = concurrentRuns
+ }
+ mu.Unlock()
+ time.Sleep(80 * time.Millisecond)
+ mu.Lock()
+ concurrentRuns--
+ mu.Unlock()
+ return nil
+ },
+ }
+ }
+
+ runners := map[string]Runner{
+ "claude": makeSlowRunner(),
+ "gemini": makeSlowRunner(),
+ }
+ logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))
+ pool := NewPool(2, runners, store, logger)
+
+ tk1 := makeTask("da-1")
+ tk1.Agent.Type = "claude"
+ tk2 := makeTask("da-2")
+ tk2.Agent.Type = "gemini"
+ store.CreateTask(tk1)
+ store.CreateTask(tk2)
+
+ pool.Submit(context.Background(), tk1)
+ pool.Submit(context.Background(), tk2)
+
+ for i := 0; i < 2; i++ {
+ select {
+ case <-pool.Results():
+ case <-time.After(5 * time.Second):
+ t.Fatal("timed out waiting for result")
+ }
+ }
+
+ mu.Lock()
+ got := maxConcurrent
+ mu.Unlock()
+ if got < 2 {
+ t.Errorf("different agents should run concurrently; max concurrent was %d", got)
+ }
+}
+
+func TestPool_ConsecutiveFailures_DrainAtTwo(t *testing.T) {
+ store := testStore(t)
+ runner := &mockRunner{err: fmt.Errorf("boom")}
+ runners := map[string]Runner{"claude": runner}
+ logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))
+ pool := NewPool(2, runners, store, logger)
+
+ // Submit two failing tasks
+ for _, id := range []string{"cf-1", "cf-2"} {
+ tk := makeTask(id)
+ store.CreateTask(tk)
+ pool.Submit(context.Background(), tk)
+ <-pool.Results() // drain
+ }
+
+ pool.mu.Lock()
+ drained := pool.drained["claude"]
+ failures := pool.consecutiveFailures["claude"]
+ pool.mu.Unlock()
+
+ if !drained {
+ t.Error("expected claude to be drained after 2 consecutive failures")
+ }
+ if failures < 2 {
+ t.Errorf("expected consecutiveFailures >= 2, got %d", failures)
+ }
+
+ // The second task should have a drain question set
+ tk2, err := store.GetTask("cf-2")
+ if err != nil {
+ t.Fatalf("GetTask: %v", err)
+ }
+ if tk2.QuestionJSON == "" {
+ t.Error("expected drain question to be set on task after drain")
+ }
+}
+
+func TestPool_ConsecutiveFailures_ResetOnSuccess(t *testing.T) {
+ store := testStore(t)
+
+ callCount := 0
+ runner := &mockRunner{
+ onRun: func(tk *task.Task, e *storage.Execution) error {
+ callCount++
+ if callCount == 1 {
+ return fmt.Errorf("first failure")
+ }
+ return nil // second call succeeds
+ },
+ }
+ runners := map[string]Runner{"claude": runner}
+ logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))
+ pool := NewPool(2, runners, store, logger)
+
+ // First task fails
+ tk1 := makeTask("rs-1")
+ store.CreateTask(tk1)
+ pool.Submit(context.Background(), tk1)
+ <-pool.Results()
+
+ pool.mu.Lock()
+ failsBefore := pool.consecutiveFailures["claude"]
+ pool.mu.Unlock()
+ if failsBefore != 1 {
+ t.Errorf("expected 1 failure after first task, got %d", failsBefore)
+ }
+
+ // Second task succeeds
+ tk2 := makeTask("rs-2")
+ store.CreateTask(tk2)
+ pool.Submit(context.Background(), tk2)
+ <-pool.Results()
+
+ pool.mu.Lock()
+ failsAfter := pool.consecutiveFailures["claude"]
+ isDrained := pool.drained["claude"]
+ pool.mu.Unlock()
+
+ if failsAfter != 0 {
+ t.Errorf("expected consecutiveFailures reset to 0 after success, got %d", failsAfter)
+ }
+ if isDrained {
+ t.Error("expected drained to be false after success")
+ }
+}
+
+func TestPool_Undrain_ResumesExecution(t *testing.T) {
+ store := testStore(t)
+
+ // Force drain state
+ runner := &mockRunner{}
+ runners := map[string]Runner{"claude": runner}
+ logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))
+ pool := NewPool(2, runners, store, logger)
+
+ pool.mu.Lock()
+ pool.drained["claude"] = true
+ pool.consecutiveFailures["claude"] = 3
+ pool.mu.Unlock()
+
+ // Undrain
+ pool.UndrainingAgent("claude")
+
+ pool.mu.Lock()
+ drained := pool.drained["claude"]
+ failures := pool.consecutiveFailures["claude"]
+ pool.mu.Unlock()
+
+ if drained {
+ t.Error("expected drained=false after UndrainingAgent")
+ }
+ if failures != 0 {
+ t.Errorf("expected consecutiveFailures=0 after UndrainingAgent, got %d", failures)
+ }
+
+ // Verify a task can now run
+ tk := makeTask("undrain-1")
+ store.CreateTask(tk)
+ pool.Submit(context.Background(), tk)
+ select {
+ case result := <-pool.Results():
+ if result.Err != nil {
+ t.Errorf("unexpected error after undrain: %v", result.Err)
+ }
+ case <-time.After(5 * time.Second):
+ t.Fatal("timed out waiting for task after undrain")
+ }
+}