summaryrefslogtreecommitdiff
path: root/internal/executor
diff options
context:
space:
mode:
Diffstat (limited to 'internal/executor')
-rw-r--r--internal/executor/claude.go152
-rw-r--r--internal/executor/claude_test.go84
-rw-r--r--internal/executor/executor.go138
-rw-r--r--internal/executor/executor_test.go206
4 files changed, 580 insertions, 0 deletions
diff --git a/internal/executor/claude.go b/internal/executor/claude.go
new file mode 100644
index 0000000..c845d58
--- /dev/null
+++ b/internal/executor/claude.go
@@ -0,0 +1,152 @@
+package executor
+
+import (
+ "bufio"
+ "context"
+ "encoding/json"
+ "fmt"
+ "io"
+ "log/slog"
+ "os"
+ "os/exec"
+ "path/filepath"
+
+ "github.com/claudomator/claudomator/internal/storage"
+ "github.com/claudomator/claudomator/internal/task"
+)
+
// ClaudeRunner spawns the `claude` CLI in non-interactive mode.
type ClaudeRunner struct {
	BinaryPath string // defaults to "claude"
	Logger     *slog.Logger
	LogDir     string // base directory for execution logs
}

// binaryPath resolves the executable to invoke, falling back to the bare
// "claude" name (looked up via PATH) when no explicit path was configured.
func (r *ClaudeRunner) binaryPath() string {
	if r.BinaryPath == "" {
		return "claude"
	}
	return r.BinaryPath
}
+
+// Run executes a claude -p invocation, streaming output to log files.
+func (r *ClaudeRunner) Run(ctx context.Context, t *task.Task, e *storage.Execution) error {
+ args := r.buildArgs(t)
+
+ cmd := exec.CommandContext(ctx, r.binaryPath(), args...)
+ if t.Claude.WorkingDir != "" {
+ cmd.Dir = t.Claude.WorkingDir
+ }
+
+ // Setup log directory for this execution.
+ logDir := filepath.Join(r.LogDir, e.ID)
+ if err := os.MkdirAll(logDir, 0700); err != nil {
+ return fmt.Errorf("creating log dir: %w", err)
+ }
+
+ stdoutPath := filepath.Join(logDir, "stdout.log")
+ stderrPath := filepath.Join(logDir, "stderr.log")
+ e.StdoutPath = stdoutPath
+ e.StderrPath = stderrPath
+ e.ArtifactDir = logDir
+
+ stdoutFile, err := os.Create(stdoutPath)
+ if err != nil {
+ return fmt.Errorf("creating stdout log: %w", err)
+ }
+ defer stdoutFile.Close()
+
+ stderrFile, err := os.Create(stderrPath)
+ if err != nil {
+ return fmt.Errorf("creating stderr log: %w", err)
+ }
+ defer stderrFile.Close()
+
+ stdoutPipe, err := cmd.StdoutPipe()
+ if err != nil {
+ return fmt.Errorf("creating stdout pipe: %w", err)
+ }
+ stderrPipe, err := cmd.StderrPipe()
+ if err != nil {
+ return fmt.Errorf("creating stderr pipe: %w", err)
+ }
+
+ if err := cmd.Start(); err != nil {
+ return fmt.Errorf("starting claude: %w", err)
+ }
+
+ // Stream output to log files and parse cost info.
+ var costUSD float64
+ go func() {
+ costUSD = streamAndParseCost(stdoutPipe, stdoutFile, r.Logger)
+ }()
+ go io.Copy(stderrFile, stderrPipe)
+
+ if err := cmd.Wait(); err != nil {
+ if exitErr, ok := err.(*exec.ExitError); ok {
+ e.ExitCode = exitErr.ExitCode()
+ }
+ e.CostUSD = costUSD
+ return fmt.Errorf("claude exited with error: %w", err)
+ }
+
+ e.ExitCode = 0
+ e.CostUSD = costUSD
+ return nil
+}
+
+func (r *ClaudeRunner) buildArgs(t *task.Task) []string {
+ args := []string{
+ "-p", t.Claude.Instructions,
+ "--output-format", "stream-json",
+ }
+
+ if t.Claude.Model != "" {
+ args = append(args, "--model", t.Claude.Model)
+ }
+ if t.Claude.MaxBudgetUSD > 0 {
+ args = append(args, "--max-budget-usd", fmt.Sprintf("%.2f", t.Claude.MaxBudgetUSD))
+ }
+ if t.Claude.PermissionMode != "" {
+ args = append(args, "--permission-mode", t.Claude.PermissionMode)
+ }
+ if t.Claude.SystemPromptAppend != "" {
+ args = append(args, "--append-system-prompt", t.Claude.SystemPromptAppend)
+ }
+ for _, tool := range t.Claude.AllowedTools {
+ args = append(args, "--allowedTools", tool)
+ }
+ for _, tool := range t.Claude.DisallowedTools {
+ args = append(args, "--disallowedTools", tool)
+ }
+ for _, f := range t.Claude.ContextFiles {
+ args = append(args, "--add-dir", f)
+ }
+ args = append(args, t.Claude.AdditionalArgs...)
+
+ return args
+}
+
// streamAndParseCost copies claude's stream-json output from r to w while
// scanning each line for cost information. It returns the last "cost_usd"
// value seen in the stream (zero when none appeared).
func streamAndParseCost(r io.Reader, w io.Writer, logger *slog.Logger) float64 {
	scanner := bufio.NewScanner(io.TeeReader(r, w))
	// Allow very long single-line JSON messages (up to 1 MiB).
	scanner.Buffer(make([]byte, 1024*1024), 1024*1024)

	var totalCost float64
	for scanner.Scan() {
		if cost, ok := parseCostLine(scanner.Bytes()); ok {
			totalCost = cost
		}
	}
	// The scan error was previously dropped and the logger parameter unused;
	// surface truncated/over-long output so incomplete cost data is visible.
	if err := scanner.Err(); err != nil && logger != nil {
		logger.Warn("error scanning claude output", "error", err)
	}
	return totalCost
}

