summaryrefslogtreecommitdiff
path: root/internal/executor/claude.go
blob: c845d58a1922a414da27953ff3911c4789d5c955 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
package executor

import (
	"bufio"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"log/slog"
	"os"
	"os/exec"
	"path/filepath"

	"github.com/claudomator/claudomator/internal/storage"
	"github.com/claudomator/claudomator/internal/task"
)

// ClaudeRunner spawns the `claude` CLI in non-interactive mode.
//
// BinaryPath is the executable to launch; when empty, "claude" is used.
// Logger is handed to the stdout stream parser during Run.
// LogDir is the base directory for logs; every execution writes its
// stdout.log and stderr.log into a subdirectory named after the execution ID.
type ClaudeRunner struct {
	BinaryPath string // defaults to "claude"
	Logger     *slog.Logger
	LogDir     string // base directory for execution logs
}

// binaryPath returns the executable to launch: the configured BinaryPath, or
// "claude" when none has been set.
func (r *ClaudeRunner) binaryPath() string {
	if r.BinaryPath == "" {
		return "claude"
	}
	return r.BinaryPath
}

// Run executes a claude -p invocation for task t, streaming its output to
// per-execution log files.
//
// It creates the directory LogDir/<e.ID> containing stdout.log and stderr.log
// and records those locations on e (StdoutPath, StderrPath, ArtifactDir).
// Once the process has finished, e.CostUSD holds the last cost figure parsed
// from the stdout stream and e.ExitCode holds the process exit code (it is
// only set on failure when the error is an *exec.ExitError).
//
// The command is created with exec.CommandContext, so cancelling ctx
// terminates the process.
//
// Run returns an error when the log directory or files cannot be created,
// when the pipes cannot be set up, when the process cannot be started, or
// when the process exits unsuccessfully.
func (r *ClaudeRunner) Run(ctx context.Context, t *task.Task, e *storage.Execution) error {
	args := r.buildArgs(t)

	cmd := exec.CommandContext(ctx, r.binaryPath(), args...)
	if t.Claude.WorkingDir != "" {
		cmd.Dir = t.Claude.WorkingDir
	}

	// Setup log directory for this execution.
	logDir := filepath.Join(r.LogDir, e.ID)
	if err := os.MkdirAll(logDir, 0700); err != nil {
		return fmt.Errorf("creating log dir: %w", err)
	}

	stdoutPath := filepath.Join(logDir, "stdout.log")
	stderrPath := filepath.Join(logDir, "stderr.log")
	e.StdoutPath = stdoutPath
	e.StderrPath = stderrPath
	e.ArtifactDir = logDir

	stdoutFile, err := os.Create(stdoutPath)
	if err != nil {
		return fmt.Errorf("creating stdout log: %w", err)
	}
	defer stdoutFile.Close()

	stderrFile, err := os.Create(stderrPath)
	if err != nil {
		return fmt.Errorf("creating stderr log: %w", err)
	}
	defer stderrFile.Close()

	stdoutPipe, err := cmd.StdoutPipe()
	if err != nil {
		return fmt.Errorf("creating stdout pipe: %w", err)
	}
	stderrPipe, err := cmd.StderrPipe()
	if err != nil {
		return fmt.Errorf("creating stderr pipe: %w", err)
	}

	if err := cmd.Start(); err != nil {
		return fmt.Errorf("starting claude: %w", err)
	}

	// Stream output to log files and parse cost info. The results travel over
	// channels so that reading them below is properly synchronised with the
	// goroutines that produce them.
	costCh := make(chan float64, 1)
	go func() {
		cost := streamAndParseCost(stdoutPipe, stdoutFile, r.Logger)
		// If parsing stopped before EOF, keep copying so the log stays complete
		// and the child never blocks on a full pipe. At EOF this is a no-op.
		// Best effort: a failure here leaves nothing further to try.
		_, _ = io.Copy(stdoutFile, stdoutPipe)
		costCh <- cost
	}()
	stderrCh := make(chan error, 1)
	go func() {
		_, copyErr := io.Copy(stderrFile, stderrPipe)
		stderrCh <- copyErr
	}()

	// Both pipes must be read to EOF before cmd.Wait is called: Wait closes
	// them as soon as the process exits, which would otherwise truncate the
	// logs and could drop the final cost message.
	e.CostUSD = <-costCh
	if copyErr := <-stderrCh; copyErr != nil && r.Logger != nil {
		r.Logger.Warn("copying claude stderr to log", "error", copyErr)
	}

	if err := cmd.Wait(); err != nil {
		if exitErr, ok := err.(*exec.ExitError); ok {
			e.ExitCode = exitErr.ExitCode()
		}
		return fmt.Errorf("claude exited with error: %w", err)
	}

	e.ExitCode = 0
	return nil
}

