summaryrefslogtreecommitdiff
path: root/internal/executor
diff options
context:
space:
mode:
Diffstat (limited to 'internal/executor')
-rw-r--r--internal/executor/claude.go98
-rw-r--r--internal/executor/executor.go8
-rw-r--r--internal/executor/question.go172
-rw-r--r--internal/executor/question_test.go253
-rw-r--r--internal/executor/ratelimit.go76
-rw-r--r--internal/executor/ratelimit_test.go170
6 files changed, 749 insertions, 28 deletions
diff --git a/internal/executor/claude.go b/internal/executor/claude.go
index 7b3884c..0029331 100644
--- a/internal/executor/claude.go
+++ b/internal/executor/claude.go
@@ -10,6 +10,9 @@ import (
"os"
"os/exec"
"path/filepath"
+ "sync"
+ "syscall"
+ "time"
"github.com/thepeterstone/claudomator/internal/storage"
"github.com/thepeterstone/claudomator/internal/task"
@@ -31,71 +34,116 @@ func (r *ClaudeRunner) binaryPath() string {
}
// Run executes a claude -p invocation, streaming output to log files.
+// It retries up to 3 times on rate-limit errors using exponential backoff.
func (r *ClaudeRunner) Run(ctx context.Context, t *task.Task, e *storage.Execution) error {
args := r.buildArgs(t)
- cmd := exec.CommandContext(ctx, r.binaryPath(), args...)
- cmd.Env = append(os.Environ(),
- "CLAUDOMATOR_API_URL="+r.APIURL,
- "CLAUDOMATOR_TASK_ID="+t.ID,
- )
if t.Claude.WorkingDir != "" {
if _, err := os.Stat(t.Claude.WorkingDir); err != nil {
return fmt.Errorf("working_dir %q: %w", t.Claude.WorkingDir, err)
}
- cmd.Dir = t.Claude.WorkingDir
}
- // Setup log directory for this execution.
+ // Setup log directory once; retries overwrite the log files.
logDir := filepath.Join(r.LogDir, e.ID)
if err := os.MkdirAll(logDir, 0700); err != nil {
return fmt.Errorf("creating log dir: %w", err)
}
-
- stdoutPath := filepath.Join(logDir, "stdout.log")
- stderrPath := filepath.Join(logDir, "stderr.log")
- e.StdoutPath = stdoutPath
- e.StderrPath = stderrPath
+ e.StdoutPath = filepath.Join(logDir, "stdout.log")
+ e.StderrPath = filepath.Join(logDir, "stderr.log")
e.ArtifactDir = logDir
- stdoutFile, err := os.Create(stdoutPath)
+ attempt := 0
+ return runWithBackoff(ctx, 3, 5*time.Second, func() error {
+ if attempt > 0 {
+ delay := 5 * time.Second * (1 << (attempt - 1))
+ r.Logger.Warn("rate-limited by Claude API, retrying",
+ "attempt", attempt,
+ "delay", delay,
+ )
+ }
+ attempt++
+ return r.execOnce(ctx, t, args, e)
+ })
+}
+
+// execOnce runs the claude subprocess once, streaming output to e's log paths.
+func (r *ClaudeRunner) execOnce(ctx context.Context, t *task.Task, args []string, e *storage.Execution) error {
+ cmd := exec.CommandContext(ctx, r.binaryPath(), args...)
+ cmd.Env = append(os.Environ(),
+ "CLAUDOMATOR_API_URL="+r.APIURL,
+ "CLAUDOMATOR_TASK_ID="+t.ID,
+ )
+ // Put the subprocess in its own process group so we can SIGKILL the entire
+ // group (MCP servers, bash children, etc.) on cancellation.
+ cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true}
+ if t.Claude.WorkingDir != "" {
+ cmd.Dir = t.Claude.WorkingDir
+ }
+
+ stdoutFile, err := os.Create(e.StdoutPath)
if err != nil {
return fmt.Errorf("creating stdout log: %w", err)
}
defer stdoutFile.Close()
- stderrFile, err := os.Create(stderrPath)
+ stderrFile, err := os.Create(e.StderrPath)
if err != nil {
return fmt.Errorf("creating stderr log: %w", err)
}
defer stderrFile.Close()
- stdoutPipe, err := cmd.StdoutPipe()
+ // Use os.Pipe for stdout so we own the read-end lifetime.
+ // cmd.StdoutPipe() would add the read-end to closeAfterWait, causing
+ // cmd.Wait() to close it before our goroutine finishes reading.
+ stdoutR, stdoutW, err := os.Pipe()
if err != nil {
return fmt.Errorf("creating stdout pipe: %w", err)
}
- stderrPipe, err := cmd.StderrPipe()
- if err != nil {
- return fmt.Errorf("creating stderr pipe: %w", err)
- }
+ cmd.Stdout = stdoutW // *os.File — not added to closeAfterStart/Wait
+ cmd.Stderr = stderrFile
if err := cmd.Start(); err != nil {
+ stdoutW.Close()
+ stdoutR.Close()
return fmt.Errorf("starting claude: %w", err)
}
+ // Close our write-end immediately; the subprocess holds its own copy.
+ // The goroutine below gets EOF when the subprocess exits.
+ stdoutW.Close()
+
+ // killDone is closed when cmd.Wait() returns, stopping the pgid-kill goroutine.
+ killDone := make(chan struct{})
+ go func() {
+ select {
+ case <-ctx.Done():
+ // SIGKILL the entire process group to reap orphan children.
+ syscall.Kill(-cmd.Process.Pid, syscall.SIGKILL)
+ case <-killDone:
+ }
+ }()
- // Stream output to log files and parse cost info.
+ // Stream stdout to the log file and parse cost.
+ // wg ensures costUSD is fully written before we read it after cmd.Wait().
var costUSD float64
+ var wg sync.WaitGroup
+ wg.Add(1)
go func() {
- costUSD = streamAndParseCost(stdoutPipe, stdoutFile, r.Logger)
+ defer wg.Done()
+ costUSD = streamAndParseCost(stdoutR, stdoutFile, r.Logger)
+ stdoutR.Close()
}()
- go io.Copy(stderrFile, stderrPipe)
- if err := cmd.Wait(); err != nil {
- if exitErr, ok := err.(*exec.ExitError); ok {
+ waitErr := cmd.Wait()
+ close(killDone) // stop the pgid-kill goroutine
+ wg.Wait() // drain remaining stdout before reading costUSD
+
+ if waitErr != nil {
+ if exitErr, ok := waitErr.(*exec.ExitError); ok {
e.ExitCode = exitErr.ExitCode()
}
e.CostUSD = costUSD
- return fmt.Errorf("claude exited with error: %w", err)
+ return fmt.Errorf("claude exited with error: %w", waitErr)
}
e.ExitCode = 0
diff --git a/internal/executor/executor.go b/internal/executor/executor.go
index d25d3b4..51f468e 100644
--- a/internal/executor/executor.go
+++ b/internal/executor/executor.go
@@ -24,9 +24,10 @@ type Pool struct {
store *storage.DB
logger *slog.Logger
- mu sync.Mutex
- active int
- resultCh chan *Result
+ mu sync.Mutex
+ active int
+ resultCh chan *Result
+ Questions *QuestionRegistry
}
// Result is emitted when a task execution completes.
@@ -46,6 +47,7 @@ func NewPool(maxConcurrent int, runner Runner, store *storage.DB, logger *slog.L
store: store,
logger: logger,
resultCh: make(chan *Result, maxConcurrent*2),
+ Questions: NewQuestionRegistry(),
}
}
diff --git a/internal/executor/question.go b/internal/executor/question.go
new file mode 100644
index 0000000..9a2b55d
--- /dev/null
+++ b/internal/executor/question.go
@@ -0,0 +1,172 @@
+package executor
+
+import (
+ "bufio"
+ "encoding/json"
+ "io"
+ "log/slog"
+ "sync"
+)
+
+// QuestionHandler is called when an agent invokes AskUserQuestion.
+// Implementations should broadcast the question and block until an answer arrives.
+type QuestionHandler interface {
+ HandleQuestion(taskID, toolUseID string, input json.RawMessage) (string, error)
+}
+
+// PendingQuestion holds state for a question awaiting a user answer.
+type PendingQuestion struct {
+ TaskID string `json:"task_id"`
+ ToolUseID string `json:"tool_use_id"`
+ Input json.RawMessage `json:"input"`
+ AnswerCh chan string `json:"-"`
+}
+
+// QuestionRegistry tracks pending questions across running tasks.
+type QuestionRegistry struct {
+ mu sync.Mutex
+ questions map[string]*PendingQuestion // keyed by toolUseID
+}
+
+// NewQuestionRegistry creates a new registry.
+func NewQuestionRegistry() *QuestionRegistry {
+ return &QuestionRegistry{
+ questions: make(map[string]*PendingQuestion),
+ }
+}
+
+// Register adds a pending question and returns its answer channel.
+func (qr *QuestionRegistry) Register(taskID, toolUseID string, input json.RawMessage) chan string {
+ ch := make(chan string, 1)
+ qr.mu.Lock()
+ qr.questions[toolUseID] = &PendingQuestion{
+ TaskID: taskID,
+ ToolUseID: toolUseID,
+ Input: input,
+ AnswerCh: ch,
+ }
+ qr.mu.Unlock()
+ return ch
+}
+
+// Answer delivers an answer for a pending question. Returns false if no such question exists.
+func (qr *QuestionRegistry) Answer(toolUseID, answer string) bool {
+ qr.mu.Lock()
+ pq, ok := qr.questions[toolUseID]
+ if ok {
+ delete(qr.questions, toolUseID)
+ }
+ qr.mu.Unlock()
+ if !ok {
+ return false
+ }
+ pq.AnswerCh <- answer
+ return true
+}
+
+// Get returns a pending question by tool_use_id, or nil.
+func (qr *QuestionRegistry) Get(toolUseID string) *PendingQuestion {
+ qr.mu.Lock()
+ defer qr.mu.Unlock()
+ return qr.questions[toolUseID]
+}
+
+// PendingForTask returns all pending questions for a given task.
+func (qr *QuestionRegistry) PendingForTask(taskID string) []*PendingQuestion {
+ qr.mu.Lock()
+ defer qr.mu.Unlock()
+ var result []*PendingQuestion
+ for _, pq := range qr.questions {
+ if pq.TaskID == taskID {
+ result = append(result, pq)
+ }
+ }
+ return result
+}
+
+// Remove removes a question without answering it (e.g., on task cancellation).
+func (qr *QuestionRegistry) Remove(toolUseID string) {
+ qr.mu.Lock()
+ delete(qr.questions, toolUseID)
+ qr.mu.Unlock()
+}
+
+// extractAskUserQuestion parses a stream-json line and returns the tool_use_id and input
+// if the line is an assistant event containing an AskUserQuestion tool_use.
+func extractAskUserQuestion(line []byte) (string, json.RawMessage) {
+ var event struct {
+ Type string `json:"type"`
+ Message struct {
+ Content []struct {
+ Type string `json:"type"`
+ ID string `json:"id"`
+ Name string `json:"name"`
+ Input json.RawMessage `json:"input"`
+ } `json:"content"`
+ } `json:"message"`
+ }
+ if err := json.Unmarshal(line, &event); err != nil {
+ return "", nil
+ }
+ if event.Type != "assistant" {
+ return "", nil
+ }
+ for _, block := range event.Message.Content {
+ if block.Type == "tool_use" && block.Name == "AskUserQuestion" {
+ return block.ID, block.Input
+ }
+ }
+ return "", nil
+}
+
+// streamAndParseWithQuestions reads streaming JSON, writes to w, parses cost,
+// and calls onQuestion for each detected AskUserQuestion tool_use.
+func streamAndParseWithQuestions(r io.Reader, w io.Writer, _ *slog.Logger, onQuestion func(string, json.RawMessage)) float64 {
+ tee := io.TeeReader(r, w)
+ scanner := bufio.NewScanner(tee)
+ scanner.Buffer(make([]byte, 1024*1024), 1024*1024)
+
+ var totalCost float64
+ for scanner.Scan() {
+ line := scanner.Bytes()
+
+ if toolUseID, input := extractAskUserQuestion(line); toolUseID != "" {
+ if onQuestion != nil {
+ onQuestion(toolUseID, input)
+ }
+ }
+
+ var msg map[string]interface{}
+ if err := json.Unmarshal(line, &msg); err != nil {
+ continue
+ }
+ if costData, ok := msg["cost_usd"]; ok {
+ if cost, ok := costData.(float64); ok {
+ totalCost = cost
+ }
+ }
+ }
+ return totalCost
+}
+
+// buildToolResultMessage builds a tool_result message to feed back to Claude
+// as the answer to an AskUserQuestion tool_use.
+func buildToolResultMessage(toolUseID, answer string) []byte {
+ answerJSON, _ := json.Marshal(map[string]interface{}{
+ "answers": map[string]string{"answer": answer},
+ })
+ msg := map[string]interface{}{
+ "message": map[string]interface{}{
+ "role": "user",
+ "content": []map[string]interface{}{
+ {
+ "type": "tool_result",
+ "tool_use_id": toolUseID,
+ "content": string(answerJSON),
+ },
+ },
+ },
+ }
+ result, _ := json.Marshal(msg)
+ return result
+}
diff --git a/internal/executor/question_test.go b/internal/executor/question_test.go
new file mode 100644
index 0000000..d0fbed9
--- /dev/null
+++ b/internal/executor/question_test.go
@@ -0,0 +1,253 @@
+package executor
+
+import (
+ "bytes"
+ "encoding/json"
+ "io"
+ "log/slog"
+ "strings"
+ "testing"
+)
+
+func TestQuestionRegistry_RegisterAndAnswer(t *testing.T) {
+ qr := NewQuestionRegistry()
+
+ ch := qr.Register("task-1", "toolu_abc", json.RawMessage(`{"question":"color?"}`))
+
+ // Answer should unblock the channel.
+ go func() {
+ ok := qr.Answer("toolu_abc", "blue")
+ if !ok {
+ t.Error("Answer returned false, expected true")
+ }
+ }()
+
+ answer := <-ch
+ if answer != "blue" {
+ t.Errorf("want 'blue', got %q", answer)
+ }
+
+ // Question should be removed after answering.
+ if qr.Get("toolu_abc") != nil {
+ t.Error("question should be removed after answering")
+ }
+}
+
+func TestQuestionRegistry_AnswerUnknown(t *testing.T) {
+ qr := NewQuestionRegistry()
+ ok := qr.Answer("nonexistent", "anything")
+ if ok {
+ t.Error("expected false for unknown question")
+ }
+}
+
+func TestQuestionRegistry_PendingForTask(t *testing.T) {
+ qr := NewQuestionRegistry()
+ qr.Register("task-1", "toolu_1", json.RawMessage(`{}`))
+ qr.Register("task-1", "toolu_2", json.RawMessage(`{}`))
+ qr.Register("task-2", "toolu_3", json.RawMessage(`{}`))
+
+ pending := qr.PendingForTask("task-1")
+ if len(pending) != 2 {
+ t.Errorf("want 2 pending for task-1, got %d", len(pending))
+ }
+
+ pending2 := qr.PendingForTask("task-2")
+ if len(pending2) != 1 {
+ t.Errorf("want 1 pending for task-2, got %d", len(pending2))
+ }
+}
+
+func TestQuestionRegistry_Remove(t *testing.T) {
+ qr := NewQuestionRegistry()
+ qr.Register("task-1", "toolu_x", json.RawMessage(`{}`))
+ qr.Remove("toolu_x")
+ if qr.Get("toolu_x") != nil {
+ t.Error("question should be removed")
+ }
+}
+
+func TestExtractAskUserQuestion_DetectsQuestion(t *testing.T) {
+ // Simulate a stream-json assistant event containing an AskUserQuestion tool_use.
+ event := map[string]interface{}{
+ "type": "assistant",
+ "message": map[string]interface{}{
+ "content": []interface{}{
+ map[string]interface{}{
+ "type": "tool_use",
+ "id": "toolu_01ABC",
+ "name": "AskUserQuestion",
+ "input": map[string]interface{}{
+ "questions": []interface{}{
+ map[string]interface{}{
+ "question": "Which color?",
+ "header": "Color",
+ "options": []interface{}{
+ map[string]interface{}{"label": "red", "description": "Red color"},
+ map[string]interface{}{"label": "blue", "description": "Blue color"},
+ },
+ "multiSelect": false,
+ },
+ },
+ },
+ },
+ },
+ },
+ }
+ line, _ := json.Marshal(event)
+
+ toolUseID, input := extractAskUserQuestion(line)
+ if toolUseID != "toolu_01ABC" {
+ t.Errorf("toolUseID: want 'toolu_01ABC', got %q", toolUseID)
+ }
+ if input == nil {
+ t.Fatal("input should not be nil")
+ }
+}
+
+func TestExtractAskUserQuestion_IgnoresOtherTools(t *testing.T) {
+ event := map[string]interface{}{
+ "type": "assistant",
+ "message": map[string]interface{}{
+ "content": []interface{}{
+ map[string]interface{}{
+ "type": "tool_use",
+ "id": "toolu_01XYZ",
+ "name": "Read",
+ "input": map[string]interface{}{"file_path": "/foo"},
+ },
+ },
+ },
+ }
+ line, _ := json.Marshal(event)
+
+ toolUseID, input := extractAskUserQuestion(line)
+ if toolUseID != "" || input != nil {
+ t.Error("should not detect non-AskUserQuestion tool_use")
+ }
+}
+
+func TestExtractAskUserQuestion_IgnoresNonAssistant(t *testing.T) {
+ event := map[string]interface{}{
+ "type": "system",
+ "subtype": "init",
+ }
+ line, _ := json.Marshal(event)
+
+ toolUseID, input := extractAskUserQuestion(line)
+ if toolUseID != "" || input != nil {
+ t.Error("should not detect from non-assistant events")
+ }
+}
+
+func TestStreamAndParseQuestions_DetectsQuestionAndCost(t *testing.T) {
+ // Build a stream with an assistant event containing AskUserQuestion and a result with cost.
+ assistantEvent := map[string]interface{}{
+ "type": "assistant",
+ "message": map[string]interface{}{
+ "content": []interface{}{
+ map[string]interface{}{
+ "type": "tool_use",
+ "id": "toolu_Q1",
+ "name": "AskUserQuestion",
+ "input": map[string]interface{}{
+ "questions": []interface{}{
+ map[string]interface{}{
+ "question": "Pick a number",
+ "header": "Num",
+ "options": []interface{}{
+ map[string]interface{}{"label": "1", "description": "One"},
+ map[string]interface{}{"label": "2", "description": "Two"},
+ },
+ "multiSelect": false,
+ },
+ },
+ },
+ },
+ },
+ },
+ }
+ resultEvent := map[string]interface{}{
+ "type": "result",
+ "cost_usd": 0.05,
+ }
+
+ var buf bytes.Buffer
+ json.NewEncoder(&buf).Encode(assistantEvent)
+ json.NewEncoder(&buf).Encode(resultEvent)
+
+ logger := slog.New(slog.NewTextHandler(io.Discard, nil))
+ var questions []questionDetected
+ onQuestion := func(toolUseID string, input json.RawMessage) {
+ questions = append(questions, questionDetected{toolUseID, input})
+ }
+
+ cost := streamAndParseWithQuestions(strings.NewReader(buf.String()), io.Discard, logger, onQuestion)
+
+ if cost != 0.05 {
+ t.Errorf("cost: want 0.05, got %f", cost)
+ }
+ if len(questions) != 1 {
+ t.Fatalf("want 1 question detected, got %d", len(questions))
+ }
+ if questions[0].toolUseID != "toolu_Q1" {
+ t.Errorf("toolUseID: want 'toolu_Q1', got %q", questions[0].toolUseID)
+ }
+}
+
+type questionDetected struct {
+ toolUseID string
+ input json.RawMessage
+}
+
+func TestBuildToolResultMessage_Format(t *testing.T) {
+ msg := buildToolResultMessage("toolu_123", "blue")
+
+ var parsed map[string]interface{}
+ if err := json.Unmarshal(msg, &parsed); err != nil {
+ t.Fatalf("invalid JSON: %v", err)
+ }
+
+ // Should have type "user" with message containing tool_result
+ msgObj, ok := parsed["message"].(map[string]interface{})
+ if !ok {
+ t.Fatal("missing 'message' field")
+ }
+ content, ok := msgObj["content"].([]interface{})
+ if !ok || len(content) == 0 {
+ t.Fatal("missing content array")
+ }
+
+ block := content[0].(map[string]interface{})
+ if block["type"] != "tool_result" {
+ t.Errorf("type: want 'tool_result', got %v", block["type"])
+ }
+ if block["tool_use_id"] != "toolu_123" {
+ t.Errorf("tool_use_id: want 'toolu_123', got %v", block["tool_use_id"])
+ }
+
+ // The content should contain the answer JSON
+ resultContent, ok := block["content"].(string)
+ if !ok {
+ t.Fatal("content should be a string")
+ }
+ var answerData map[string]interface{}
+ if err := json.Unmarshal([]byte(resultContent), &answerData); err != nil {
+ t.Fatalf("answer content is not valid JSON: %v", err)
+ }
+ answers, ok := answerData["answers"].(map[string]interface{})
+ if !ok {
+ t.Fatal("missing answers in result content")
+ }
+ // At least one answer key should have the value "blue"
+ found := false
+ for _, v := range answers {
+ if v == "blue" {
+ found = true
+ break
+ }
+ }
+ if !found {
+ t.Errorf("expected 'blue' in answers, got %v", answers)
+ }
+}
diff --git a/internal/executor/ratelimit.go b/internal/executor/ratelimit.go
new file mode 100644
index 0000000..884da43
--- /dev/null
+++ b/internal/executor/ratelimit.go
@@ -0,0 +1,76 @@
+package executor
+
+import (
+ "context"
+ "fmt"
+ "regexp"
+ "strconv"
+ "strings"
+ "time"
+)
+
+var retryAfterRe = regexp.MustCompile(`(?i)retry[-_ ]after[:\s]+(\d+)`)
+
+const maxBackoffDelay = 5 * time.Minute
+
+// isRateLimitError returns true if err looks like a Claude API rate-limit response.
+func isRateLimitError(err error) bool {
+ if err == nil {
+ return false
+ }
+ msg := strings.ToLower(err.Error())
+ return strings.Contains(msg, "rate limit") ||
+ strings.Contains(msg, "too many requests") ||
+ strings.Contains(msg, "429") ||
+ strings.Contains(msg, "overloaded")
+}
+
+// parseRetryAfter extracts a Retry-After duration from an error message.
+// Returns 0 if no retry-after value is found.
+func parseRetryAfter(msg string) time.Duration {
+ m := retryAfterRe.FindStringSubmatch(msg)
+ if m == nil {
+ return 0
+ }
+ secs, err := strconv.Atoi(m[1])
+ if err != nil || secs <= 0 {
+ return 0
+ }
+ return time.Duration(secs) * time.Second
+}
+
+// runWithBackoff calls fn repeatedly on rate-limit errors, using exponential backoff.
+// maxRetries is the max number of retry attempts (not counting the initial call).
+// baseDelay is the initial backoff duration (doubled each retry).
+func runWithBackoff(ctx context.Context, maxRetries int, baseDelay time.Duration, fn func() error) error {
+ var lastErr error
+ for attempt := 0; attempt <= maxRetries; attempt++ {
+ lastErr = fn()
+ if lastErr == nil {
+ return nil
+ }
+ if !isRateLimitError(lastErr) {
+ return lastErr
+ }
+ if attempt == maxRetries {
+ break
+ }
+
+ // Compute exponential backoff delay.
+ delay := baseDelay * (1 << attempt)
+ if delay > maxBackoffDelay {
+ delay = maxBackoffDelay
+ }
+ // Use Retry-After header value if present.
+ if ra := parseRetryAfter(lastErr.Error()); ra > 0 {
+ delay = ra
+ }
+
+ select {
+ case <-ctx.Done():
+ return fmt.Errorf("context cancelled during rate-limit backoff: %w", ctx.Err())
+ case <-time.After(delay):
+ }
+ }
+ return lastErr
+}
diff --git a/internal/executor/ratelimit_test.go b/internal/executor/ratelimit_test.go
new file mode 100644
index 0000000..f45216f
--- /dev/null
+++ b/internal/executor/ratelimit_test.go
@@ -0,0 +1,170 @@
+package executor
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "testing"
+ "time"
+)
+
+// --- isRateLimitError tests ---
+
+func TestIsRateLimitError_RateLimitMessage(t *testing.T) {
+ err := errors.New("claude exited with error: rate limit exceeded")
+ if !isRateLimitError(err) {
+ t.Error("want true for 'rate limit exceeded', got false")
+ }
+}
+
+func TestIsRateLimitError_TooManyRequests(t *testing.T) {
+ err := errors.New("too many requests to the API")
+ if !isRateLimitError(err) {
+ t.Error("want true for 'too many requests', got false")
+ }
+}
+
+func TestIsRateLimitError_HTTP429(t *testing.T) {
+ err := errors.New("API returned status 429")
+ if !isRateLimitError(err) {
+ t.Error("want true for '429', got false")
+ }
+}
+
+func TestIsRateLimitError_Overloaded(t *testing.T) {
+ err := errors.New("API overloaded, please retry later")
+ if !isRateLimitError(err) {
+ t.Error("want true for 'overloaded', got false")
+ }
+}
+
+func TestIsRateLimitError_NonRateLimitError(t *testing.T) {
+ err := errors.New("claude exited with error: exit status 1")
+ if isRateLimitError(err) {
+ t.Error("want false for non-rate-limit error, got true")
+ }
+}
+
+func TestIsRateLimitError_NilError(t *testing.T) {
+ if isRateLimitError(nil) {
+ t.Error("want false for nil error, got true")
+ }
+}
+
+// --- parseRetryAfter tests ---
+
+func TestParseRetryAfter_RetryAfterSeconds(t *testing.T) {
+ msg := "rate limit exceeded, retry after 30 seconds"
+ d := parseRetryAfter(msg)
+ if d != 30*time.Second {
+ t.Errorf("want 30s, got %v", d)
+ }
+}
+
+func TestParseRetryAfter_RetryAfterHeader(t *testing.T) {
+ msg := "rate_limit_error: retry-after: 60"
+ d := parseRetryAfter(msg)
+ if d != 60*time.Second {
+ t.Errorf("want 60s, got %v", d)
+ }
+}
+
+func TestParseRetryAfter_NoRetryInfo(t *testing.T) {
+ msg := "rate limit exceeded"
+ d := parseRetryAfter(msg)
+ if d != 0 {
+ t.Errorf("want 0, got %v", d)
+ }
+}
+
+// --- runWithBackoff tests ---
+
+func TestRunWithBackoff_SuccessOnFirstTry(t *testing.T) {
+ calls := 0
+ fn := func() error {
+ calls++
+ return nil
+ }
+ err := runWithBackoff(context.Background(), 3, time.Millisecond, fn)
+ if err != nil {
+ t.Errorf("want nil error, got %v", err)
+ }
+ if calls != 1 {
+ t.Errorf("want 1 call, got %d", calls)
+ }
+}
+
+func TestRunWithBackoff_RetriesOnRateLimit(t *testing.T) {
+ calls := 0
+ fn := func() error {
+ calls++
+ if calls < 3 {
+ return fmt.Errorf("rate limit exceeded")
+ }
+ return nil
+ }
+ err := runWithBackoff(context.Background(), 3, time.Millisecond, fn)
+ if err != nil {
+ t.Errorf("want nil error, got %v", err)
+ }
+ if calls != 3 {
+ t.Errorf("want 3 calls, got %d", calls)
+ }
+}
+
+func TestRunWithBackoff_GivesUpAfterMaxRetries(t *testing.T) {
+ calls := 0
+ rateLimitErr := fmt.Errorf("rate limit exceeded")
+ fn := func() error {
+ calls++
+ return rateLimitErr
+ }
+ err := runWithBackoff(context.Background(), 3, time.Millisecond, fn)
+ if err == nil {
+ t.Fatal("want error after max retries, got nil")
+ }
+ // maxRetries=3: 1 initial call + 3 retries = 4 total calls
+ if calls != 4 {
+ t.Errorf("want 4 calls (1 initial + 3 retries), got %d", calls)
+ }
+}
+
+func TestRunWithBackoff_DoesNotRetryNonRateLimitError(t *testing.T) {
+ calls := 0
+ fn := func() error {
+ calls++
+ return fmt.Errorf("permission denied")
+ }
+ err := runWithBackoff(context.Background(), 3, time.Millisecond, fn)
+ if err == nil {
+ t.Fatal("want error, got nil")
+ }
+ if calls != 1 {
+ t.Errorf("want 1 call (no retry for non-rate-limit), got %d", calls)
+ }
+}
+
+func TestRunWithBackoff_ContextCancellation(t *testing.T) {
+ ctx, cancel := context.WithCancel(context.Background())
+ calls := 0
+
+ fn := func() error {
+ calls++
+ cancel() // cancel immediately after first call
+ return fmt.Errorf("rate limit exceeded")
+ }
+
+ start := time.Now()
+ err := runWithBackoff(ctx, 3, time.Second, fn) // large delay confirms ctx preempts wait
+ elapsed := time.Since(start)
+
+ if err == nil {
+ t.Fatal("want error on context cancellation, got nil")
+ }
+ if elapsed > 500*time.Millisecond {
+ t.Errorf("context cancellation too slow: %v (want < 500ms)", elapsed)
+ }
+ if calls != 1 {
+ t.Errorf("want 1 call before cancellation, got %d", calls)
+ }
+}