diff options
Diffstat (limited to 'internal/executor')
| -rw-r--r-- | internal/executor/claude.go | 152 | ||||
| -rw-r--r-- | internal/executor/claude_test.go | 84 | ||||
| -rw-r--r-- | internal/executor/executor.go | 138 | ||||
| -rw-r--r-- | internal/executor/executor_test.go | 206 |
4 files changed, 580 insertions, 0 deletions
diff --git a/internal/executor/claude.go b/internal/executor/claude.go
new file mode 100644
index 0000000..c845d58
--- /dev/null
+++ b/internal/executor/claude.go
@@ -0,0 +1,152 @@
+package executor
+
+import (
+	"bufio"
+	"context"
+	"encoding/json"
+	"fmt"
+	"io"
+	"log/slog"
+	"os"
+	"os/exec"
+	"path/filepath"
+
+	"github.com/claudomator/claudomator/internal/storage"
+	"github.com/claudomator/claudomator/internal/task"
+)
+
+// ClaudeRunner spawns the `claude` CLI in non-interactive mode.
+type ClaudeRunner struct {
+	BinaryPath string       // defaults to "claude"
+	Logger     *slog.Logger // optional; when nil, stream errors are not logged
+	LogDir     string       // base directory for execution logs
+}
+
+// binaryPath returns the configured binary, falling back to "claude".
+func (r *ClaudeRunner) binaryPath() string {
+	if r.BinaryPath != "" {
+		return r.BinaryPath
+	}
+	return "claude"
+}
+
+// Run executes a claude -p invocation, streaming output to log files.
+//
+// It creates <LogDir>/<e.ID>/stdout.log and stderr.log, records those paths
+// (and the directory) on e, and fills in e.ExitCode and e.CostUSD once the
+// process has finished. A failing process is reported as a wrapped error.
+func (r *ClaudeRunner) Run(ctx context.Context, t *task.Task, e *storage.Execution) error {
+	args := r.buildArgs(t)
+
+	cmd := exec.CommandContext(ctx, r.binaryPath(), args...)
+	if t.Claude.WorkingDir != "" {
+		cmd.Dir = t.Claude.WorkingDir
+	}
+
+	// Setup log directory for this execution.
+	logDir := filepath.Join(r.LogDir, e.ID)
+	if err := os.MkdirAll(logDir, 0700); err != nil {
+		return fmt.Errorf("creating log dir: %w", err)
+	}
+
+	stdoutPath := filepath.Join(logDir, "stdout.log")
+	stderrPath := filepath.Join(logDir, "stderr.log")
+	e.StdoutPath = stdoutPath
+	e.StderrPath = stderrPath
+	e.ArtifactDir = logDir
+
+	stdoutFile, err := os.Create(stdoutPath)
+	if err != nil {
+		return fmt.Errorf("creating stdout log: %w", err)
+	}
+	defer stdoutFile.Close()
+
+	stderrFile, err := os.Create(stderrPath)
+	if err != nil {
+		return fmt.Errorf("creating stderr log: %w", err)
+	}
+	defer stderrFile.Close()
+
+	stdoutPipe, err := cmd.StdoutPipe()
+	if err != nil {
+		return fmt.Errorf("creating stdout pipe: %w", err)
+	}
+	stderrPipe, err := cmd.StderrPipe()
+	if err != nil {
+		return fmt.Errorf("creating stderr pipe: %w", err)
+	}
+
+	if err := cmd.Start(); err != nil {
+		return fmt.Errorf("starting claude: %w", err)
+	}
+
+	// Stream output to log files and parse cost info. Both readers must hit
+	// EOF before cmd.Wait is called: Wait closes the pipes, so calling it
+	// first can cut the logs short. Handing the cost back over a channel
+	// also removes the unsynchronised read of a goroutine-written variable.
+	costCh := make(chan float64, 1)
+	go func() {
+		costCh <- streamAndParseCost(stdoutPipe, stdoutFile, r.Logger)
+	}()
+	stderrDone := make(chan struct{})
+	go func() {
+		defer close(stderrDone)
+		if _, copyErr := io.Copy(stderrFile, stderrPipe); copyErr != nil && r.Logger != nil {
+			r.Logger.Warn("copying claude stderr", "error", copyErr)
+		}
+	}()
+
+	e.CostUSD = <-costCh
+	<-stderrDone
+
+	if err := cmd.Wait(); err != nil {
+		if exitErr, ok := err.(*exec.ExitError); ok {
+			e.ExitCode = exitErr.ExitCode()
+		}
+		return fmt.Errorf("claude exited with error: %w", err)
+	}
+
+	e.ExitCode = 0
+	return nil
+}
+
+// buildArgs translates the task's Claude configuration into CLI arguments.
+// Instructions and the stream-json output format are always present; every
+// other flag is emitted only when its field is set. AdditionalArgs are
+// appended verbatim at the end.
+func (r *ClaudeRunner) buildArgs(t *task.Task) []string {
+	args := []string{
+		"-p", t.Claude.Instructions,
+		"--output-format", "stream-json",
+	}
+
+	if t.Claude.Model != "" {
+		args = append(args, "--model", t.Claude.Model)
+	}
+	if t.Claude.MaxBudgetUSD > 0 {
+		args = append(args, "--max-budget-usd", fmt.Sprintf("%.2f", t.Claude.MaxBudgetUSD))
+	}
+	if t.Claude.PermissionMode != "" {
+		args = append(args, "--permission-mode", t.Claude.PermissionMode)
+	}
+	if t.Claude.SystemPromptAppend != "" {
+		args = append(args, "--append-system-prompt", t.Claude.SystemPromptAppend)
+	}
+	for _, tool := range t.Claude.AllowedTools {
+		args = append(args, "--allowedTools", tool)
+	}
+	for _, tool := range t.Claude.DisallowedTools {
+		args = append(args, "--disallowedTools", tool)
+	}
+	for _, f := range t.Claude.ContextFiles {
+		args = append(args, "--add-dir", f)
+	}
+	args = append(args, t.Claude.AdditionalArgs...)
+
+	return args
+}
+
+// streamAndParseCost reads streaming JSON from claude and writes to the log file,
+// extracting cost data from the stream.
+//
+// Every byte read from r is copied to w. Lines that are not JSON objects are
+// skipped. The returned value is the last numeric cost seen, or 0 if none.
+// logger may be nil.
+func streamAndParseCost(r io.Reader, w io.Writer, logger *slog.Logger) float64 {
+	tee := io.TeeReader(r, w)
+	scanner := bufio.NewScanner(tee)
+	scanner.Buffer(make([]byte, 1024*1024), 1024*1024) // 1MB buffer for large lines
+
+	// NOTE(review): "total_cost_usd" is believed to be the key used by newer
+	// CLI releases for the final result message; confirm against the CLI
+	// version in use. "cost_usd" is the key this code has always read.
+	costKeys := [...]string{"cost_usd", "total_cost_usd"}
+
+	var totalCost float64
+	for scanner.Scan() {
+		var msg map[string]any
+		if err := json.Unmarshal(scanner.Bytes(), &msg); err != nil {
+			continue
+		}
+		// Extract cost from result messages.
+		for _, key := range costKeys {
+			if cost, ok := msg[key].(float64); ok {
+				totalCost = cost
+			}
+		}
+	}
+
+	if err := scanner.Err(); err != nil {
+		if logger != nil {
+			logger.Warn("parsing claude output stopped early", "error", err)
+		}
+		// The scanner gives up after an over-long line or a read error.
+		// Keep draining through the tee so the log stays complete and the
+		// child process never blocks on a full pipe.
+		if _, copyErr := io.Copy(io.Discard, tee); copyErr != nil && logger != nil {
+			logger.Warn("draining claude output", "error", copyErr)
+		}
+	}
+	return totalCost
+}
diff --git a/internal/executor/claude_test.go b/internal/executor/claude_test.go
new file mode 100644
index 0000000..448ab40
--- /dev/null
+++ b/internal/executor/claude_test.go
@@ -0,0 +1,84 @@
+package executor
+
+import (
+	"testing"
+
+	"github.com/claudomator/claudomator/internal/task"
+)
+
+func TestClaudeRunner_BuildArgs_BasicTask(t *testing.T) {
+	r := &ClaudeRunner{}
+	tk := &task.Task{
+		Claude: task.ClaudeConfig{
+			Instructions: "fix the bug",
+			Model:        "sonnet",
+		},
+	}
+
+	args := r.buildArgs(tk)
+
+	expected := []string{"-p", "fix the bug", "--output-format", "stream-json", "--model", "sonnet"}
+	if len(args) != len(expected) {
+		t.Fatalf("args length: want %d, got %d: %v", len(expected), len(args), args)
+	}
+	for i, want := range expected {
+		if args[i] != want {
+			t.Errorf("arg[%d]: want %q, got %q", i, want, args[i])
+		}
+	}
+}
+
+func TestClaudeRunner_BuildArgs_FullConfig(t *testing.T) {
+	r := &ClaudeRunner{}
+	tk := &task.Task{
+		Claude: task.ClaudeConfig{
+			Instructions:       "implement feature",
+			Model:              "opus",
+			MaxBudgetUSD:       5.0,
+			PermissionMode:     "bypassPermissions",
+			SystemPromptAppend: "Follow TDD",
+			AllowedTools:       []string{"Bash", "Edit"},
+			DisallowedTools:    []string{"Write"},
+			ContextFiles:       []string{"/src"},
+			AdditionalArgs:     []string{"--verbose"},
+		},
+	}
+
+	args := r.buildArgs(tk)
+
+	// Check key args are present.
+	argMap := make(map[string]bool)
+	for _, a := range args {
+		argMap[a] = true
+	}
+
+	requiredArgs := []string{
+		"-p", "implement feature", "--output-format", "stream-json",
+		"--model", "opus", "--max-budget-usd", "5.00",
+		"--permission-mode", "bypassPermissions",
+		"--append-system-prompt", "Follow TDD",
+		"--allowedTools", "Bash", "Edit",
+		"--disallowedTools", "Write",
+		"--add-dir", "/src",
+		"--verbose",
+	}
+	for _, req := range requiredArgs {
+		if !argMap[req] {
+			t.Errorf("missing arg %q in %v", req, args)
+		}
+	}
+}
+
+func TestClaudeRunner_BinaryPath_Default(t *testing.T) {
+	r := &ClaudeRunner{}
+	if r.binaryPath() != "claude" {
+		t.Errorf("want 'claude', got %q", r.binaryPath())
+	}
+}
+
+func TestClaudeRunner_BinaryPath_Custom(t *testing.T) {
+	r := &ClaudeRunner{BinaryPath: "/usr/local/bin/claude"}
+	if r.binaryPath() != "/usr/local/bin/claude" {
+		t.Errorf("want custom path, got %q", r.binaryPath())
+	}
+}
diff --git a/internal/executor/executor.go b/internal/executor/executor.go
new file mode 100644
index 0000000..c6c5124
--- /dev/null
+++ b/internal/executor/executor.go
@@ -0,0 +1,138 @@
+package executor
+
+import (
+	"context"
+	"fmt"
+	"log/slog"
+	"sync"
+	"time"
+
+	"github.com/claudomator/claudomator/internal/storage"
+	"github.com/claudomator/claudomator/internal/task"
+	"github.com/google/uuid"
+)
+
+// Runner executes a single task and returns the result.
+type Runner interface {
+	Run(ctx context.Context, t *task.Task, exec *storage.Execution) error
+}
+
+// Pool manages a bounded set of concurrent task workers.
+type Pool struct {
+	maxConcurrent int
+	runner        Runner
+	store         *storage.DB
+	logger        *slog.Logger
+
+	mu       sync.Mutex
+	active   int // tasks currently holding a slot; guarded by mu
+	resultCh chan *Result
+}
+
+// Result is emitted when a task execution completes.
+type Result struct {
+	TaskID    string
+	Execution *storage.Execution
+	Err       error
+}
+
+// NewPool builds a Pool that runs at most maxConcurrent tasks at once.
+// Values below 1 are raised to 1. The result channel is buffered to twice
+// the concurrency limit.
+func NewPool(maxConcurrent int, runner Runner, store *storage.DB, logger *slog.Logger) *Pool {
+	if maxConcurrent < 1 {
+		maxConcurrent = 1
+	}
+	return &Pool{
+		maxConcurrent: maxConcurrent,
+		runner:        runner,
+		store:         store,
+		logger:        logger,
+		resultCh:      make(chan *Result, maxConcurrent*2),
+	}
+}
+
+// Submit dispatches a task for execution on its own goroutine. It never
+// blocks: when the pool is already running maxConcurrent tasks it returns an
+// error and the task is not started.
+func (p *Pool) Submit(ctx context.Context, t *task.Task) error {
+	p.mu.Lock()
+	if p.active >= p.maxConcurrent {
+		active := p.active
+		p.mu.Unlock()
+		// maxConcurrent is only written in NewPool, so it is safe to read
+		// without the lock.
+		return fmt.Errorf("executor pool at capacity (%d/%d)", active, p.maxConcurrent)
+	}
+	p.active++
+	p.mu.Unlock()
+
+	go p.execute(ctx, t)
+	return nil
+}
+
+// Results returns the channel for reading execution results.
+func (p *Pool) Results() <-chan *Result {
+	return p.resultCh
+}
+
+// ActiveCount returns the number of currently running tasks.
+func (p *Pool) ActiveCount() int {
+	p.mu.Lock()
+	defer p.mu.Unlock()
+	return p.active
+}
+
+// logIfErr logs err under msg when it is non-nil.
+func (p *Pool) logIfErr(msg string, err error) {
+	if err != nil {
+		p.logger.Error(msg, "error", err)
+	}
+}
+
+// execute runs t to completion: it records the execution, applies the task
+// timeout, maps the outcome onto an execution status and task state, frees
+// the pool slot, and finally publishes a Result.
+//
+// Store failures are logged rather than returned. The send on the result
+// channel blocks once its buffer is full, so callers need to drain Results.
+func (p *Pool) execute(ctx context.Context, t *task.Task) {
+	exec := &storage.Execution{
+		ID:        uuid.New().String(),
+		TaskID:    t.ID,
+		StartTime: time.Now().UTC(),
+		Status:    "RUNNING",
+	}
+
+	// Record execution start.
+	p.logIfErr("failed to create execution record", p.store.CreateExecution(exec))
+	p.logIfErr("failed to update task state", p.store.UpdateTaskState(t.ID, task.StateRunning))
+
+	// Apply task timeout.
+	var cancel context.CancelFunc
+	if t.Timeout.Duration > 0 {
+		ctx, cancel = context.WithTimeout(ctx, t.Timeout.Duration)
+	} else {
+		ctx, cancel = context.WithCancel(ctx)
+	}
+	defer cancel()
+
+	// Run the task.
+	err := p.runner.Run(ctx, t, exec)
+	exec.EndTime = time.Now().UTC()
+
+	// A failed run is classified by the context first: a deadline or a
+	// cancellation takes precedence over whatever error the runner returned.
+	var stateErr error
+	switch {
+	case err == nil:
+		exec.Status = "COMPLETED"
+		stateErr = p.store.UpdateTaskState(t.ID, task.StateCompleted)
+	case ctx.Err() == context.DeadlineExceeded:
+		exec.Status = "TIMED_OUT"
+		exec.ErrorMsg = "execution timed out"
+		stateErr = p.store.UpdateTaskState(t.ID, task.StateTimedOut)
+	case ctx.Err() == context.Canceled:
+		exec.Status = "CANCELLED"
+		exec.ErrorMsg = "execution cancelled"
+		stateErr = p.store.UpdateTaskState(t.ID, task.StateCancelled)
+	default:
+		exec.Status = "FAILED"
+		exec.ErrorMsg = err.Error()
+		stateErr = p.store.UpdateTaskState(t.ID, task.StateFailed)
+	}
+	p.logIfErr("failed to update task state", stateErr)
+	p.logIfErr("failed to update execution", p.store.UpdateExecution(exec))
+
+	// Free the slot before publishing so a slow consumer does not hold up
+	// new submissions.
+	p.mu.Lock()
+	p.active--
+	p.mu.Unlock()
+
+	p.resultCh <- &Result{TaskID: t.ID, Execution: exec, Err: err}
+}
diff --git a/internal/executor/executor_test.go b/internal/executor/executor_test.go
new file mode 100644
index 0000000..acce95b
--- /dev/null
+++ b/internal/executor/executor_test.go
@@ -0,0 +1,206 @@
+package executor
+
+import (
+	"context"
+	"fmt"
+	"log/slog"
+	"os"
+	"path/filepath"
+	"sync"
+	"testing"
+	"time"
+
+	"github.com/claudomator/claudomator/internal/storage"
+	"github.com/claudomator/claudomator/internal/task"
+)
+
+// mockRunner implements Runner for testing.
+type mockRunner struct { + mu sync.Mutex + calls int + delay time.Duration + err error + exitCode int +} + +func (m *mockRunner) Run(ctx context.Context, t *task.Task, e *storage.Execution) error { + m.mu.Lock() + m.calls++ + m.mu.Unlock() + + if m.delay > 0 { + select { + case <-time.After(m.delay): + case <-ctx.Done(): + return ctx.Err() + } + } + if m.err != nil { + e.ExitCode = m.exitCode + return m.err + } + return nil +} + +func (m *mockRunner) callCount() int { + m.mu.Lock() + defer m.mu.Unlock() + return m.calls +} + +func testStore(t *testing.T) *storage.DB { + t.Helper() + dbPath := filepath.Join(t.TempDir(), "test.db") + db, err := storage.Open(dbPath) + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { db.Close() }) + return db +} + +func makeTask(id string) *task.Task { + now := time.Now().UTC() + return &task.Task{ + ID: id, Name: "Test " + id, + Claude: task.ClaudeConfig{Instructions: "test"}, + Priority: task.PriorityNormal, + Retry: task.RetryConfig{MaxAttempts: 1, Backoff: "linear"}, + Tags: []string{}, + DependsOn: []string{}, + State: task.StateQueued, + CreatedAt: now, UpdatedAt: now, + } +} + +func TestPool_Submit_Success(t *testing.T) { + store := testStore(t) + runner := &mockRunner{} + logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) + pool := NewPool(2, runner, store, logger) + + tk := makeTask("ps-1") + store.CreateTask(tk) + + if err := pool.Submit(context.Background(), tk); err != nil { + t.Fatalf("submit: %v", err) + } + + result := <-pool.Results() + if result.Err != nil { + t.Errorf("expected no error, got: %v", result.Err) + } + if result.Execution.Status != "COMPLETED" { + t.Errorf("status: want COMPLETED, got %q", result.Execution.Status) + } + + // Verify task state in DB. 
+ got, _ := store.GetTask("ps-1") + if got.State != task.StateCompleted { + t.Errorf("task state: want COMPLETED, got %v", got.State) + } +} + +func TestPool_Submit_Failure(t *testing.T) { + store := testStore(t) + runner := &mockRunner{err: fmt.Errorf("boom"), exitCode: 1} + logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) + pool := NewPool(2, runner, store, logger) + + tk := makeTask("pf-1") + store.CreateTask(tk) + pool.Submit(context.Background(), tk) + + result := <-pool.Results() + if result.Err == nil { + t.Fatal("expected error") + } + if result.Execution.Status != "FAILED" { + t.Errorf("status: want FAILED, got %q", result.Execution.Status) + } +} + +func TestPool_Submit_Timeout(t *testing.T) { + store := testStore(t) + runner := &mockRunner{delay: 5 * time.Second} + logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) + pool := NewPool(2, runner, store, logger) + + tk := makeTask("pt-1") + tk.Timeout.Duration = 50 * time.Millisecond + store.CreateTask(tk) + pool.Submit(context.Background(), tk) + + result := <-pool.Results() + if result.Execution.Status != "TIMED_OUT" { + t.Errorf("status: want TIMED_OUT, got %q", result.Execution.Status) + } +} + +func TestPool_Submit_Cancellation(t *testing.T) { + store := testStore(t) + runner := &mockRunner{delay: 5 * time.Second} + logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) + pool := NewPool(2, runner, store, logger) + + ctx, cancel := context.WithCancel(context.Background()) + tk := makeTask("pc-1") + store.CreateTask(tk) + pool.Submit(ctx, tk) + + time.Sleep(20 * time.Millisecond) + cancel() + + result := <-pool.Results() + if result.Execution.Status != "CANCELLED" { + t.Errorf("status: want CANCELLED, got %q", result.Execution.Status) + } +} + +func TestPool_AtCapacity(t *testing.T) { + store := testStore(t) + runner := &mockRunner{delay: time.Second} + logger := 
slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) + pool := NewPool(1, runner, store, logger) + + tk1 := makeTask("cap-1") + store.CreateTask(tk1) + pool.Submit(context.Background(), tk1) + + // Pool is at capacity, second submit should fail. + time.Sleep(10 * time.Millisecond) // let goroutine start + tk2 := makeTask("cap-2") + store.CreateTask(tk2) + err := pool.Submit(context.Background(), tk2) + if err == nil { + t.Fatal("expected capacity error") + } + + <-pool.Results() // drain +} + +func TestPool_ConcurrentExecution(t *testing.T) { + store := testStore(t) + runner := &mockRunner{delay: 50 * time.Millisecond} + logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) + pool := NewPool(3, runner, store, logger) + + for i := 0; i < 3; i++ { + tk := makeTask(fmt.Sprintf("cc-%d", i)) + store.CreateTask(tk) + if err := pool.Submit(context.Background(), tk); err != nil { + t.Fatalf("submit %d: %v", i, err) + } + } + + for i := 0; i < 3; i++ { + result := <-pool.Results() + if result.Execution.Status != "COMPLETED" { + t.Errorf("task %s: want COMPLETED, got %q", result.TaskID, result.Execution.Status) + } + } + + if runner.callCount() != 3 { + t.Errorf("calls: want 3, got %d", runner.callCount()) + } +} |
