diff options
Diffstat (limited to 'internal/executor')
| -rw-r--r-- | internal/executor/claude.go | 170 | ||||
| -rw-r--r-- | internal/executor/claude_test.go | 127 | ||||
| -rw-r--r-- | internal/executor/container.go | 549 | ||||
| -rw-r--r-- | internal/executor/container_test.go | 687 | ||||
| -rw-r--r-- | internal/executor/executor.go | 695 | ||||
| -rw-r--r-- | internal/executor/executor_test.go | 911 | ||||
| -rw-r--r-- | internal/executor/helpers.go | 205 | ||||
| -rw-r--r-- | internal/executor/preamble.go | 1 | ||||
| -rw-r--r-- | internal/executor/preamble_test.go | 7 | ||||
| -rw-r--r-- | internal/executor/question.go | 84 | ||||
| -rw-r--r-- | internal/executor/question_test.go | 58 | ||||
| -rw-r--r-- | internal/executor/ratelimit.go | 6 | ||||
| -rw-r--r-- | internal/executor/stream_test.go | 25 |
13 files changed, 2985 insertions, 540 deletions
diff --git a/internal/executor/claude.go b/internal/executor/claude.go index fa68382..3c87f26 100644 --- a/internal/executor/claude.go +++ b/internal/executor/claude.go @@ -1,11 +1,8 @@ package executor import ( - "bufio" "context" - "encoding/json" "fmt" - "io" "log/slog" "os" "os/exec" @@ -30,14 +27,6 @@ type ClaudeRunner struct { // BlockedError is returned by Run when the agent wrote a question file and exited. // The pool transitions the task to BLOCKED and stores the question for the user. -type BlockedError struct { - QuestionJSON string // raw JSON from the question file - SessionID string // claude session to resume once the user answers - SandboxDir string // preserved sandbox path; resume must run here so Claude finds its session files -} - -func (e *BlockedError) Error() string { return fmt.Sprintf("task blocked: %s", e.QuestionJSON) } - // ExecLogDir returns the log directory for the given execution ID. // Implements LogPather so the pool can persist paths before execution starts. func (r *ClaudeRunner) ExecLogDir(execID string) string { @@ -200,50 +189,6 @@ func (r *ClaudeRunner) Run(ctx context.Context, t *task.Task, e *storage.Executi return nil } -// isCompletionReport returns true when a question-file JSON looks like a -// completion report rather than a real user question. Heuristic: no options -// (or empty options) and no "?" anywhere in the text. -func isCompletionReport(questionJSON string) bool { - var q struct { - Text string `json:"text"` - Options []string `json:"options"` - } - if err := json.Unmarshal([]byte(questionJSON), &q); err != nil { - return false - } - return len(q.Options) == 0 && !strings.Contains(q.Text, "?") -} - -// extractQuestionText returns the "text" field from a question-file JSON, or -// the raw string if parsing fails. -func extractQuestionText(questionJSON string) string { - var q struct { - Text string `json:"text"` - } - if err := json.Unmarshal([]byte(questionJSON), &q); err != nil { - return questionJSON - } - return strings.TrimSpace(q.Text) -} - -// gitSafe returns git arguments that prepend safety overrides so that -// commands succeed regardless of the repository owner or the host's global -// git configuration. Specifically: -// -// - "-c safe.directory=*" lets us operate on directories owned by a -// different OS user. -// - "-c commit.gpgsign=false" / "-c tag.gpgsign=false" stop git from -// trying to sign commits via the host's signing tooling. Sandbox commits -// are internal and don't need to be signed; an unconfigured or broken -// signing setup on the host should never block a sandbox merge. -func gitSafe(args ...string) []string { - return append([]string{ - "-c", "safe.directory=*", - "-c", "commit.gpgsign=false", - "-c", "tag.gpgsign=false", - }, args...) -} - // sandboxCloneSource returns the URL to clone the sandbox from. It prefers a // remote named "local" (a local bare repo that accepts pushes cleanly), then // falls back to "origin", then to the working copy path itself. @@ -497,7 +442,7 @@ func (r *ClaudeRunner) execOnce(ctx context.Context, args []string, workingDir, wg.Add(1) go func() { defer wg.Done() - costUSD, streamErr = parseStream(stdoutR, stdoutFile, r.Logger) + costUSD, _, streamErr = parseStream(stdoutR, stdoutFile, r.Logger) stdoutR.Close() }() @@ -605,116 +550,3 @@ func (r *ClaudeRunner) buildArgs(t *task.Task, e *storage.Execution, questionFil return args } -// parseStream reads streaming JSON from claude, writes to w, and returns -// (costUSD, error). error is non-nil if the stream signals task failure: -// - result message has is_error:true -// - a tool_result was denied due to missing permissions -func parseStream(r io.Reader, w io.Writer, logger *slog.Logger) (float64, error) { - tee := io.TeeReader(r, w) - scanner := bufio.NewScanner(tee) - scanner.Buffer(make([]byte, 1024*1024), 1024*1024) // 1MB buffer for large lines - - var totalCost float64 - var streamErr error - - for scanner.Scan() { - line := scanner.Bytes() - var msg map[string]interface{} - if err := json.Unmarshal(line, &msg); err != nil { - continue - } - - msgType, _ := msg["type"].(string) - switch msgType { - case "rate_limit_event": - if info, ok := msg["rate_limit_info"].(map[string]interface{}); ok { - status, _ := info["status"].(string) - if status == "rejected" { - streamErr = fmt.Errorf("claude rate limit reached (rejected): %v", msg) - // Immediately break since we can't continue anyway - break - } - } - case "assistant": - if errStr, ok := msg["error"].(string); ok && errStr == "rate_limit" { - streamErr = fmt.Errorf("claude rate limit reached: %v", msg) - } - case "result": - if isErr, _ := msg["is_error"].(bool); isErr { - result, _ := msg["result"].(string) - if result != "" { - streamErr = fmt.Errorf("claude task failed: %s", result) - } else { - streamErr = fmt.Errorf("claude task failed (is_error=true in result)") - } - } - // Prefer total_cost_usd from result message; fall through to legacy check below. - if cost, ok := msg["total_cost_usd"].(float64); ok { - totalCost = cost - } - case "user": - // Detect permission-denial tool_results. These occur when permission_mode - // is not bypassPermissions and claude exits 0 without completing its task. - if err := permissionDenialError(msg); err != nil && streamErr == nil { - streamErr = err - } - } - - // Legacy cost field used by older claude versions. - if cost, ok := msg["cost_usd"].(float64); ok { - totalCost = cost - } - } - - return totalCost, streamErr -} - -// permissionDenialError inspects a "user" stream message for tool_result entries -// that were denied due to missing permissions. Returns an error if found. -func permissionDenialError(msg map[string]interface{}) error { - message, ok := msg["message"].(map[string]interface{}) - if !ok { - return nil - } - content, ok := message["content"].([]interface{}) - if !ok { - return nil - } - for _, item := range content { - itemMap, ok := item.(map[string]interface{}) - if !ok { - continue - } - if itemMap["type"] != "tool_result" { - continue - } - if isErr, _ := itemMap["is_error"].(bool); !isErr { - continue - } - text, _ := itemMap["content"].(string) - if strings.Contains(text, "requested permissions") || strings.Contains(text, "haven't granted") { - return fmt.Errorf("permission denied by host: %s", text) - } - } - return nil -} - -// tailFile returns the last n lines of the file at path, or empty string if -// the file cannot be read. Used to surface subprocess stderr on failure. -func tailFile(path string, n int) string { - f, err := os.Open(path) - if err != nil { - return "" - } - defer f.Close() - - var lines []string - scanner := bufio.NewScanner(f) - for scanner.Scan() { - lines = append(lines, scanner.Text()) - if len(lines) > n { - lines = lines[1:] - } - } - return strings.Join(lines, "\n") -} diff --git a/internal/executor/claude_test.go b/internal/executor/claude_test.go index cbb5947..c01e160 100644 --- a/internal/executor/claude_test.go +++ b/internal/executor/claude_test.go @@ -2,7 +2,6 @@ package executor import ( "context" - "errors" "fmt" "io" "log/slog" @@ -697,57 +696,6 @@ func TestTeardownSandbox_CleanSandboxWithNoNewCommits_RemovesSandbox(t *testing. } } -// TestBlockedError_IncludesSandboxDir verifies that when a task is blocked in a -// sandbox, the BlockedError carries the sandbox path so the resume execution can -// run in the same directory (where Claude's session files are stored). -func TestBlockedError_IncludesSandboxDir(t *testing.T) { - src := t.TempDir() - initGitRepo(t, src) - - logDir := t.TempDir() - - // Use a script that writes question.json to the env-var path and exits 0 - // (simulating a blocked agent that asks a question before exiting). - scriptPath := filepath.Join(t.TempDir(), "fake-claude.sh") - if err := os.WriteFile(scriptPath, []byte(`#!/bin/sh -if [ -n "$CLAUDOMATOR_QUESTION_FILE" ]; then - printf '{"text":"Should I continue?"}' > "$CLAUDOMATOR_QUESTION_FILE" -fi -`), 0755); err != nil { - t.Fatalf("write script: %v", err) - } - - r := &ClaudeRunner{ - BinaryPath: scriptPath, - Logger: slog.New(slog.NewTextHandler(io.Discard, nil)), - LogDir: logDir, - } - tk := &task.Task{ - Agent: task.AgentConfig{ - Type: "claude", - Instructions: "do something", - ProjectDir: src, - SkipPlanning: true, - }, - } - exec := &storage.Execution{ID: "blocked-exec-uuid", TaskID: "task-1"} - - err := r.Run(context.Background(), tk, exec) - - var blocked *BlockedError - if !errors.As(err, &blocked) { - t.Fatalf("expected BlockedError, got: %v", err) - } - if blocked.SandboxDir == "" { - t.Error("BlockedError.SandboxDir should be set when task runs in a sandbox") - } - // Sandbox should still exist (preserved for resume). - if _, statErr := os.Stat(blocked.SandboxDir); os.IsNotExist(statErr) { - t.Error("sandbox directory should be preserved when blocked") - } else { - os.RemoveAll(blocked.SandboxDir) // cleanup - } -} // TestClaudeRunner_Run_ResumeUsesStoredSandboxDir verifies that when a resume // execution has SandboxDir set, the runner uses that directory (not project_dir) @@ -853,69 +801,6 @@ func TestClaudeRunner_Run_StaleSandboxDir_ClonesAfresh(t *testing.T) { } } -func TestIsCompletionReport(t *testing.T) { - tests := []struct { - name string - json string - expected bool - }{ - { - name: "real question with options", - json: `{"text": "Should I proceed with implementation?", "options": ["Yes", "No"]}`, - expected: false, - }, - { - name: "real question no options", - json: `{"text": "Which approach do you prefer?"}`, - expected: false, - }, - { - name: "completion report no options no question mark", - json: `{"text": "All tests pass. Implementation complete. Summary written to CLAUDOMATOR_SUMMARY_FILE."}`, - expected: true, - }, - { - name: "completion report with empty options", - json: `{"text": "Feature implemented and committed.", "options": []}`, - expected: true, - }, - { - name: "invalid json treated as not a report", - json: `not json`, - expected: false, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got := isCompletionReport(tt.json) - if got != tt.expected { - t.Errorf("isCompletionReport(%q) = %v, want %v", tt.json, got, tt.expected) - } - }) - } -} - -func TestTailFile_ReturnsLastNLines(t *testing.T) { - f, err := os.CreateTemp("", "tailfile-*") - if err != nil { - t.Fatal(err) - } - defer os.Remove(f.Name()) - for i := 1; i <= 30; i++ { - fmt.Fprintf(f, "line %d\n", i) - } - f.Close() - - got := tailFile(f.Name(), 5) - lines := strings.Split(got, "\n") - if len(lines) != 5 { - t.Fatalf("want 5 lines, got %d: %q", len(lines), got) - } - if lines[0] != "line 26" || lines[4] != "line 30" { - t.Errorf("want lines 26-30, got: %q", got) - } -} - func TestTailFile_MissingFile_ReturnsEmpty(t *testing.T) { got := tailFile("/nonexistent/path/file.log", 10) if got != "" { @@ -923,15 +808,3 @@ func TestTailFile_MissingFile_ReturnsEmpty(t *testing.T) { } } -func TestGitSafe_PrependsSafeDirectory(t *testing.T) { - got := gitSafe("-C", "/some/path", "status") - want := []string{"-c", "safe.directory=*", "-c", "commit.gpgsign=false", "-c", "tag.gpgsign=false", "-C", "/some/path", "status"} - if len(got) != len(want) { - t.Fatalf("gitSafe() = %v, want %v", got, want) - } - for i := range want { - if got[i] != want[i] { - t.Errorf("gitSafe()[%d] = %q, want %q", i, got[i], want[i]) - } - } -} diff --git a/internal/executor/container.go b/internal/executor/container.go new file mode 100644 index 0000000..61ac29c --- /dev/null +++ b/internal/executor/container.go @@ -0,0 +1,549 @@ +package executor + +import ( + "context" + "errors" + "fmt" + "log/slog" + "os" + "os/exec" + "path/filepath" + "strings" + "sync" + "syscall" + + "github.com/thepeterstone/claudomator/internal/storage" + "github.com/thepeterstone/claudomator/internal/task" +) + +// ContainerRunner executes an agent inside a container. +type ContainerRunner struct { + Image string // default image if not specified in task + Logger *slog.Logger + LogDir string + APIURL string + DropsDir string + SSHAuthSock string // optional path to host SSH agent + ClaudeBinary string // optional path to claude binary in container + GeminiBinary string // optional path to gemini binary in container + ClaudeConfigDir string // host path to ~/.claude; mounted into container for auth credentials + CredentialSyncCmd string // optional path to sync-credentials script for auth-error auto-recovery + Store Store // optional; used to look up stories and projects for story-aware cloning + // Command allows mocking exec.CommandContext for tests. + Command func(ctx context.Context, name string, arg ...string) *exec.Cmd +} + +func isAuthError(err error) bool { + if err == nil { + return false + } + s := err.Error() + return strings.Contains(s, "Not logged in") || + strings.Contains(s, "OAuth token has expired") || + strings.Contains(s, "authentication_error") || + strings.Contains(s, "Please run /login") +} + +func (r *ContainerRunner) command(ctx context.Context, name string, arg ...string) *exec.Cmd { + if r.Command != nil { + return r.Command(ctx, name, arg...) + } + return exec.CommandContext(ctx, name, arg...) +} + +func (r *ContainerRunner) ExecLogDir(execID string) string { + if r.LogDir == "" { + return "" + } + return filepath.Join(r.LogDir, execID) +} + +// ensureStoryBranch checks whether branchName exists in remoteURL and creates +// it from main if not. Uses localPath as a reference clone for speed if set. +func (r *ContainerRunner) ensureStoryBranch(ctx context.Context, remoteURL, branchName, localPath string) error { + // Check if branch already exists. + out, err := r.command(ctx, "git", "ls-remote", "--heads", remoteURL, branchName).CombinedOutput() + if err == nil && len(strings.TrimSpace(string(out))) > 0 { + return nil // already exists + } + + r.Logger.Info("story branch missing, creating from main", "branch", branchName, "remote", remoteURL) + + // Clone into a temp dir so we can create the branch. + tmp, err := os.MkdirTemp("", "claudomator-branchsetup-*") + if err != nil { + return fmt.Errorf("mktemp for branch setup: %w", err) + } + defer os.RemoveAll(tmp) + + // Remove the dir git clone expects to create. + if err := os.Remove(tmp); err != nil { + return fmt.Errorf("removing tmp dir before clone: %w", err) + } + + var cloneArgs []string + if localPath != "" { + cloneArgs = []string{"clone", "--reference", localPath, remoteURL, tmp} + } else { + cloneArgs = []string{"clone", remoteURL, tmp} + } + if out, err := r.command(ctx, "git", cloneArgs...).CombinedOutput(); err != nil { + return fmt.Errorf("git clone for branch setup: %w\n%s", err, string(out)) + } + if out, err := r.command(ctx, "git", "-C", tmp, "checkout", "-b", branchName).CombinedOutput(); err != nil { + return fmt.Errorf("git checkout -b %q: %w\n%s", branchName, err, string(out)) + } + if out, err := r.command(ctx, "git", "-C", tmp, "push", "origin", branchName).CombinedOutput(); err != nil { + return fmt.Errorf("git push %q: %w\n%s", branchName, err, string(out)) + } + r.Logger.Info("story branch created and pushed", "branch", branchName) + return nil +} + +func (r *ContainerRunner) Run(ctx context.Context, t *task.Task, e *storage.Execution) error { + var err error + repoURL := t.RepositoryURL + if repoURL == "" { + return fmt.Errorf("task %s has no repository_url", t.ID) + } + + // 1. Setup workspace on host + var workspace string + isResume := false + if e.SandboxDir != "" { + if _, err = os.Stat(e.SandboxDir); err == nil { + workspace = e.SandboxDir + isResume = true + r.Logger.Info("resuming in preserved workspace", "path", workspace) + } + } + + if workspace == "" { + workspace, err = os.MkdirTemp("", "claudomator-workspace-*") + if err != nil { + return fmt.Errorf("creating workspace: %w", err) + } + // chmod applied after clone; see step 2. + } + + // Note: workspace is only removed on success. On failure, it's preserved for debugging. + // If the task becomes BLOCKED, it's also preserved for resumption. + success := false + isBlocked := false + defer func() { + if success && !isBlocked { + os.RemoveAll(workspace) + } else { + r.Logger.Warn("preserving workspace", "path", workspace, "success", success, "blocked", isBlocked) + } + }() + + // Resolve story branch and project local path if this is a story task. + var storyBranch string + var storyLocalPath string + if t.StoryID != "" && r.Store != nil { + if story, err := r.Store.GetStory(t.StoryID); err == nil && story != nil { + storyBranch = story.BranchName + if story.ProjectID != "" { + if proj, err := r.Store.GetProject(story.ProjectID); err == nil && proj != nil { + storyLocalPath = proj.LocalPath + } + } + } + } + // Fall back to task-level BranchName (e.g. set explicitly by executor or tests). + if storyBranch == "" { + storyBranch = t.BranchName + } + + // 2. Ensure story branch exists in the remote before cloning. + // If the branch is missing (e.g. story approved before fix, or branch push failed), + // create it from main using the project local path as a reference repo. + if storyBranch != "" && !isResume { + if err := r.ensureStoryBranch(ctx, repoURL, storyBranch, storyLocalPath); err != nil { + r.Logger.Warn("ensureStoryBranch failed (will attempt checkout anyway)", "branch", storyBranch, "error", err) + } + } + + // 3. Clone repo into workspace if not resuming. + // git clone requires the target directory to not exist; remove the MkdirTemp-created dir first. + if !isResume { + if err := os.Remove(workspace); err != nil { + return fmt.Errorf("removing workspace before clone: %w", err) + } + r.Logger.Info("cloning repository", "url", repoURL, "workspace", workspace) + var cloneArgs []string + if storyLocalPath != "" { + cloneArgs = []string{"clone", "--reference", storyLocalPath, repoURL, workspace} + } else { + cloneArgs = []string{"clone", repoURL, workspace} + } + if out, err := r.command(ctx, "git", cloneArgs...).CombinedOutput(); err != nil { + return fmt.Errorf("git clone failed: %w\n%s", err, string(out)) + } + if storyBranch != "" { + r.Logger.Info("checking out story branch", "branch", storyBranch) + if out, err := r.command(ctx, "git", "-C", workspace, "checkout", storyBranch).CombinedOutput(); err != nil { + return fmt.Errorf("git checkout story branch %q failed: %w\n%s", storyBranch, err, string(out)) + } + } + if err = os.Chmod(workspace, 0755); err != nil { + return fmt.Errorf("chmod cloned workspace: %w", err) + } + } + e.SandboxDir = workspace + + // Set up a writable $HOME staging dir so any agent tool (claude, gemini, etc.) + // can freely create subdirs (session-env, .gemini, .cache, …) without hitting + // a non-existent or read-only home. We copy only the claude credentials into it. + agentHome := filepath.Join(workspace, ".agent-home") + if err := os.MkdirAll(filepath.Join(agentHome, ".claude"), 0755); err != nil { + return fmt.Errorf("creating agent home staging dir: %w", err) + } + if err := os.MkdirAll(filepath.Join(agentHome, ".gemini"), 0755); err != nil { + return fmt.Errorf("creating .gemini dir: %w", err) + } + if r.ClaudeConfigDir != "" { + // credentials + if srcData, readErr := os.ReadFile(filepath.Join(r.ClaudeConfigDir, ".credentials.json")); readErr == nil { + _ = os.WriteFile(filepath.Join(agentHome, ".claude", ".credentials.json"), srcData, 0600) + } + // settings (used by claude CLI; copy so it can write updates without hitting the host) + if srcData, readErr := os.ReadFile(filepath.Join(r.ClaudeConfigDir, ".claude.json")); readErr == nil { + _ = os.WriteFile(filepath.Join(agentHome, ".claude.json"), srcData, 0644) + } + } + + // Pre-flight: verify credentials were actually copied before spinning up a container. + if r.ClaudeConfigDir != "" { + credsPath := filepath.Join(agentHome, ".claude", ".credentials.json") + settingsPath := filepath.Join(agentHome, ".claude.json") + if _, err := os.Stat(credsPath); os.IsNotExist(err) { + return fmt.Errorf("credentials not found at %s — run sync-credentials", r.ClaudeConfigDir) + } + if _, err := os.Stat(settingsPath); os.IsNotExist(err) { + return fmt.Errorf("claude settings (.claude.json) not found at %s — run sync-credentials", r.ClaudeConfigDir) + } + } + + // Run container (with auth retry on failure). + runErr := r.runContainer(ctx, t, e, workspace, agentHome, isResume, storyBranch) + if runErr != nil && isAuthError(runErr) && r.CredentialSyncCmd != "" { + r.Logger.Warn("auth failure detected, syncing credentials and retrying once", "taskID", t.ID) + syncOut, syncErr := r.command(ctx, r.CredentialSyncCmd).CombinedOutput() + if syncErr != nil { + r.Logger.Warn("sync-credentials failed", "error", syncErr, "output", string(syncOut)) + } + // Re-copy credentials into agentHome with fresh files. + if srcData, readErr := os.ReadFile(filepath.Join(r.ClaudeConfigDir, ".credentials.json")); readErr == nil { + _ = os.WriteFile(filepath.Join(agentHome, ".claude", ".credentials.json"), srcData, 0600) + } + if srcData, readErr := os.ReadFile(filepath.Join(r.ClaudeConfigDir, ".claude.json")); readErr == nil { + _ = os.WriteFile(filepath.Join(agentHome, ".claude.json"), srcData, 0644) + } + runErr = r.runContainer(ctx, t, e, workspace, agentHome, isResume, storyBranch) + } + + if runErr == nil { + success = true + } + var blockedErr *BlockedError + if errors.As(runErr, &blockedErr) { + isBlocked = true + success = true // preserve workspace for resumption + } + return runErr +} + +// runContainer runs the docker container for the given task and handles log setup, +// environment files, instructions, and post-execution git operations. +func (r *ContainerRunner) runContainer(ctx context.Context, t *task.Task, e *storage.Execution, workspace, agentHome string, isResume bool, storyBranch string) error { + repoURL := t.RepositoryURL + + image := t.Agent.ContainerImage + if image == "" { + image = r.Image + } + if image == "" { + image = "claudomator-agent:latest" + } + + // 3. Prepare logs + logDir := r.ExecLogDir(e.ID) + if logDir == "" { + logDir = filepath.Join(workspace, ".claudomator-logs") + } + if err := os.MkdirAll(logDir, 0700); err != nil { + return fmt.Errorf("creating log dir: %w", err) + } + e.StdoutPath = filepath.Join(logDir, "stdout.log") + e.StderrPath = filepath.Join(logDir, "stderr.log") + e.ArtifactDir = logDir + + stdoutFile, err := os.Create(e.StdoutPath) + if err != nil { + return fmt.Errorf("creating stdout log: %w", err) + } + defer stdoutFile.Close() + + stderrFile, err := os.Create(e.StderrPath) + if err != nil { + return fmt.Errorf("creating stderr log: %w", err) + } + defer stderrFile.Close() + + // 4. Run container + + // Write API keys to a temporary env file to avoid exposure in 'ps' or 'docker inspect' + envFile := filepath.Join(workspace, ".claudomator-env") + envContent := fmt.Sprintf("ANTHROPIC_API_KEY=%s\nGOOGLE_API_KEY=%s\nGEMINI_API_KEY=%s\n", os.Getenv("ANTHROPIC_API_KEY"), os.Getenv("GOOGLE_API_KEY"), os.Getenv("GEMINI_API_KEY")) + if err := os.WriteFile(envFile, []byte(envContent), 0600); err != nil { + return fmt.Errorf("writing env file: %w", err) + } + + // Inject custom instructions via file to avoid CLI length limits + instructionsFile := filepath.Join(workspace, ".claudomator-instructions.txt") + if err := os.WriteFile(instructionsFile, []byte(t.Agent.Instructions), 0644); err != nil { + return fmt.Errorf("writing instructions: %w", err) + } + + args := r.buildDockerArgs(workspace, agentHome, e.TaskID) + innerCmd := r.buildInnerCmd(t, e, isResume) + + fullArgs := append(args, image) + fullArgs = append(fullArgs, innerCmd...) + + r.Logger.Info("starting container", "image", image, "taskID", t.ID) + cmd := r.command(ctx, "docker", fullArgs...) + cmd.Stderr = stderrFile + cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true} + + // Use os.Pipe for stdout so we can parse it in real-time + var stdoutR, stdoutW *os.File + stdoutR, stdoutW, err = os.Pipe() + if err != nil { + return fmt.Errorf("creating stdout pipe: %w", err) + } + cmd.Stdout = stdoutW + + if err := cmd.Start(); err != nil { + stdoutW.Close() + stdoutR.Close() + return fmt.Errorf("starting container: %w", err) + } + stdoutW.Close() + + // Watch for context cancellation to kill the process group (Issue 1) + done := make(chan struct{}) + defer close(done) + go func() { + select { + case <-ctx.Done(): + r.Logger.Info("killing container process group due to context cancellation", "taskID", t.ID) + syscall.Kill(-cmd.Process.Pid, syscall.SIGKILL) + case <-done: + } + }() + + // Stream stdout to the log file and parse cost/errors. + var costUSD float64 + var sessionID string + var streamErr error + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + costUSD, sessionID, streamErr = parseStream(stdoutR, stdoutFile, r.Logger) + stdoutR.Close() + }() + + waitErr := cmd.Wait() + wg.Wait() + + e.CostUSD = costUSD + if sessionID != "" { + e.SessionID = sessionID + } + + // Check whether the agent left a question before exiting. + questionFile := filepath.Join(logDir, "question.json") + if data, readErr := os.ReadFile(questionFile); readErr == nil { + os.Remove(questionFile) // consumed + questionJSON := strings.TrimSpace(string(data)) + if isCompletionReport(questionJSON) { + r.Logger.Info("treating question file as completion report", "taskID", e.TaskID) + e.Summary = extractQuestionText(questionJSON) + } else { + if e.SessionID == "" { + r.Logger.Warn("missing session ID; resume will start fresh", "taskID", e.TaskID) + } + return &BlockedError{ + QuestionJSON: questionJSON, + SessionID: e.SessionID, + SandboxDir: workspace, + } + } + } + + // Read agent summary if written. + summaryFile := filepath.Join(logDir, "summary.txt") + if summaryData, readErr := os.ReadFile(summaryFile); readErr == nil { + os.Remove(summaryFile) // consumed + e.Summary = strings.TrimSpace(string(summaryData)) + } + + // 5. Post-execution: push changes if successful + if waitErr == nil && streamErr == nil { + // Check if there are any commits to push (HEAD ahead of origin/HEAD). + // If origin/HEAD doesn't exist (e.g. fresh clone with no commits), we attempt push anyway. + hasCommits := true + if out, err := r.command(ctx, "git", "-C", workspace, "rev-list", "origin/HEAD..HEAD").CombinedOutput(); err == nil { + if len(strings.TrimSpace(string(out))) == 0 { + hasCommits = false + } + } + + if hasCommits { + pushRef := "HEAD" + if storyBranch != "" { + pushRef = storyBranch + } + r.Logger.Info("pushing changes back to remote", "url", repoURL, "ref", pushRef) + if out, err := r.command(ctx, "git", "-C", workspace, "push", "origin", pushRef).CombinedOutput(); err != nil { + r.Logger.Warn("git push failed", "error", err, "output", string(out)) + return fmt.Errorf("git push failed: %w\n%s", err, string(out)) + } + } else { + // No commits pushed — check whether the agent left uncommitted work behind. + // If so, fail loudly: the work would be silently lost when the sandbox is deleted. + if err := detectUncommittedChanges(workspace); err != nil { + return err + } + r.Logger.Info("no new commits to push", "taskID", t.ID) + } + } + + if waitErr != nil { + // Append the tail of stderr so error classifiers (isQuotaExhausted, isRateLimitError) + // can inspect agent-specific messages (e.g. Gemini TerminalQuotaError). + stderrTail := readFileTail(e.StderrPath, 4096) + if stderrTail != "" { + return fmt.Errorf("container execution failed: %w\n%s", waitErr, stderrTail) + } + return fmt.Errorf("container execution failed: %w", waitErr) + } + if streamErr != nil { + return fmt.Errorf("stream parsing failed: %w", streamErr) + } + + return nil +} + +func (r *ContainerRunner) buildDockerArgs(workspace, claudeHome, taskID string) []string { + // --env-file takes a HOST path. + hostEnvFile := filepath.Join(workspace, ".claudomator-env") + + // Replace localhost with host.docker.internal so the container can reach the host API. + apiURL := strings.ReplaceAll(r.APIURL, "localhost", "host.docker.internal") + + args := []string{ + "run", "--rm", + // Allow container to reach the host via host.docker.internal. + "--add-host=host.docker.internal:host-gateway", + // Run as the current process UID:GID so the container can read host-owned files. + fmt.Sprintf("--user=%d:%d", os.Getuid(), os.Getgid()), + "-v", workspace + ":/workspace", + "-v", claudeHome + ":/home/agent", + "-w", "/workspace", + "--env-file", hostEnvFile, + "-e", "HOME=/home/agent", + "-e", "CLAUDOMATOR_API_URL=" + apiURL, + "-e", "CLAUDOMATOR_TASK_ID=" + taskID, + "-e", "CLAUDOMATOR_DROP_DIR=" + r.DropsDir, + } + if r.SSHAuthSock != "" { + args = append(args, "-v", r.SSHAuthSock+":/tmp/ssh-auth.sock", "-e", "SSH_AUTH_SOCK=/tmp/ssh-auth.sock") + } + return args +} + +func (r *ContainerRunner) buildInnerCmd(t *task.Task, e *storage.Execution, isResume bool) []string { + // Claude CLI uses -p for prompt text. To pass a file, we use a shell to cat it. + // We use a shell variable to capture the expansion to avoid quoting issues with instructions contents. + // The outer single quotes around the sh -c argument prevent host-side expansion. + + claudeBin := r.ClaudeBinary + if claudeBin == "" { + claudeBin = "claude" + } + geminiBin := r.GeminiBinary + if geminiBin == "" { + geminiBin = "gemini" + } + + if t.Agent.Type == "gemini" { + return []string{"sh", "-c", fmt.Sprintf("INST=$(cat /workspace/.claudomator-instructions.txt); %s -p \"$INST\"", geminiBin)} + } + + // Claude + var claudeCmd strings.Builder + claudeCmd.WriteString(fmt.Sprintf("INST=$(cat /workspace/.claudomator-instructions.txt); %s -p \"$INST\"", claudeBin)) + if isResume && e.ResumeSessionID != "" { + claudeCmd.WriteString(fmt.Sprintf(" --resume %s", e.ResumeSessionID)) + } + claudeCmd.WriteString(" --output-format stream-json --verbose --permission-mode bypassPermissions") + + return []string{"sh", "-c", claudeCmd.String()} +} + +// scaffoldPrefixes are files/dirs written by the harness into the workspace before the agent +// runs. They are not part of the repo and must not trigger the uncommitted-changes check. +var scaffoldPrefixes = []string{ + ".claudomator-env", + ".claudomator-instructions.txt", + ".agent-home", +} + +func isScaffold(path string) bool { + for _, p := range scaffoldPrefixes { + if path == p || strings.HasPrefix(path, p+"/") { + return true + } + } + return false +} + +// detectUncommittedChanges returns an error if the workspace contains modified or +// untracked source files that the agent forgot to commit. Scaffold files written by +// the harness (.claudomator-env, .claudomator-instructions.txt, .agent-home/) are +// excluded from the check. +func detectUncommittedChanges(workspace string) error { + // Modified or staged tracked files + diffOut, err := exec.Command("git", "-c", "safe.directory=*", "-C", workspace, + "diff", "--name-only", "HEAD").CombinedOutput() + if err == nil { + for _, line := range strings.Split(strings.TrimSpace(string(diffOut)), "\n") { + if line != "" && !isScaffold(line) { + return fmt.Errorf("agent left uncommitted changes (work would be lost on sandbox deletion):\n%s\nInstructions must include: git add -A && git commit && git push origin main", strings.TrimSpace(string(diffOut))) + } + } + } + + // Untracked new source files (excludes gitignored files) + lsOut, err := exec.Command("git", "-c", "safe.directory=*", "-C", workspace, + "ls-files", "--others", "--exclude-standard").CombinedOutput() + if err == nil { + var dirty []string + for _, line := range strings.Split(strings.TrimSpace(string(lsOut)), "\n") { + if line != "" && !isScaffold(line) { + dirty = append(dirty, line) + } + } + if len(dirty) > 0 { + return fmt.Errorf("agent left untracked files not committed (work would be lost on sandbox deletion):\n%s\nInstructions must include: git add -A && git commit && git push origin main", strings.Join(dirty, "\n")) + } + } + + return nil +} + diff --git a/internal/executor/container_test.go b/internal/executor/container_test.go new file mode 100644 index 0000000..f0b2a3a --- /dev/null +++ b/internal/executor/container_test.go @@ -0,0 +1,687 @@ +package executor + +import ( + "context" + "fmt" + "io" + "log/slog" + "os" + "os/exec" + "path/filepath" + "strings" + "testing" + + "github.com/thepeterstone/claudomator/internal/storage" + "github.com/thepeterstone/claudomator/internal/task" +) + +func TestContainerRunner_BuildDockerArgs(t *testing.T) { + runner := &ContainerRunner{ + APIURL: "http://localhost:8484", + DropsDir: "/data/drops", + SSHAuthSock: "/tmp/ssh.sock", + } + workspace := "/tmp/ws" + taskID := "task-123" + + agentHome := "/tmp/ws/.agent-home" + args := runner.buildDockerArgs(workspace, agentHome, taskID) + + expected := []string{ + "run", "--rm", + "--add-host=host.docker.internal:host-gateway", + fmt.Sprintf("--user=%d:%d", os.Getuid(), os.Getgid()), + "-v", "/tmp/ws:/workspace", + "-v", "/tmp/ws/.agent-home:/home/agent", + "-w", "/workspace", + "--env-file", "/tmp/ws/.claudomator-env", + "-e", "HOME=/home/agent", + "-e", "CLAUDOMATOR_API_URL=http://host.docker.internal:8484", + "-e", "CLAUDOMATOR_TASK_ID=task-123", + "-e", "CLAUDOMATOR_DROP_DIR=/data/drops", + "-v", "/tmp/ssh.sock:/tmp/ssh-auth.sock", + "-e", "SSH_AUTH_SOCK=/tmp/ssh-auth.sock", + } + + if len(args) != len(expected) { + t.Fatalf("expected %d args, got %d. Got: %v", len(expected), len(args), args) + } + for i, v := range args { + if v != expected[i] { + t.Errorf("arg %d: expected %q, got %q", i, expected[i], v) + } + } +} + +func TestContainerRunner_BuildInnerCmd(t *testing.T) { + runner := &ContainerRunner{} + + t.Run("claude-fresh", func(t *testing.T) { + tk := &task.Task{Agent: task.AgentConfig{Type: "claude"}} + exec := &storage.Execution{} + cmd := runner.buildInnerCmd(tk, exec, false) + + cmdStr := strings.Join(cmd, " ") + if strings.Contains(cmdStr, "--resume") { + t.Errorf("unexpected --resume flag in fresh run: %q", cmdStr) + } + if !strings.Contains(cmdStr, "INST=$(cat /workspace/.claudomator-instructions.txt); claude -p \"$INST\"") { + t.Errorf("expected cat instructions in sh command, got %q", cmdStr) + } + }) + + t.Run("claude-resume", func(t *testing.T) { + tk := &task.Task{Agent: task.AgentConfig{Type: "claude"}} + exec := &storage.Execution{ResumeSessionID: "orig-session-123"} + cmd := runner.buildInnerCmd(tk, exec, true) + + cmdStr := strings.Join(cmd, " ") + if !strings.Contains(cmdStr, "--resume orig-session-123") { + t.Errorf("expected --resume flag with correct session ID, got %q", cmdStr) + } + }) + + t.Run("gemini", func(t *testing.T) { + tk := &task.Task{Agent: task.AgentConfig{Type: "gemini"}} + exec := &storage.Execution{} + cmd := runner.buildInnerCmd(tk, exec, false) + + cmdStr := strings.Join(cmd, " ") + if !strings.Contains(cmdStr, "gemini -p \"$INST\"") { + t.Errorf("expected gemini command with safer quoting, got %q", cmdStr) + } + }) + + t.Run("custom-binaries", func(t *testing.T) { + runnerCustom := &ContainerRunner{ + ClaudeBinary: "/usr/bin/claude-v2", + GeminiBinary: "/usr/local/bin/gemini-pro", + } + + tkClaude := &task.Task{Agent: task.AgentConfig{Type: "claude"}} + cmdClaude := runnerCustom.buildInnerCmd(tkClaude, &storage.Execution{}, false) + if !strings.Contains(strings.Join(cmdClaude, " "), "/usr/bin/claude-v2 -p") { + t.Errorf("expected custom claude binary, got %q", cmdClaude) + } + + tkGemini := &task.Task{Agent: task.AgentConfig{Type: "gemini"}} + cmdGemini := runnerCustom.buildInnerCmd(tkGemini, &storage.Execution{}, false) + if !strings.Contains(strings.Join(cmdGemini, " "), "/usr/local/bin/gemini-pro -p") { + t.Errorf("expected custom gemini binary, got %q", cmdGemini) + } + }) +} + +func TestContainerRunner_Run_PreservesWorkspaceOnFailure(t *testing.T) { + logger := slog.New(slog.NewTextHandler(io.Discard, nil)) + runner := &ContainerRunner{ + Logger: logger, + Image: "busybox", + Command: func(ctx context.Context, name string, arg ...string) *exec.Cmd { + // Mock docker run to exit 1 + if name == "docker" { + return exec.Command("sh", "-c", "exit 1") + } + // Mock git clone to succeed and create the directory + if name == "git" && len(arg) > 0 && arg[0] == "clone" { + dir := arg[len(arg)-1] + os.MkdirAll(dir, 0755) + return exec.Command("true") + } + return exec.Command("true") + }, + } + + tk := &task.Task{ + ID: "test-task", + RepositoryURL: "https://github.com/example/repo.git", + Agent: task.AgentConfig{Type: "claude"}, + } + exec := &storage.Execution{ID: "test-exec", TaskID: "test-task"} + + err := runner.Run(context.Background(), tk, exec) + if err == nil { + t.Fatal("expected error due to mocked docker failure") + } + + // Verify SandboxDir was set and directory exists. + if exec.SandboxDir == "" { + t.Fatal("expected SandboxDir to be set even on failure") + } + if _, statErr := os.Stat(exec.SandboxDir); statErr != nil { + t.Errorf("expected sandbox directory to be preserved, but stat failed: %v", statErr) + } else { + os.RemoveAll(exec.SandboxDir) + } +} + +func TestBlockedError_IncludesSandboxDir(t *testing.T) { + // This test requires mocking 'docker run' or the whole Run() which is hard. + // But we can test that returning BlockedError works. + err := &BlockedError{ + QuestionJSON: `{"text":"?"}`, + SessionID: "s1", + SandboxDir: "/tmp/s1", + } + if !strings.Contains(err.Error(), "task blocked") { + t.Errorf("wrong error message: %v", err) + } +} + +func TestIsCompletionReport(t *testing.T) { + tests := []struct { + name string + json string + expected bool + }{ + { + name: "real question with options", + json: `{"text": "Should I proceed with implementation?", "options": ["Yes", "No"]}`, + expected: false, + }, + { + name: "real question no options", + json: `{"text": "Which approach do you prefer?"}`, + expected: false, + }, + { + name: "completion report no options no question mark", + json: `{"text": "All tests pass. Implementation complete. Summary written to CLAUDOMATOR_SUMMARY_FILE."}`, + expected: true, + }, + { + name: "completion report with empty options", + json: `{"text": "Feature implemented and committed.", "options": []}`, + expected: true, + }, + { + name: "invalid json treated as not a report", + json: `not json`, + expected: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := isCompletionReport(tt.json) + if got != tt.expected { + t.Errorf("isCompletionReport(%q) = %v, want %v", tt.json, got, tt.expected) + } + }) + } +} + +func TestTailFile_ReturnsLastNLines(t *testing.T) { + f, err := os.CreateTemp("", "tailfile-*") + if err != nil { + t.Fatal(err) + } + defer os.Remove(f.Name()) + for i := 1; i <= 30; i++ { + fmt.Fprintf(f, "line %d\n", i) + } + f.Close() + + got := tailFile(f.Name(), 5) + lines := strings.Split(strings.TrimSpace(got), "\n") + if len(lines) != 5 { + t.Fatalf("want 5 lines, got %d: %q", len(lines), got) + } + if lines[0] != "line 26" || lines[4] != "line 30" { + t.Errorf("want lines 26-30, got: %q", got) + } +} + +func TestDetectUncommittedChanges_ModifiedFile(t *testing.T) { + dir := t.TempDir() + run := func(args ...string) { + cmd := exec.Command(args[0], args[1:]...) + cmd.Dir = dir + if out, err := cmd.CombinedOutput(); err != nil { + t.Fatalf("%v: %s", args, out) + } + } + run("git", "init", dir) + run("git", "config", "user.email", "test@test.com") + run("git", "config", "user.name", "Test") + // Create and commit a file + if err := os.WriteFile(dir+"/main.go", []byte("package main"), 0644); err != nil { + t.Fatal(err) + } + run("git", "add", "main.go") + run("git", "commit", "-m", "init") + // Now modify without committing — simulates agent that forgot to commit + if err := os.WriteFile(dir+"/main.go", []byte("package main\n// changed"), 0644); err != nil { + t.Fatal(err) + } + err := detectUncommittedChanges(dir) + if err == nil { + t.Fatal("expected error for modified uncommitted file, got nil") + } + if !strings.Contains(err.Error(), "uncommitted") { + t.Errorf("error should mention uncommitted, got: %v", err) + } +} + +func TestDetectUncommittedChanges_NewUntrackedSourceFile(t *testing.T) { + dir := t.TempDir() + run := func(args ...string) { + cmd := exec.Command(args[0], args[1:]...) + cmd.Dir = dir + if out, err := cmd.CombinedOutput(); err != nil { + t.Fatalf("%v: %s", args, out) + } + } + run("git", "init", dir) + run("git", "config", "user.email", "test@test.com") + run("git", "config", "user.name", "Test") + run("git", "commit", "--allow-empty", "-m", "init") + // Agent wrote a new file but never committed it + if err := os.WriteFile(dir+"/newfile.go", []byte("package main"), 0644); err != nil { + t.Fatal(err) + } + err := detectUncommittedChanges(dir) + if err == nil { + t.Fatal("expected error for new untracked source file, got nil") + } +} + +func TestDetectUncommittedChanges_ScaffoldFilesIgnored(t *testing.T) { + dir := t.TempDir() + run := func(args ...string) { + cmd := exec.Command(args[0], args[1:]...) + cmd.Dir = dir + if out, err := cmd.CombinedOutput(); err != nil { + t.Fatalf("%v: %s", args, out) + } + } + run("git", "init", dir) + run("git", "config", "user.email", "test@test.com") + run("git", "config", "user.name", "Test") + run("git", "commit", "--allow-empty", "-m", "init") + // Write only scaffold files that the harness injects — should not trigger error + _ = os.WriteFile(dir+"/.claudomator-env", []byte("KEY=val"), 0600) + _ = os.WriteFile(dir+"/.claudomator-instructions.txt", []byte("do stuff"), 0644) + _ = os.MkdirAll(dir+"/.agent-home/.claude", 0755) + err := detectUncommittedChanges(dir) + if err != nil { + t.Errorf("scaffold files should not trigger uncommitted error, got: %v", err) + } +} + +func TestDetectUncommittedChanges_CleanRepo(t *testing.T) { + dir := t.TempDir() + run := func(args ...string) { + cmd := exec.Command(args[0], args[1:]...) + cmd.Dir = dir + if out, err := cmd.CombinedOutput(); err != nil { + t.Fatalf("%v: %s", args, out) + } + } + run("git", "init", dir) + run("git", "config", "user.email", "test@test.com") + run("git", "config", "user.name", "Test") + if err := os.WriteFile(dir+"/main.go", []byte("package main"), 0644); err != nil { + t.Fatal(err) + } + run("git", "add", "main.go") + run("git", "commit", "-m", "init") + // No modifications — should pass + err := detectUncommittedChanges(dir) + if err != nil { + t.Errorf("clean repo should not error, got: %v", err) + } +} + +func TestGitSafe_PrependsSafeDirectory(t *testing.T) { + got := gitSafe("-C", "/some/path", "status") + want := []string{"-c", "safe.directory=*", "-c", "commit.gpgsign=false", "-c", "tag.gpgsign=false", "-C", "/some/path", "status"} + if len(got) != len(want) { + t.Fatalf("gitSafe() = %v, want %v", got, want) + } + for i := range want { + if got[i] != want[i] { + t.Errorf("gitSafe()[%d] = %q, want %q", i, got[i], want[i]) + } + } +} + +func TestContainerRunner_MissingCredentials_FailsFast(t *testing.T) { + logger := slog.New(slog.NewTextHandler(io.Discard, nil)) + + claudeConfigDir := t.TempDir() + + // Set up ClaudeConfigDir with MISSING credentials (so pre-flight fails) + // Don't create .credentials.json + // But DO create .claude.json so the test isolates the credentials check + if err := os.WriteFile(filepath.Join(claudeConfigDir, ".claude.json"), []byte("{}"), 0644); err != nil { + t.Fatal(err) + } + + runner := &ContainerRunner{ + Logger: logger, + Image: "busybox", + ClaudeConfigDir: claudeConfigDir, + Command: func(ctx context.Context, name string, arg ...string) *exec.Cmd { + if name == "git" && len(arg) > 0 && arg[0] == "clone" { + dir := arg[len(arg)-1] + os.MkdirAll(dir, 0755) + return exec.Command("true") + } + return exec.Command("true") + }, + } + + tk := &task.Task{ + ID: "test-missing-creds", + RepositoryURL: "https://github.com/example/repo.git", + Agent: task.AgentConfig{Type: "claude"}, + } + e := &storage.Execution{ID: "test-exec", TaskID: "test-missing-creds"} + + err := runner.Run(context.Background(), tk, e) + if err == nil { + t.Fatal("expected error due to missing credentials, got nil") + } + if !strings.Contains(err.Error(), "credentials not found") { + t.Errorf("expected 'credentials not found' error, got: %v", err) + } +} + +func TestContainerRunner_MissingSettings_FailsFast(t *testing.T) { + logger := slog.New(slog.NewTextHandler(io.Discard, nil)) + + claudeConfigDir := t.TempDir() + + // Only create credentials but NOT .claude.json + if err := os.WriteFile(filepath.Join(claudeConfigDir, ".credentials.json"), []byte("{}"), 0600); err != nil { + t.Fatal(err) + } + + runner := &ContainerRunner{ + Logger: logger, + Image: "busybox", + ClaudeConfigDir: claudeConfigDir, + Command: func(ctx context.Context, name string, arg ...string) *exec.Cmd { + if name == "git" && len(arg) > 0 && arg[0] == "clone" { + dir := arg[len(arg)-1] + os.MkdirAll(dir, 0755) + return exec.Command("true") + } + return exec.Command("true") + }, + } + + tk := &task.Task{ + ID: "test-missing-settings", + RepositoryURL: "https://github.com/example/repo.git", + Agent: task.AgentConfig{Type: "claude"}, + } + e := &storage.Execution{ID: "test-exec-2", TaskID: "test-missing-settings"} + + err := runner.Run(context.Background(), tk, e) + if err == nil { + t.Fatal("expected error due to missing settings, got nil") + } + if !strings.Contains(err.Error(), "claude settings") { + t.Errorf("expected 'claude settings' error, got: %v", err) + } +} + +func TestIsAuthError_DetectsAllVariants(t *testing.T) { + tests := []struct { + msg string + want bool + }{ + {"Not logged in", true}, + {"OAuth token has expired", true}, + {"authentication_error: invalid token", true}, + {"Please run /login to authenticate", true}, + {"container execution failed: exit status 1", false}, + {"git clone failed", false}, + {"", false}, + } + for _, tt := range tests { + var err error + if tt.msg != "" { + err = fmt.Errorf("%s", tt.msg) + } + got := isAuthError(err) + if got != tt.want { + t.Errorf("isAuthError(%q) = %v, want %v", tt.msg, got, tt.want) + } + } +} + +func TestContainerRunner_AuthError_SyncsAndRetries(t *testing.T) { + logger := slog.New(slog.NewTextHandler(io.Discard, nil)) + + // Create a sync script that creates a marker file + syncDir := t.TempDir() + syncMarker := filepath.Join(syncDir, "sync-called") + syncScript := filepath.Join(syncDir, "sync-creds") + os.WriteFile(syncScript, []byte("#!/bin/sh\ntouch "+syncMarker+"\n"), 0755) + + claudeConfigDir := t.TempDir() + // Create both credential files in ClaudeConfigDir + os.WriteFile(filepath.Join(claudeConfigDir, ".credentials.json"), []byte(`{"token":"fresh"}`), 0600) + os.WriteFile(filepath.Join(claudeConfigDir, ".claude.json"), []byte("{}"), 0644) + + callCount := 0 + runner := &ContainerRunner{ + Logger: logger, + Image: "busybox", + ClaudeConfigDir: claudeConfigDir, + CredentialSyncCmd: syncScript, + Command: func(ctx context.Context, name string, arg ...string) *exec.Cmd { + if name == "git" { + if len(arg) > 0 && arg[0] == "clone" { + dir := arg[len(arg)-1] + os.MkdirAll(dir, 0755) + } + return exec.Command("true") + } + if name == "docker" { + callCount++ + if callCount == 1 { + // First docker call fails with auth error + return exec.Command("sh", "-c", "echo 'Not logged in' >&2; exit 1") + } + // Second docker call "succeeds" + return exec.Command("sh", "-c", "exit 0") + } + if name == syncScript { + return exec.Command("sh", "-c", "touch "+syncMarker) + } + return exec.Command("true") + }, + } + + tk := &task.Task{ + ID: "auth-retry-test", + RepositoryURL: "https://github.com/example/repo.git", + Agent: task.AgentConfig{Type: "claude", Instructions: "test"}, + } + e := &storage.Execution{ID: "auth-retry-exec", TaskID: "auth-retry-test"} + + // Run — first attempt will fail with auth error, triggering sync+retry + runner.Run(context.Background(), tk, e) + // We don't check error strictly since second run may also fail (git push etc.) + // What we care about is that docker was called twice and sync was called + if callCount < 2 { + t.Errorf("expected docker to be called at least twice (original + retry), got %d", callCount) + } + if _, err := os.Stat(syncMarker); os.IsNotExist(err) { + t.Error("expected sync-credentials to be called, but marker file not found") + } +} + +func TestContainerRunner_ClonesStoryBranch(t *testing.T) { + logger := slog.New(slog.NewTextHandler(io.Discard, nil)) + + var checkoutArgs []string + runner := &ContainerRunner{ + Logger: logger, + Image: "busybox", + Command: func(ctx context.Context, name string, arg ...string) *exec.Cmd { + if name == "git" && len(arg) > 0 && arg[0] == "clone" { + dir := arg[len(arg)-1] + os.MkdirAll(dir, 0755) + return exec.Command("true") + } + // Capture checkout calls: both "git checkout <branch>" and "git -C <dir> checkout <branch>" + for i, a := range arg { + if a == "checkout" { + checkoutArgs = append([]string{}, arg[i:]...) + break + } + } + if name == "docker" { + return exec.Command("sh", "-c", "exit 1") + } + return exec.Command("true") + }, + } + + tk := &task.Task{ + ID: "story-branch-test", + RepositoryURL: "https://example.com/repo.git", + BranchName: "story/my-feature", + Agent: task.AgentConfig{Type: "claude"}, + } + e := &storage.Execution{ID: "exec-1", TaskID: "story-branch-test"} + + runner.Run(context.Background(), tk, e) + os.RemoveAll(e.SandboxDir) + + // Assert git checkout was called with the story branch name. + if len(checkoutArgs) == 0 { + t.Fatal("expected git checkout to be called for story branch, but it was not") + } + found := false + for _, a := range checkoutArgs { + if a == "story/my-feature" { + found = true + break + } + } + if !found { + t.Errorf("expected git checkout story/my-feature, got args: %v", checkoutArgs) + } +} + +func TestContainerRunner_ClonesDefaultBranchWhenNoBranchName(t *testing.T) { + logger := slog.New(slog.NewTextHandler(io.Discard, nil)) + + var cloneArgs []string + runner := &ContainerRunner{ + Logger: logger, + Image: "busybox", + Command: func(ctx context.Context, name string, arg ...string) *exec.Cmd { + if name == "git" && len(arg) > 0 && arg[0] == "clone" { + cloneArgs = append([]string{}, arg...) + dir := arg[len(arg)-1] + os.MkdirAll(dir, 0755) + return exec.Command("true") + } + if name == "docker" { + return exec.Command("sh", "-c", "exit 1") + } + return exec.Command("true") + }, + } + + tk := &task.Task{ + ID: "no-branch-test", + RepositoryURL: "https://example.com/repo.git", + Agent: task.AgentConfig{Type: "claude"}, + } + e := &storage.Execution{ID: "exec-2", TaskID: "no-branch-test"} + + runner.Run(context.Background(), tk, e) + os.RemoveAll(e.SandboxDir) + + for _, a := range cloneArgs { + if a == "--branch" { + t.Errorf("expected no --branch flag for task without BranchName, got args: %v", cloneArgs) + } + } +} + +func TestEnsureStoryBranch_CreatesMissingBranch(t *testing.T) { + // Set up a bare repo and a local clone to test branch creation. + dir := t.TempDir() + bare := filepath.Join(dir, "bare.git") + local := filepath.Join(dir, "local") + + // Create bare repo with an initial commit. + if out, err := exec.Command("git", "init", "--bare", bare).CombinedOutput(); err != nil { + t.Fatalf("git init bare: %v\n%s", err, out) + } + if out, err := exec.Command("git", "clone", bare, local).CombinedOutput(); err != nil { + t.Fatalf("git clone: %v\n%s", err, out) + } + if out, err := exec.Command("git", "-C", local, "commit", "--allow-empty", "-m", "init").CombinedOutput(); err != nil { + t.Fatalf("git commit: %v\n%s", err, out) + } + if out, err := exec.Command("git", "-C", local, "push", "origin", "main").CombinedOutput(); err != nil { + // try master + if out2, err2 := exec.Command("git", "-C", local, "push", "origin", "HEAD:main").CombinedOutput(); err2 != nil { + t.Fatalf("git push main: %v\n%s\n%s", err, out, out2) + } + } + + runner := &ContainerRunner{Logger: slog.Default()} + + branch := "story/test-branch" + + // Branch should not exist yet. + out, _ := exec.Command("git", "ls-remote", "--heads", bare, branch).CombinedOutput() + if len(strings.TrimSpace(string(out))) > 0 { + t.Fatal("branch should not exist before ensureStoryBranch") + } + + if err := runner.ensureStoryBranch(context.Background(), bare, branch, ""); err != nil { + t.Fatalf("ensureStoryBranch: %v", err) + } + + // Branch should now exist in the bare repo. + out, err := exec.Command("git", "ls-remote", "--heads", bare, branch).CombinedOutput() + if err != nil || len(strings.TrimSpace(string(out))) == 0 { + t.Errorf("branch %q not found in bare repo after ensureStoryBranch: %s", branch, out) + } +} + +func TestEnsureStoryBranch_IdempotentIfExists(t *testing.T) { + dir := t.TempDir() + bare := filepath.Join(dir, "bare.git") + local := filepath.Join(dir, "local") + + if out, err := exec.Command("git", "init", "--bare", bare).CombinedOutput(); err != nil { + t.Fatalf("git init bare: %v\n%s", err, out) + } + if out, err := exec.Command("git", "clone", bare, local).CombinedOutput(); err != nil { + t.Fatalf("git clone: %v\n%s", err, out) + } + if out, err := exec.Command("git", "-C", local, "commit", "--allow-empty", "-m", "init").CombinedOutput(); err != nil { + t.Fatalf("git commit: %v\n%s", err, out) + } + if _, err := exec.Command("git", "-C", local, "push", "origin", "HEAD:main").CombinedOutput(); err != nil { + t.Fatalf("push main: %v", err) + } + + branch := "story/existing-branch" + // Pre-create the branch. + if out, err := exec.Command("git", "-C", local, "checkout", "-b", branch).CombinedOutput(); err != nil { + t.Fatalf("checkout -b: %v\n%s", err, out) + } + if out, err := exec.Command("git", "-C", local, "push", "origin", branch).CombinedOutput(); err != nil { + t.Fatalf("push branch: %v\n%s", err, out) + } + + runner := &ContainerRunner{Logger: slog.Default()} + + // Should be a no-op, not an error. + if err := runner.ensureStoryBranch(context.Background(), bare, branch, ""); err != nil { + t.Fatalf("ensureStoryBranch on existing branch: %v", err) + } +} diff --git a/internal/executor/executor.go b/internal/executor/executor.go index 315030d..09169bd 100644 --- a/internal/executor/executor.go +++ b/internal/executor/executor.go @@ -2,9 +2,11 @@ package executor import ( "context" + "encoding/json" "errors" "fmt" "log/slog" + "os/exec" "path/filepath" "strings" "sync" @@ -25,6 +27,7 @@ type Store interface { ListSubtasks(parentID string) ([]*task.Task, error) ListExecutions(taskID string) ([]*storage.Execution, error) CreateExecution(e *storage.Execution) error + CreateExecutionAndSetRunning(e *storage.Execution) error UpdateExecution(e *storage.Execution) error UpdateTaskState(id string, newState task.State) error UpdateTaskQuestion(taskID, questionJSON string) error @@ -32,6 +35,14 @@ type Store interface { AppendTaskInteraction(taskID string, interaction task.Interaction) error UpdateTaskAgent(id string, agent task.AgentConfig) error UpdateExecutionChangestats(execID string, stats *task.Changestats) error + RecordAgentEvent(e storage.AgentEvent) error + GetProject(id string) (*task.Project, error) + GetStory(id string) (*task.Story, error) + ListTasksByStory(storyID string) ([]*task.Task, error) + UpdateStoryStatus(id string, status task.StoryState) error + CreateTask(t *task.Task) error + UpdateTaskCheckerReport(id, report string) error + GetCheckerTask(checkedTaskID string) (*task.Task, error) } // LogPather is an optional interface runners can implement to provide the log @@ -56,24 +67,28 @@ type workItem struct { // Pool manages a bounded set of concurrent task workers. type Pool struct { maxConcurrent int + maxPerAgent int runners map[string]Runner store Store logger *slog.Logger - depPollInterval time.Duration // how often waitForDependencies polls; defaults to 5s - - mu sync.Mutex - active int - activePerAgent map[string]int - rateLimited map[string]time.Time // agentType -> until - cancels map[string]context.CancelFunc // taskID → cancel - resultCh chan *Result - workCh chan workItem // internal bounded queue; Submit enqueues here - doneCh chan struct{} // signals when a worker slot is freed - Questions *QuestionRegistry - Classifier *Classifier - // LLM, when non-nil, enables LLM-synthesized summaries for executions - // whose stdout did not include a "## Summary" heading. - LLM *llm.Client + depPollInterval time.Duration // how often waitForDependencies polls; defaults to 5s + requeueDelay time.Duration // how long to wait before requeuing a blocked-per-agent task; defaults to 30s + + mu sync.Mutex + active int + activePerAgent map[string]int + rateLimited map[string]time.Time // agentType -> until + cancels map[string]context.CancelFunc // taskID → cancel + consecutiveFailures map[string]int // agentType -> count + closed bool // set to true when Shutdown has been called + resultCh chan *Result + startedCh chan string // task IDs that just transitioned to RUNNING + workCh chan workItem // internal bounded queue; Submit enqueues here + doneCh chan struct{} // signals when a worker slot is freed + workerWg sync.WaitGroup // tracks in-flight execute/executeResume goroutines + dispatchDone chan struct{} // closed when the dispatch goroutine exits + Classifier *Classifier + LLM *llm.Client } // Result is emitted when a task execution completes. @@ -88,18 +103,22 @@ func NewPool(maxConcurrent int, runners map[string]Runner, store Store, logger * maxConcurrent = 1 } p := &Pool{ - maxConcurrent: maxConcurrent, - runners: runners, - store: store, - logger: logger, - depPollInterval: 5 * time.Second, - activePerAgent: make(map[string]int), - rateLimited: make(map[string]time.Time), - cancels: make(map[string]context.CancelFunc), - resultCh: make(chan *Result, maxConcurrent*2), - workCh: make(chan workItem, maxConcurrent*10+100), - doneCh: make(chan struct{}, maxConcurrent), - Questions: NewQuestionRegistry(), + maxConcurrent: maxConcurrent, + maxPerAgent: 1, + runners: runners, + store: store, + logger: logger, + depPollInterval: 5 * time.Second, + requeueDelay: 30 * time.Second, + activePerAgent: make(map[string]int), + rateLimited: make(map[string]time.Time), + cancels: make(map[string]context.CancelFunc), + consecutiveFailures: make(map[string]int), + resultCh: make(chan *Result, maxConcurrent*2), + startedCh: make(chan string, maxConcurrent*2), + workCh: make(chan workItem, maxConcurrent*10+100), + doneCh: make(chan struct{}, maxConcurrent), + dispatchDone: make(chan struct{}), } go p.dispatch() return p @@ -109,6 +128,7 @@ func NewPool(maxConcurrent int, runners map[string]Runner, store Store, logger * // and launches goroutines as soon as a pool slot is available. This prevents // tasks from being rejected when the pool is temporarily at capacity. func (p *Pool) dispatch() { + defer close(p.dispatchDone) for item := range p.workCh { for { p.mu.Lock() @@ -116,9 +136,9 @@ func (p *Pool) dispatch() { p.active++ p.mu.Unlock() if item.exec != nil { - go p.executeResume(item.ctx, item.task, item.exec) + p.workerWg.Add(1); go func(i workItem) { defer p.workerWg.Done(); p.executeResume(i.ctx, i.task, i.exec) }(item) } else { - go p.execute(item.ctx, item.task) + p.workerWg.Add(1); go func(i workItem) { defer p.workerWg.Done(); p.execute(i.ctx, i.task) }(item) } break } @@ -132,19 +152,64 @@ func (p *Pool) dispatch() { // work queue is full. When the pool is at capacity the task is buffered and // dispatched as soon as a slot becomes available. func (p *Pool) Submit(ctx context.Context, t *task.Task) error { + p.mu.Lock() + if p.closed { + p.mu.Unlock() + return fmt.Errorf("executor pool is shut down") + } + // Send while holding the lock so that Shutdown cannot close workCh between + // the closed-check above and the send below. The dispatch goroutine never + // holds p.mu while receiving from workCh, so this cannot deadlock. select { case p.workCh <- workItem{ctx: ctx, task: t}: + p.mu.Unlock() return nil default: + p.mu.Unlock() return fmt.Errorf("executor work queue full (capacity %d)", cap(p.workCh)) } } +// Started returns a channel that emits task IDs when they transition to RUNNING. +func (p *Pool) Started() <-chan string { + return p.startedCh +} + // Results returns the channel for reading execution results. func (p *Pool) Results() <-chan *Result { return p.resultCh } +// Shutdown stops accepting new work and waits for all in-flight workers to +// finish. Returns ctx.Err() if the context deadline is exceeded before all +// workers complete. +func (p *Pool) Shutdown(ctx context.Context) error { + // Stop the dispatch goroutine. We must wait for it to exit before calling + // workerWg.Wait() to avoid a race between dispatch's Add(1) and Wait(). + p.mu.Lock() + p.closed = true + p.mu.Unlock() + close(p.workCh) + select { + case <-p.dispatchDone: + case <-ctx.Done(): + return ctx.Err() + } + + done := make(chan struct{}) + go func() { + p.workerWg.Wait() + close(done) + }() + + select { + case <-done: + return nil + case <-ctx.Done(): + return ctx.Err() + } +} + // Cancel requests cancellation of a running task. Returns false if the task // is not currently running in this pool. func (p *Pool) Cancel(taskID string) bool { @@ -250,11 +315,12 @@ func (p *Pool) executeResume(ctx context.Context, t *task.Task, exec *storage.Ex exec.StartTime = time.Now().UTC() exec.Status = "RUNNING" - if err := p.store.CreateExecution(exec); err != nil { + if err := p.store.CreateExecutionAndSetRunning(exec); err != nil { p.logger.Error("failed to create resume execution record", "error", err) } - if err := p.store.UpdateTaskState(t.ID, task.StateRunning); err != nil { - p.logger.Error("failed to update task state", "error", err) + select { + case p.startedCh <- t.ID: + default: } var cancel context.CancelFunc @@ -273,6 +339,19 @@ func (p *Pool) executeResume(ctx context.Context, t *task.Task, exec *storage.Ex p.mu.Unlock() }() + // Populate RepositoryURL from Project registry if missing (ADR-007). + if t.RepositoryURL == "" && t.Project != "" { + if proj, err := p.store.GetProject(t.Project); err == nil && proj.RemoteURL != "" { + t.RepositoryURL = proj.RemoteURL + } + } + // Populate BranchName from Story if missing (ADR-007). + if t.BranchName == "" && t.StoryID != "" { + if story, err := p.store.GetStory(t.StoryID); err == nil && story.BranchName != "" { + t.BranchName = story.BranchName + } + } + err = runner.Run(ctx, t, exec) exec.EndTime = time.Now().UTC() @@ -289,16 +368,32 @@ func (p *Pool) handleRunResult(ctx context.Context, t *task.Task, exec *storage. if retry.IsRateLimitError(err) || isQuotaExhausted(err) { p.mu.Lock() retryAfter := retry.ParseRetryAfter(err.Error()) - if retryAfter == 0 { - if isQuotaExhausted(err) { + reason := "transient" + if isQuotaExhausted(err) { + reason = "quota" + if retryAfter == 0 { retryAfter = 5 * time.Hour - } else { - retryAfter = 1 * time.Minute } + } else if retryAfter == 0 { + retryAfter = 1 * time.Minute } - p.rateLimited[agentType] = time.Now().Add(retryAfter) + until := time.Now().Add(retryAfter) + p.rateLimited[agentType] = until p.logger.Info("agent rate limited", "agent", agentType, "retryAfter", retryAfter, "quotaExhausted", isQuotaExhausted(err)) p.mu.Unlock() + go func() { + ev := storage.AgentEvent{ + ID: uuid.New().String(), + Agent: agentType, + Event: "rate_limited", + Timestamp: time.Now(), + Until: &until, + Reason: reason, + } + if recErr := p.store.RecordAgentEvent(ev); recErr != nil { + p.logger.Warn("failed to record agent event", "error", recErr) + } + }() } var blockedErr *BlockedError @@ -335,9 +430,51 @@ func (p *Pool) handleRunResult(ctx context.Context, t *task.Task, exec *storage. if err := p.store.UpdateTaskState(t.ID, task.StateFailed); err != nil { p.logger.Error("failed to update task state", "taskID", t.ID, "state", task.StateFailed, "error", err) } + p.mu.Lock() + p.consecutiveFailures[agentType]++ + p.mu.Unlock() + } + // If this is a checker task, attach the failure report for any terminal + // failure state (FAILED, TIMED_OUT, CANCELLED, BUDGET_EXCEEDED). + if t.CheckerForTaskID != "" && exec.ErrorMsg != "" { + if reportErr := p.store.UpdateTaskCheckerReport(t.CheckerForTaskID, exec.ErrorMsg); reportErr != nil { + p.logger.Error("handleRunResult: failed to set checker report", "taskID", t.CheckerForTaskID, "error", reportErr) + } + } + if t.StoryID != "" && exec.Status == "FAILED" { + storyID := t.StoryID + errMsg := exec.ErrorMsg + go func() { + story, getErr := p.store.GetStory(storyID) + if getErr != nil { + return + } + if story.Status == task.StoryValidating { + p.checkValidationResult(ctx, storyID, task.StateFailed, errMsg) + } + }() } } else { - if t.ParentTaskID == "" { + p.mu.Lock() + p.consecutiveFailures[agentType] = 0 + p.mu.Unlock() + if t.CheckerForTaskID != "" { + // Checker task succeeded — auto-accept the checked task. + exec.Status = "COMPLETED" + if err := p.store.UpdateTaskState(t.ID, task.StateCompleted); err != nil { + p.logger.Error("handleRunResult: failed to complete checker task", "taskID", t.ID, "error", err) + } + checkedTask, getErr := p.store.GetTask(t.CheckerForTaskID) + if getErr == nil { + if acceptErr := p.store.UpdateTaskState(t.CheckerForTaskID, task.StateCompleted); acceptErr != nil { + p.logger.Error("handleRunResult: failed to auto-accept checked task", "taskID", t.CheckerForTaskID, "error", acceptErr) + } else if checkedTask.StoryID != "" { + go p.checkStoryCompletion(context.Background(), checkedTask.StoryID) + } + } else { + p.logger.Error("handleRunResult: failed to get checked task", "taskID", t.CheckerForTaskID, "error", getErr) + } + } else if t.ParentTaskID == "" { subtasks, subErr := p.store.ListSubtasks(t.ID) if subErr != nil { p.logger.Error("failed to list subtasks", "taskID", t.ID, "error", subErr) @@ -352,6 +489,7 @@ func (p *Pool) handleRunResult(ctx context.Context, t *task.Task, exec *storage. if err := p.store.UpdateTaskState(t.ID, task.StateReady); err != nil { p.logger.Error("failed to update task state", "taskID", t.ID, "state", task.StateReady, "error", err) } + go p.spawnCheckerTask(context.Background(), t) } } else { exec.Status = "COMPLETED" @@ -360,6 +498,21 @@ func (p *Pool) handleRunResult(ctx context.Context, t *task.Task, exec *storage. } p.maybeUnblockParent(t.ParentTaskID) } + if t.StoryID != "" { + storyID := t.StoryID + go func() { + story, getErr := p.store.GetStory(storyID) + if getErr != nil { + p.logger.Error("handleRunResult: failed to get story", "storyID", storyID, "error", getErr) + return + } + if story.Status == task.StoryValidating { + p.checkValidationResult(ctx, storyID, task.StateCompleted, "") + } else { + p.checkStoryCompletion(ctx, storyID) + } + }() + } } summary := exec.Summary @@ -374,6 +527,13 @@ func (p *Pool) handleRunResult(ctx context.Context, t *task.Task, exec *storage. p.logger.Error("failed to update task summary", "taskID", t.ID, "error", summaryErr) } } + terminalFailure := exec.Status == "FAILED" || exec.Status == "TIMED_OUT" || exec.Status == "CANCELLED" || exec.Status == "BUDGET_EXCEEDED" + if t.CheckerForTaskID != "" && terminalFailure && summary != "" { + // Overwrite the initial error-message report with the richer summary. + if reportErr := p.store.UpdateTaskCheckerReport(t.CheckerForTaskID, summary); reportErr != nil { + p.logger.Error("handleRunResult: failed to update checker report with summary", "taskID", t.CheckerForTaskID, "error", reportErr) + } + } if exec.StdoutPath != "" { if cs := task.ParseChangestatFromFile(exec.StdoutPath); cs != nil { exec.Changestats = cs @@ -388,6 +548,256 @@ func (p *Pool) handleRunResult(ctx context.Context, t *task.Task, exec *storage. p.resultCh <- &Result{TaskID: t.ID, Execution: exec, Err: err} } +// checkStoryCompletion checks whether all top-level tasks in a story have reached +// a terminal success state and transitions the story to SHIPPABLE if so. +// Subtasks are intentionally excluded — a parent task reaching READY/COMPLETED +// already accounts for its subtasks. +// CheckStoryCompletion is the exported entry point for story completion checks +// called from outside the package (e.g. the API accept handler). +func (p *Pool) CheckStoryCompletion(ctx context.Context, storyID string) { + p.checkStoryCompletion(ctx, storyID) +} + +func (p *Pool) checkStoryCompletion(ctx context.Context, storyID string) { + story, err := p.store.GetStory(storyID) + if err != nil { + p.logger.Error("checkStoryCompletion: failed to get story", "storyID", storyID, "error", err) + return + } + if story.Status != task.StoryInProgress { + return // already SHIPPABLE or beyond — nothing to do + } + tasks, err := p.store.ListTasksByStory(storyID) + if err != nil { + p.logger.Error("checkStoryCompletion: failed to list tasks", "storyID", storyID, "error", err) + return + } + if len(tasks) == 0 { + return + } + topLevelCount := 0 + for _, t := range tasks { + if t.ParentTaskID != "" { + continue // subtasks are covered by their parent + } + topLevelCount++ + if t.State != task.StateCompleted { + return // not all top-level tasks done; READY alone is not sufficient (checker may be pending) + } + } + if topLevelCount == 0 { + return // no top-level tasks — don't auto-complete + } + if err := p.store.UpdateStoryStatus(storyID, task.StoryShippable); err != nil { + p.logger.Error("checkStoryCompletion: failed to update story status", "storyID", storyID, "error", err) + return + } + p.logger.Info("story transitioned to SHIPPABLE", "storyID", storyID) + // Deploy is now triggered explicitly by the human via POST /api/stories/{id}/ship. +} + +// ShipStory merges the story branch and runs the deploy script. +// Returns an error if the story is not in SHIPPABLE state. +func (p *Pool) ShipStory(ctx context.Context, storyID string) error { + story, err := p.store.GetStory(storyID) + if err != nil { + return fmt.Errorf("story not found: %w", err) + } + if story.Status != task.StoryShippable { + return fmt.Errorf("story is not SHIPPABLE (current status: %s)", story.Status) + } + go p.triggerStoryDeploy(context.Background(), storyID) + return nil +} + +// spawnCheckerTask creates and submits a checker task for the given completed task. +// Guards: not called for subtasks, checker tasks, tasks without a repository URL, +// or tasks that already have a checker. +func (p *Pool) spawnCheckerTask(ctx context.Context, checked *task.Task) { + // Never spawn a checker for subtasks, checker tasks, or tasks without a repository. + if checked.ParentTaskID != "" || checked.CheckerForTaskID != "" || checked.RepositoryURL == "" { + return + } + // Idempotent: don't create a second checker if one already exists. + existing, err := p.store.GetCheckerTask(checked.ID) + if err != nil { + p.logger.Error("spawnCheckerTask: GetCheckerTask failed", "taskID", checked.ID, "error", err) + return + } + if existing != nil { + return + } + + criteria := checked.AcceptanceCriteria + if criteria == "" { + criteria = checked.Agent.Instructions + } + + instructions := fmt.Sprintf(`You are validating a completed task. Do not make any changes to the code or repository. + +Task: %s +Instructions given to the implementor: +%s + +Acceptance criteria: +%s + +Steps: +1. Clone the repository and review the changes made. +2. Verify each acceptance criterion is met. Run tests or make HTTP requests as needed. +3. If all criteria are satisfied, exit normally (success). +4. If any criterion is not met, use the Bash tool to exit with a non-zero code: + bash -c "exit 1" + Before exiting, write a brief summary of what failed.`, checked.Name, checked.Agent.Instructions, criteria) + + now := time.Now().UTC() + checker := &task.Task{ + ID: uuid.New().String(), + Name: "Check: " + checked.Name, + CheckerForTaskID: checked.ID, + RepositoryURL: checked.RepositoryURL, + Agent: task.AgentConfig{ + Type: "claude", + Instructions: instructions, + MaxBudgetUSD: 0.50, + AllowedTools: []string{"Bash", "Read", "Glob", "Grep"}, + }, + Timeout: task.Duration{Duration: 10 * time.Minute}, + Priority: task.PriorityNormal, + Tags: []string{}, + DependsOn: []string{}, + Retry: task.RetryConfig{MaxAttempts: 1, Backoff: "linear"}, + State: task.StatePending, + CreatedAt: now, + UpdatedAt: now, + } + + if err := p.store.CreateTask(checker); err != nil { + p.logger.Error("spawnCheckerTask: CreateTask failed", "error", err) + return + } + checker.State = task.StateQueued + if err := p.store.UpdateTaskState(checker.ID, task.StateQueued); err != nil { + p.logger.Error("spawnCheckerTask: UpdateTaskState failed", "error", err) + return + } + if err := p.Submit(ctx, checker); err != nil { + p.logger.Error("spawnCheckerTask: Submit failed", "error", err) + } +} + +// triggerStoryDeploy runs the project deploy script for a SHIPPABLE story +// and advances it to DEPLOYED on success. +func (p *Pool) triggerStoryDeploy(ctx context.Context, storyID string) { + story, err := p.store.GetStory(storyID) + if err != nil { + p.logger.Error("triggerStoryDeploy: failed to get story", "storyID", storyID, "error", err) + return + } + if story.ProjectID == "" { + return + } + proj, err := p.store.GetProject(story.ProjectID) + if err != nil { + p.logger.Error("triggerStoryDeploy: failed to get project", "storyID", storyID, "projectID", story.ProjectID, "error", err) + return + } + if proj.DeployScript == "" { + return + } + // Merge story branch to main before deploying (ADR-007). + if story.BranchName != "" && proj.LocalPath != "" { + mergeSteps := [][]string{ + {"git", "-C", proj.LocalPath, "fetch", "origin"}, + {"git", "-C", proj.LocalPath, "checkout", "main"}, + {"git", "-C", proj.LocalPath, "merge", "--no-ff", story.BranchName, "-m", "Merge " + story.BranchName}, + {"git", "-C", proj.LocalPath, "push", "origin", "main"}, + } + for _, args := range mergeSteps { + if mergeOut, mergeErr := exec.CommandContext(ctx, args[0], args[1:]...).CombinedOutput(); mergeErr != nil { + p.logger.Error("triggerStoryDeploy: merge failed", "cmd", args, "output", string(mergeOut), "error", mergeErr) + return + } + } + p.logger.Info("story branch merged to main", "storyID", storyID, "branch", story.BranchName) + } + out, err := exec.CommandContext(ctx, proj.DeployScript).CombinedOutput() + if err != nil { + p.logger.Error("triggerStoryDeploy: deploy script failed", "storyID", storyID, "script", proj.DeployScript, "output", string(out), "error", err) + return + } + if err := p.store.UpdateStoryStatus(storyID, task.StoryDeployed); err != nil { + p.logger.Error("triggerStoryDeploy: failed to update story status", "storyID", storyID, "error", err) + return + } + p.logger.Info("story transitioned to DEPLOYED", "storyID", storyID) + go p.createValidationTask(ctx, storyID) +} + +// createValidationTask creates a validation subtask from the story's ValidationJSON +// and transitions the story to VALIDATING. +func (p *Pool) createValidationTask(ctx context.Context, storyID string) { + story, err := p.store.GetStory(storyID) + if err != nil { + p.logger.Error("createValidationTask: failed to get story", "storyID", storyID, "error", err) + return + } + if story.ValidationJSON == "" { + p.logger.Warn("createValidationTask: story has no ValidationJSON, skipping", "storyID", storyID) + return + } + + var spec map[string]interface{} + if err := json.Unmarshal([]byte(story.ValidationJSON), &spec); err != nil { + p.logger.Error("createValidationTask: failed to parse ValidationJSON", "storyID", storyID, "error", err) + return + } + + instructions := fmt.Sprintf("Validate the deployment for story %q.\n\nValidation spec:\n%s", story.Name, story.ValidationJSON) + + now := time.Now().UTC() + vtask := &task.Task{ + ID: uuid.New().String(), + Name: fmt.Sprintf("validation: %s", story.Name), + StoryID: storyID, + State: task.StateQueued, + Agent: task.AgentConfig{Type: "claude", Instructions: instructions}, + Tags: []string{}, + DependsOn: []string{}, + CreatedAt: now, + UpdatedAt: now, + } + + if err := p.store.CreateTask(vtask); err != nil { + p.logger.Error("createValidationTask: failed to create task", "storyID", storyID, "error", err) + return + } + if err := p.store.UpdateStoryStatus(storyID, task.StoryValidating); err != nil { + p.logger.Error("createValidationTask: failed to update story status", "storyID", storyID, "error", err) + return + } + p.logger.Info("validation task created and story transitioned to VALIDATING", "storyID", storyID, "taskID", vtask.ID) + p.Submit(ctx, vtask) //nolint:errcheck +} + +// checkValidationResult inspects a completed validation task and transitions +// the story to REVIEW_READY or NEEDS_FIX accordingly. +func (p *Pool) checkValidationResult(ctx context.Context, storyID string, taskState task.State, errorMsg string) { + if taskState == task.StateCompleted { + if err := p.store.UpdateStoryStatus(storyID, task.StoryReviewReady); err != nil { + p.logger.Error("checkValidationResult: failed to update story status", "storyID", storyID, "error", err) + return + } + p.logger.Info("story transitioned to REVIEW_READY", "storyID", storyID) + } else { + if err := p.store.UpdateStoryStatus(storyID, task.StoryNeedsFix); err != nil { + p.logger.Error("checkValidationResult: failed to update story status", "storyID", storyID, "error", err) + return + } + p.logger.Info("story transitioned to NEEDS_FIX", "storyID", storyID, "error", errorMsg) + } +} + // ActiveCount returns the number of currently running tasks. func (p *Pool) ActiveCount() int { p.mu.Lock() @@ -395,6 +805,34 @@ func (p *Pool) ActiveCount() int { return p.active } +// AgentStatusInfo holds the current state of a single agent. +type AgentStatusInfo struct { + Agent string `json:"agent"` + ActiveTasks int `json:"active_tasks"` + RateLimited bool `json:"rate_limited"` + Until *time.Time `json:"until,omitempty"` +} + +// AgentStatuses returns the current status of all registered agents. +func (p *Pool) AgentStatuses() []AgentStatusInfo { + p.mu.Lock() + defer p.mu.Unlock() + now := time.Now() + var out []AgentStatusInfo + for agent := range p.runners { + info := AgentStatusInfo{ + Agent: agent, + ActiveTasks: p.activePerAgent[agent], + } + if deadline, ok := p.rateLimited[agent]; ok && now.Before(deadline) { + info.RateLimited = true + info.Until = &deadline + } + out = append(out, info) + } + return out +} + // pickAgent selects the best agent from the given SystemStatus using explicit // load balancing: prefer the available (non-rate-limited) agent with the fewest // active tasks. If all agents are rate-limited, fall back to fewest active. @@ -436,6 +874,18 @@ func (p *Pool) execute(ctx context.Context, t *task.Task) { activeTasks[agent] = p.activePerAgent[agent] if deadline, ok := p.rateLimited[agent]; ok && now.After(deadline) { delete(p.rateLimited, agent) + agentName := agent + go func() { + ev := storage.AgentEvent{ + ID: uuid.New().String(), + Agent: agentName, + Event: "available", + Timestamp: time.Now(), + } + if recErr := p.store.RecordAgentEvent(ev); recErr != nil { + p.logger.Warn("failed to record agent available event", "error", recErr) + } + }() } rateLimited[agent] = now.Before(p.rateLimited[agent]) } @@ -479,9 +929,58 @@ func (p *Pool) execute(ctx context.Context, t *task.Task) { agentType = "claude" } + // Check dependencies before taking the per-agent slot to avoid deadlock: + // if a dependent task holds the slot while waiting for its dependency to run, + // the dependency can never start (maxPerAgent=1). + if len(t.DependsOn) > 0 { + ready, depErr := p.checkDepsReady(t) + if depErr != nil { + // A dependency hit a terminal failure — cancel this task immediately. + now := time.Now().UTC() + exec := &storage.Execution{ + ID: uuid.New().String(), + TaskID: t.ID, + StartTime: now, + EndTime: now, + Status: "CANCELLED", + ErrorMsg: depErr.Error(), + } + if createErr := p.store.CreateExecution(exec); createErr != nil { + p.logger.Error("failed to create execution record", "error", createErr) + } + if err := p.store.UpdateTaskState(t.ID, task.StateCancelled); err != nil { + p.logger.Error("failed to update task state", "taskID", t.ID, "state", task.StateCancelled, "error", err) + } + p.resultCh <- &Result{TaskID: t.ID, Execution: exec, Err: depErr} + return + } + if !ready { + // Dependencies not yet done — requeue without holding the slot. + time.AfterFunc(p.requeueDelay, func() { p.workCh <- workItem{ctx: ctx, task: t} }) + return + } + } p.mu.Lock() + + if p.activePerAgent[agentType] >= p.maxPerAgent { + p.mu.Unlock() + time.AfterFunc(p.requeueDelay, func() { p.workCh <- workItem{ctx: ctx, task: t} }) + return + } if deadline, ok := p.rateLimited[agentType]; ok && time.Now().After(deadline) { delete(p.rateLimited, agentType) + agentName := agentType + go func() { + ev := storage.AgentEvent{ + ID: uuid.New().String(), + Agent: agentName, + Event: "available", + Timestamp: time.Now(), + } + if recErr := p.store.RecordAgentEvent(ev); recErr != nil { + p.logger.Warn("failed to record agent available event", "error", recErr) + } + }() } p.activePerAgent[agentType]++ p.mu.Unlock() @@ -512,30 +1011,6 @@ func (p *Pool) execute(ctx context.Context, t *task.Task) { return } - // Wait for all dependencies to complete before starting execution. - if len(t.DependsOn) > 0 { - if err := p.waitForDependencies(ctx, t); err != nil { - now := time.Now().UTC() - exec := &storage.Execution{ - ID: uuid.New().String(), - TaskID: t.ID, - StartTime: now, - EndTime: now, - Status: "FAILED", - ErrorMsg: err.Error(), - } - if createErr := p.store.CreateExecution(exec); createErr != nil { - p.logger.Error("failed to create execution record", "error", createErr) - } - if err := p.store.UpdateTaskState(t.ID, task.StateFailed); err != nil { - p.logger.Error("failed to update task state", "taskID", t.ID, "state", task.StateFailed, "error", err) - } - p.decActiveAgent(agentType, &cleaned) - p.resultCh <- &Result{TaskID: t.ID, Execution: exec, Err: err} - return - } - } - execID := uuid.New().String() exec := &storage.Execution{ ID: execID, @@ -554,12 +1029,13 @@ func (p *Pool) execute(ctx context.Context, t *task.Task) { } } - // Record execution start. - if err := p.store.CreateExecution(exec); err != nil { + // Record execution start atomically with the RUNNING state transition. + if err := p.store.CreateExecutionAndSetRunning(exec); err != nil { p.logger.Error("failed to create execution record", "error", err) } - if err := p.store.UpdateTaskState(t.ID, task.StateRunning); err != nil { - p.logger.Error("failed to update task state", "error", err) + select { + case p.startedCh <- t.ID: + default: } // Apply task timeout and register cancel so callers can stop this task. @@ -583,6 +1059,19 @@ func (p *Pool) execute(ctx context.Context, t *task.Task) { priorExecs, priorErr := p.store.ListExecutions(t.ID) t = withFailureHistory(t, priorExecs, priorErr) + // Populate RepositoryURL from Project registry if missing (ADR-007). + if t.RepositoryURL == "" && t.Project != "" { + if proj, err := p.store.GetProject(t.Project); err == nil && proj.RemoteURL != "" { + t.RepositoryURL = proj.RemoteURL + } + } + // Populate BranchName from Story if missing (ADR-007). + if t.BranchName == "" && t.StoryID != "" { + if story, err := p.store.GetStory(t.StoryID); err == nil && story.BranchName != "" { + t.BranchName = story.BranchName + } + } + // Run the task. err = runner.Run(ctx, t, exec) exec.EndTime = time.Now().UTC() @@ -650,18 +1139,31 @@ func (p *Pool) RecoverStaleQueued(ctx context.Context) { } } -// RecoverStaleBlocked promotes any BLOCKED parent task to READY when all of its -// subtasks are already COMPLETED. This handles the case where the server was -// restarted after subtasks finished but before maybeUnblockParent could fire. +// RecoverStaleBlocked promotes any BLOCKED or QUEUED parent task to READY when +// all of its subtasks are already COMPLETED. This handles the case where the +// server was restarted after subtasks finished but before maybeUnblockParent +// could fire, and also the case where story approval pre-created subtasks +// without ever running the parent task. // Call this once on server startup, after RecoverStaleRunning and RecoverStaleQueued. func (p *Pool) RecoverStaleBlocked() { - tasks, err := p.store.ListTasks(storage.TaskFilter{State: task.StateBlocked}) - if err != nil { - p.logger.Error("RecoverStaleBlocked: list tasks", "error", err) - return - } - for _, t := range tasks { - p.maybeUnblockParent(t.ID) + ctx := context.Background() + for _, state := range []task.State{task.StateBlocked, task.StateQueued} { + tasks, err := p.store.ListTasks(storage.TaskFilter{State: state}) + if err != nil { + p.logger.Error("RecoverStaleBlocked: list tasks", "error", err, "state", state) + continue + } + for _, t := range tasks { + if t.ParentTaskID != "" { + continue // only promote actual parents + } + before := t.State + p.maybeUnblockParent(t.ID) + // If the parent was promoted, check story completion. + if after, err := p.store.GetTask(t.ID); err == nil && after.State != before && t.StoryID != "" { + p.checkStoryCompletion(ctx, t.StoryID) + } + } } } @@ -673,6 +1175,32 @@ var terminalFailureStates = map[task.State]bool{ task.StateBudgetExceeded: true, } +// depDoneStates are task states that satisfy a DependsOn dependency. +var depDoneStates = map[task.State]bool{ + task.StateCompleted: true, + task.StateReady: true, // leaf tasks finish at READY +} + +// checkDepsReady does a single synchronous check of t.DependsOn. +// Returns (true, nil) if all deps are done, (false, nil) if any are still pending, +// or (false, err) if a dep entered a terminal failure state. +func (p *Pool) checkDepsReady(t *task.Task) (bool, error) { + for _, depID := range t.DependsOn { + dep, err := p.store.GetTask(depID) + if err != nil { + return false, fmt.Errorf("dependency %q not found: %w", depID, err) + } + if depDoneStates[dep.State] { + continue + } + if terminalFailureStates[dep.State] { + return false, fmt.Errorf("dependency %q ended in state %s", depID, dep.State) + } + return false, nil // still pending + } + return true, nil +} + // withFailureHistory returns a shallow copy of t with prior failed execution // error messages prepended to SystemPromptAppend so the agent knows what went // wrong in previous attempts. @@ -710,16 +1238,16 @@ func withFailureHistory(t *task.Task, execs []*storage.Execution, err error) *ta return © } -// maybeUnblockParent transitions the parent task from BLOCKED to READY if all -// of its subtasks are in the COMPLETED state. If any subtask is not COMPLETED -// (including FAILED, CANCELLED, RUNNING, etc.) the parent stays BLOCKED. +// maybeUnblockParent transitions the parent task to READY if all of its subtasks +// are in the COMPLETED state. Handles both BLOCKED parents (ran, created subtasks, +// paused) and QUEUED parents (story approval created subtasks without running parent). func (p *Pool) maybeUnblockParent(parentID string) { parent, err := p.store.GetTask(parentID) if err != nil { p.logger.Error("maybeUnblockParent: get parent", "parentID", parentID, "error", err) return } - if parent.State != task.StateBlocked { + if parent.State != task.StateBlocked && parent.State != task.StateQueued { return } subtasks, err := p.store.ListSubtasks(parentID) @@ -727,6 +1255,11 @@ func (p *Pool) maybeUnblockParent(parentID string) { p.logger.Error("maybeUnblockParent: list subtasks", "parentID", parentID, "error", err) return } + // A task with no subtasks was never blocked by subtask delegation — don't promote it. + // This prevents incorrectly promoting leaf tasks that are stuck in QUEUED to READY. + if len(subtasks) == 0 { + return + } for _, sub := range subtasks { if sub.State != task.StateCompleted { return @@ -747,7 +1280,7 @@ func (p *Pool) waitForDependencies(ctx context.Context, t *task.Task) error { if err != nil { return fmt.Errorf("dependency %q not found: %w", depID, err) } - if dep.State == task.StateCompleted { + if depDoneStates[dep.State] { continue } if terminalFailureStates[dep.State] { diff --git a/internal/executor/executor_test.go b/internal/executor/executor_test.go index b1173cb..9214872 100644 --- a/internal/executor/executor_test.go +++ b/internal/executor/executor_test.go @@ -5,6 +5,7 @@ import ( "fmt" "log/slog" "os" + "os/exec" "path/filepath" "strings" "sync" @@ -600,10 +601,17 @@ func TestPool_RecoverStaleRunning(t *testing.T) { // Execution record should be closed as FAILED. execs, _ := store.ListExecutions(tk.ID) - if len(execs) == 0 || execs[0].Status != "FAILED" { + var failedExec *storage.Execution + for _, e := range execs { + if e.ID == "exec-stale-1" { + failedExec = e + break + } + } + if failedExec == nil || failedExec.Status != "FAILED" { t.Errorf("execution status: want FAILED, got %+v", execs) } - if execs[0].ErrorMsg == "" { + if failedExec.ErrorMsg == "" { t.Error("expected non-empty error message on recovered execution") } @@ -739,6 +747,119 @@ func TestPool_RecoverStaleBlocked_KeepsBlockedWhenSubtaskIncomplete(t *testing.T } } +func TestPool_RecoverStaleBlocked_PromotesQueuedParentWithAllSubtasksDone(t *testing.T) { + store := testStore(t) + logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) + pool := NewPool(2, map[string]Runner{"claude": &mockRunner{}}, store, logger) + + now := time.Now().UTC() + story := &task.Story{ + ID: "story-queued-parent", Name: "Queued Parent Story", + Status: task.StoryInProgress, CreatedAt: now, UpdatedAt: now, + } + store.CreateStory(story) + + // Parent task stuck QUEUED (approved with pre-created subtasks, never run). + parent := makeTask("queued-parent-1") + parent.State = task.StateQueued + parent.StoryID = story.ID + store.CreateTask(parent) + + for i := 0; i < 2; i++ { + sub := makeTask(fmt.Sprintf("queued-sub-%d", i)) + sub.ParentTaskID = parent.ID + sub.StoryID = story.ID + sub.State = task.StateCompleted + store.CreateTask(sub) + } + + pool.RecoverStaleBlocked() + + got, err := store.GetTask(parent.ID) + if err != nil { + t.Fatalf("GetTask: %v", err) + } + if got.State != task.StateReady { + t.Errorf("parent state: want READY, got %s", got.State) + } + + // Story should still be IN_PROGRESS — READY tasks don't satisfy the completion check; + // the task must be accepted (READY → COMPLETED) before the story advances to SHIPPABLE. + s, err := store.GetStory(story.ID) + if err != nil { + t.Fatalf("GetStory: %v", err) + } + if s.Status != task.StoryInProgress { + t.Errorf("story status: want IN_PROGRESS, got %s", s.Status) + } +} + +// TestPool_RecoverStaleBlocked_DoesNotPromoteQueuedLeafTask verifies that a top-level +// QUEUED task with NO subtasks is not promoted to READY by RecoverStaleBlocked. +// This guards against the bug where a task that failed to start (stuck in QUEUED due +// to a DB error) was incorrectly promoted to READY because the "all subtasks done" +// check is vacuously true when there are no subtasks. +func TestPool_RecoverStaleBlocked_DoesNotPromoteQueuedLeafTask(t *testing.T) { + store := testStore(t) + logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) + pool := NewPool(2, map[string]Runner{"claude": &mockRunner{}}, store, logger) + + // A top-level task stuck in QUEUED with no subtasks (e.g. DB lock prevented RUNNING transition). + leaf := makeTask("queued-leaf-no-subtasks") + leaf.State = task.StateQueued + store.CreateTask(leaf) + + pool.RecoverStaleBlocked() + + got, err := store.GetTask(leaf.ID) + if err != nil { + t.Fatalf("GetTask: %v", err) + } + if got.State != task.StateQueued { + t.Errorf("leaf task state: want QUEUED (unchanged), got %s", got.State) + } +} + +// TestPool_CheckStoryCompletion_ReadyTasksNotSufficient verifies that READY tasks +// alone do not advance a story to SHIPPABLE — tasks must be COMPLETED. +func TestPool_CheckStoryCompletion_ReadyTasksNotSufficient(t *testing.T) { + store := testStore(t) + logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) + pool := NewPool(2, map[string]Runner{"claude": &mockRunner{}}, store, logger) + + now := time.Now().UTC() + story := &task.Story{ + ID: "story-ready-only", + Name: "Ready Only Story", + Status: task.StoryInProgress, + CreatedAt: now, + UpdatedAt: now, + } + store.CreateStory(story) + + // One task driven to READY (checker pending), one COMPLETED. + tk1 := makeTask("ro-task-1") + tk1.StoryID = story.ID + store.CreateTask(tk1) + for _, s := range []task.State{task.StateQueued, task.StateRunning, task.StateReady} { + store.UpdateTaskState(tk1.ID, s) + } + + tk2 := makeTask("ro-task-2") + tk2.StoryID = story.ID + store.CreateTask(tk2) + for _, s := range []task.State{task.StateQueued, task.StateRunning, task.StateReady, task.StateCompleted} { + store.UpdateTaskState(tk2.ID, s) + } + + pool.checkStoryCompletion(context.Background(), story.ID) + + got, _ := store.GetStory(story.ID) + if got.Status != task.StoryInProgress { + t.Errorf("story status: want IN_PROGRESS (tk1 still READY/checker pending), got %s", got.Status) + } +} + func TestPool_ActivePerAgent_DeletesZeroEntries(t *testing.T) { store := testStore(t) runner := &mockRunner{} @@ -1014,7 +1135,10 @@ func (m *minimalMockStore) ListSubtasks(parentID string) ([]*task.Task, error) { return nil, nil } func (m *minimalMockStore) ListExecutions(_ string) ([]*storage.Execution, error) { return nil, nil } -func (m *minimalMockStore) CreateExecution(e *storage.Execution) error { return nil } +func (m *minimalMockStore) CreateExecution(e *storage.Execution) error { return nil } +func (m *minimalMockStore) CreateExecutionAndSetRunning(e *storage.Execution) error { + return nil +} func (m *minimalMockStore) UpdateExecution(e *storage.Execution) error { return m.updateExecErr } @@ -1064,6 +1188,14 @@ func (m *minimalMockStore) UpdateExecutionChangestats(execID string, stats *task m.mu.Unlock() return nil } +func (m *minimalMockStore) RecordAgentEvent(_ storage.AgentEvent) error { return nil } +func (m *minimalMockStore) GetProject(_ string) (*task.Project, error) { return nil, nil } +func (m *minimalMockStore) GetStory(_ string) (*task.Story, error) { return nil, nil } +func (m *minimalMockStore) ListTasksByStory(_ string) ([]*task.Task, error) { return nil, nil } +func (m *minimalMockStore) UpdateStoryStatus(_ string, _ task.StoryState) error { return nil } +func (m *minimalMockStore) CreateTask(_ *task.Task) error { return nil } +func (m *minimalMockStore) UpdateTaskCheckerReport(_ string, _ string) error { return nil } +func (m *minimalMockStore) GetCheckerTask(_ string) (*task.Task, error) { return nil, nil } func (m *minimalMockStore) lastStateUpdate() (string, task.State, bool) { m.mu.Lock() @@ -1078,17 +1210,18 @@ func (m *minimalMockStore) lastStateUpdate() (string, task.State, bool) { func newPoolWithMockStore(store Store) *Pool { logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) return &Pool{ - maxConcurrent: 2, - runners: map[string]Runner{"claude": &mockRunner{}}, - store: store, - logger: logger, - activePerAgent: make(map[string]int), - rateLimited: make(map[string]time.Time), - cancels: make(map[string]context.CancelFunc), - resultCh: make(chan *Result, 4), - workCh: make(chan workItem, 4), - doneCh: make(chan struct{}, 2), - Questions: NewQuestionRegistry(), + maxConcurrent: 2, + maxPerAgent: 1, + runners: map[string]Runner{"claude": &mockRunner{}}, + store: store, + logger: logger, + activePerAgent: make(map[string]int), + rateLimited: make(map[string]time.Time), + cancels: make(map[string]context.CancelFunc), + consecutiveFailures: make(map[string]int), + resultCh: make(chan *Result, 4), + workCh: make(chan workItem, 4), + doneCh: make(chan struct{}, 2), } } @@ -1236,6 +1369,11 @@ func TestPool_SpecificAgent_SkipsLoadBalancing(t *testing.T) { logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) pool := NewPool(4, runners, store, logger) + // Raise per-agent limit so the concurrency gate doesn't interfere with this test. + // The injected activePerAgent is only to make pickAgent prefer "claude", + // verifying that explicit agent type bypasses load balancing. + pool.maxPerAgent = 10 + // Inject 2 active tasks for gemini, 0 for claude. // pickAgent would normally pick "claude". pool.mu.Lock() @@ -1425,3 +1563,748 @@ func TestExecute_MalformedChangestats(t *testing.T) { t.Errorf("expected nil changestats for malformed output, got %+v", execs[0].Changestats) } } + +func TestPool_MaxPerAgent_BlocksSecondTask(t *testing.T) { + store := testStore(t) + + var mu sync.Mutex + concurrentRuns := 0 + maxConcurrent := 0 + + runner := &mockRunner{ + delay: 100 * time.Millisecond, + onRun: func(tk *task.Task, e *storage.Execution) error { + mu.Lock() + concurrentRuns++ + if concurrentRuns > maxConcurrent { + maxConcurrent = concurrentRuns + } + mu.Unlock() + time.Sleep(100 * time.Millisecond) + mu.Lock() + concurrentRuns-- + mu.Unlock() + return nil + }, + } + runners := map[string]Runner{"claude": runner} + logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) + pool := NewPool(2, runners, store, logger) // pool size 2, but maxPerAgent=1 + pool.requeueDelay = 50 * time.Millisecond // speed up test + + tk1 := makeTask("mpa-1") + tk2 := makeTask("mpa-2") + store.CreateTask(tk1) + store.CreateTask(tk2) + + pool.Submit(context.Background(), tk1) + pool.Submit(context.Background(), tk2) + + for i := 0; i < 2; i++ { + select { + case <-pool.Results(): + case <-time.After(10 * time.Second): + t.Fatal("timed out waiting for result") + } + } + + mu.Lock() + got := maxConcurrent + mu.Unlock() + if got > 1 { + t.Errorf("maxPerAgent=1 violated: %d claude tasks ran concurrently", got) + } +} + +func TestPool_MaxPerAgent_AllowsDifferentAgents(t *testing.T) { + store := testStore(t) + + var mu sync.Mutex + concurrentRuns := 0 + maxConcurrent := 0 + + makeSlowRunner := func() *mockRunner { + return &mockRunner{ + onRun: func(tk *task.Task, e *storage.Execution) error { + mu.Lock() + concurrentRuns++ + if concurrentRuns > maxConcurrent { + maxConcurrent = concurrentRuns + } + mu.Unlock() + time.Sleep(80 * time.Millisecond) + mu.Lock() + concurrentRuns-- + mu.Unlock() + return nil + }, + } + } + + runners := map[string]Runner{ + "claude": makeSlowRunner(), + "gemini": makeSlowRunner(), + } + logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) + pool := NewPool(2, runners, store, logger) + + tk1 := makeTask("da-1") + tk1.Agent.Type = "claude" + tk2 := makeTask("da-2") + tk2.Agent.Type = "gemini" + store.CreateTask(tk1) + store.CreateTask(tk2) + + pool.Submit(context.Background(), tk1) + pool.Submit(context.Background(), tk2) + + for i := 0; i < 2; i++ { + select { + case <-pool.Results(): + case <-time.After(5 * time.Second): + t.Fatal("timed out waiting for result") + } + } + + mu.Lock() + got := maxConcurrent + mu.Unlock() + if got < 2 { + t.Errorf("different agents should run concurrently; max concurrent was %d", got) + } +} + +func TestPool_ConsecutiveFailures_ResetOnSuccess(t *testing.T) { + store := testStore(t) + + callCount := 0 + runner := &mockRunner{ + onRun: func(tk *task.Task, e *storage.Execution) error { + callCount++ + if callCount == 1 { + return fmt.Errorf("first failure") + } + return nil // second call succeeds + }, + } + runners := map[string]Runner{"claude": runner} + logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) + pool := NewPool(2, runners, store, logger) + + // First task fails + tk1 := makeTask("rs-1") + store.CreateTask(tk1) + pool.Submit(context.Background(), tk1) + <-pool.Results() + + pool.mu.Lock() + failsBefore := pool.consecutiveFailures["claude"] + pool.mu.Unlock() + if failsBefore != 1 { + t.Errorf("expected 1 failure after first task, got %d", failsBefore) + } + + // Second task succeeds — counter resets. + tk2 := makeTask("rs-2") + store.CreateTask(tk2) + pool.Submit(context.Background(), tk2) + <-pool.Results() + + pool.mu.Lock() + failsAfter := pool.consecutiveFailures["claude"] + pool.mu.Unlock() + + if failsAfter != 0 { + t.Errorf("expected consecutiveFailures reset to 0 after success, got %d", failsAfter) + } +} + +func TestPool_CheckStoryCompletion_AllComplete(t *testing.T) { + store := testStore(t) + logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) + pool := NewPool(2, map[string]Runner{"claude": &mockRunner{}}, store, logger) + + // Create a story in IN_PROGRESS state. + now := time.Now().UTC() + story := &task.Story{ + ID: "story-comp-1", + Name: "Completion Test", + Status: task.StoryInProgress, + CreatedAt: now, + UpdatedAt: now, + } + if err := store.CreateStory(story); err != nil { + t.Fatalf("CreateStory: %v", err) + } + + // Create two top-level story tasks and drive them through valid transitions to COMPLETED. + for i, id := range []string{"sctask-1", "sctask-2"} { + tk := makeTask(id) + tk.StoryID = "story-comp-1" + tk.State = task.StatePending + if err := store.CreateTask(tk); err != nil { + t.Fatalf("CreateTask %d: %v", i, err) + } + for _, s := range []task.State{task.StateQueued, task.StateRunning, task.StateReady, task.StateCompleted} { + if err := store.UpdateTaskState(id, s); err != nil { + t.Fatalf("UpdateTaskState %s → %s: %v", id, s, err) + } + } + } + + pool.checkStoryCompletion(context.Background(), "story-comp-1") + + got, err := store.GetStory("story-comp-1") + if err != nil { + t.Fatalf("GetStory: %v", err) + } + if got.Status != task.StoryShippable { + t.Errorf("story status: want SHIPPABLE, got %v", got.Status) + } +} + +func TestPool_CheckStoryCompletion_PartialComplete(t *testing.T) { + store := testStore(t) + logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) + pool := NewPool(2, map[string]Runner{"claude": &mockRunner{}}, store, logger) + + now := time.Now().UTC() + story := &task.Story{ + ID: "story-partial-1", + Name: "Partial Test", + Status: task.StoryInProgress, + CreatedAt: now, + UpdatedAt: now, + } + if err := store.CreateStory(story); err != nil { + t.Fatalf("CreateStory: %v", err) + } + + // First top-level task driven to READY. + tk1 := makeTask("sptask-1") + tk1.StoryID = "story-partial-1" + store.CreateTask(tk1) + for _, s := range []task.State{task.StateQueued, task.StateRunning, task.StateReady} { + store.UpdateTaskState("sptask-1", s) + } + + // Second top-level task still in PENDING (not done). + tk2 := makeTask("sptask-2") + tk2.StoryID = "story-partial-1" + store.CreateTask(tk2) + + pool.checkStoryCompletion(context.Background(), "story-partial-1") + + got, err := store.GetStory("story-partial-1") + if err != nil { + t.Fatalf("GetStory: %v", err) + } + if got.Status != task.StoryInProgress { + t.Errorf("story status: want IN_PROGRESS (no transition), got %v", got.Status) + } +} + +func TestPool_StoryDeploy_RunsDeployScript(t *testing.T) { + store := testStore(t) + runner := &mockRunner{} + runners := map[string]Runner{"claude": runner} + logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) + pool := NewPool(2, runners, store, logger) + + // Create a deploy script that writes a marker file. + tmpDir := t.TempDir() + markerFile := filepath.Join(tmpDir, "deployed.marker") + scriptPath := filepath.Join(tmpDir, "deploy.sh") + scriptContent := "#!/bin/sh\ntouch " + markerFile + "\n" + if err := os.WriteFile(scriptPath, []byte(scriptContent), 0755); err != nil { + t.Fatalf("write deploy script: %v", err) + } + + proj := &task.Project{ + ID: "proj-deploy-1", + Name: "Deploy Test Project", + DeployScript: scriptPath, + } + if err := store.CreateProject(proj); err != nil { + t.Fatalf("create project: %v", err) + } + + story := &task.Story{ + ID: "story-deploy-1", + Name: "Deploy Test Story", + ProjectID: proj.ID, + Status: task.StoryShippable, + } + if err := store.CreateStory(story); err != nil { + t.Fatalf("create story: %v", err) + } + + pool.triggerStoryDeploy(context.Background(), story.ID) + + if _, err := os.Stat(markerFile); os.IsNotExist(err) { + t.Error("deploy script did not run: marker file not found") + } + + got, err := store.GetStory(story.ID) + if err != nil { + t.Fatalf("get story: %v", err) + } + if got.Status != task.StoryDeployed { + t.Errorf("story status: want DEPLOYED, got %q", got.Status) + } +} + +func runGit(t *testing.T, dir string, args ...string) { + t.Helper() + cmd := exec.Command("git", args...) + if dir != "" { + cmd.Dir = dir + } + if out, err := cmd.CombinedOutput(); err != nil { + t.Fatalf("git %v: %v\n%s", args, err, out) + } +} + +func TestPool_StoryDeploy_MergesStoryBranch(t *testing.T) { + tmpDir := t.TempDir() + + // Set up bare repo + working copy with a story branch. + bareDir := filepath.Join(tmpDir, "bare.git") + localDir := filepath.Join(tmpDir, "local") + runGit(t, "", "init", "--bare", bareDir) + runGit(t, "", "clone", bareDir, localDir) + runGit(t, localDir, "config", "user.email", "test@test.com") + runGit(t, localDir, "config", "user.name", "Test") + + // Initial commit on main. + runGit(t, localDir, "checkout", "-b", "main") + os.WriteFile(filepath.Join(localDir, "README.md"), []byte("initial"), 0644) + runGit(t, localDir, "add", ".") + runGit(t, localDir, "commit", "-m", "initial") + runGit(t, localDir, "push", "-u", "origin", "main") + + // Story branch with a feature commit. + runGit(t, localDir, "checkout", "-b", "story/test-feature") + os.WriteFile(filepath.Join(localDir, "feature.go"), []byte("package main"), 0644) + runGit(t, localDir, "add", ".") + runGit(t, localDir, "commit", "-m", "feature work") + runGit(t, localDir, "push", "origin", "story/test-feature") + runGit(t, localDir, "checkout", "main") + + store := testStore(t) + logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) + pool := NewPool(2, map[string]Runner{"claude": &mockRunner{}}, store, logger) + + scriptPath := filepath.Join(tmpDir, "deploy.sh") + os.WriteFile(scriptPath, []byte("#!/bin/sh\nexit 0\n"), 0755) + + proj := &task.Project{ + ID: "proj-merge-1", Name: "Merge Test", + LocalPath: localDir, DeployScript: scriptPath, + } + if err := store.CreateProject(proj); err != nil { + t.Fatalf("create project: %v", err) + } + story := &task.Story{ + ID: "story-merge-1", Name: "Merge Test Story", + ProjectID: proj.ID, BranchName: "story/test-feature", + Status: task.StoryShippable, + } + if err := store.CreateStory(story); err != nil { + t.Fatalf("create story: %v", err) + } + + pool.triggerStoryDeploy(context.Background(), story.ID) + + // feature.go should now be on main in the working copy. + if _, err := os.Stat(filepath.Join(localDir, "feature.go")); os.IsNotExist(err) { + t.Error("story branch was not merged to main: feature.go missing") + } + got, _ := store.GetStory(story.ID) + if got.Status != task.StoryDeployed { + t.Errorf("story status: want DEPLOYED, got %q", got.Status) + } +} + +func TestPool_PostDeploy_CreatesValidationTask(t *testing.T) { + store := testStore(t) + logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) + pool := NewPool(2, map[string]Runner{"claude": &mockRunner{}}, store, logger) + + now := time.Now().UTC() + validationSpec := `{"type":"smoke","steps":["curl /health"],"success_criteria":"status 200"}` + story := &task.Story{ + ID: "story-postdeploy-1", + Name: "Post Deploy Test", + Status: task.StoryDeployed, + ValidationJSON: validationSpec, + CreatedAt: now, + UpdatedAt: now, + } + if err := store.CreateStory(story); err != nil { + t.Fatalf("CreateStory: %v", err) + } + + pool.createValidationTask(context.Background(), story.ID) + + // Story should now be VALIDATING. + got, err := store.GetStory(story.ID) + if err != nil { + t.Fatalf("GetStory: %v", err) + } + if got.Status != task.StoryValidating { + t.Errorf("story status: want VALIDATING, got %q", got.Status) + } + + // A validation task should have been created. + tasks, err := store.ListTasksByStory(story.ID) + if err != nil { + t.Fatalf("ListTasksByStory: %v", err) + } + if len(tasks) == 0 { + t.Fatal("expected a validation task to be created, got none") + } + vtask := tasks[0] + if !strings.Contains(strings.ToLower(vtask.Name), "validation") { + t.Errorf("task name %q does not contain 'validation'", vtask.Name) + } + if vtask.StoryID != story.ID { + t.Errorf("task story_id: want %q, got %q", story.ID, vtask.StoryID) + } + if !strings.Contains(vtask.Agent.Instructions, "smoke") { + t.Errorf("task instructions %q do not reference validation spec content", vtask.Agent.Instructions) + } +} + +func TestPool_ValidationTask_Pass_SetsReviewReady(t *testing.T) { + store := testStore(t) + logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) + pool := NewPool(2, map[string]Runner{"claude": &mockRunner{}}, store, logger) + + now := time.Now().UTC() + story := &task.Story{ + ID: "story-val-pass-1", + Name: "Validation Pass", + Status: task.StoryValidating, + CreatedAt: now, + UpdatedAt: now, + } + if err := store.CreateStory(story); err != nil { + t.Fatalf("CreateStory: %v", err) + } + + pool.checkValidationResult(context.Background(), story.ID, task.StateCompleted, "") + + got, err := store.GetStory(story.ID) + if err != nil { + t.Fatalf("GetStory: %v", err) + } + if got.Status != task.StoryReviewReady { + t.Errorf("story status: want REVIEW_READY, got %q", got.Status) + } +} + +// TestPool_DependsOn_NoDeadlock verifies that a task waiting for a dependency +// does NOT hold the per-agent slot, allowing the dependency to run first. +func TestPool_DependsOn_NoDeadlock(t *testing.T) { + store := testStore(t) + runner := &mockRunner{} // succeeds immediately + pool := NewPool(2, map[string]Runner{"claude": runner}, store, + slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))) + pool.requeueDelay = 10 * time.Millisecond + + // Task A has no deps; Task B depends on A. + taskA := makeTask("dep-a") + taskA.State = task.StateQueued + taskB := makeTask("dep-b") + taskB.DependsOn = []string{"dep-a"} + taskB.State = task.StateQueued + + store.CreateTask(taskA) + store.CreateTask(taskB) + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + // Submit B first — it should not deadlock by holding the slot while waiting for A. + pool.Submit(ctx, taskB) + pool.Submit(ctx, taskA) + + var gotA, gotB bool + for i := 0; i < 2; i++ { + select { + case res := <-pool.Results(): + if res.TaskID == "dep-a" { + gotA = true + } + if res.TaskID == "dep-b" { + gotB = true + } + case <-ctx.Done(): + t.Fatal("timeout: likely deadlock — dep-b held the slot while waiting for dep-a") + } + } + if !gotA || !gotB { + t.Errorf("expected both tasks to complete: gotA=%v gotB=%v", gotA, gotB) + } + + // B must complete after A. + ta, _ := store.GetTask("dep-a") + tb, _ := store.GetTask("dep-b") + if ta.State != task.StateReady && ta.State != task.StateCompleted { + t.Errorf("dep-a should be READY/COMPLETED, got %s", ta.State) + } + if tb.State != task.StateReady && tb.State != task.StateCompleted { + t.Errorf("dep-b should be READY/COMPLETED, got %s", tb.State) + } +} + +func TestPool_ValidationTask_Fail_SetsNeedsFix(t *testing.T) { + store := testStore(t) + logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) + pool := NewPool(2, map[string]Runner{"claude": &mockRunner{}}, store, logger) + + now := time.Now().UTC() + story := &task.Story{ + ID: "story-val-fail-1", + Name: "Validation Fail", + Status: task.StoryValidating, + CreatedAt: now, + UpdatedAt: now, + } + if err := store.CreateStory(story); err != nil { + t.Fatalf("CreateStory: %v", err) + } + + execErr := "smoke test failed: /health returned 503" + pool.checkValidationResult(context.Background(), story.ID, task.StateFailed, execErr) + + got, err := store.GetStory(story.ID) + if err != nil { + t.Fatalf("GetStory: %v", err) + } + if got.Status != task.StoryNeedsFix { + t.Errorf("story status: want NEEDS_FIX, got %q", got.Status) + } +} + +func TestPool_Shutdown_WaitsForWorkers(t *testing.T) { + store := testStore(t) + started := make(chan struct{}) + unblock := make(chan struct{}) + runner := &mockRunner{ + onRun: func(t *task.Task, e *storage.Execution) error { + close(started) + <-unblock + return nil + }, + } + pool := NewPool(1, map[string]Runner{"claude": runner}, store, + slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))) + + tk := makeTask("shutdown-task") + tk.State = task.StateQueued + store.CreateTask(tk) + pool.Submit(context.Background(), tk) + + // Wait until the worker has started. + select { + case <-started: + case <-time.After(5 * time.Second): + t.Fatal("worker did not start") + } + + // Shutdown should block until we unblock the worker. + done := make(chan error, 1) + go func() { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + done <- pool.Shutdown(ctx) + }() + + // Shutdown should not have returned yet. + select { + case err := <-done: + t.Fatalf("Shutdown returned early: %v", err) + case <-time.After(50 * time.Millisecond): + } + + close(unblock) // let the worker finish + + select { + case err := <-done: + if err != nil { + t.Errorf("Shutdown returned error: %v", err) + } + case <-time.After(5 * time.Second): + t.Fatal("Shutdown did not return after worker finished") + } +} + +func TestPool_Shutdown_TimesOut(t *testing.T) { + store := testStore(t) + unblock := make(chan struct{}) + runner := &mockRunner{ + onRun: func(t *task.Task, e *storage.Execution) error { + <-unblock // never unblocked + return nil + }, + } + pool := NewPool(1, map[string]Runner{"claude": runner}, store, + slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))) + + tk := makeTask("shutdown-timeout-task") + tk.State = task.StateQueued + store.CreateTask(tk) + pool.Submit(context.Background(), tk) + + // Give worker a moment to start. + time.Sleep(50 * time.Millisecond) + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + err := pool.Shutdown(ctx) + if err == nil { + t.Error("expected timeout error, got nil") + } + close(unblock) // cleanup +} + +func TestPool_CheckerSpawned_OnReady(t *testing.T) { + store := testStore(t) + runner := &mockRunner{} // succeeds instantly + pool := NewPool(2, map[string]Runner{"claude": runner}, store, slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))) + + tk := makeTask("checker-spawn-1") + tk.RepositoryURL = "https://github.com/x/y" + store.CreateTask(tk) + pool.Submit(context.Background(), tk) + <-pool.Results() // wait for original task to finish + + // Poll until the async spawnCheckerTask goroutine has written the checker task. + var checker *task.Task + var err error + deadline := time.Now().Add(5 * time.Second) + for time.Now().Before(deadline) { + checker, err = store.GetCheckerTask("checker-spawn-1") + if err != nil { + t.Fatalf("GetCheckerTask: %v", err) + } + if checker != nil { + break + } + time.Sleep(50 * time.Millisecond) + } + if checker == nil { + t.Fatal("expected a checker task to be created, got nil") + } + if checker.CheckerForTaskID != "checker-spawn-1" { + t.Errorf("expected CheckerForTaskID=checker-spawn-1, got %q", checker.CheckerForTaskID) + } +} + +func TestPool_CheckerNotSpawned_ForSubtask(t *testing.T) { + store := testStore(t) + runner := &mockRunner{} + pool := NewPool(2, map[string]Runner{"claude": runner}, store, slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))) + + parent := makeTask("no-checker-parent") + parent.RepositoryURL = "https://github.com/x/y" + store.CreateTask(parent) + + sub := makeTask("no-checker-sub") + sub.ParentTaskID = "no-checker-parent" + sub.RepositoryURL = "https://github.com/x/y" + store.CreateTask(sub) + + pool.Submit(context.Background(), sub) + <-pool.Results() + + time.Sleep(100 * time.Millisecond) + + checker, err := store.GetCheckerTask("no-checker-sub") + if err != nil { + t.Fatalf("GetCheckerTask: %v", err) + } + if checker != nil { + t.Error("expected no checker for subtask, but one was created") + } +} + +func TestPool_CheckerPass_AutoAcceptsTask(t *testing.T) { + store := testStore(t) + // Two-phase: first runner succeeds (original task), second also succeeds (checker). + runner := &mockRunner{ + onRun: func(t *task.Task, e *storage.Execution) error { + return nil // both original and checker succeed + }, + } + pool := NewPool(2, map[string]Runner{"claude": runner}, store, slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))) + + tk := makeTask("autoaccept-1") + tk.RepositoryURL = "https://github.com/x/y" + store.CreateTask(tk) + pool.Submit(context.Background(), tk) + <-pool.Results() // original finishes → READY + checker spawned + + // Wait for checker to run and complete. + deadline := time.Now().Add(5 * time.Second) + for time.Now().Before(deadline) { + got, _ := store.GetTask("autoaccept-1") + if got != nil && got.State == task.StateCompleted { + break + } + <-pool.Results() + } + + got, err := store.GetTask("autoaccept-1") + if err != nil { + t.Fatalf("GetTask: %v", err) + } + if got.State != task.StateCompleted { + t.Errorf("expected COMPLETED after checker pass, got %s", got.State) + } +} + +func TestPool_CheckerFail_AttachesReport(t *testing.T) { + store := testStore(t) + runner := &mockRunner{ + onRun: func(t *task.Task, e *storage.Execution) error { + if t.CheckerForTaskID != "" { + return fmt.Errorf("test suite failed: 3 failures") + } + return nil // original task succeeds + }, + } + pool := NewPool(2, map[string]Runner{"claude": runner}, store, slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))) + + tk := makeTask("fail-checker-1") + tk.RepositoryURL = "https://github.com/x/y" + store.CreateTask(tk) + pool.Submit(context.Background(), tk) + <-pool.Results() // original → READY + + // Wait for checker to fail. + deadline := time.Now().Add(5 * time.Second) + for time.Now().Before(deadline) { + got, _ := store.GetTask("fail-checker-1") + if got != nil && got.CheckerReport != "" { + break + } + select { + case <-pool.Results(): + case <-time.After(100 * time.Millisecond): + } + } + + got, err := store.GetTask("fail-checker-1") + if err != nil { + t.Fatalf("GetTask: %v", err) + } + if got.State != task.StateReady { + t.Errorf("expected task to stay READY after checker fail, got %s", got.State) + } + if got.CheckerReport == "" { + t.Error("expected checker_report to be set after checker failure") + } +} diff --git a/internal/executor/helpers.go b/internal/executor/helpers.go new file mode 100644 index 0000000..76bf8b1 --- /dev/null +++ b/internal/executor/helpers.go @@ -0,0 +1,205 @@ +package executor + +import ( + "bufio" + "encoding/json" + "fmt" + "io" + "log/slog" + "os" + "strings" +) + +// BlockedError is returned by Run when the agent wrote a question file and exited. +// The pool transitions the task to BLOCKED and stores the question for the user. +type BlockedError struct { + QuestionJSON string // raw JSON from the question file + SessionID string // claude session to resume once the user answers + SandboxDir string // preserved sandbox path; resume must run here so Claude finds its session files +} + +func (e *BlockedError) Error() string { return fmt.Sprintf("task blocked: %s", e.QuestionJSON) } + +// parseStream reads streaming JSON from claude, writes to w, and returns +// (costUSD, error). error is non-nil if the stream signals task failure: +// - result message has is_error:true +// - a tool_result was denied due to missing permissions +func parseStream(r io.Reader, w io.Writer, logger *slog.Logger) (float64, string, error) { + tee := io.TeeReader(r, w) + scanner := bufio.NewScanner(tee) + scanner.Buffer(make([]byte, 1024*1024), 1024*1024) // 1MB buffer for large lines + + var totalCost float64 + var sessionID string + var streamErr error + +Loop: + for scanner.Scan() { + line := scanner.Bytes() + var msg map[string]interface{} + if err := json.Unmarshal(line, &msg); err != nil { + continue + } + + msgType, _ := msg["type"].(string) + switch msgType { + case "system": + if subtype, ok := msg["subtype"].(string); ok && subtype == "init" { + if sid, ok := msg["session_id"].(string); ok { + sessionID = sid + } + } + case "rate_limit_event": + if info, ok := msg["rate_limit_info"].(map[string]interface{}); ok { + status, _ := info["status"].(string) + if status == "rejected" { + streamErr = fmt.Errorf("claude rate limit reached (rejected): %v", msg) + // Immediately break since we can't continue anyway + break Loop + } + } + case "assistant": + if errStr, ok := msg["error"].(string); ok && errStr == "rate_limit" { + streamErr = fmt.Errorf("claude rate limit reached: %v", msg) + } + case "result": + if isErr, _ := msg["is_error"].(bool); isErr { + result, _ := msg["result"].(string) + if result != "" { + streamErr = fmt.Errorf("claude task failed: %s", result) + } else { + streamErr = fmt.Errorf("claude task failed (is_error=true in result)") + } + } + // Prefer total_cost_usd from result message; fall through to legacy check below. + if cost, ok := msg["total_cost_usd"].(float64); ok { + totalCost = cost + } + case "user": + // Detect permission-denial tool_results. These occur when permission_mode + // is not bypassPermissions and claude exits 0 without completing its task. + if err := permissionDenialError(msg); err != nil && streamErr == nil { + streamErr = err + } + } + + // Legacy cost field used by older claude versions. + if cost, ok := msg["cost_usd"].(float64); ok { + totalCost = cost + } + } + if err := scanner.Err(); err != nil && streamErr == nil { + streamErr = fmt.Errorf("reading claude stdout: %w", err) + } + + return totalCost, sessionID, streamErr +} + + +// permissionDenialError inspects a "user" stream message for tool_result entries +// that were denied due to missing permissions. Returns an error if found. +func permissionDenialError(msg map[string]interface{}) error { + message, ok := msg["message"].(map[string]interface{}) + if !ok { + return nil + } + content, ok := message["content"].([]interface{}) + if !ok { + return nil + } + for _, item := range content { + itemMap, ok := item.(map[string]interface{}) + if !ok { + continue + } + if itemMap["type"] != "tool_result" { + continue + } + if isErr, _ := itemMap["is_error"].(bool); !isErr { + continue + } + text, _ := itemMap["content"].(string) + if strings.Contains(text, "requested permissions") || strings.Contains(text, "haven't granted") { + return fmt.Errorf("permission denied by host: %s", text) + } + } + return nil +} + +// tailFile returns the last n lines of the file at path, or empty string if +// the file cannot be read. Used to surface subprocess stderr on failure. +func tailFile(path string, n int) string { + f, err := os.Open(path) + if err != nil { + return "" + } + defer f.Close() + + var lines []string + scanner := bufio.NewScanner(f) + for scanner.Scan() { + lines = append(lines, scanner.Text()) + if len(lines) > n { + lines = lines[1:] + } + } + return strings.Join(lines, "\n") +} + +// readFileTail returns the last maxBytes bytes of the file at path as a string, +// or empty string if the file cannot be read. Used to surface agent stderr on failure. +func readFileTail(path string, maxBytes int64) string { + f, err := os.Open(path) + if err != nil { + return "" + } + defer f.Close() + fi, err := f.Stat() + if err != nil { + return "" + } + offset := fi.Size() - maxBytes + if offset < 0 { + offset = 0 + } + buf := make([]byte, fi.Size()-offset) + n, err := f.ReadAt(buf, offset) + if err != nil && n == 0 { + return "" + } + return strings.TrimSpace(string(buf[:n])) +} + +func gitSafe(args ...string) []string { + return append([]string{ + "-c", "safe.directory=*", + "-c", "commit.gpgsign=false", + "-c", "tag.gpgsign=false", + }, args...) +} + +// isCompletionReport returns true when a question-file JSON looks like a +// completion report rather than a real user question. Heuristic: no options +// (or empty options) and no "?" anywhere in the text. +func isCompletionReport(questionJSON string) bool { + var q struct { + Text string `json:"text"` + Options []string `json:"options"` + } + if err := json.Unmarshal([]byte(questionJSON), &q); err != nil { + return false + } + return len(q.Options) == 0 && !strings.Contains(q.Text, "?") +} + +// extractQuestionText returns the "text" field from a question-file JSON, or +// the raw string if parsing fails. +func extractQuestionText(questionJSON string) string { + var q struct { + Text string `json:"text"` + } + if err := json.Unmarshal([]byte(questionJSON), &q); err != nil { + return questionJSON + } + return strings.TrimSpace(q.Text) +} diff --git a/internal/executor/preamble.go b/internal/executor/preamble.go index f5dba2b..b949986 100644 --- a/internal/executor/preamble.go +++ b/internal/executor/preamble.go @@ -45,6 +45,7 @@ The sandbox is rejected if there are any uncommitted modifications. - One commit is fine. Multiple focused commits are also fine. - If you realise the task was already done and you made no changes, that is also fine — just exit cleanly without committing. - Do not exit with uncommitted edits. +- **CRITICAL:** Run ALL git commands from your current directory — do NOT use absolute paths or "cd <project_path> && git ...". Your working directory IS the project. Using absolute paths bypasses the sandbox and breaks commit tracking. --- diff --git a/internal/executor/preamble_test.go b/internal/executor/preamble_test.go index 984f786..5c31b4f 100644 --- a/internal/executor/preamble_test.go +++ b/internal/executor/preamble_test.go @@ -22,3 +22,10 @@ func TestPlanningPreamble_SummaryInstructsEchoToFile(t *testing.T) { t.Error("planningPreamble should show example of writing to $CLAUDOMATOR_SUMMARY_FILE via echo") } } + +func TestPlanningPreamble_GitDiscipline_ForbidsAbsolutePaths(t *testing.T) { + // Agents must not bypass the sandbox by using absolute project paths in git commands. + if !strings.Contains(planningPreamble, "do NOT use absolute paths") { + t.Error("planningPreamble should warn agents not to use absolute paths in git commands") + } +} diff --git a/internal/executor/question.go b/internal/executor/question.go index 9a2b55d..0ae1b08 100644 --- a/internal/executor/question.go +++ b/internal/executor/question.go @@ -5,92 +5,8 @@ import ( "encoding/json" "io" "log/slog" - "sync" ) -// QuestionHandler is called when an agent invokes AskUserQuestion. -// Implementations should broadcast the question and block until an answer arrives. -type QuestionHandler interface { - HandleQuestion(taskID, toolUseID string, input json.RawMessage) (string, error) -} - -// PendingQuestion holds state for a question awaiting a user answer. -type PendingQuestion struct { - TaskID string `json:"task_id"` - ToolUseID string `json:"tool_use_id"` - Input json.RawMessage `json:"input"` - AnswerCh chan string `json:"-"` -} - -// QuestionRegistry tracks pending questions across running tasks. -type QuestionRegistry struct { - mu sync.Mutex - questions map[string]*PendingQuestion // keyed by toolUseID -} - -// NewQuestionRegistry creates a new registry. -func NewQuestionRegistry() *QuestionRegistry { - return &QuestionRegistry{ - questions: make(map[string]*PendingQuestion), - } -} - -// Register adds a pending question and returns its answer channel. -func (qr *QuestionRegistry) Register(taskID, toolUseID string, input json.RawMessage) chan string { - ch := make(chan string, 1) - qr.mu.Lock() - qr.questions[toolUseID] = &PendingQuestion{ - TaskID: taskID, - ToolUseID: toolUseID, - Input: input, - AnswerCh: ch, - } - qr.mu.Unlock() - return ch -} - -// Answer delivers an answer for a pending question. Returns false if no such question exists. -func (qr *QuestionRegistry) Answer(toolUseID, answer string) bool { - qr.mu.Lock() - pq, ok := qr.questions[toolUseID] - if ok { - delete(qr.questions, toolUseID) - } - qr.mu.Unlock() - if !ok { - return false - } - pq.AnswerCh <- answer - return true -} - -// Get returns a pending question by tool_use_id, or nil. -func (qr *QuestionRegistry) Get(toolUseID string) *PendingQuestion { - qr.mu.Lock() - defer qr.mu.Unlock() - return qr.questions[toolUseID] -} - -// PendingForTask returns all pending questions for a given task. -func (qr *QuestionRegistry) PendingForTask(taskID string) []*PendingQuestion { - qr.mu.Lock() - defer qr.mu.Unlock() - var result []*PendingQuestion - for _, pq := range qr.questions { - if pq.TaskID == taskID { - result = append(result, pq) - } - } - return result -} - -// Remove removes a question without answering it (e.g., on task cancellation). -func (qr *QuestionRegistry) Remove(toolUseID string) { - qr.mu.Lock() - delete(qr.questions, toolUseID) - qr.mu.Unlock() -} - // extractAskUserQuestion parses a stream-json line and returns the tool_use_id and input // if the line is an assistant event containing an AskUserQuestion tool_use. func extractAskUserQuestion(line []byte) (string, json.RawMessage) { diff --git a/internal/executor/question_test.go b/internal/executor/question_test.go index d0fbed9..6686c15 100644 --- a/internal/executor/question_test.go +++ b/internal/executor/question_test.go @@ -9,64 +9,6 @@ import ( "testing" ) -func TestQuestionRegistry_RegisterAndAnswer(t *testing.T) { - qr := NewQuestionRegistry() - - ch := qr.Register("task-1", "toolu_abc", json.RawMessage(`{"question":"color?"}`)) - - // Answer should unblock the channel. - go func() { - ok := qr.Answer("toolu_abc", "blue") - if !ok { - t.Error("Answer returned false, expected true") - } - }() - - answer := <-ch - if answer != "blue" { - t.Errorf("want 'blue', got %q", answer) - } - - // Question should be removed after answering. - if qr.Get("toolu_abc") != nil { - t.Error("question should be removed after answering") - } -} - -func TestQuestionRegistry_AnswerUnknown(t *testing.T) { - qr := NewQuestionRegistry() - ok := qr.Answer("nonexistent", "anything") - if ok { - t.Error("expected false for unknown question") - } -} - -func TestQuestionRegistry_PendingForTask(t *testing.T) { - qr := NewQuestionRegistry() - qr.Register("task-1", "toolu_1", json.RawMessage(`{}`)) - qr.Register("task-1", "toolu_2", json.RawMessage(`{}`)) - qr.Register("task-2", "toolu_3", json.RawMessage(`{}`)) - - pending := qr.PendingForTask("task-1") - if len(pending) != 2 { - t.Errorf("want 2 pending for task-1, got %d", len(pending)) - } - - pending2 := qr.PendingForTask("task-2") - if len(pending2) != 1 { - t.Errorf("want 1 pending for task-2, got %d", len(pending2)) - } -} - -func TestQuestionRegistry_Remove(t *testing.T) { - qr := NewQuestionRegistry() - qr.Register("task-1", "toolu_x", json.RawMessage(`{}`)) - qr.Remove("toolu_x") - if qr.Get("toolu_x") != nil { - t.Error("question should be removed") - } -} - func TestExtractAskUserQuestion_DetectsQuestion(t *testing.T) { // Simulate a stream-json assistant event containing an AskUserQuestion tool_use. event := map[string]interface{}{ diff --git a/internal/executor/ratelimit.go b/internal/executor/ratelimit.go index 109aa49..ee9a336 100644 --- a/internal/executor/ratelimit.go +++ b/internal/executor/ratelimit.go @@ -13,5 +13,9 @@ func isQuotaExhausted(err error) bool { strings.Contains(msg, "you've hit your limit") || strings.Contains(msg, "you have hit your limit") || strings.Contains(msg, "rate limit reached (rejected)") || - strings.Contains(msg, "status: rejected") + strings.Contains(msg, "status: rejected") || + // Gemini CLI quota exhaustion + strings.Contains(msg, "terminalquotaerror") || + strings.Contains(msg, "exhausted your daily quota") || + strings.Contains(msg, "generate_content_free_tier_requests") } diff --git a/internal/executor/stream_test.go b/internal/executor/stream_test.go index 10eb858..11a6178 100644 --- a/internal/executor/stream_test.go +++ b/internal/executor/stream_test.go @@ -12,7 +12,7 @@ func streamLine(json string) string { return json + "\n" } func TestParseStream_ResultIsError_ReturnsError(t *testing.T) { input := streamLine(`{"type":"result","subtype":"error_during_execution","is_error":true,"result":"something went wrong"}`) - _, err := parseStream(strings.NewReader(input), io.Discard, slog.New(slog.NewTextHandler(io.Discard, nil))) + _, _, err := parseStream(strings.NewReader(input), io.Discard, slog.New(slog.NewTextHandler(io.Discard, nil))) if err == nil { t.Fatal("expected error when result.is_error=true, got nil") } @@ -27,7 +27,7 @@ func TestParseStream_PermissionDenied_ReturnsError(t *testing.T) { input := streamLine(`{"type":"user","message":{"role":"user","content":[{"type":"tool_result","is_error":true,"content":"Claude requested permissions to write to /foo/bar.go, but you haven't granted it yet.","tool_use_id":"tu_abc"}]}}`) + streamLine(`{"type":"result","subtype":"success","is_error":false,"result":"I need permission","total_cost_usd":0.1}`) - _, err := parseStream(strings.NewReader(input), io.Discard, slog.New(slog.NewTextHandler(io.Discard, nil))) + _, _, err := parseStream(strings.NewReader(input), io.Discard, slog.New(slog.NewTextHandler(io.Discard, nil))) if err == nil { t.Fatal("expected error for permission denial, got nil") } @@ -40,7 +40,7 @@ func TestParseStream_Success_ReturnsNilError(t *testing.T) { input := streamLine(`{"type":"assistant","message":{"content":[{"type":"text","text":"Done."}]}}`) + streamLine(`{"type":"result","subtype":"success","is_error":false,"result":"All tests pass.","total_cost_usd":0.05}`) - _, err := parseStream(strings.NewReader(input), io.Discard, slog.New(slog.NewTextHandler(io.Discard, nil))) + _, _, err := parseStream(strings.NewReader(input), io.Discard, slog.New(slog.NewTextHandler(io.Discard, nil))) if err != nil { t.Fatalf("expected nil error for success stream, got: %v", err) } @@ -49,7 +49,7 @@ func TestParseStream_Success_ReturnsNilError(t *testing.T) { func TestParseStream_ExtractsCostFromResultMessage(t *testing.T) { input := streamLine(`{"type":"result","subtype":"success","is_error":false,"result":"done","total_cost_usd":1.2345}`) - cost, err := parseStream(strings.NewReader(input), io.Discard, slog.New(slog.NewTextHandler(io.Discard, nil))) + cost, _, err := parseStream(strings.NewReader(input), io.Discard, slog.New(slog.NewTextHandler(io.Discard, nil))) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -62,7 +62,7 @@ func TestParseStream_ExtractsCostFromLegacyCostUSD(t *testing.T) { // Some versions emit cost_usd at the top level rather than total_cost_usd. input := streamLine(`{"type":"result","subtype":"success","is_error":false,"result":"done","cost_usd":0.99}`) - cost, err := parseStream(strings.NewReader(input), io.Discard, slog.New(slog.NewTextHandler(io.Discard, nil))) + cost, _, err := parseStream(strings.NewReader(input), io.Discard, slog.New(slog.NewTextHandler(io.Discard, nil))) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -78,8 +78,21 @@ func TestParseStream_NonToolResultIsError_DoesNotFail(t *testing.T) { input := streamLine(`{"type":"user","message":{"role":"user","content":[{"type":"tool_result","is_error":true,"content":"exit status 1","tool_use_id":"tu_xyz"}]}}`) + streamLine(`{"type":"result","subtype":"success","is_error":false,"result":"Fixed it.","total_cost_usd":0.2}`) - _, err := parseStream(strings.NewReader(input), io.Discard, slog.New(slog.NewTextHandler(io.Discard, nil))) + _, _, err := parseStream(strings.NewReader(input), io.Discard, slog.New(slog.NewTextHandler(io.Discard, nil))) if err != nil { t.Fatalf("non-permission tool errors should not fail the task, got: %v", err) } } + +func TestParseStream_ExtractsSessionID(t *testing.T) { + input := streamLine(`{"type":"system","subtype":"init","session_id":"sess-999"}`) + + streamLine(`{"type":"result","subtype":"success","is_error":false,"result":"ok","total_cost_usd":0.01}`) + + _, sid, err := parseStream(strings.NewReader(input), io.Discard, slog.New(slog.NewTextHandler(io.Discard, nil))) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if sid != "sess-999" { + t.Errorf("want session ID sess-999, got %q", sid) + } +} |
