summaryrefslogtreecommitdiff
path: root/internal/executor/claude.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/executor/claude.go')
-rw-r--r--internal/executor/claude.go152
1 files changed, 152 insertions, 0 deletions
diff --git a/internal/executor/claude.go b/internal/executor/claude.go
new file mode 100644
index 0000000..c845d58
--- /dev/null
+++ b/internal/executor/claude.go
@@ -0,0 +1,152 @@
+package executor
+
+import (
+ "bufio"
+ "context"
+ "encoding/json"
+ "fmt"
+ "io"
+ "log/slog"
+ "os"
+ "os/exec"
+ "path/filepath"
+
+ "github.com/claudomator/claudomator/internal/storage"
+ "github.com/claudomator/claudomator/internal/task"
+)
+
+// ClaudeRunner spawns the `claude` CLI in non-interactive mode.
+type ClaudeRunner struct {
+ BinaryPath string // defaults to "claude"
+ Logger *slog.Logger
+ LogDir string // base directory for execution logs
+}
+
+func (r *ClaudeRunner) binaryPath() string {
+ if r.BinaryPath != "" {
+ return r.BinaryPath
+ }
+ return "claude"
+}
+
+// Run executes a claude -p invocation, streaming output to log files.
+func (r *ClaudeRunner) Run(ctx context.Context, t *task.Task, e *storage.Execution) error {
+ args := r.buildArgs(t)
+
+ cmd := exec.CommandContext(ctx, r.binaryPath(), args...)
+ if t.Claude.WorkingDir != "" {
+ cmd.Dir = t.Claude.WorkingDir
+ }
+
+ // Setup log directory for this execution.
+ logDir := filepath.Join(r.LogDir, e.ID)
+ if err := os.MkdirAll(logDir, 0700); err != nil {
+ return fmt.Errorf("creating log dir: %w", err)
+ }
+
+ stdoutPath := filepath.Join(logDir, "stdout.log")
+ stderrPath := filepath.Join(logDir, "stderr.log")
+ e.StdoutPath = stdoutPath
+ e.StderrPath = stderrPath
+ e.ArtifactDir = logDir
+
+ stdoutFile, err := os.Create(stdoutPath)
+ if err != nil {
+ return fmt.Errorf("creating stdout log: %w", err)
+ }
+ defer stdoutFile.Close()
+
+ stderrFile, err := os.Create(stderrPath)
+ if err != nil {
+ return fmt.Errorf("creating stderr log: %w", err)
+ }
+ defer stderrFile.Close()
+
+ stdoutPipe, err := cmd.StdoutPipe()
+ if err != nil {
+ return fmt.Errorf("creating stdout pipe: %w", err)
+ }
+ stderrPipe, err := cmd.StderrPipe()
+ if err != nil {
+ return fmt.Errorf("creating stderr pipe: %w", err)
+ }
+
+ if err := cmd.Start(); err != nil {
+ return fmt.Errorf("starting claude: %w", err)
+ }
+
+ // Stream output to log files and parse cost info.
+ var costUSD float64
+ go func() {
+ costUSD = streamAndParseCost(stdoutPipe, stdoutFile, r.Logger)
+ }()
+ go io.Copy(stderrFile, stderrPipe)
+
+ if err := cmd.Wait(); err != nil {
+ if exitErr, ok := err.(*exec.ExitError); ok {
+ e.ExitCode = exitErr.ExitCode()
+ }
+ e.CostUSD = costUSD
+ return fmt.Errorf("claude exited with error: %w", err)
+ }
+
+ e.ExitCode = 0
+ e.CostUSD = costUSD
+ return nil
+}
+
+func (r *ClaudeRunner) buildArgs(t *task.Task) []string {
+ args := []string{
+ "-p", t.Claude.Instructions,
+ "--output-format", "stream-json",
+ }
+
+ if t.Claude.Model != "" {
+ args = append(args, "--model", t.Claude.Model)
+ }
+ if t.Claude.MaxBudgetUSD > 0 {
+ args = append(args, "--max-budget-usd", fmt.Sprintf("%.2f", t.Claude.MaxBudgetUSD))
+ }
+ if t.Claude.PermissionMode != "" {
+ args = append(args, "--permission-mode", t.Claude.PermissionMode)
+ }
+ if t.Claude.SystemPromptAppend != "" {
+ args = append(args, "--append-system-prompt", t.Claude.SystemPromptAppend)
+ }
+ for _, tool := range t.Claude.AllowedTools {
+ args = append(args, "--allowedTools", tool)
+ }
+ for _, tool := range t.Claude.DisallowedTools {
+ args = append(args, "--disallowedTools", tool)
+ }
+ for _, f := range t.Claude.ContextFiles {
+ args = append(args, "--add-dir", f)
+ }
+ args = append(args, t.Claude.AdditionalArgs...)
+
+ return args
+}
+
+// streamAndParseCost reads streaming JSON from claude and writes to the log file,
+// extracting cost data from the stream.
+func streamAndParseCost(r io.Reader, w io.Writer, logger *slog.Logger) float64 {
+ tee := io.TeeReader(r, w)
+ scanner := bufio.NewScanner(tee)
+ scanner.Buffer(make([]byte, 1024*1024), 1024*1024) // 1MB buffer for large lines
+
+ var totalCost float64
+ for scanner.Scan() {
+ line := scanner.Bytes()
+ var msg map[string]interface{}
+ if err := json.Unmarshal(line, &msg); err != nil {
+ continue
+ }
+ // Extract cost from result messages.
+ if costData, ok := msg["cost_usd"]; ok {
+ if cost, ok := costData.(float64); ok {
+ totalCost = cost
+ }
+ }
+ }
+ return totalCost
+}