// buildArgs translates the task's Claude settings into the CLI argument list.
//
// The prompt (-p) and the stream-json output format are always present.
// Model, budget, permission mode and the appended system prompt are added
// only when set; allowed/disallowed tools and context files contribute one
// flag per entry; AdditionalArgs are appended verbatim at the end.
func (r *ClaudeRunner) buildArgs(t *task.Task) []string {
	argv := []string{"-p", t.Claude.Instructions, "--output-format", "stream-json"}

	// optional adds "flag value" when value is non-empty.
	optional := func(flag, value string) {
		if value == "" {
			return
		}
		argv = append(argv, flag, value)
	}
	// repeated adds "flag value" once for every entry in values.
	repeated := func(flag string, values []string) {
		for i := range values {
			argv = append(argv, flag, values[i])
		}
	}

	optional("--model", t.Claude.Model)
	if budget := t.Claude.MaxBudgetUSD; budget > 0 {
		argv = append(argv, "--max-budget-usd", fmt.Sprintf("%.2f", budget))
	}
	optional("--permission-mode", t.Claude.PermissionMode)
	optional("--append-system-prompt", t.Claude.SystemPromptAppend)

	repeated("--allowedTools", t.Claude.AllowedTools)
	repeated("--disallowedTools", t.Claude.DisallowedTools)
	repeated("--add-dir", t.Claude.ContextFiles)

	return append(argv, t.Claude.AdditionalArgs...)
}

// streamAndParseCost copies the newline-delimited JSON stream r to w verbatim
// and returns the cost reported in it.
//
// Every line is decoded as a JSON object; the returned value is the last
// top-level numeric "cost_usd" field seen, or 0 when none was found. Lines
// that are not valid JSON objects are ignored. Lines longer than 1 MiB are
// still copied to w but are not parsed, so a single huge message neither
// stops the copy nor hides a later cost message.
//
// A read (or log write) error other than io.EOF is reported through logger,
// which may be nil, and ends parsing; the cost collected so far is returned.
func streamAndParseCost(r io.Reader, w io.Writer, logger *slog.Logger) float64 {
	// Upper bound on the size of a line that is buffered for JSON decoding.
	const maxLineBytes = 1024 * 1024

	reader := bufio.NewReaderSize(io.TeeReader(r, w), 64*1024)

	var (
		totalCost float64
		line      []byte // fragments of the current line, joined
		oversized bool   // current line exceeded maxLineBytes; skip parsing
	)

	// finishLine parses the accumulated line (if usable) and resets the state
	// for the next one.
	finishLine := func() {
		if !oversized && len(line) > 0 {
			var msg map[string]any
			if err := json.Unmarshal(line, &msg); err == nil {
				// Extract cost from result messages.
				if cost, ok := msg["cost_usd"].(float64); ok {
					totalCost = cost
				}
			}
		}
		line, oversized = line[:0], false
	}

	for {
		frag, isPrefix, err := reader.ReadLine()
		if err != nil {
			// A final unterminated line that exactly filled the read buffer
			// is still pending at this point.
			finishLine()
			if err != io.EOF {
				if logger != nil {
					logger.Warn("reading claude output stream", "error", err)
				}
				// Best effort: keep consuming so the writer on the other end
				// of r is not left blocked on a full pipe.
				_, _ = io.Copy(io.Discard, r)
			}
			return totalCost
		}

		if !oversized {
			if len(line)+len(frag) > maxLineBytes {
				oversized = true
				if logger != nil {
					logger.Warn("skipping cost parsing for oversized output line", "limit_bytes", maxLineBytes)
				}
			} else {
				line = append(line, frag...)
			}
		}
		if !isPrefix {
			finishLine()
		}
	}
}