diff options
Diffstat (limited to 'internal')
| -rw-r--r-- | internal/api/server.go | 2 | ||||
| -rw-r--r-- | internal/api/webhook.go | 3 | ||||
| -rw-r--r-- | internal/api/webhook_test.go | 8 | ||||
| -rw-r--r-- | internal/cli/serve.go | 31 | ||||
| -rw-r--r-- | internal/executor/container.go | 172 | ||||
| -rw-r--r-- | internal/storage/db.go | 19 | ||||
| -rw-r--r-- | internal/task/task.go | 7 |
7 files changed, 215 insertions, 27 deletions
diff --git a/internal/api/server.go b/internal/api/server.go index 48440e1..64d2c3a 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -424,6 +424,7 @@ func (s *Server) handleCreateTask(w http.ResponseWriter, r *http.Request) { Description string `json:"description"` ElaborationInput string `json:"elaboration_input"` Project string `json:"project"` + RepositoryURL string `json:"repository_url"` Agent task.AgentConfig `json:"agent"` Claude task.AgentConfig `json:"claude"` // legacy alias Timeout string `json:"timeout"` @@ -448,6 +449,7 @@ func (s *Server) handleCreateTask(w http.ResponseWriter, r *http.Request) { Description: input.Description, ElaborationInput: input.ElaborationInput, Project: input.Project, + RepositoryURL: input.RepositoryURL, Agent: input.Agent, Priority: task.Priority(input.Priority), Tags: input.Tags, diff --git a/internal/api/webhook.go b/internal/api/webhook.go index 0530f3e..a28b43f 100644 --- a/internal/api/webhook.go +++ b/internal/api/webhook.go @@ -219,7 +219,8 @@ func (s *Server) createCIFailureTask(w http.ResponseWriter, repoName, fullName, UpdatedAt: now, } if project != nil { - t.Agent.ProjectDir = project.Dir + t.RepositoryURL = fmt.Sprintf("https://github.com/%s.git", fullName) + t.Project = project.Name } if err := s.store.CreateTask(t); err != nil { diff --git a/internal/api/webhook_test.go b/internal/api/webhook_test.go index 1bc4aaa..0fc9664 100644 --- a/internal/api/webhook_test.go +++ b/internal/api/webhook_test.go @@ -124,8 +124,8 @@ func TestGitHubWebhook_CheckRunFailure_CreatesTask(t *testing.T) { if !strings.Contains(tk.Name, "main") { t.Errorf("task name %q does not contain branch", tk.Name) } - if tk.Agent.ProjectDir != "/workspace/myrepo" { - t.Errorf("task project dir = %q, want /workspace/myrepo", tk.Agent.ProjectDir) + if tk.RepositoryURL != "https://github.com/owner/myrepo.git" { + t.Errorf("task repository url = %q, want https://github.com/owner/myrepo.git", tk.RepositoryURL) } if !contains(tk.Tags, "ci") || !contains(tk.Tags, "auto") { t.Errorf("task tags %v missing expected ci/auto tags", tk.Tags) @@ -375,8 +375,8 @@ func TestGitHubWebhook_FallbackToSingleProject(t *testing.T) { if err != nil { t.Fatalf("task not found: %v", err) } - if tk.Agent.ProjectDir != "/workspace/someapp" { - t.Errorf("expected fallback to /workspace/someapp, got %q", tk.Agent.ProjectDir) + if tk.RepositoryURL != "https://github.com/owner/myrepo.git" { + t.Errorf("expected fallback repository url, got %q", tk.RepositoryURL) } } diff --git a/internal/cli/serve.go b/internal/cli/serve.go index 5677562..56947bf 100644 --- a/internal/cli/serve.go +++ b/internal/cli/serve.go @@ -74,19 +74,26 @@ func serve(addr string) error { } runners := map[string]executor.Runner{ - "claude": &executor.ClaudeRunner{ - BinaryPath: cfg.ClaudeBinaryPath, - Logger: logger, - LogDir: cfg.LogDir, - APIURL: apiURL, - DropsDir: cfg.DropsDir, + "claude": &executor.ContainerRunner{ + Image: "claudomator-agent:latest", + Logger: logger, + LogDir: cfg.LogDir, + APIURL: apiURL, + DropsDir: cfg.DropsDir, }, - "gemini": &executor.GeminiRunner{ - BinaryPath: cfg.GeminiBinaryPath, - Logger: logger, - LogDir: cfg.LogDir, - APIURL: apiURL, - DropsDir: cfg.DropsDir, + "gemini": &executor.ContainerRunner{ + Image: "claudomator-agent:latest", + Logger: logger, + LogDir: cfg.LogDir, + APIURL: apiURL, + DropsDir: cfg.DropsDir, + }, + "container": &executor.ContainerRunner{ + Image: "claudomator-agent:latest", + Logger: logger, + LogDir: cfg.LogDir, + APIURL: apiURL, + DropsDir: cfg.DropsDir, }, } diff --git a/internal/executor/container.go b/internal/executor/container.go new file mode 100644 index 0000000..e148620 --- /dev/null +++ b/internal/executor/container.go @@ -0,0 +1,172 @@ +package executor + +import ( + "context" + "fmt" + "log/slog" + "os" + "os/exec" + "path/filepath" + "sync" + "syscall" + + "github.com/thepeterstone/claudomator/internal/storage" + "github.com/thepeterstone/claudomator/internal/task" +) + +// ContainerRunner executes an agent inside a container. +type ContainerRunner struct { + Image string // default image if not specified in task + Logger *slog.Logger + LogDir string + APIURL string + DropsDir string + SSHAuthSock string // optional path to host SSH agent +} + +func (r *ContainerRunner) ExecLogDir(execID string) string { + if r.LogDir == "" { + return "" + } + return filepath.Join(r.LogDir, execID) +} + +func (r *ContainerRunner) Run(ctx context.Context, t *task.Task, e *storage.Execution) error { + repoURL := t.RepositoryURL + if repoURL == "" { + // Fallback to project_dir if repository_url is not set (legacy support) + if t.Agent.ProjectDir != "" { + repoURL = t.Agent.ProjectDir + } else { + return fmt.Errorf("task %s has no repository_url or project_dir", t.ID) + } + } + + image := t.Agent.ContainerImage + if image == "" { + image = r.Image + } + if image == "" { + image = "claudomator-agent:latest" + } + + // 1. Setup workspace on host + workspace, err := os.MkdirTemp("", "claudomator-workspace-*") + if err != nil { + return fmt.Errorf("creating workspace: %w", err) + } + defer os.RemoveAll(workspace) + + // 2. Clone repo into workspace + r.Logger.Info("cloning repository", "url", repoURL, "workspace", workspace) + if out, err := exec.CommandContext(ctx, "git", "clone", repoURL, workspace).CombinedOutput(); err != nil { + return fmt.Errorf("git clone failed: %w\n%s", err, string(out)) + } + + // 3. Prepare logs + logDir := r.ExecLogDir(e.ID) + if logDir == "" { + logDir = filepath.Join(workspace, ".claudomator-logs") + } + if err := os.MkdirAll(logDir, 0700); err != nil { + return fmt.Errorf("creating log dir: %w", err) + } + e.StdoutPath = filepath.Join(logDir, "stdout.log") + e.StderrPath = filepath.Join(logDir, "stderr.log") + e.ArtifactDir = logDir + + stdoutFile, err := os.Create(e.StdoutPath) + if err != nil { + return fmt.Errorf("creating stdout log: %w", err) + } + defer stdoutFile.Close() + + stderrFile, err := os.Create(e.StderrPath) + if err != nil { + return fmt.Errorf("creating stderr log: %w", err) + } + defer stderrFile.Close() + + // 4. Run container + // Build docker command + args := []string{ + "run", "--rm", + "-v", workspace + ":/workspace", + "-w", "/workspace", + "-e", "CLAUDOMATOR_API_URL=" + r.APIURL, + "-e", "CLAUDOMATOR_TASK_ID=" + e.TaskID, + "-e", "CLAUDOMATOR_DROP_DIR=" + r.DropsDir, + "-e", "ANTHROPIC_API_KEY=" + os.Getenv("ANTHROPIC_API_KEY"), + "-e", "GOOGLE_API_KEY=" + os.Getenv("GOOGLE_API_KEY"), + } + + // Inject custom instructions as environment variable or via file + instructionsFile := filepath.Join(workspace, ".claudomator-instructions.txt") + if err := os.WriteFile(instructionsFile, []byte(t.Agent.Instructions), 0600); err != nil { + return fmt.Errorf("writing instructions: %w", err) + } + + // Command to run inside container: we assume the image has 'claude' or 'gemini' + // and a wrapper script that reads CLAUDOMATOR_TASK_ID etc. + innerCmd := []string{"claude", "-p", t.Agent.Instructions, "--session-id", e.ID, "--output-format", "stream-json", "--verbose", "--permission-mode", "bypassPermissions"} + if t.Agent.Type == "gemini" { + innerCmd = []string{"gemini", "-p", t.Agent.Instructions} // simplified for now + } + + args = append(args, image) + args = append(args, innerCmd...) + + r.Logger.Info("starting container", "image", image, "taskID", t.ID) + cmd := exec.CommandContext(ctx, "docker", args...) + cmd.Stderr = stderrFile + cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true} + + // Use os.Pipe for stdout so we can parse it in real-time + stdoutR, stdoutW, err := os.Pipe() + if err != nil { + return fmt.Errorf("creating stdout pipe: %w", err) + } + cmd.Stdout = stdoutW + + if err := cmd.Start(); err != nil { + stdoutW.Close() + stdoutR.Close() + return fmt.Errorf("starting container: %w", err) + } + stdoutW.Close() + + // Stream stdout to the log file and parse cost/errors. + var costUSD float64 + var streamErr error + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + costUSD, streamErr = parseStream(stdoutR, stdoutFile, r.Logger) + stdoutR.Close() + }() + + waitErr := cmd.Wait() + wg.Wait() + + e.CostUSD = costUSD + + // 5. Post-execution: push changes if successful + if waitErr == nil && streamErr == nil { + r.Logger.Info("pushing changes back to remote", "url", repoURL) + // We assume the sandbox has committed changes (the agent image should enforce this) + if out, err := exec.CommandContext(ctx, "git", "-C", workspace, "push", "origin", "HEAD").CombinedOutput(); err != nil { + r.Logger.Warn("git push failed", "error", err, "output", string(out)) + // Don't fail the task just because push failed, but record it? + // Actually, user said: "they should only ever commit to their sandbox, and only ever push to an actual remote" + // So push failure is a task failure in this new model. + return fmt.Errorf("git push failed: %w\n%s", err, string(out)) + } + } + + if waitErr != nil { + return fmt.Errorf("container execution failed: %w", waitErr) + } + + return nil +} diff --git a/internal/storage/db.go b/internal/storage/db.go index 25801b2..8bc9864 100644 --- a/internal/storage/db.go +++ b/internal/storage/db.go @@ -87,6 +87,7 @@ func (s *DB) migrate() error { `ALTER TABLE executions ADD COLUMN commits_json TEXT NOT NULL DEFAULT '[]'`, `ALTER TABLE tasks ADD COLUMN elaboration_input TEXT`, `ALTER TABLE tasks ADD COLUMN project TEXT`, + `ALTER TABLE tasks ADD COLUMN repository_url TEXT`, `CREATE TABLE IF NOT EXISTS push_subscriptions ( id TEXT PRIMARY KEY, endpoint TEXT NOT NULL UNIQUE, @@ -135,9 +136,9 @@ func (s *DB) CreateTask(t *task.Task) error { } _, err = s.db.Exec(` - INSERT INTO tasks (id, name, description, elaboration_input, project, config_json, priority, timeout_ns, retry_json, tags_json, depends_on_json, parent_task_id, state, created_at, updated_at) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, - t.ID, t.Name, t.Description, t.ElaborationInput, t.Project, string(configJSON), string(t.Priority), + INSERT INTO tasks (id, name, description, elaboration_input, project, repository_url, config_json, priority, timeout_ns, retry_json, tags_json, depends_on_json, parent_task_id, state, created_at, updated_at) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, + t.ID, t.Name, t.Description, t.ElaborationInput, t.Project, t.RepositoryURL, string(configJSON), string(t.Priority), t.Timeout.Duration.Nanoseconds(), string(retryJSON), string(tagsJSON), string(depsJSON), t.ParentTaskID, string(t.State), t.CreatedAt.UTC(), t.UpdatedAt.UTC(), ) @@ -146,13 +147,13 @@ func (s *DB) CreateTask(t *task.Task) error { // GetTask retrieves a task by ID. func (s *DB) GetTask(id string) (*task.Task, error) { - row := s.db.QueryRow(`SELECT id, name, description, elaboration_input, project, config_json, priority, timeout_ns, retry_json, tags_json, depends_on_json, parent_task_id, state, created_at, updated_at, rejection_comment, question_json, summary, interactions_json FROM tasks WHERE id = ?`, id) + row := s.db.QueryRow(`SELECT id, name, description, elaboration_input, project, repository_url, config_json, priority, timeout_ns, retry_json, tags_json, depends_on_json, parent_task_id, state, created_at, updated_at, rejection_comment, question_json, summary, interactions_json FROM tasks WHERE id = ?`, id) return scanTask(row) } // ListTasks returns tasks matching the given filter. func (s *DB) ListTasks(filter TaskFilter) ([]*task.Task, error) { - query := `SELECT id, name, description, elaboration_input, project, config_json, priority, timeout_ns, retry_json, tags_json, depends_on_json, parent_task_id, state, created_at, updated_at, rejection_comment, question_json, summary, interactions_json FROM tasks WHERE 1=1` + query := `SELECT id, name, description, elaboration_input, project, repository_url, config_json, priority, timeout_ns, retry_json, tags_json, depends_on_json, parent_task_id, state, created_at, updated_at, rejection_comment, question_json, summary, interactions_json FROM tasks WHERE 1=1` var args []interface{} if filter.State != "" { @@ -188,7 +189,7 @@ func (s *DB) ListTasks(filter TaskFilter) ([]*task.Task, error) { // ListSubtasks returns all tasks whose parent_task_id matches the given ID. func (s *DB) ListSubtasks(parentID string) ([]*task.Task, error) { - rows, err := s.db.Query(`SELECT id, name, description, elaboration_input, project, config_json, priority, timeout_ns, retry_json, tags_json, depends_on_json, parent_task_id, state, created_at, updated_at, rejection_comment, question_json, summary, interactions_json FROM tasks WHERE parent_task_id = ? ORDER BY created_at ASC`, parentID) + rows, err := s.db.Query(`SELECT id, name, description, elaboration_input, project, repository_url, config_json, priority, timeout_ns, retry_json, tags_json, depends_on_json, parent_task_id, state, created_at, updated_at, rejection_comment, question_json, summary, interactions_json FROM tasks WHERE parent_task_id = ? ORDER BY created_at ASC`, parentID) if err != nil { return nil, err } @@ -241,7 +242,7 @@ func (s *DB) ResetTaskForRetry(id string) (*task.Task, error) { } defer tx.Rollback() //nolint:errcheck - t, err := scanTask(tx.QueryRow(`SELECT id, name, description, elaboration_input, project, config_json, priority, timeout_ns, retry_json, tags_json, depends_on_json, parent_task_id, state, created_at, updated_at, rejection_comment, question_json, summary, interactions_json FROM tasks WHERE id = ?`, id)) + t, err := scanTask(tx.QueryRow(`SELECT id, name, description, elaboration_input, project, repository_url, config_json, priority, timeout_ns, retry_json, tags_json, depends_on_json, parent_task_id, state, created_at, updated_at, rejection_comment, question_json, summary, interactions_json FROM tasks WHERE id = ?`, id)) if err != nil { if err == sql.ErrNoRows { return nil, fmt.Errorf("task %q not found", id) @@ -688,15 +689,17 @@ func scanTask(row scanner) (*task.Task, error) { parentTaskID sql.NullString elaborationInput sql.NullString project sql.NullString + repositoryURL sql.NullString rejectionComment sql.NullString questionJSON sql.NullString summary sql.NullString interactionsJSON sql.NullString ) - err := row.Scan(&t.ID, &t.Name, &t.Description, &elaborationInput, &project, &configJSON, &priority, &timeoutNS, &retryJSON, &tagsJSON, &depsJSON, &parentTaskID, &state, &t.CreatedAt, &t.UpdatedAt, &rejectionComment, &questionJSON, &summary, &interactionsJSON) + err := row.Scan(&t.ID, &t.Name, &t.Description, &elaborationInput, &project, &repositoryURL, &configJSON, &priority, &timeoutNS, &retryJSON, &tagsJSON, &depsJSON, &parentTaskID, &state, &t.CreatedAt, &t.UpdatedAt, &rejectionComment, &questionJSON, &summary, &interactionsJSON) t.ParentTaskID = parentTaskID.String t.ElaborationInput = elaborationInput.String t.Project = project.String + t.RepositoryURL = repositoryURL.String t.RejectionComment = rejectionComment.String t.QuestionJSON = questionJSON.String t.Summary = summary.String diff --git a/internal/task/task.go b/internal/task/task.go index 3a04716..5b2fff3 100644 --- a/internal/task/task.go +++ b/internal/task/task.go @@ -32,7 +32,9 @@ type AgentConfig struct { Model string `yaml:"model" json:"model"` ContextFiles []string `yaml:"context_files" json:"context_files"` Instructions string `yaml:"instructions" json:"instructions"` - ProjectDir string `yaml:"project_dir" json:"project_dir"` + RepositoryURL string `yaml:"repository_url" json:"repository_url"` + ContainerImage string `yaml:"container_image" json:"container_image"` + ProjectDir string `yaml:"project_dir" json:"project_dir"` // Deprecated: use RepositoryURL MaxBudgetUSD float64 `yaml:"max_budget_usd" json:"max_budget_usd"` PermissionMode string `yaml:"permission_mode" json:"permission_mode"` AllowedTools []string `yaml:"allowed_tools" json:"allowed_tools"` @@ -74,7 +76,8 @@ type Task struct { ParentTaskID string `yaml:"parent_task_id" json:"parent_task_id"` Name string `yaml:"name" json:"name"` Description string `yaml:"description" json:"description"` - Project string `yaml:"project" json:"project"` + Project string `yaml:"project" json:"project"` // Human-readable project name + RepositoryURL string `yaml:"repository_url" json:"repository_url"` Agent AgentConfig `yaml:"agent" json:"agent"` Timeout Duration `yaml:"timeout" json:"timeout"` Retry RetryConfig `yaml:"retry" json:"retry"` |
