summary refs log tree commit diff
path: root/internal/executor
diff options
context:
space:
mode:
Diffstat (limited to 'internal/executor')
-rw-r--r-- internal/executor/container.go 98
-rw-r--r-- internal/executor/container_test.go 52
-rw-r--r-- internal/executor/helpers.go 4
3 files changed, 122 insertions, 32 deletions
diff --git a/internal/executor/container.go b/internal/executor/container.go
index d21aea3..45758d2 100644
--- a/internal/executor/container.go
+++ b/internal/executor/container.go
@@ -17,12 +17,23 @@ import (
// ContainerRunner executes an agent inside a container.
type ContainerRunner struct {
- Image string // default image if not specified in task
- Logger *slog.Logger
- LogDir string
- APIURL string
- DropsDir string
- SSHAuthSock string // optional path to host SSH agent
+ Image string // default image if not specified in task
+ Logger *slog.Logger
+ LogDir string
+ APIURL string
+ DropsDir string
+ SSHAuthSock string // optional path to host SSH agent
+ ClaudeBinary string // optional path to claude binary in container
+ GeminiBinary string // optional path to gemini binary in container
+ // Command allows mocking exec.CommandContext for tests.
+ Command func(ctx context.Context, name string, arg ...string) *exec.Cmd
+}
+
+func (r *ContainerRunner) command(ctx context.Context, name string, arg ...string) *exec.Cmd {
+ if r.Command != nil {
+ return r.Command(ctx, name, arg...)
+ }
+ return exec.CommandContext(ctx, name, arg...)
}
func (r *ContainerRunner) ExecLogDir(execID string) string {
@@ -88,7 +99,11 @@ func (r *ContainerRunner) Run(ctx context.Context, t *task.Task, e *storage.Exec
// 2. Clone repo into workspace if not resuming
if !isResume {
r.Logger.Info("cloning repository", "url", repoURL, "workspace", workspace)
- if out, err := exec.CommandContext(ctx, "git", "clone", repoURL, workspace).CombinedOutput(); err != nil {
+ if out, err := r.command(ctx, "git", "clone", repoURL, workspace).CombinedOutput(); err != nil {
+ // If it looks like a remote URL, fail fast.
+ if strings.HasPrefix(repoURL, "http") || strings.HasPrefix(repoURL, "git@") || strings.HasPrefix(repoURL, "ssh://") {
+ return fmt.Errorf("git clone failed for remote repository: %w\n%s", err, string(out))
+ }
r.Logger.Warn("git clone failed, attempting fallback init", "url", repoURL, "error", err)
if initErr := r.fallbackGitInit(repoURL, workspace); initErr != nil {
return fmt.Errorf("git clone and fallback init failed: %w\n%s", err, string(out))
@@ -143,7 +158,7 @@ func (r *ContainerRunner) Run(ctx context.Context, t *task.Task, e *storage.Exec
fullArgs = append(fullArgs, innerCmd...)
r.Logger.Info("starting container", "image", image, "taskID", t.ID)
- cmd := exec.CommandContext(ctx, "docker", fullArgs...)
+ cmd := r.command(ctx, "docker", fullArgs...)
cmd.Stderr = stderrFile
cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true}
@@ -162,6 +177,18 @@ func (r *ContainerRunner) Run(ctx context.Context, t *task.Task, e *storage.Exec
}
stdoutW.Close()
+ // Watch for context cancellation to kill the process group (Issue 1)
+ done := make(chan struct{})
+ defer close(done)
+ go func() {
+ select {
+ case <-ctx.Done():
+ r.Logger.Info("killing container process group due to context cancellation", "taskID", t.ID)
+ syscall.Kill(-cmd.Process.Pid, syscall.SIGKILL)
+ case <-done:
+ }
+ }()
+
// Stream stdout to the log file and parse cost/errors.
var costUSD float64
var sessionID string
@@ -193,6 +220,9 @@ func (r *ContainerRunner) Run(ctx context.Context, t *task.Task, e *storage.Exec
} else {
isBlocked = true
success = true // We consider BLOCKED as a "success" for workspace preservation
+ if e.SessionID == "" {
+ r.Logger.Warn("missing session ID; resume will start fresh", "taskID", e.TaskID)
+ }
return &BlockedError{
QuestionJSON: questionJSON,
SessionID: e.SessionID,
@@ -210,14 +240,24 @@ func (r *ContainerRunner) Run(ctx context.Context, t *task.Task, e *storage.Exec
// 5. Post-execution: push changes if successful
if waitErr == nil && streamErr == nil {
- r.Logger.Info("pushing changes back to remote", "url", repoURL)
- // We assume the sandbox has committed changes (the agent image should enforce this)
- if out, err := exec.CommandContext(ctx, "git", "-C", workspace, "push", "origin", "HEAD").CombinedOutput(); err != nil {
- r.Logger.Warn("git push failed or no changes", "error", err, "output", string(out))
- // Only set success = true if we consider this "good enough".
- // Review says: "If the remote is missing or the push fails, the task is marked FAILED and the host-side workspace is preserved"
- // So we MUST return error here.
- return fmt.Errorf("git push failed: %w\n%s", err, string(out))
+ // Check if there are any commits to push (Issue 10)
+ // We use rev-list to see if HEAD is ahead of origin/HEAD.
+ // If origin/HEAD doesn't exist (e.g. fresh init), we just attempt to push.
+ hasCommits := true
+ if out, err := r.command(ctx, "git", "-C", workspace, "rev-list", "origin/HEAD..HEAD").CombinedOutput(); err == nil {
+ if len(strings.TrimSpace(string(out))) == 0 {
+ hasCommits = false
+ }
+ }
+
+ if hasCommits {
+ r.Logger.Info("pushing changes back to remote", "url", repoURL)
+ if out, err := r.command(ctx, "git", "-C", workspace, "push", "origin", "HEAD").CombinedOutput(); err != nil {
+ r.Logger.Warn("git push failed", "error", err, "output", string(out))
+ return fmt.Errorf("git push failed: %w\n%s", err, string(out))
+ }
+ } else {
+ r.Logger.Info("no new commits to push", "taskID", t.ID)
}
success = true
}
@@ -235,7 +275,7 @@ func (r *ContainerRunner) Run(ctx context.Context, t *task.Task, e *storage.Exec
func (r *ContainerRunner) buildDockerArgs(workspace, taskID string) []string {
// --env-file takes a HOST path.
hostEnvFile := filepath.Join(workspace, ".claudomator-env")
- return []string{
+ args := []string{
"run", "--rm",
"-v", workspace + ":/workspace",
"-w", "/workspace",
@@ -244,28 +284,42 @@ func (r *ContainerRunner) buildDockerArgs(workspace, taskID string) []string {
"-e", "CLAUDOMATOR_TASK_ID=" + taskID,
"-e", "CLAUDOMATOR_DROP_DIR=" + r.DropsDir,
}
+ if r.SSHAuthSock != "" {
+ args = append(args, "-v", r.SSHAuthSock+":/tmp/ssh-auth.sock", "-e", "SSH_AUTH_SOCK=/tmp/ssh-auth.sock")
+ }
+ return args
}
func (r *ContainerRunner) buildInnerCmd(t *task.Task, e *storage.Execution, isResume bool) []string {
// Claude CLI uses -p for prompt text. To pass a file, we use a shell to cat it.
// We use a shell variable to capture the expansion to avoid quoting issues with instructions contents.
// The outer single quotes around the sh -c argument prevent host-side expansion.
-
+
+ claudeBin := r.ClaudeBinary
+ if claudeBin == "" {
+ claudeBin = "claude"
+ }
+ geminiBin := r.GeminiBinary
+ if geminiBin == "" {
+ geminiBin = "gemini"
+ }
+
if t.Agent.Type == "gemini" {
- return []string{"sh", "-c", "INST=$(cat /workspace/.claudomator-instructions.txt); gemini -p \"$INST\""}
+ return []string{"sh", "-c", fmt.Sprintf("INST=$(cat /workspace/.claudomator-instructions.txt); %s -p \"$INST\"", geminiBin)}
}
// Claude
var claudeCmd strings.Builder
- claudeCmd.WriteString("INST=$(cat /workspace/.claudomator-instructions.txt); claude -p \"$INST\"")
+ claudeCmd.WriteString(fmt.Sprintf("INST=$(cat /workspace/.claudomator-instructions.txt); %s -p \"$INST\"", claudeBin))
if isResume && e.ResumeSessionID != "" {
claudeCmd.WriteString(fmt.Sprintf(" --resume %s", e.ResumeSessionID))
}
claudeCmd.WriteString(" --output-format stream-json --verbose --permission-mode bypassPermissions")
-
+
return []string{"sh", "-c", claudeCmd.String()}
}
+
func (r *ContainerRunner) fallbackGitInit(repoURL, workspace string) error {
// Ensure directory exists
if err := os.MkdirAll(workspace, 0755); err != nil {
@@ -281,7 +335,7 @@ func (r *ContainerRunner) fallbackGitInit(repoURL, workspace string) error {
// git clone handle local paths fine if they are repos.
// This fallback is only if it's NOT a repo.
for _, args := range cmds {
- if out, err := exec.Command("git", args...).CombinedOutput(); err != nil {
+ if out, err := r.command(context.Background(), "git", args...).CombinedOutput(); err != nil {
return fmt.Errorf("git init failed: %w\n%s", err, out)
}
}
diff --git a/internal/executor/container_test.go b/internal/executor/container_test.go
index 0e36def..d4d591e 100644
--- a/internal/executor/container_test.go
+++ b/internal/executor/container_test.go
@@ -6,6 +6,7 @@ import (
"io"
"log/slog"
"os"
+ "os/exec"
"strings"
"testing"
@@ -15,14 +16,15 @@ import (
func TestContainerRunner_BuildDockerArgs(t *testing.T) {
runner := &ContainerRunner{
- APIURL: "http://localhost:8484",
- DropsDir: "/data/drops",
+ APIURL: "http://localhost:8484",
+ DropsDir: "/data/drops",
+ SSHAuthSock: "/tmp/ssh.sock",
}
workspace := "/tmp/ws"
taskID := "task-123"
args := runner.buildDockerArgs(workspace, taskID)
-
+
expected := []string{
"run", "--rm",
"-v", "/tmp/ws:/workspace",
@@ -31,11 +33,12 @@ func TestContainerRunner_BuildDockerArgs(t *testing.T) {
"-e", "CLAUDOMATOR_API_URL=http://localhost:8484",
"-e", "CLAUDOMATOR_TASK_ID=task-123",
"-e", "CLAUDOMATOR_DROP_DIR=/data/drops",
+ "-v", "/tmp/ssh.sock:/tmp/ssh-auth.sock",
+ "-e", "SSH_AUTH_SOCK=/tmp/ssh-auth.sock",
}
-
if len(args) != len(expected) {
- t.Fatalf("expected %d args, got %d", len(expected), len(args))
+ t.Fatalf("expected %d args, got %d. Got: %v", len(expected), len(args), args)
}
for i, v := range args {
if v != expected[i] {
@@ -76,12 +79,31 @@ func TestContainerRunner_BuildInnerCmd(t *testing.T) {
tk := &task.Task{Agent: task.AgentConfig{Type: "gemini"}}
exec := &storage.Execution{}
cmd := runner.buildInnerCmd(tk, exec, false)
-
+
cmdStr := strings.Join(cmd, " ")
if !strings.Contains(cmdStr, "gemini -p \"$INST\"") {
t.Errorf("expected gemini command with safer quoting, got %q", cmdStr)
}
})
+
+ t.Run("custom-binaries", func(t *testing.T) {
+ runnerCustom := &ContainerRunner{
+ ClaudeBinary: "/usr/bin/claude-v2",
+ GeminiBinary: "/usr/local/bin/gemini-pro",
+ }
+
+ tkClaude := &task.Task{Agent: task.AgentConfig{Type: "claude"}}
+ cmdClaude := runnerCustom.buildInnerCmd(tkClaude, &storage.Execution{}, false)
+ if !strings.Contains(strings.Join(cmdClaude, " "), "/usr/bin/claude-v2 -p") {
+ t.Errorf("expected custom claude binary, got %q", cmdClaude)
+ }
+
+ tkGemini := &task.Task{Agent: task.AgentConfig{Type: "gemini"}}
+ cmdGemini := runnerCustom.buildInnerCmd(tkGemini, &storage.Execution{}, false)
+ if !strings.Contains(strings.Join(cmdGemini, " "), "/usr/local/bin/gemini-pro -p") {
+ t.Errorf("expected custom gemini binary, got %q", cmdGemini)
+ }
+ })
}
func TestContainerRunner_Run_PreservesWorkspaceOnFailure(t *testing.T) {
@@ -89,19 +111,31 @@ func TestContainerRunner_Run_PreservesWorkspaceOnFailure(t *testing.T) {
runner := &ContainerRunner{
Logger: logger,
Image: "busybox",
+ Command: func(ctx context.Context, name string, arg ...string) *exec.Cmd {
+ // Mock docker run to exit 1
+ if name == "docker" {
+ return exec.Command("sh", "-c", "exit 1")
+ }
+ // Mock git clone to succeed and create the directory
+ if name == "git" && len(arg) > 0 && arg[0] == "clone" {
+ dir := arg[len(arg)-1]
+ os.MkdirAll(dir, 0755)
+ return exec.Command("true")
+ }
+ return exec.Command("true")
+ },
}
- // Use an invalid repo URL to trigger failure.
tk := &task.Task{
ID: "test-task",
- RepositoryURL: "/nonexistent/repo",
+ RepositoryURL: "https://github.com/example/repo.git",
Agent: task.AgentConfig{Type: "claude"},
}
exec := &storage.Execution{ID: "test-exec", TaskID: "test-task"}
err := runner.Run(context.Background(), tk, exec)
if err == nil {
- t.Fatal("expected error due to invalid repo")
+ t.Fatal("expected error due to mocked docker failure")
}
// Verify SandboxDir was set and directory exists.
diff --git a/internal/executor/helpers.go b/internal/executor/helpers.go
index 36cd050..9e4530b 100644
--- a/internal/executor/helpers.go
+++ b/internal/executor/helpers.go
@@ -33,6 +33,7 @@ func parseStream(r io.Reader, w io.Writer, logger *slog.Logger) (float64, string
var sessionID string
var streamErr error
+Loop:
for scanner.Scan() {
line := scanner.Bytes()
var msg map[string]interface{}
@@ -54,7 +55,7 @@ func parseStream(r io.Reader, w io.Writer, logger *slog.Logger) (float64, string
if status == "rejected" {
streamErr = fmt.Errorf("claude rate limit reached (rejected): %v", msg)
// Immediately break since we can't continue anyway
- break
+ break Loop
}
}
case "assistant":
@@ -91,6 +92,7 @@ func parseStream(r io.Reader, w io.Writer, logger *slog.Logger) (float64, string
return totalCost, sessionID, streamErr
}
+
// permissionDenialError inspects a "user" stream message for tool_result entries
// that were denied due to missing permissions. Returns an error if found.
func permissionDenialError(msg map[string]interface{}) error {