summaryrefslogtreecommitdiff
path: root/internal
diff options
context:
space:
mode:
Diffstat (limited to 'internal')
-rw-r--r--internal/api/scripts.go50
-rw-r--r--internal/api/server.go52
-rw-r--r--internal/api/server_test.go140
-rw-r--r--internal/executor/claude.go98
-rw-r--r--internal/executor/executor.go8
-rw-r--r--internal/executor/question.go172
-rw-r--r--internal/executor/question_test.go253
-rw-r--r--internal/executor/ratelimit.go76
-rw-r--r--internal/executor/ratelimit_test.go170
9 files changed, 988 insertions, 31 deletions
diff --git a/internal/api/scripts.go b/internal/api/scripts.go
new file mode 100644
index 0000000..492570b
--- /dev/null
+++ b/internal/api/scripts.go
@@ -0,0 +1,50 @@
+package api
+
+import (
+ "bytes"
+ "context"
+ "net/http"
+ "os/exec"
+ "path/filepath"
+ "time"
+)
+
+const scriptTimeout = 30 * time.Second
+
+func (s *Server) startNextTaskScriptPath() string {
+ if s.startNextTaskScript != "" {
+ return s.startNextTaskScript
+ }
+ return filepath.Join(s.workDir, "scripts", "start-next-task")
+}
+
+func (s *Server) handleStartNextTask(w http.ResponseWriter, r *http.Request) {
+ ctx, cancel := context.WithTimeout(r.Context(), scriptTimeout)
+ defer cancel()
+
+ scriptPath := s.startNextTaskScriptPath()
+ cmd := exec.CommandContext(ctx, scriptPath)
+
+ var stdout, stderr bytes.Buffer
+ cmd.Stdout = &stdout
+ cmd.Stderr = &stderr
+
+ err := cmd.Run()
+ exitCode := 0
+ if err != nil {
+ if exitErr, ok := err.(*exec.ExitError); ok {
+ exitCode = exitErr.ExitCode()
+ } else {
+ s.logger.Error("start-next-task: script execution failed", "error", err, "path", scriptPath)
+ writeJSON(w, http.StatusInternalServerError, map[string]string{
+ "error": "script execution failed: " + err.Error(),
+ })
+ return
+ }
+ }
+
+ writeJSON(w, http.StatusOK, map[string]interface{}{
+ "output": stdout.String(),
+ "exit_code": exitCode,
+ })
+}
diff --git a/internal/api/server.go b/internal/api/server.go
index 608cdd4..bac98b6 100644
--- a/internal/api/server.go
+++ b/internal/api/server.go
@@ -24,9 +24,10 @@ type Server struct {
hub *Hub
logger *slog.Logger
mux *http.ServeMux
- claudeBinPath string // path to claude binary; defaults to "claude"
- elaborateCmdPath string // overrides claudeBinPath; used in tests
- workDir string // working directory injected into elaborate system prompt
+ claudeBinPath string // path to claude binary; defaults to "claude"
+ elaborateCmdPath string // overrides claudeBinPath; used in tests
+ startNextTaskScript string // path to start-next-task script; overridden in tests
+ workDir string // working directory injected into elaborate system prompt
}
func NewServer(store *storage.DB, pool *executor.Pool, logger *slog.Logger, claudeBinPath string) *Server {
@@ -71,6 +72,8 @@ func (s *Server) routes() {
s.mux.HandleFunc("GET /api/templates/{id}", s.handleGetTemplate)
s.mux.HandleFunc("PUT /api/templates/{id}", s.handleUpdateTemplate)
s.mux.HandleFunc("DELETE /api/templates/{id}", s.handleDeleteTemplate)
+ s.mux.HandleFunc("POST /api/tasks/{id}/answer", s.handleAnswerQuestion)
+ s.mux.HandleFunc("POST /api/scripts/start-next-task", s.handleStartNextTask)
s.mux.HandleFunc("GET /api/ws", s.handleWebSocket)
s.mux.HandleFunc("GET /api/health", s.handleHealth)
s.mux.Handle("GET /", http.FileServerFS(webui.Files))
@@ -93,6 +96,49 @@ func (s *Server) forwardResults() {
}
}
+// BroadcastQuestion sends a task_question event to all WebSocket clients.
+func (s *Server) BroadcastQuestion(taskID, toolUseID string, questionData json.RawMessage) {
+ event := map[string]interface{}{
+ "type": "task_question",
+ "task_id": taskID,
+ "question_id": toolUseID,
+ "data": json.RawMessage(questionData),
+ "timestamp": time.Now().UTC(),
+ }
+ data, _ := json.Marshal(event)
+ s.hub.Broadcast(data)
+}
+
+func (s *Server) handleAnswerQuestion(w http.ResponseWriter, r *http.Request) {
+ taskID := r.PathValue("id")
+
+ if _, err := s.store.GetTask(taskID); err != nil {
+ writeJSON(w, http.StatusNotFound, map[string]string{"error": "task not found"})
+ return
+ }
+
+ var input struct {
+ QuestionID string `json:"question_id"`
+ Answer string `json:"answer"`
+ }
+ if err := json.NewDecoder(r.Body).Decode(&input); err != nil {
+ writeJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid JSON: " + err.Error()})
+ return
+ }
+ if input.QuestionID == "" {
+ writeJSON(w, http.StatusBadRequest, map[string]string{"error": "question_id is required"})
+ return
+ }
+
+ ok := s.pool.Questions.Answer(input.QuestionID, input.Answer)
+ if !ok {
+ writeJSON(w, http.StatusNotFound, map[string]string{"error": "no pending question with that ID"})
+ return
+ }
+
+ writeJSON(w, http.StatusOK, map[string]string{"status": "delivered"})
+}
+
func (s *Server) handleHealth(w http.ResponseWriter, r *http.Request) {
writeJSON(w, http.StatusOK, map[string]string{"status": "ok"})
}
diff --git a/internal/api/server_test.go b/internal/api/server_test.go
index 5094961..af93a77 100644
--- a/internal/api/server_test.go
+++ b/internal/api/server_test.go
@@ -335,3 +335,143 @@ func TestCORS_Headers(t *testing.T) {
t.Errorf("OPTIONS status: want 200, got %d", w.Code)
}
}
+
+func TestAnswerQuestion_NoTask_Returns404(t *testing.T) {
+ srv, _ := testServer(t)
+
+ payload := `{"question_id": "toolu_abc", "answer": "blue"}`
+ req := httptest.NewRequest("POST", "/api/tasks/nonexistent/answer", bytes.NewBufferString(payload))
+ req.Header.Set("Content-Type", "application/json")
+ w := httptest.NewRecorder()
+
+ srv.Handler().ServeHTTP(w, req)
+
+ if w.Code != http.StatusNotFound {
+ t.Errorf("status: want 404, got %d; body: %s", w.Code, w.Body.String())
+ }
+}
+
+func TestAnswerQuestion_NoPendingQuestion_Returns404(t *testing.T) {
+ srv, store := testServer(t)
+ createTaskWithState(t, store, "answer-task-1", task.StatePending)
+
+ payload := `{"question_id": "toolu_nonexistent", "answer": "blue"}`
+ req := httptest.NewRequest("POST", "/api/tasks/answer-task-1/answer", bytes.NewBufferString(payload))
+ req.Header.Set("Content-Type", "application/json")
+ w := httptest.NewRecorder()
+
+ srv.Handler().ServeHTTP(w, req)
+
+ if w.Code != http.StatusNotFound {
+ t.Errorf("status: want 404, got %d; body: %s", w.Code, w.Body.String())
+ }
+ var body map[string]string
+ json.NewDecoder(w.Body).Decode(&body)
+ if body["error"] != "no pending question with that ID" {
+ t.Errorf("error: want 'no pending question with that ID', got %q", body["error"])
+ }
+}
+
+func TestAnswerQuestion_WithPendingQuestion_Returns200(t *testing.T) {
+ srv, store := testServer(t)
+ createTaskWithState(t, store, "answer-task-2", task.StateRunning)
+
+ ch := srv.pool.Questions.Register("answer-task-2", "toolu_Q1", []byte(`{}`))
+
+ go func() {
+ payload := `{"question_id": "toolu_Q1", "answer": "red"}`
+ req := httptest.NewRequest("POST", "/api/tasks/answer-task-2/answer", bytes.NewBufferString(payload))
+ req.Header.Set("Content-Type", "application/json")
+ w := httptest.NewRecorder()
+ srv.Handler().ServeHTTP(w, req)
+
+ if w.Code != http.StatusOK {
+ t.Errorf("status: want 200, got %d; body: %s", w.Code, w.Body.String())
+ }
+ }()
+
+ answer := <-ch
+ if answer != "red" {
+ t.Errorf("answer: want 'red', got %q", answer)
+ }
+}
+
+func TestAnswerQuestion_MissingQuestionID_Returns400(t *testing.T) {
+ srv, store := testServer(t)
+ createTaskWithState(t, store, "answer-task-3", task.StateRunning)
+
+ payload := `{"answer": "blue"}`
+ req := httptest.NewRequest("POST", "/api/tasks/answer-task-3/answer", bytes.NewBufferString(payload))
+ req.Header.Set("Content-Type", "application/json")
+ w := httptest.NewRecorder()
+
+ srv.Handler().ServeHTTP(w, req)
+
+ if w.Code != http.StatusBadRequest {
+ t.Errorf("status: want 400, got %d", w.Code)
+ }
+}
+
+func TestHandleStartNextTask_Success(t *testing.T) {
+ dir := t.TempDir()
+ script := filepath.Join(dir, "start-next-task")
+ if err := os.WriteFile(script, []byte("#!/bin/sh\necho 'claudomator start abc-123'\n"), 0755); err != nil {
+ t.Fatal(err)
+ }
+
+ srv, _ := testServer(t)
+ srv.startNextTaskScript = script
+
+ req := httptest.NewRequest("POST", "/api/scripts/start-next-task", nil)
+ w := httptest.NewRecorder()
+ srv.Handler().ServeHTTP(w, req)
+
+ if w.Code != http.StatusOK {
+ t.Fatalf("want 200, got %d; body: %s", w.Code, w.Body.String())
+ }
+ var body map[string]interface{}
+ json.NewDecoder(w.Body).Decode(&body)
+ if body["output"] != "claudomator start abc-123\n" {
+ t.Errorf("unexpected output: %v", body["output"])
+ }
+ if body["exit_code"] != float64(0) {
+ t.Errorf("unexpected exit_code: %v", body["exit_code"])
+ }
+}
+
+func TestHandleStartNextTask_NoTask(t *testing.T) {
+ dir := t.TempDir()
+ script := filepath.Join(dir, "start-next-task")
+ if err := os.WriteFile(script, []byte("#!/bin/sh\necho 'No task to start.'\n"), 0755); err != nil {
+ t.Fatal(err)
+ }
+
+ srv, _ := testServer(t)
+ srv.startNextTaskScript = script
+
+ req := httptest.NewRequest("POST", "/api/scripts/start-next-task", nil)
+ w := httptest.NewRecorder()
+ srv.Handler().ServeHTTP(w, req)
+
+ if w.Code != http.StatusOK {
+ t.Fatalf("want 200, got %d; body: %s", w.Code, w.Body.String())
+ }
+ var body map[string]interface{}
+ json.NewDecoder(w.Body).Decode(&body)
+ if body["output"] != "No task to start.\n" {
+ t.Errorf("unexpected output: %v", body["output"])
+ }
+}
+
+func TestHandleStartNextTask_ScriptNotFound(t *testing.T) {
+ srv, _ := testServer(t)
+ srv.startNextTaskScript = "/nonexistent/start-next-task"
+
+ req := httptest.NewRequest("POST", "/api/scripts/start-next-task", nil)
+ w := httptest.NewRecorder()
+ srv.Handler().ServeHTTP(w, req)
+
+ if w.Code != http.StatusInternalServerError {
+ t.Errorf("want 500, got %d; body: %s", w.Code, w.Body.String())
+ }
+}
diff --git a/internal/executor/claude.go b/internal/executor/claude.go
index 7b3884c..0029331 100644
--- a/internal/executor/claude.go
+++ b/internal/executor/claude.go
@@ -10,6 +10,9 @@ import (
"os"
"os/exec"
"path/filepath"
+ "sync"
+ "syscall"
+ "time"
"github.com/thepeterstone/claudomator/internal/storage"
"github.com/thepeterstone/claudomator/internal/task"
@@ -31,71 +34,116 @@ func (r *ClaudeRunner) binaryPath() string {
}
// Run executes a claude -p invocation, streaming output to log files.
+// It retries up to 3 times on rate-limit errors using exponential backoff.
func (r *ClaudeRunner) Run(ctx context.Context, t *task.Task, e *storage.Execution) error {
args := r.buildArgs(t)
- cmd := exec.CommandContext(ctx, r.binaryPath(), args...)
- cmd.Env = append(os.Environ(),
- "CLAUDOMATOR_API_URL="+r.APIURL,
- "CLAUDOMATOR_TASK_ID="+t.ID,
- )
if t.Claude.WorkingDir != "" {
if _, err := os.Stat(t.Claude.WorkingDir); err != nil {
return fmt.Errorf("working_dir %q: %w", t.Claude.WorkingDir, err)
}
- cmd.Dir = t.Claude.WorkingDir
}
- // Setup log directory for this execution.
+ // Setup log directory once; retries overwrite the log files.
logDir := filepath.Join(r.LogDir, e.ID)
if err := os.MkdirAll(logDir, 0700); err != nil {
return fmt.Errorf("creating log dir: %w", err)
}
-
- stdoutPath := filepath.Join(logDir, "stdout.log")
- stderrPath := filepath.Join(logDir, "stderr.log")
- e.StdoutPath = stdoutPath
- e.StderrPath = stderrPath
+ e.StdoutPath = filepath.Join(logDir, "stdout.log")
+ e.StderrPath = filepath.Join(logDir, "stderr.log")
e.ArtifactDir = logDir
- stdoutFile, err := os.Create(stdoutPath)
+ attempt := 0
+ return runWithBackoff(ctx, 3, 5*time.Second, func() error {
+ if attempt > 0 {
+ delay := 5 * time.Second * (1 << (attempt - 1))
+ r.Logger.Warn("rate-limited by Claude API, retrying",
+ "attempt", attempt,
+ "delay", delay,
+ )
+ }
+ attempt++
+ return r.execOnce(ctx, t, args, e)
+ })
+}
+
+// execOnce runs the claude subprocess once, streaming output to e's log paths.
+func (r *ClaudeRunner) execOnce(ctx context.Context, t *task.Task, args []string, e *storage.Execution) error {
+ cmd := exec.CommandContext(ctx, r.binaryPath(), args...)
+ cmd.Env = append(os.Environ(),
+ "CLAUDOMATOR_API_URL="+r.APIURL,
+ "CLAUDOMATOR_TASK_ID="+t.ID,
+ )
+ // Put the subprocess in its own process group so we can SIGKILL the entire
+ // group (MCP servers, bash children, etc.) on cancellation.
+ cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true}
+ if t.Claude.WorkingDir != "" {
+ cmd.Dir = t.Claude.WorkingDir
+ }
+
+ stdoutFile, err := os.Create(e.StdoutPath)
if err != nil {
return fmt.Errorf("creating stdout log: %w", err)
}
defer stdoutFile.Close()
- stderrFile, err := os.Create(stderrPath)
+ stderrFile, err := os.Create(e.StderrPath)
if err != nil {
return fmt.Errorf("creating stderr log: %w", err)
}
defer stderrFile.Close()
- stdoutPipe, err := cmd.StdoutPipe()
+ // Use os.Pipe for stdout so we own the read-end lifetime.
+ // cmd.StdoutPipe() would add the read-end to closeAfterWait, causing
+ // cmd.Wait() to close it before our goroutine finishes reading.
+ stdoutR, stdoutW, err := os.Pipe()
if err != nil {
return fmt.Errorf("creating stdout pipe: %w", err)
}
- stderrPipe, err := cmd.StderrPipe()
- if err != nil {
- return fmt.Errorf("creating stderr pipe: %w", err)
- }
+ cmd.Stdout = stdoutW // *os.File — not added to closeAfterStart/Wait
+ cmd.Stderr = stderrFile
if err := cmd.Start(); err != nil {
+ stdoutW.Close()
+ stdoutR.Close()
return fmt.Errorf("starting claude: %w", err)
}
+ // Close our write-end immediately; the subprocess holds its own copy.
+ // The goroutine below gets EOF when the subprocess exits.
+ stdoutW.Close()
+
+ // killDone is closed when cmd.Wait() returns, stopping the pgid-kill goroutine.
+ killDone := make(chan struct{})
+ go func() {
+ select {
+ case <-ctx.Done():
+ // SIGKILL the entire process group to reap orphan children.
+ syscall.Kill(-cmd.Process.Pid, syscall.SIGKILL)
+ case <-killDone:
+ }
+ }()
- // Stream output to log files and parse cost info.
+ // Stream stdout to the log file and parse cost.
+ // wg ensures costUSD is fully written before we read it after cmd.Wait().
var costUSD float64
+ var wg sync.WaitGroup
+ wg.Add(1)
go func() {
- costUSD = streamAndParseCost(stdoutPipe, stdoutFile, r.Logger)
+ defer wg.Done()
+ costUSD = streamAndParseCost(stdoutR, stdoutFile, r.Logger)
+ stdoutR.Close()
}()
- go io.Copy(stderrFile, stderrPipe)
- if err := cmd.Wait(); err != nil {
- if exitErr, ok := err.(*exec.ExitError); ok {
+ waitErr := cmd.Wait()
+ close(killDone) // stop the pgid-kill goroutine
+ wg.Wait() // drain remaining stdout before reading costUSD
+
+ if waitErr != nil {
+ if exitErr, ok := waitErr.(*exec.ExitError); ok {
e.ExitCode = exitErr.ExitCode()
}
e.CostUSD = costUSD
- return fmt.Errorf("claude exited with error: %w", err)
+ return fmt.Errorf("claude exited with error: %w", waitErr)
}
e.ExitCode = 0
diff --git a/internal/executor/executor.go b/internal/executor/executor.go
index d25d3b4..51f468e 100644
--- a/internal/executor/executor.go
+++ b/internal/executor/executor.go
@@ -24,9 +24,10 @@ type Pool struct {
store *storage.DB
logger *slog.Logger
- mu sync.Mutex
- active int
- resultCh chan *Result
+ mu sync.Mutex
+ active int
+ resultCh chan *Result
+ Questions *QuestionRegistry
}
// Result is emitted when a task execution completes.
@@ -46,6 +47,7 @@ func NewPool(maxConcurrent int, runner Runner, store *storage.DB, logger *slog.L
store: store,
logger: logger,
resultCh: make(chan *Result, maxConcurrent*2),
+ Questions: NewQuestionRegistry(),
}
}
diff --git a/internal/executor/question.go b/internal/executor/question.go
new file mode 100644
index 0000000..9a2b55d
--- /dev/null
+++ b/internal/executor/question.go
@@ -0,0 +1,172 @@
+package executor
+
+import (
+ "bufio"
+ "encoding/json"
+ "io"
+ "log/slog"
+ "sync"
+)
+
+// QuestionHandler is called when an agent invokes AskUserQuestion.
+// Implementations should broadcast the question and block until an answer arrives.
+type QuestionHandler interface {
+ HandleQuestion(taskID, toolUseID string, input json.RawMessage) (string, error)
+}
+
+// PendingQuestion holds state for a question awaiting a user answer.
+type PendingQuestion struct {
+ TaskID string `json:"task_id"`
+ ToolUseID string `json:"tool_use_id"`
+ Input json.RawMessage `json:"input"`
+ AnswerCh chan string `json:"-"`
+}
+
+// QuestionRegistry tracks pending questions across running tasks.
+type QuestionRegistry struct {
+ mu sync.Mutex
+ questions map[string]*PendingQuestion // keyed by toolUseID
+}
+
+// NewQuestionRegistry creates a new registry.
+func NewQuestionRegistry() *QuestionRegistry {
+ return &QuestionRegistry{
+ questions: make(map[string]*PendingQuestion),
+ }
+}
+
+// Register adds a pending question and returns its answer channel.
+func (qr *QuestionRegistry) Register(taskID, toolUseID string, input json.RawMessage) chan string {
+ ch := make(chan string, 1)
+ qr.mu.Lock()
+ qr.questions[toolUseID] = &PendingQuestion{
+ TaskID: taskID,
+ ToolUseID: toolUseID,
+ Input: input,
+ AnswerCh: ch,
+ }
+ qr.mu.Unlock()
+ return ch
+}
+
+// Answer delivers an answer for a pending question. Returns false if no such question exists.
+func (qr *QuestionRegistry) Answer(toolUseID, answer string) bool {
+ qr.mu.Lock()
+ pq, ok := qr.questions[toolUseID]
+ if ok {
+ delete(qr.questions, toolUseID)
+ }
+ qr.mu.Unlock()
+ if !ok {
+ return false
+ }
+ pq.AnswerCh <- answer
+ return true
+}
+
+// Get returns a pending question by tool_use_id, or nil.
+func (qr *QuestionRegistry) Get(toolUseID string) *PendingQuestion {
+ qr.mu.Lock()
+ defer qr.mu.Unlock()
+ return qr.questions[toolUseID]
+}
+
+// PendingForTask returns all pending questions for a given task.
+func (qr *QuestionRegistry) PendingForTask(taskID string) []*PendingQuestion {
+ qr.mu.Lock()
+ defer qr.mu.Unlock()
+ var result []*PendingQuestion
+ for _, pq := range qr.questions {
+ if pq.TaskID == taskID {
+ result = append(result, pq)
+ }
+ }
+ return result
+}
+
+// Remove removes a question without answering it (e.g., on task cancellation).
+func (qr *QuestionRegistry) Remove(toolUseID string) {
+ qr.mu.Lock()
+ delete(qr.questions, toolUseID)
+ qr.mu.Unlock()
+}
+
+// extractAskUserQuestion parses a stream-json line and returns the tool_use_id and input
+// if the line is an assistant event containing an AskUserQuestion tool_use.
+func extractAskUserQuestion(line []byte) (string, json.RawMessage) {
+ var event struct {
+ Type string `json:"type"`
+ Message struct {
+ Content []struct {
+ Type string `json:"type"`
+ ID string `json:"id"`
+ Name string `json:"name"`
+ Input json.RawMessage `json:"input"`
+ } `json:"content"`
+ } `json:"message"`
+ }
+ if err := json.Unmarshal(line, &event); err != nil {
+ return "", nil
+ }
+ if event.Type != "assistant" {
+ return "", nil
+ }
+ for _, block := range event.Message.Content {
+ if block.Type == "tool_use" && block.Name == "AskUserQuestion" {
+ return block.ID, block.Input
+ }
+ }
+ return "", nil
+}
+
+// streamAndParseWithQuestions reads streaming JSON, writes to w, parses cost,
+// and calls onQuestion for each detected AskUserQuestion tool_use.
+func streamAndParseWithQuestions(r io.Reader, w io.Writer, _ *slog.Logger, onQuestion func(string, json.RawMessage)) float64 {
+ tee := io.TeeReader(r, w)
+ scanner := bufio.NewScanner(tee)
+ scanner.Buffer(make([]byte, 1024*1024), 1024*1024)
+
+ var totalCost float64
+ for scanner.Scan() {
+ line := scanner.Bytes()
+
+ if toolUseID, input := extractAskUserQuestion(line); toolUseID != "" {
+ if onQuestion != nil {
+ onQuestion(toolUseID, input)
+ }
+ }
+
+ var msg map[string]interface{}
+ if err := json.Unmarshal(line, &msg); err != nil {
+ continue
+ }
+ if costData, ok := msg["cost_usd"]; ok {
+ if cost, ok := costData.(float64); ok {
+ totalCost = cost
+ }
+ }
+ }
+ return totalCost
+}
+
+// buildToolResultMessage builds a tool_result message to feed back to Claude
+// as the answer to an AskUserQuestion tool_use.
+func buildToolResultMessage(toolUseID, answer string) []byte {
+ answerJSON, _ := json.Marshal(map[string]interface{}{
+ "answers": map[string]string{"answer": answer},
+ })
+ msg := map[string]interface{}{
+ "message": map[string]interface{}{
+ "role": "user",
+ "content": []map[string]interface{}{
+ {
+ "type": "tool_result",
+ "tool_use_id": toolUseID,
+ "content": string(answerJSON),
+ },
+ },
+ },
+ }
+ result, _ := json.Marshal(msg)
+ return result
+}
diff --git a/internal/executor/question_test.go b/internal/executor/question_test.go
new file mode 100644
index 0000000..d0fbed9
--- /dev/null
+++ b/internal/executor/question_test.go
@@ -0,0 +1,253 @@
+package executor
+
+import (
+ "bytes"
+ "encoding/json"
+ "io"
+ "log/slog"
+ "strings"
+ "testing"
+)
+
+func TestQuestionRegistry_RegisterAndAnswer(t *testing.T) {
+ qr := NewQuestionRegistry()
+
+ ch := qr.Register("task-1", "toolu_abc", json.RawMessage(`{"question":"color?"}`))
+
+ // Answer should unblock the channel.
+ go func() {
+ ok := qr.Answer("toolu_abc", "blue")
+ if !ok {
+ t.Error("Answer returned false, expected true")
+ }
+ }()
+
+ answer := <-ch
+ if answer != "blue" {
+ t.Errorf("want 'blue', got %q", answer)
+ }
+
+ // Question should be removed after answering.
+ if qr.Get("toolu_abc") != nil {
+ t.Error("question should be removed after answering")
+ }
+}
+
+func TestQuestionRegistry_AnswerUnknown(t *testing.T) {
+ qr := NewQuestionRegistry()
+ ok := qr.Answer("nonexistent", "anything")
+ if ok {
+ t.Error("expected false for unknown question")
+ }
+}
+
+func TestQuestionRegistry_PendingForTask(t *testing.T) {
+ qr := NewQuestionRegistry()
+ qr.Register("task-1", "toolu_1", json.RawMessage(`{}`))
+ qr.Register("task-1", "toolu_2", json.RawMessage(`{}`))
+ qr.Register("task-2", "toolu_3", json.RawMessage(`{}`))
+
+ pending := qr.PendingForTask("task-1")
+ if len(pending) != 2 {
+ t.Errorf("want 2 pending for task-1, got %d", len(pending))
+ }
+
+ pending2 := qr.PendingForTask("task-2")
+ if len(pending2) != 1 {
+ t.Errorf("want 1 pending for task-2, got %d", len(pending2))
+ }
+}
+
+func TestQuestionRegistry_Remove(t *testing.T) {
+ qr := NewQuestionRegistry()
+ qr.Register("task-1", "toolu_x", json.RawMessage(`{}`))
+ qr.Remove("toolu_x")
+ if qr.Get("toolu_x") != nil {
+ t.Error("question should be removed")
+ }
+}
+
+func TestExtractAskUserQuestion_DetectsQuestion(t *testing.T) {
+ // Simulate a stream-json assistant event containing an AskUserQuestion tool_use.
+ event := map[string]interface{}{
+ "type": "assistant",
+ "message": map[string]interface{}{
+ "content": []interface{}{
+ map[string]interface{}{
+ "type": "tool_use",
+ "id": "toolu_01ABC",
+ "name": "AskUserQuestion",
+ "input": map[string]interface{}{
+ "questions": []interface{}{
+ map[string]interface{}{
+ "question": "Which color?",
+ "header": "Color",
+ "options": []interface{}{
+ map[string]interface{}{"label": "red", "description": "Red color"},
+ map[string]interface{}{"label": "blue", "description": "Blue color"},
+ },
+ "multiSelect": false,
+ },
+ },
+ },
+ },
+ },
+ },
+ }
+ line, _ := json.Marshal(event)
+
+ toolUseID, input := extractAskUserQuestion(line)
+ if toolUseID != "toolu_01ABC" {
+ t.Errorf("toolUseID: want 'toolu_01ABC', got %q", toolUseID)
+ }
+ if input == nil {
+ t.Fatal("input should not be nil")
+ }
+}
+
+func TestExtractAskUserQuestion_IgnoresOtherTools(t *testing.T) {
+ event := map[string]interface{}{
+ "type": "assistant",
+ "message": map[string]interface{}{
+ "content": []interface{}{
+ map[string]interface{}{
+ "type": "tool_use",
+ "id": "toolu_01XYZ",
+ "name": "Read",
+ "input": map[string]interface{}{"file_path": "/foo"},
+ },
+ },
+ },
+ }
+ line, _ := json.Marshal(event)
+
+ toolUseID, input := extractAskUserQuestion(line)
+ if toolUseID != "" || input != nil {
+ t.Error("should not detect non-AskUserQuestion tool_use")
+ }
+}
+
+func TestExtractAskUserQuestion_IgnoresNonAssistant(t *testing.T) {
+ event := map[string]interface{}{
+ "type": "system",
+ "subtype": "init",
+ }
+ line, _ := json.Marshal(event)
+
+ toolUseID, input := extractAskUserQuestion(line)
+ if toolUseID != "" || input != nil {
+ t.Error("should not detect from non-assistant events")
+ }
+}
+
+func TestStreamAndParseQuestions_DetectsQuestionAndCost(t *testing.T) {
+ // Build a stream with an assistant event containing AskUserQuestion and a result with cost.
+ assistantEvent := map[string]interface{}{
+ "type": "assistant",
+ "message": map[string]interface{}{
+ "content": []interface{}{
+ map[string]interface{}{
+ "type": "tool_use",
+ "id": "toolu_Q1",
+ "name": "AskUserQuestion",
+ "input": map[string]interface{}{
+ "questions": []interface{}{
+ map[string]interface{}{
+ "question": "Pick a number",
+ "header": "Num",
+ "options": []interface{}{
+ map[string]interface{}{"label": "1", "description": "One"},
+ map[string]interface{}{"label": "2", "description": "Two"},
+ },
+ "multiSelect": false,
+ },
+ },
+ },
+ },
+ },
+ },
+ }
+ resultEvent := map[string]interface{}{
+ "type": "result",
+ "cost_usd": 0.05,
+ }
+
+ var buf bytes.Buffer
+ json.NewEncoder(&buf).Encode(assistantEvent)
+ json.NewEncoder(&buf).Encode(resultEvent)
+
+ logger := slog.New(slog.NewTextHandler(io.Discard, nil))
+ var questions []questionDetected
+ onQuestion := func(toolUseID string, input json.RawMessage) {
+ questions = append(questions, questionDetected{toolUseID, input})
+ }
+
+ cost := streamAndParseWithQuestions(strings.NewReader(buf.String()), io.Discard, logger, onQuestion)
+
+ if cost != 0.05 {
+ t.Errorf("cost: want 0.05, got %f", cost)
+ }
+ if len(questions) != 1 {
+ t.Fatalf("want 1 question detected, got %d", len(questions))
+ }
+ if questions[0].toolUseID != "toolu_Q1" {
+ t.Errorf("toolUseID: want 'toolu_Q1', got %q", questions[0].toolUseID)
+ }
+}
+
+type questionDetected struct {
+ toolUseID string
+ input json.RawMessage
+}
+
+func TestBuildToolResultMessage_Format(t *testing.T) {
+ msg := buildToolResultMessage("toolu_123", "blue")
+
+ var parsed map[string]interface{}
+ if err := json.Unmarshal(msg, &parsed); err != nil {
+ t.Fatalf("invalid JSON: %v", err)
+ }
+
+ // Should have type "user" with message containing tool_result
+ msgObj, ok := parsed["message"].(map[string]interface{})
+ if !ok {
+ t.Fatal("missing 'message' field")
+ }
+ content, ok := msgObj["content"].([]interface{})
+ if !ok || len(content) == 0 {
+ t.Fatal("missing content array")
+ }
+
+ block := content[0].(map[string]interface{})
+ if block["type"] != "tool_result" {
+ t.Errorf("type: want 'tool_result', got %v", block["type"])
+ }
+ if block["tool_use_id"] != "toolu_123" {
+ t.Errorf("tool_use_id: want 'toolu_123', got %v", block["tool_use_id"])
+ }
+
+ // The content should contain the answer JSON
+ resultContent, ok := block["content"].(string)
+ if !ok {
+ t.Fatal("content should be a string")
+ }
+ var answerData map[string]interface{}
+ if err := json.Unmarshal([]byte(resultContent), &answerData); err != nil {
+ t.Fatalf("answer content is not valid JSON: %v", err)
+ }
+ answers, ok := answerData["answers"].(map[string]interface{})
+ if !ok {
+ t.Fatal("missing answers in result content")
+ }
+ // At least one answer key should have the value "blue"
+ found := false
+ for _, v := range answers {
+ if v == "blue" {
+ found = true
+ break
+ }
+ }
+ if !found {
+ t.Errorf("expected 'blue' in answers, got %v", answers)
+ }
+}
diff --git a/internal/executor/ratelimit.go b/internal/executor/ratelimit.go
new file mode 100644
index 0000000..884da43
--- /dev/null
+++ b/internal/executor/ratelimit.go
@@ -0,0 +1,76 @@
+package executor
+
+import (
+ "context"
+ "fmt"
+ "regexp"
+ "strconv"
+ "strings"
+ "time"
+)
+
+var retryAfterRe = regexp.MustCompile(`(?i)retry[-_ ]after[:\s]+(\d+)`)
+
+const maxBackoffDelay = 5 * time.Minute
+
+// isRateLimitError returns true if err looks like a Claude API rate-limit response.
+func isRateLimitError(err error) bool {
+ if err == nil {
+ return false
+ }
+ msg := strings.ToLower(err.Error())
+ return strings.Contains(msg, "rate limit") ||
+ strings.Contains(msg, "too many requests") ||
+ strings.Contains(msg, "429") ||
+ strings.Contains(msg, "overloaded")
+}
+
+// parseRetryAfter extracts a Retry-After duration from an error message.
+// Returns 0 if no retry-after value is found.
+func parseRetryAfter(msg string) time.Duration {
+ m := retryAfterRe.FindStringSubmatch(msg)
+ if m == nil {
+ return 0
+ }
+ secs, err := strconv.Atoi(m[1])
+ if err != nil || secs <= 0 {
+ return 0
+ }
+ return time.Duration(secs) * time.Second
+}
+
+// runWithBackoff calls fn repeatedly on rate-limit errors, using exponential backoff.
+// maxRetries is the max number of retry attempts (not counting the initial call).
+// baseDelay is the initial backoff duration (doubled each retry).
+func runWithBackoff(ctx context.Context, maxRetries int, baseDelay time.Duration, fn func() error) error {
+ var lastErr error
+ for attempt := 0; attempt <= maxRetries; attempt++ {
+ lastErr = fn()
+ if lastErr == nil {
+ return nil
+ }
+ if !isRateLimitError(lastErr) {
+ return lastErr
+ }
+ if attempt == maxRetries {
+ break
+ }
+
+ // Compute exponential backoff delay.
+ delay := baseDelay * (1 << attempt)
+ if delay > maxBackoffDelay {
+ delay = maxBackoffDelay
+ }
+ // Use Retry-After header value if present.
+ if ra := parseRetryAfter(lastErr.Error()); ra > 0 {
+ delay = ra
+ }
+
+ select {
+ case <-ctx.Done():
+ return fmt.Errorf("context cancelled during rate-limit backoff: %w", ctx.Err())
+ case <-time.After(delay):
+ }
+ }
+ return lastErr
+}
diff --git a/internal/executor/ratelimit_test.go b/internal/executor/ratelimit_test.go
new file mode 100644
index 0000000..f45216f
--- /dev/null
+++ b/internal/executor/ratelimit_test.go
@@ -0,0 +1,170 @@
+package executor
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "testing"
+ "time"
+)
+
+// --- isRateLimitError tests ---
+
+func TestIsRateLimitError_RateLimitMessage(t *testing.T) {
+ err := errors.New("claude exited with error: rate limit exceeded")
+ if !isRateLimitError(err) {
+ t.Error("want true for 'rate limit exceeded', got false")
+ }
+}
+
+func TestIsRateLimitError_TooManyRequests(t *testing.T) {
+ err := errors.New("too many requests to the API")
+ if !isRateLimitError(err) {
+ t.Error("want true for 'too many requests', got false")
+ }
+}
+
+func TestIsRateLimitError_HTTP429(t *testing.T) {
+ err := errors.New("API returned status 429")
+ if !isRateLimitError(err) {
+ t.Error("want true for '429', got false")
+ }
+}
+
+func TestIsRateLimitError_Overloaded(t *testing.T) {
+ err := errors.New("API overloaded, please retry later")
+ if !isRateLimitError(err) {
+ t.Error("want true for 'overloaded', got false")
+ }
+}
+
+func TestIsRateLimitError_NonRateLimitError(t *testing.T) {
+ err := errors.New("claude exited with error: exit status 1")
+ if isRateLimitError(err) {
+ t.Error("want false for non-rate-limit error, got true")
+ }
+}
+
+func TestIsRateLimitError_NilError(t *testing.T) {
+ if isRateLimitError(nil) {
+ t.Error("want false for nil error, got true")
+ }
+}
+
+// --- parseRetryAfter tests ---
+
+func TestParseRetryAfter_RetryAfterSeconds(t *testing.T) {
+ msg := "rate limit exceeded, retry after 30 seconds"
+ d := parseRetryAfter(msg)
+ if d != 30*time.Second {
+ t.Errorf("want 30s, got %v", d)
+ }
+}
+
+func TestParseRetryAfter_RetryAfterHeader(t *testing.T) {
+ msg := "rate_limit_error: retry-after: 60"
+ d := parseRetryAfter(msg)
+ if d != 60*time.Second {
+ t.Errorf("want 60s, got %v", d)
+ }
+}
+
+func TestParseRetryAfter_NoRetryInfo(t *testing.T) {
+ msg := "rate limit exceeded"
+ d := parseRetryAfter(msg)
+ if d != 0 {
+ t.Errorf("want 0, got %v", d)
+ }
+}
+
+// --- runWithBackoff tests ---
+
+func TestRunWithBackoff_SuccessOnFirstTry(t *testing.T) {
+ calls := 0
+ fn := func() error {
+ calls++
+ return nil
+ }
+ err := runWithBackoff(context.Background(), 3, time.Millisecond, fn)
+ if err != nil {
+ t.Errorf("want nil error, got %v", err)
+ }
+ if calls != 1 {
+ t.Errorf("want 1 call, got %d", calls)
+ }
+}
+
+func TestRunWithBackoff_RetriesOnRateLimit(t *testing.T) {
+ calls := 0
+ fn := func() error {
+ calls++
+ if calls < 3 {
+ return fmt.Errorf("rate limit exceeded")
+ }
+ return nil
+ }
+ err := runWithBackoff(context.Background(), 3, time.Millisecond, fn)
+ if err != nil {
+ t.Errorf("want nil error, got %v", err)
+ }
+ if calls != 3 {
+ t.Errorf("want 3 calls, got %d", calls)
+ }
+}
+
+func TestRunWithBackoff_GivesUpAfterMaxRetries(t *testing.T) {
+ calls := 0
+ rateLimitErr := fmt.Errorf("rate limit exceeded")
+ fn := func() error {
+ calls++
+ return rateLimitErr
+ }
+ err := runWithBackoff(context.Background(), 3, time.Millisecond, fn)
+ if err == nil {
+ t.Fatal("want error after max retries, got nil")
+ }
+ // maxRetries=3: 1 initial call + 3 retries = 4 total calls
+ if calls != 4 {
+ t.Errorf("want 4 calls (1 initial + 3 retries), got %d", calls)
+ }
+}
+
+func TestRunWithBackoff_DoesNotRetryNonRateLimitError(t *testing.T) {
+ calls := 0
+ fn := func() error {
+ calls++
+ return fmt.Errorf("permission denied")
+ }
+ err := runWithBackoff(context.Background(), 3, time.Millisecond, fn)
+ if err == nil {
+ t.Fatal("want error, got nil")
+ }
+ if calls != 1 {
+ t.Errorf("want 1 call (no retry for non-rate-limit), got %d", calls)
+ }
+}
+
+func TestRunWithBackoff_ContextCancellation(t *testing.T) {
+ ctx, cancel := context.WithCancel(context.Background())
+ calls := 0
+
+ fn := func() error {
+ calls++
+ cancel() // cancel immediately after first call
+ return fmt.Errorf("rate limit exceeded")
+ }
+
+ start := time.Now()
+ err := runWithBackoff(ctx, 3, time.Second, fn) // large delay confirms ctx preempts wait
+ elapsed := time.Since(start)
+
+ if err == nil {
+ t.Fatal("want error on context cancellation, got nil")
+ }
+ if elapsed > 500*time.Millisecond {
+ t.Errorf("context cancellation too slow: %v (want < 500ms)", elapsed)
+ }
+ if calls != 1 {
+ t.Errorf("want 1 call before cancellation, got %d", calls)
+ }
+}