diff options
Diffstat (limited to 'internal/executor')
| -rw-r--r-- | internal/executor/claude.go | 23 | ||||
| -rw-r--r-- | internal/executor/claude_test.go | 37 | ||||
| -rw-r--r-- | internal/executor/executor.go | 57 | ||||
| -rw-r--r-- | internal/executor/gemini.go | 87 | ||||
| -rw-r--r-- | internal/executor/gemini_test.go | 1 |
5 files changed, 130 insertions, 75 deletions
diff --git a/internal/executor/claude.go b/internal/executor/claude.go index e3f8e1c..fa68382 100644 --- a/internal/executor/claude.go +++ b/internal/executor/claude.go @@ -117,7 +117,7 @@ func (r *ClaudeRunner) Run(ctx context.Context, t *task.Task, e *storage.Executi e.SandboxDir = "" if projectDir != "" { var err error - sandboxDir, err := setupSandbox(t.Agent.ProjectDir, r.Logger) + sandboxDir, err = setupSandbox(t.Agent.ProjectDir, r.Logger) if err != nil { return fmt.Errorf("setting up sandbox: %w", err) } @@ -129,7 +129,7 @@ func (r *ClaudeRunner) Run(ctx context.Context, t *task.Task, e *storage.Executi } } else if projectDir != "" { var err error - sandboxDir, err := setupSandbox(t.Agent.ProjectDir, r.Logger) + sandboxDir, err = setupSandbox(t.Agent.ProjectDir, r.Logger) if err != nil { return fmt.Errorf("setting up sandbox: %w", err) } @@ -226,11 +226,22 @@ func extractQuestionText(questionJSON string) string { return strings.TrimSpace(q.Text) } -// gitSafe returns git arguments that prepend "-c safe.directory=*" so that -// commands succeed regardless of the repository owner. This is needed when -// claudomator operates on project directories owned by a different OS user. +// gitSafe returns git arguments that prepend safety overrides so that +// commands succeed regardless of the repository owner or the host's global +// git configuration. Specifically: +// +// - "-c safe.directory=*" lets us operate on directories owned by a +// different OS user. +// - "-c commit.gpgsign=false" / "-c tag.gpgsign=false" stop git from +// trying to sign commits via the host's signing tooling. Sandbox commits +// are internal and don't need to be signed; an unconfigured or broken +// signing setup on the host should never block a sandbox merge. func gitSafe(args ...string) []string { - return append([]string{"-c", "safe.directory=*"}, args...) + return append([]string{ + "-c", "safe.directory=*", + "-c", "commit.gpgsign=false", + "-c", "tag.gpgsign=false", + }, args...) } // sandboxCloneSource returns the URL to clone the sandbox from. It prefers a diff --git a/internal/executor/claude_test.go b/internal/executor/claude_test.go index 77596ca..b40c4ae 100644 --- a/internal/executor/claude_test.go +++ b/internal/executor/claude_test.go @@ -353,9 +353,9 @@ func TestExecOnce_NoGoroutineLeak_OnNaturalExit(t *testing.T) { func initGitRepo(t *testing.T, dir string) { t.Helper() cmds := [][]string{ - {"git", "-c", "safe.directory=*", "-C", dir, "init", "-b", "main"}, - {"git", "-c", "safe.directory=*", "-C", dir, "config", "user.email", "test@test"}, - {"git", "-c", "safe.directory=*", "-C", dir, "config", "user.name", "test"}, + {"git", "-c", "safe.directory=*", "-c", "commit.gpgsign=false", "-C", dir, "init", "-b", "main"}, + {"git", "-c", "safe.directory=*", "-c", "commit.gpgsign=false", "-C", dir, "config", "user.email", "test@test"}, + {"git", "-c", "safe.directory=*", "-c", "commit.gpgsign=false", "-C", dir, "config", "user.name", "test"}, } for _, args := range cmds { if out, err := exec.Command(args[0], args[1:]...).CombinedOutput(); err != nil { @@ -365,10 +365,10 @@ func initGitRepo(t *testing.T, dir string) { if err := os.WriteFile(filepath.Join(dir, "init.txt"), []byte("init"), 0644); err != nil { t.Fatal(err) } - if out, err := exec.Command("git", "-c", "safe.directory=*", "-C", dir, "add", ".").CombinedOutput(); err != nil { + if out, err := exec.Command("git", "-c", "safe.directory=*", "-c", "commit.gpgsign=false", "-C", dir, "add", ".").CombinedOutput(); err != nil { t.Fatalf("git add: %v\n%s", err, out) } - if out, err := exec.Command("git", "-c", "safe.directory=*", "-C", dir, "commit", "-m", "init").CombinedOutput(); err != nil { + if out, err := exec.Command("git", "-c", "safe.directory=*", "-c", "commit.gpgsign=false", "-C", dir, "commit", "-m", "init").CombinedOutput(); err != nil { t.Fatalf("git commit: %v\n%s", err, out) } } @@ -391,7 +391,10 @@ func TestSandboxCloneSource_PrefersLocalRemote(t *testing.T) { func TestSandboxCloneSource_FallsBackToOrigin(t *testing.T) { dir := t.TempDir() initGitRepo(t, dir) - originURL := "https://example.com/origin-repo" + // sandboxCloneSource intentionally filters to local-FS remotes (so + // `git clone <src>` doesn't go over the network). Use a local path + // for origin to verify the fallback semantics. + originURL := t.TempDir() exec.Command("git", "-C", dir, "remote", "add", "origin", originURL).Run() got := sandboxCloneSource(dir) @@ -455,23 +458,23 @@ func TestSetupSandbox_InitialisesNonGitDir(t *testing.T) { func TestTeardownSandbox_AutocommitsChanges(t *testing.T) { // Create a bare repo as origin so push succeeds. bare := t.TempDir() - if out, err := exec.Command("git", "init", "--bare", bare).CombinedOutput(); err != nil { + if out, err := exec.Command("git", "init", "--bare", "-b", "main", bare).CombinedOutput(); err != nil { t.Fatalf("git init bare: %v\n%s", err, out) } // Create a sandbox directly. sandbox := t.TempDir() initGitRepo(t, sandbox) - if out, err := exec.Command("git", "-c", "safe.directory=*", "-C", sandbox, "remote", "add", "origin", bare).CombinedOutput(); err != nil { + if out, err := exec.Command("git", "-c", "safe.directory=*", "-c", "commit.gpgsign=false", "-C", sandbox, "remote", "add", "origin", bare).CombinedOutput(); err != nil { t.Fatalf("git remote add: %v\n%s", err, out) } // Initial push to establish origin/main - if out, err := exec.Command("git", "-c", "safe.directory=*", "-C", sandbox, "push", "origin", "main").CombinedOutput(); err != nil { + if out, err := exec.Command("git", "-c", "safe.directory=*", "-c", "commit.gpgsign=false", "-C", sandbox, "push", "origin", "main").CombinedOutput(); err != nil { t.Fatalf("git push initial: %v\n%s", err, out) } // Capture startHEAD - headOut, err := exec.Command("git", "-c", "safe.directory=*", "-C", sandbox, "rev-parse", "HEAD").Output() + headOut, err := exec.Command("git", "-c", "safe.directory=*", "-c", "commit.gpgsign=false", "-C", sandbox, "rev-parse", "HEAD").Output() if err != nil { t.Fatalf("rev-parse HEAD: %v", err) } @@ -514,18 +517,18 @@ func TestTeardownSandbox_AutocommitsChanges(t *testing.T) { func TestTeardownSandbox_BuildFailure_BlocksAutocommit(t *testing.T) { bare := t.TempDir() - if out, err := exec.Command("git", "init", "--bare", bare).CombinedOutput(); err != nil { + if out, err := exec.Command("git", "init", "--bare", "-b", "main", bare).CombinedOutput(); err != nil { t.Fatalf("git init bare: %v\n%s", err, out) } sandbox := t.TempDir() initGitRepo(t, sandbox) - if out, err := exec.Command("git", "-c", "safe.directory=*", "-C", sandbox, "remote", "add", "origin", bare).CombinedOutput(); err != nil { + if out, err := exec.Command("git", "-c", "safe.directory=*", "-c", "commit.gpgsign=false", "-C", sandbox, "remote", "add", "origin", bare).CombinedOutput(); err != nil { t.Fatalf("git remote add: %v\n%s", err, out) } // Capture startHEAD - headOut, err := exec.Command("git", "-c", "safe.directory=*", "-C", sandbox, "rev-parse", "HEAD").Output() + headOut, err := exec.Command("git", "-c", "safe.directory=*", "-c", "commit.gpgsign=false", "-C", sandbox, "rev-parse", "HEAD").Output() if err != nil { t.Fatalf("rev-parse HEAD: %v", err) } @@ -566,18 +569,18 @@ func TestTeardownSandbox_BuildFailure_BlocksAutocommit(t *testing.T) { func TestTeardownSandbox_BuildSuccess_ProceedsToAutocommit(t *testing.T) { bare := t.TempDir() - if out, err := exec.Command("git", "init", "--bare", bare).CombinedOutput(); err != nil { + if out, err := exec.Command("git", "init", "--bare", "-b", "main", bare).CombinedOutput(); err != nil { t.Fatalf("git init bare: %v\n%s", err, out) } sandbox := t.TempDir() initGitRepo(t, sandbox) - if out, err := exec.Command("git", "-c", "safe.directory=*", "-C", sandbox, "remote", "add", "origin", bare).CombinedOutput(); err != nil { + if out, err := exec.Command("git", "-c", "safe.directory=*", "-c", "commit.gpgsign=false", "-C", sandbox, "remote", "add", "origin", bare).CombinedOutput(); err != nil { t.Fatalf("git remote add: %v\n%s", err, out) } // Capture startHEAD - headOut, err := exec.Command("git", "-c", "safe.directory=*", "-C", sandbox, "rev-parse", "HEAD").Output() + headOut, err := exec.Command("git", "-c", "safe.directory=*", "-c", "commit.gpgsign=false", "-C", sandbox, "rev-parse", "HEAD").Output() if err != nil { t.Fatalf("rev-parse HEAD: %v", err) } @@ -870,7 +873,7 @@ func TestTailFile_MissingFile_ReturnsEmpty(t *testing.T) { func TestGitSafe_PrependsSafeDirectory(t *testing.T) { got := gitSafe("-C", "/some/path", "status") - want := []string{"-c", "safe.directory=*", "-C", "/some/path", "status"} + want := []string{"-c", "safe.directory=*", "-c", "commit.gpgsign=false", "-c", "tag.gpgsign=false", "-C", "/some/path", "status"} if len(got) != len(want) { t.Fatalf("gitSafe() = %v, want %v", got, want) } diff --git a/internal/executor/executor.go b/internal/executor/executor.go index 4501a3c..315030d 100644 --- a/internal/executor/executor.go +++ b/internal/executor/executor.go @@ -196,6 +196,28 @@ func (p *Pool) getRunner(t *task.Task) (Runner, error) { return runner, nil } +// decActiveAgent decrements the active counters for a finished task. Safe to +// call multiple times — subsequent calls are no-ops via the cleaned flag. +// Always call this before sending on resultCh so consumers observing a result +// see the accounting already settled (no zero-count map entries lingering). +func (p *Pool) decActiveAgent(agentType string, cleaned *bool) { + if *cleaned { + return + } + *cleaned = true + p.mu.Lock() + p.active-- + p.activePerAgent[agentType]-- + if p.activePerAgent[agentType] == 0 { + delete(p.activePerAgent, agentType) + } + p.mu.Unlock() + select { + case p.doneCh <- struct{}{}: + default: + } +} + func (p *Pool) executeResume(ctx context.Context, t *task.Task, exec *storage.Execution) { agentType := t.Agent.Type if agentType == "" { @@ -206,23 +228,13 @@ func (p *Pool) executeResume(ctx context.Context, t *task.Task, exec *storage.Ex p.activePerAgent[agentType]++ p.mu.Unlock() - defer func() { - p.mu.Lock() - p.active-- - p.activePerAgent[agentType]-- - if p.activePerAgent[agentType] == 0 { - delete(p.activePerAgent, agentType) - } - p.mu.Unlock() - select { - case p.doneCh <- struct{}{}: - default: - } - }() + var cleaned bool + defer p.decActiveAgent(agentType, &cleaned) runner, err := p.getRunner(t) if err != nil { p.logger.Error("failed to get runner for resume", "error", err, "taskID", t.ID) + p.decActiveAgent(agentType, &cleaned) p.resultCh <- &Result{TaskID: t.ID, Execution: exec, Err: err} return } @@ -264,6 +276,7 @@ func (p *Pool) executeResume(ctx context.Context, t *task.Task, exec *storage.Ex err = runner.Run(ctx, t, exec) exec.EndTime = time.Now().UTC() + p.decActiveAgent(agentType, &cleaned) p.handleRunResult(ctx, t, exec, err, agentType) } @@ -473,19 +486,8 @@ func (p *Pool) execute(ctx context.Context, t *task.Task) { p.activePerAgent[agentType]++ p.mu.Unlock() - defer func() { - p.mu.Lock() - p.active-- - p.activePerAgent[agentType]-- - if p.activePerAgent[agentType] == 0 { - delete(p.activePerAgent, agentType) - } - p.mu.Unlock() - select { - case p.doneCh <- struct{}{}: - default: - } - }() + var cleaned bool + defer p.decActiveAgent(agentType, &cleaned) runner, err := p.getRunner(t) if err != nil { @@ -505,6 +507,7 @@ func (p *Pool) execute(ctx context.Context, t *task.Task) { if err := p.store.UpdateTaskState(t.ID, task.StateFailed); err != nil { p.logger.Error("failed to update task state", "taskID", t.ID, "state", task.StateFailed, "error", err) } + p.decActiveAgent(agentType, &cleaned) p.resultCh <- &Result{TaskID: t.ID, Execution: exec, Err: err} return } @@ -527,6 +530,7 @@ func (p *Pool) execute(ctx context.Context, t *task.Task) { if err := p.store.UpdateTaskState(t.ID, task.StateFailed); err != nil { p.logger.Error("failed to update task state", "taskID", t.ID, "state", task.StateFailed, "error", err) } + p.decActiveAgent(agentType, &cleaned) p.resultCh <- &Result{TaskID: t.ID, Execution: exec, Err: err} return } @@ -583,6 +587,7 @@ func (p *Pool) execute(ctx context.Context, t *task.Task) { err = runner.Run(ctx, t, exec) exec.EndTime = time.Now().UTC() + p.decActiveAgent(agentType, &cleaned) p.handleRunResult(ctx, t, exec, err, agentType) } diff --git a/internal/executor/gemini.go b/internal/executor/gemini.go index d79c47d..7f2f54f 100644 --- a/internal/executor/gemini.go +++ b/internal/executor/gemini.go @@ -2,6 +2,7 @@ package executor import ( "context" + "encoding/json" "fmt" "io" "log/slog" @@ -117,16 +118,21 @@ func (r *GeminiRunner) execOnce(ctx context.Context, args []string, workingDir, var streamErr error + var streamCost float64 var wg sync.WaitGroup wg.Add(1) go func() { defer wg.Done() - _, streamErr = parseGeminiStream(stdoutR, stdoutFile, r.Logger) + streamCost, streamErr = parseGeminiStream(stdoutR, stdoutFile, r.Logger) stdoutR.Close() }() wg.Wait() // Wait for parseGeminiStream to finish + if streamCost > 0 { + e.CostUSD = streamCost + } + // Set a dummy exit code for this simulated run e.ExitCode = 0 @@ -136,9 +142,10 @@ func (r *GeminiRunner) execOnce(ctx context.Context, args []string, workingDir, return nil } -// parseGeminiStream reads streaming JSON from the gemini CLI, unwraps markdown -// code blocks, writes the inner JSON to w, and returns (costUSD, error). -// For now, it focuses on unwrapping and writing, not detailed parsing of cost/errors. +// parseGeminiStream reads streaming JSON from the gemini CLI, strips markdown +// code fences if the output is wrapped in them, writes the inner stream-json +// to w, and returns (costUSD, error). If a `result` event has `is_error: true`, +// an error wrapping the result message is returned. func parseGeminiStream(r io.Reader, w io.Writer, logger *slog.Logger) (float64, error) { fullOutput, err := io.ReadAll(r) if err != nil { @@ -146,31 +153,61 @@ func parseGeminiStream(r io.Reader, w io.Writer, logger *slog.Logger) (float64, } logger.Debug("parseGeminiStream: raw output received", "output", string(fullOutput)) - outputStr := strings.TrimSpace(string(fullOutput)) // Trim leading/trailing whitespace/newlines from the whole output - - jsonContent := outputStr // Default to raw output if no markdown block is found or malformed - jsonStartIdx := strings.Index(outputStr, "```json") - if jsonStartIdx != -1 { - // Found "```json", now look for the closing "```" - jsonEndIdx := strings.LastIndex(outputStr, "```") - if jsonEndIdx != -1 && jsonEndIdx > jsonStartIdx { - // Extract content between the markdown fences. - jsonContent = outputStr[jsonStartIdx+len("```json"):jsonEndIdx] - jsonContent = strings.TrimSpace(jsonContent) // Trim again after extraction, to remove potential inner newlines - } else { - logger.Warn("Malformed markdown JSON block from Gemini (missing closing ``` or invalid structure), falling back to raw output.", "outputLength", len(outputStr)) + inner := stripGeminiFences(string(fullOutput), logger) + if _, writeErr := w.Write([]byte(inner)); writeErr != nil { + return 0, fmt.Errorf("writing gemini output: %w", writeErr) + } + + // Walk lines looking for a result event so we can surface errors and cost. + var ( + cost float64 + errMsg string + isError bool + ) + for _, raw := range strings.Split(inner, "\n") { + line := strings.TrimSpace(raw) + if line == "" { + continue + } + var evt struct { + Type string `json:"type"` + IsError bool `json:"is_error"` + Result string `json:"result"` + Cost float64 `json:"total_cost_usd"` + } + if err := json.Unmarshal([]byte(line), &evt); err != nil { + continue + } + if evt.Type == "result" { + if evt.Cost > 0 { + cost = evt.Cost + } + if evt.IsError { + isError = true + errMsg = evt.Result + } } - } else { - logger.Warn("No markdown JSON block found from Gemini, falling back to raw output.", "outputLength", len(outputStr)) } - - // Write the (possibly extracted and trimmed) JSON content to the writer. - _, writeErr := w.Write([]byte(jsonContent)) - if writeErr != nil { - return 0, fmt.Errorf("writing extracted gemini json: %w", writeErr) + if isError { + return cost, fmt.Errorf("gemini reported error: %s", errMsg) } + return cost, nil +} - return 0, nil // For now, no cost/error parsing for Gemini stream +// stripGeminiFences removes a surrounding ```json ... ``` markdown block if +// present, returning the trimmed inner content. If no markdown fence is +// found, the input is returned verbatim (no whitespace trimming) so callers +// that expect byte-exact pass-through behavior get it. +func stripGeminiFences(raw string, logger *slog.Logger) string { + trimmed := strings.TrimSpace(raw) + if start := strings.Index(trimmed, "```json"); start != -1 { + if end := strings.LastIndex(trimmed, "```"); end > start { + return strings.TrimSpace(trimmed[start+len("```json") : end]) + } + logger.Warn("malformed gemini markdown block (missing closing fence); using raw output", "len", len(trimmed)) + return trimmed + } + return raw } func (r *GeminiRunner) buildArgs(t *task.Task, e *storage.Execution, questionFile string) []string { diff --git a/internal/executor/gemini_test.go b/internal/executor/gemini_test.go index 75e3b45..4b0339e 100644 --- a/internal/executor/gemini_test.go +++ b/internal/executor/gemini_test.go @@ -148,7 +148,6 @@ func TestGeminiRunner_BinaryPath_Custom(t *testing.T) { func TestParseGeminiStream_ParsesStructuredOutput(t *testing.T) { - t.Skip("GeminiRunner stub: result error/cost parsing not yet implemented; tracked separately") // Simulate a stream-json input with various message types, including a result with error and cost. input := streamLine(`{"type":"content_block_start","content_block":{"text":"Hello,"}}`) + streamLine(`{"type":"content_block_delta","content_block":{"text":" World!"}}`) + |