// parseCostLine extracts the numeric "cost_usd" field from one stream-json
// line. Non-JSON lines and lines without a numeric cost_usd return ok=false.
func parseCostLine(line []byte) (float64, bool) {
	var msg map[string]interface{}
	if err := json.Unmarshal(line, &msg); err != nil {
		return 0, false
	}
	cost, ok := msg["cost_usd"].(float64)
	return cost, ok
}
diff --git a/internal/executor/claude_test.go b/internal/executor/claude_test.go
new file mode 100644
index 0000000..448ab40
--- /dev/null
+++ b/internal/executor/claude_test.go
@@ -0,0 +1,84 @@
+package executor
+
+import (
+ "testing"
+
+ "github.com/claudomator/claudomator/internal/task"
+)
+
+func TestClaudeRunner_BuildArgs_BasicTask(t *testing.T) {
+ r := &ClaudeRunner{}
+ tk := &task.Task{
+ Claude: task.ClaudeConfig{
+ Instructions: "fix the bug",
+ Model: "sonnet",
+ },
+ }
+
+ args := r.buildArgs(tk)
+
+ expected := []string{"-p", "fix the bug", "--output-format", "stream-json", "--model", "sonnet"}
+ if len(args) != len(expected) {
+ t.Fatalf("args length: want %d, got %d: %v", len(expected), len(args), args)
+ }
+ for i, want := range expected {
+ if args[i] != want {
+ t.Errorf("arg[%d]: want %q, got %q", i, want, args[i])
+ }
+ }
+}
+
+func TestClaudeRunner_BuildArgs_FullConfig(t *testing.T) {
+ r := &ClaudeRunner{}
+ tk := &task.Task{
+ Claude: task.ClaudeConfig{
+ Instructions: "implement feature",
+ Model: "opus",
+ MaxBudgetUSD: 5.0,
+ PermissionMode: "bypassPermissions",
+ SystemPromptAppend: "Follow TDD",
+ AllowedTools: []string{"Bash", "Edit"},
+ DisallowedTools: []string{"Write"},
+ ContextFiles: []string{"/src"},
+ AdditionalArgs: []string{"--verbose"},
+ },
+ }
+
+ args := r.buildArgs(tk)
+
+ // Check key args are present.
+ argMap := make(map[string]bool)
+ for _, a := range args {
+ argMap[a] = true
+ }
+
+ requiredArgs := []string{
+ "-p", "implement feature", "--output-format", "stream-json",
+ "--model", "opus", "--max-budget-usd", "5.00",
+ "--permission-mode", "bypassPermissions",
+ "--append-system-prompt", "Follow TDD",
+ "--allowedTools", "Bash", "Edit",
+ "--disallowedTools", "Write",
+ "--add-dir", "/src",
+ "--verbose",
+ }
+ for _, req := range requiredArgs {
+ if !argMap[req] {
+ t.Errorf("missing arg %q in %v", req, args)
+ }
+ }
+}
+
+func TestClaudeRunner_BinaryPath_Default(t *testing.T) {
+ r := &ClaudeRunner{}
+ if r.binaryPath() != "claude" {
+ t.Errorf("want 'claude', got %q", r.binaryPath())
+ }
+}
+
+func TestClaudeRunner_BinaryPath_Custom(t *testing.T) {
+ r := &ClaudeRunner{BinaryPath: "/usr/local/bin/claude"}
+ if r.binaryPath() != "/usr/local/bin/claude" {
+ t.Errorf("want custom path, got %q", r.binaryPath())
+ }
+}
diff --git a/internal/executor/executor.go b/internal/executor/executor.go
new file mode 100644
index 0000000..c6c5124
--- /dev/null
+++ b/internal/executor/executor.go
@@ -0,0 +1,138 @@
+package executor
+
+import (
+ "context"
+ "fmt"
+ "log/slog"
+ "sync"
+ "time"
+
+ "github.com/claudomator/claudomator/internal/storage"
+ "github.com/claudomator/claudomator/internal/task"
+ "github.com/google/uuid"
+)
+
// Runner executes a single task and returns the result.
// Implementations record execution details (log paths, exit code, cost)
// on the supplied Execution in place and should honor ctx cancellation.
type Runner interface {
	// Run performs one execution attempt of t, mutating exec as it goes.
	Run(ctx context.Context, t *task.Task, exec *storage.Execution) error
}
+
// Pool manages a bounded set of concurrent task workers.
type Pool struct {
	maxConcurrent int          // hard cap on simultaneously running tasks
	runner        Runner       // strategy that actually executes each task
	store         *storage.DB  // persistence for execution records and task state
	logger        *slog.Logger // structured logger for non-fatal storage errors

	mu       sync.Mutex   // guards active
	active   int          // number of in-flight executions
	resultCh chan *Result // buffered (2x maxConcurrent); see Results
}

// Result is emitted when a task execution completes.
type Result struct {
	TaskID    string             // ID of the task that ran
	Execution *storage.Execution // final execution record (status, timing, cost)
	Err       error              // non-nil when the run failed, timed out, or was cancelled
}
+
+func NewPool(maxConcurrent int, runner Runner, store *storage.DB, logger *slog.Logger) *Pool {
+ if maxConcurrent < 1 {
+ maxConcurrent = 1
+ }
+ return &Pool{
+ maxConcurrent: maxConcurrent,
+ runner: runner,
+ store: store,
+ logger: logger,
+ resultCh: make(chan *Result, maxConcurrent*2),
+ }
+}
+
// Submit dispatches a task for execution on a new goroutine.
// It does NOT block: if the pool is already at capacity it returns an
// error immediately and the task is not started.
func (p *Pool) Submit(ctx context.Context, t *task.Task) error {
	p.mu.Lock()
	if p.active >= p.maxConcurrent {
		// Capture counts under the lock so the error message is consistent.
		active := p.active
		max := p.maxConcurrent
		p.mu.Unlock()
		return fmt.Errorf("executor pool at capacity (%d/%d)", active, max)
	}
	p.active++ // reserve the slot before releasing the lock
	p.mu.Unlock()

	go p.execute(ctx, t)
	return nil
}
+
// Results returns the channel for reading execution results.
// The channel is buffered; callers must drain it, since a worker blocks on
// the send once the buffer fills.
func (p *Pool) Results() <-chan *Result {
	return p.resultCh
}
+
+// ActiveCount returns the number of currently running tasks.
+func (p *Pool) ActiveCount() int {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+ return p.active
+}
+
+func (p *Pool) execute(ctx context.Context, t *task.Task) {
+ execID := uuid.New().String()
+ exec := &storage.Execution{
+ ID: execID,
+ TaskID: t.ID,
+ StartTime: time.Now().UTC(),
+ Status: "RUNNING",
+ }
+
+ // Record execution start.
+ if err := p.store.CreateExecution(exec); err != nil {
+ p.logger.Error("failed to create execution record", "error", err)
+ }
+ if err := p.store.UpdateTaskState(t.ID, task.StateRunning); err != nil {
+ p.logger.Error("failed to update task state", "error", err)
+ }
+
+ // Apply task timeout.
+ var cancel context.CancelFunc
+ if t.Timeout.Duration > 0 {
+ ctx, cancel = context.WithTimeout(ctx, t.Timeout.Duration)
+ } else {
+ ctx, cancel = context.WithCancel(ctx)
+ }
+ defer cancel()
+
+ // Run the task.
+ err := p.runner.Run(ctx, t, exec)
+ exec.EndTime = time.Now().UTC()
+
+ if err != nil {
+ if ctx.Err() == context.DeadlineExceeded {
+ exec.Status = "TIMED_OUT"
+ exec.ErrorMsg = "execution timed out"
+ p.store.UpdateTaskState(t.ID, task.StateTimedOut)
+ } else if ctx.Err() == context.Canceled {
+ exec.Status = "CANCELLED"
+ exec.ErrorMsg = "execution cancelled"
+ p.store.UpdateTaskState(t.ID, task.StateCancelled)
+ } else {
+ exec.Status = "FAILED"
+ exec.ErrorMsg = err.Error()
+ p.store.UpdateTaskState(t.ID, task.StateFailed)
+ }
+ } else {
+ exec.Status = "COMPLETED"
+ p.store.UpdateTaskState(t.ID, task.StateCompleted)
+ }
+
+ if updateErr := p.store.UpdateExecution(exec); updateErr != nil {
+ p.logger.Error("failed to update execution", "error", updateErr)
+ }
+
+ p.mu.Lock()
+ p.active--
+ p.mu.Unlock()
+
+ p.resultCh <- &Result{TaskID: t.ID, Execution: exec, Err: err}
+}
diff --git a/internal/executor/executor_test.go b/internal/executor/executor_test.go
new file mode 100644
index 0000000..acce95b
--- /dev/null
+++ b/internal/executor/executor_test.go
@@ -0,0 +1,206 @@
+package executor
+
+import (
+ "context"
+ "fmt"
+ "log/slog"
+ "os"
+ "path/filepath"
+ "sync"
+ "testing"
+ "time"
+
+ "github.com/claudomator/claudomator/internal/storage"
+ "github.com/claudomator/claudomator/internal/task"
+)
+
// mockRunner implements Runner for testing.
type mockRunner struct {
	mu       sync.Mutex    // guards calls
	calls    int           // number of Run invocations
	delay    time.Duration // simulated execution time; interruptible via ctx
	err      error         // error Run returns after the delay, if set
	exitCode int           // exit code recorded on the execution when err is set
}
+
+func (m *mockRunner) Run(ctx context.Context, t *task.Task, e *storage.Execution) error {
+ m.mu.Lock()
+ m.calls++
+ m.mu.Unlock()
+
+ if m.delay > 0 {
+ select {
+ case <-time.After(m.delay):
+ case <-ctx.Done():
+ return ctx.Err()
+ }
+ }
+ if m.err != nil {
+ e.ExitCode = m.exitCode
+ return m.err
+ }
+ return nil
+}
+
+func (m *mockRunner) callCount() int {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+ return m.calls
+}
+
// testStore opens a storage.DB backed by a file in a per-test temp
// directory and registers cleanup so it is closed when the test finishes.
func testStore(t *testing.T) *storage.DB {
	t.Helper()
	dbPath := filepath.Join(t.TempDir(), "test.db")
	db, err := storage.Open(dbPath)
	if err != nil {
		t.Fatal(err)
	}
	t.Cleanup(func() { db.Close() })
	return db
}
+
+func makeTask(id string) *task.Task {
+ now := time.Now().UTC()
+ return &task.Task{
+ ID: id, Name: "Test " + id,
+ Claude: task.ClaudeConfig{Instructions: "test"},
+ Priority: task.PriorityNormal,
+ Retry: task.RetryConfig{MaxAttempts: 1, Backoff: "linear"},
+ Tags: []string{},
+ DependsOn: []string{},
+ State: task.StateQueued,
+ CreatedAt: now, UpdatedAt: now,
+ }
+}
+
+func TestPool_Submit_Success(t *testing.T) {
+ store := testStore(t)
+ runner := &mockRunner{}
+ logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))
+ pool := NewPool(2, runner, store, logger)
+
+ tk := makeTask("ps-1")
+ store.CreateTask(tk)
+
+ if err := pool.Submit(context.Background(), tk); err != nil {
+ t.Fatalf("submit: %v", err)
+ }
+
+ result := <-pool.Results()
+ if result.Err != nil {
+ t.Errorf("expected no error, got: %v", result.Err)
+ }
+ if result.Execution.Status != "COMPLETED" {
+ t.Errorf("status: want COMPLETED, got %q", result.Execution.Status)
+ }
+
+ // Verify task state in DB.
+ got, _ := store.GetTask("ps-1")
+ if got.State != task.StateCompleted {
+ t.Errorf("task state: want COMPLETED, got %v", got.State)
+ }
+}
+
+func TestPool_Submit_Failure(t *testing.T) {
+ store := testStore(t)
+ runner := &mockRunner{err: fmt.Errorf("boom"), exitCode: 1}
+ logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))
+ pool := NewPool(2, runner, store, logger)
+
+ tk := makeTask("pf-1")
+ store.CreateTask(tk)
+ pool.Submit(context.Background(), tk)
+
+ result := <-pool.Results()
+ if result.Err == nil {
+ t.Fatal("expected error")
+ }
+ if result.Execution.Status != "FAILED" {
+ t.Errorf("status: want FAILED, got %q", result.Execution.Status)
+ }
+}
+
+func TestPool_Submit_Timeout(t *testing.T) {
+ store := testStore(t)
+ runner := &mockRunner{delay: 5 * time.Second}
+ logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))
+ pool := NewPool(2, runner, store, logger)
+
+ tk := makeTask("pt-1")
+ tk.Timeout.Duration = 50 * time.Millisecond
+ store.CreateTask(tk)
+ pool.Submit(context.Background(), tk)
+
+ result := <-pool.Results()
+ if result.Execution.Status != "TIMED_OUT" {
+ t.Errorf("status: want TIMED_OUT, got %q", result.Execution.Status)
+ }
+}
+
+func TestPool_Submit_Cancellation(t *testing.T) {
+ store := testStore(t)
+ runner := &mockRunner{delay: 5 * time.Second}
+ logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))
+ pool := NewPool(2, runner, store, logger)
+
+ ctx, cancel := context.WithCancel(context.Background())
+ tk := makeTask("pc-1")
+ store.CreateTask(tk)
+ pool.Submit(ctx, tk)
+
+ time.Sleep(20 * time.Millisecond)
+ cancel()
+
+ result := <-pool.Results()
+ if result.Execution.Status != "CANCELLED" {
+ t.Errorf("status: want CANCELLED, got %q", result.Execution.Status)
+ }
+}
+
+func TestPool_AtCapacity(t *testing.T) {
+ store := testStore(t)
+ runner := &mockRunner{delay: time.Second}
+ logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))
+ pool := NewPool(1, runner, store, logger)
+
+ tk1 := makeTask("cap-1")
+ store.CreateTask(tk1)
+ pool.Submit(context.Background(), tk1)
+
+ // Pool is at capacity, second submit should fail.
+ time.Sleep(10 * time.Millisecond) // let goroutine start
+ tk2 := makeTask("cap-2")
+ store.CreateTask(tk2)
+ err := pool.Submit(context.Background(), tk2)
+ if err == nil {
+ t.Fatal("expected capacity error")
+ }
+
+ <-pool.Results() // drain
+}
+
+func TestPool_ConcurrentExecution(t *testing.T) {
+ store := testStore(t)
+ runner := &mockRunner{delay: 50 * time.Millisecond}
+ logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))
+ pool := NewPool(3, runner, store, logger)
+
+ for i := 0; i < 3; i++ {
+ tk := makeTask(fmt.Sprintf("cc-%d", i))
+ store.CreateTask(tk)
+ if err := pool.Submit(context.Background(), tk); err != nil {
+ t.Fatalf("submit %d: %v", i, err)
+ }
+ }
+
+ for i := 0; i < 3; i++ {
+ result := <-pool.Results()
+ if result.Execution.Status != "COMPLETED" {
+ t.Errorf("task %s: want COMPLETED, got %q", result.TaskID, result.Execution.Status)
+ }
+ }
+
+ if runner.callCount() != 3 {
+ t.Errorf("calls: want 3, got %d", runner.callCount())
+ }
+}