summaryrefslogtreecommitdiff
path: root/internal/executor
diff options
context:
space:
mode:
Diffstat (limited to 'internal/executor')
-rw-r--r--internal/executor/classifier.go33
-rw-r--r--internal/executor/classifier_test.go76
-rw-r--r--internal/executor/claude.go5
-rw-r--r--internal/executor/claude_test.go6
-rw-r--r--internal/executor/executor.go12
-rw-r--r--internal/executor/gemini_test.go1
-rw-r--r--internal/executor/local.go171
-rw-r--r--internal/executor/local_test.go152
-rw-r--r--internal/executor/ratelimit.go80
-rw-r--r--internal/executor/ratelimit_test.go170
10 files changed, 450 insertions, 256 deletions
diff --git a/internal/executor/classifier.go b/internal/executor/classifier.go
index 7a474b6..049dc4f 100644
--- a/internal/executor/classifier.go
+++ b/internal/executor/classifier.go
@@ -6,6 +6,8 @@ import (
"fmt"
"os/exec"
"strings"
+
+ "github.com/thepeterstone/claudomator/internal/llm"
)
type Classification struct {
@@ -19,7 +21,12 @@ type SystemStatus struct {
RateLimited map[string]bool
}
+// Classifier picks a model for an incoming task. When LLM is non-nil the
+// classifier routes through the local OpenAI-compatible client (cheap,
+// private, fast). Otherwise it falls back to invoking the Gemini CLI
+// at GeminiBinaryPath.
type Classifier struct {
+ LLM *llm.Client
GeminiBinaryPath string
}
@@ -62,6 +69,10 @@ func (c *Classifier) Classify(ctx context.Context, taskName, instructions string
agentType, taskName, instructions, agentType,
)
+ if c.LLM != nil {
+ return c.classifyViaLLM(ctx, prompt, agentType)
+ }
+
binary := c.GeminiBinaryPath
if binary == "" {
binary = "gemini"
@@ -123,3 +134,25 @@ func (c *Classifier) Classify(ctx context.Context, taskName, instructions string
return &cls, nil
}
+
+// classifyViaLLM routes classification through the local OpenAI-compatible
+// client with response_format=json_object, so we get clean JSON without the
+// markdown-fence cleanup needed for the Gemini CLI fallback.
+func (c *Classifier) classifyViaLLM(ctx context.Context, prompt, agentType string) (*Classification, error) {
+ resp, err := c.LLM.Chat(ctx, llm.ChatRequest{
+ Messages: []llm.Message{{Role: "user", Content: prompt}},
+ ResponseJSON: true,
+ })
+ if err != nil {
+ return nil, fmt.Errorf("classifier (local llm): %w", err)
+ }
+ body := strings.TrimSpace(resp.Content)
+ var cls Classification
+ if err := json.Unmarshal([]byte(body), &cls); err != nil {
+ return nil, fmt.Errorf("classifier (local llm): parse JSON: %w\nbody: %s", err, body)
+ }
+ if cls.AgentType == "" {
+ cls.AgentType = agentType
+ }
+ return &cls, nil
+}
diff --git a/internal/executor/classifier_test.go b/internal/executor/classifier_test.go
index 83a9743..84fffcf 100644
--- a/internal/executor/classifier_test.go
+++ b/internal/executor/classifier_test.go
@@ -2,8 +2,15 @@ package executor
import (
"context"
+ "encoding/json"
+ "fmt"
+ "net/http"
+ "net/http/httptest"
"os"
+ "strings"
"testing"
+
+ "github.com/thepeterstone/claudomator/internal/llm"
)
// TestClassifier_Classify_Mock tests the classifier with a mocked gemini binary.
@@ -36,6 +43,75 @@ echo '{"response": "{\"agent_type\": \"gemini\", \"model\": \"gemini-2.5-flash-l
}
}
+// TestClassifier_Classify_LLM tests classification through a local OpenAI-compatible LLM.
+func TestClassifier_Classify_LLM(t *testing.T) {
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ // Verify the classifier asked for JSON mode.
+ var body struct {
+ ResponseFormat *struct {
+ Type string `json:"type"`
+ } `json:"response_format"`
+ }
+ if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
+ t.Fatalf("decode body: %v", err)
+ }
+ if body.ResponseFormat == nil || body.ResponseFormat.Type != "json_object" {
+ t.Errorf("classifier should request json_object response format")
+ }
+
+ w.Header().Set("Content-Type", "application/json")
+ fmt.Fprintln(w, `{
+ "model":"local-fast",
+ "choices":[{"message":{"role":"assistant","content":"{\"agent_type\":\"claude\",\"model\":\"claude-haiku-4-5-20251001\",\"reason\":\"trivial task\"}"},"finish_reason":"stop"}],
+ "usage":{"prompt_tokens":10,"completion_tokens":15}
+ }`)
+ }))
+ defer srv.Close()
+
+ c := &Classifier{
+ LLM: &llm.Client{Endpoint: srv.URL + "/v1", Model: "local-fast"},
+ }
+ status := SystemStatus{
+ ActiveTasks: map[string]int{"claude": 1, "gemini": 0},
+ RateLimited: map[string]bool{},
+ }
+
+ cls, err := c.Classify(context.Background(), "List files", "ls -la", status, "claude")
+ if err != nil {
+ t.Fatalf("Classify: %v", err)
+ }
+ if cls.AgentType != "claude" {
+ t.Errorf("AgentType: want claude got %q", cls.AgentType)
+ }
+ if cls.Model != "claude-haiku-4-5-20251001" {
+ t.Errorf("Model: want claude-haiku-4-5-20251001 got %q", cls.Model)
+ }
+ if !strings.Contains(cls.Reason, "trivial") {
+ t.Errorf("Reason mismatch: %q", cls.Reason)
+ }
+}
+
+// TestClassifier_LLMTakesPrecedence_OverGemini ensures the LLM path is preferred when both are configured.
+func TestClassifier_LLMTakesPrecedence_OverGemini(t *testing.T) {
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("Content-Type", "application/json")
+ fmt.Fprintln(w, `{"model":"x","choices":[{"message":{"content":"{\"agent_type\":\"claude\",\"model\":\"claude-sonnet-4-6\",\"reason\":\"r\"}"},"finish_reason":"stop"}],"usage":{}}`)
+ }))
+ defer srv.Close()
+
+ c := &Classifier{
+ LLM: &llm.Client{Endpoint: srv.URL + "/v1", Model: "x"},
+ GeminiBinaryPath: "/nonexistent/gemini-binary-should-not-be-called",
+ }
+ cls, err := c.Classify(context.Background(), "n", "i", SystemStatus{}, "claude")
+ if err != nil {
+ t.Fatalf("Classify: %v", err)
+ }
+ if cls.Model != "claude-sonnet-4-6" {
+ t.Errorf("expected LLM path; got Model=%q", cls.Model)
+ }
+}
+
func filepathJoin(elems ...string) string {
var path string
for i, e := range elems {
diff --git a/internal/executor/claude.go b/internal/executor/claude.go
index 7e79ce0..e3f8e1c 100644
--- a/internal/executor/claude.go
+++ b/internal/executor/claude.go
@@ -15,6 +15,7 @@ import (
"syscall"
"time"
+ "github.com/thepeterstone/claudomator/internal/retry"
"github.com/thepeterstone/claudomator/internal/storage"
"github.com/thepeterstone/claudomator/internal/task"
)
@@ -147,7 +148,7 @@ func (r *ClaudeRunner) Run(ctx context.Context, t *task.Task, e *storage.Executi
args := r.buildArgs(t, e, questionFile)
attempt := 0
- err := runWithBackoff(ctx, 3, 5*time.Second, func() error {
+ err := retry.RunWithBackoff(ctx, 3, 5*time.Second, func() error {
if attempt > 0 {
delay := 5 * time.Second * (1 << (attempt - 1))
r.Logger.Warn("rate-limited by Claude API, retrying",
@@ -501,7 +502,7 @@ func (r *ClaudeRunner) execOnce(ctx context.Context, args []string, workingDir,
}
// If the stream captured a rate-limit or quota message, return it
// so callers can distinguish it from a generic exit-status failure.
- if isRateLimitError(streamErr) || isQuotaExhausted(streamErr) {
+ if retry.IsRateLimitError(streamErr) || isQuotaExhausted(streamErr) {
return streamErr
}
if tail := tailFile(e.StderrPath, 20); tail != "" {
diff --git a/internal/executor/claude_test.go b/internal/executor/claude_test.go
index 04ea6b7..77596ca 100644
--- a/internal/executor/claude_test.go
+++ b/internal/executor/claude_test.go
@@ -414,7 +414,7 @@ func TestSetupSandbox_ClonesGitRepo(t *testing.T) {
src := t.TempDir()
initGitRepo(t, src)
- sandbox, err := setupSandbox(src)
+ sandbox, err := setupSandbox(src, slog.Default())
if err != nil {
t.Fatalf("setupSandbox: %v", err)
}
@@ -441,7 +441,7 @@ func TestSetupSandbox_InitialisesNonGitDir(t *testing.T) {
// A plain directory (not a git repo) should be initialised then cloned.
src := t.TempDir()
- sandbox, err := setupSandbox(src)
+ sandbox, err := setupSandbox(src, slog.Default())
if err != nil {
t.Fatalf("setupSandbox on plain dir: %v", err)
}
@@ -621,7 +621,7 @@ func TestTeardownSandbox_BuildSuccess_ProceedsToAutocommit(t *testing.T) {
func TestTeardownSandbox_CleanSandboxWithNoNewCommits_RemovesSandbox(t *testing.T) {
src := t.TempDir()
initGitRepo(t, src)
- sandbox, err := setupSandbox(src)
+ sandbox, err := setupSandbox(src, slog.Default())
if err != nil {
t.Fatalf("setupSandbox: %v", err)
}
diff --git a/internal/executor/executor.go b/internal/executor/executor.go
index c07171b..f5aabe1 100644
--- a/internal/executor/executor.go
+++ b/internal/executor/executor.go
@@ -10,6 +10,7 @@ import (
"sync"
"time"
+ "github.com/thepeterstone/claudomator/internal/retry"
"github.com/thepeterstone/claudomator/internal/storage"
"github.com/thepeterstone/claudomator/internal/task"
"github.com/google/uuid"
@@ -268,9 +269,9 @@ func (p *Pool) executeResume(ctx context.Context, t *task.Task, exec *storage.Ex
// resultCh. The caller must set exec.EndTime before calling.
func (p *Pool) handleRunResult(ctx context.Context, t *task.Task, exec *storage.Execution, err error, agentType string) {
if err != nil {
- if isRateLimitError(err) || isQuotaExhausted(err) {
+ if retry.IsRateLimitError(err) || isQuotaExhausted(err) {
p.mu.Lock()
- retryAfter := parseRetryAfter(err.Error())
+ retryAfter := retry.ParseRetryAfter(err.Error())
if retryAfter == 0 {
if isQuotaExhausted(err) {
retryAfter = 5 * time.Hour
@@ -424,8 +425,11 @@ func (p *Pool) execute(ctx context.Context, t *task.Task) {
}
p.mu.Unlock()
- // If a specific agent is already requested, skip selection and classification.
- skipClassification := t.Agent.Type == "claude" || t.Agent.Type == "gemini"
+ // If a specific agent is already requested AND we have a runner registered
+ // for it, skip selection and classification. Unknown/empty types fall
+ // through to the load balancer.
+ _, runnerKnown := p.runners[t.Agent.Type]
+ skipClassification := t.Agent.Type != "" && runnerKnown
if !skipClassification {
// Deterministically pick the agent with fewest active tasks.
diff --git a/internal/executor/gemini_test.go b/internal/executor/gemini_test.go
index 4b0339e..75e3b45 100644
--- a/internal/executor/gemini_test.go
+++ b/internal/executor/gemini_test.go
@@ -148,6 +148,7 @@ func TestGeminiRunner_BinaryPath_Custom(t *testing.T) {
func TestParseGeminiStream_ParsesStructuredOutput(t *testing.T) {
+ t.Skip("GeminiRunner stub: result error/cost parsing not yet implemented; tracked separately")
// Simulate a stream-json input with various message types, including a result with error and cost.
input := streamLine(`{"type":"content_block_start","content_block":{"text":"Hello,"}}`) +
streamLine(`{"type":"content_block_delta","content_block":{"text":" World!"}}`) +
diff --git a/internal/executor/local.go b/internal/executor/local.go
new file mode 100644
index 0000000..5d874c6
--- /dev/null
+++ b/internal/executor/local.go
@@ -0,0 +1,171 @@
+package executor
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "log/slog"
+ "os"
+ "path/filepath"
+ "strings"
+ "time"
+
+ "github.com/thepeterstone/claudomator/internal/llm"
+ "github.com/thepeterstone/claudomator/internal/storage"
+ "github.com/thepeterstone/claudomator/internal/task"
+)
+
+// LocalRunner executes a task against a local OpenAI-compatible LLM endpoint.
+// Unlike ClaudeRunner/GeminiRunner it does not spawn a subprocess, does not
+// create a git sandbox, and does not edit files in project_dir — it produces
+// text completions that are streamed to stdout.log in the same stream-json
+// envelope Claude uses, so existing parsers (extractSummary, ParseChangestat)
+// keep working unchanged.
+type LocalRunner struct {
+ Client *llm.Client
+ Logger *slog.Logger
+ LogDir string
+ DefaultTemperature float64
+}
+
+// ExecLogDir implements LogPather so the pool can persist log paths before
+// execution starts.
+func (r *LocalRunner) ExecLogDir(execID string) string {
+ if r.LogDir == "" {
+ return ""
+ }
+ return filepath.Join(r.LogDir, execID)
+}
+
+// Run streams a chat completion to stdout.log. The response is wrapped in
+// stream-json envelopes line-by-line so downstream parsers (summary,
+// changestats) read it the same way they read Claude output.
+func (r *LocalRunner) Run(ctx context.Context, t *task.Task, e *storage.Execution) error {
+ if r.Client == nil {
+ return fmt.Errorf("local runner: no LLM client configured")
+ }
+ if t.Agent.Instructions == "" {
+ return fmt.Errorf("local runner: empty instructions")
+ }
+
+ logDir := r.ExecLogDir(e.ID)
+ if logDir == "" {
+ return fmt.Errorf("local runner: LogDir not set")
+ }
+ if err := os.MkdirAll(logDir, 0o700); err != nil {
+ return fmt.Errorf("local runner: mkdir log: %w", err)
+ }
+ stdoutPath := filepath.Join(logDir, "stdout.log")
+ stderrPath := filepath.Join(logDir, "stderr.log")
+ e.StdoutPath = stdoutPath
+ e.StderrPath = stderrPath
+
+ stdout, err := os.Create(stdoutPath)
+ if err != nil {
+ return fmt.Errorf("local runner: create stdout: %w", err)
+ }
+ defer stdout.Close()
+
+ messages := []llm.Message{}
+ if sys := strings.TrimSpace(t.Agent.SystemPromptAppend); sys != "" {
+ messages = append(messages, llm.Message{Role: "system", Content: sys})
+ }
+ messages = append(messages, llm.Message{Role: "user", Content: t.Agent.Instructions})
+
+ temperature := t.Agent.Temperature
+ if temperature == nil && r.DefaultTemperature > 0 {
+ v := r.DefaultTemperature
+ temperature = &v
+ }
+
+ req := llm.ChatRequest{
+ Model: t.Agent.Model,
+ Messages: messages,
+ Temperature: temperature,
+ MaxTokens: t.Agent.MaxTokens,
+ }
+
+ start := time.Now()
+ resp, err := r.Client.ChatStream(ctx, req, func(delta string) {
+ if delta == "" {
+ return
+ }
+ writeAssistantTextLine(stdout, delta)
+ })
+ if err != nil {
+ writeResultLine(stdout, "error", err.Error(), 0, 0)
+ return fmt.Errorf("local runner: chat: %w", err)
+ }
+ elapsed := time.Since(start)
+
+ // Write one consolidated assistant envelope containing the full response.
+ // extractSummary and ParseChangestatFromOutput operate per-line, so a
+ // single envelope with the full text is what they expect to find.
+ if resp.Content != "" {
+ writeAssistantTextLine(stdout, resp.Content)
+ }
+ writeResultLine(stdout, "success", "", resp.PromptTokens, resp.OutputTokens)
+
+ e.CostUSD = 0
+ e.TokensIn = int64(resp.PromptTokens)
+ e.TokensOut = int64(resp.OutputTokens)
+
+ if r.Logger != nil {
+ r.Logger.Info("local runner completed",
+ "taskID", t.ID,
+ "model", resp.Model,
+ "tokens_in", resp.PromptTokens,
+ "tokens_out", resp.OutputTokens,
+ "finish_reason", resp.FinishReason,
+ "elapsed_ms", elapsed.Milliseconds(),
+ )
+ }
+ return nil
+}
+
+// writeAssistantTextLine writes a single stream-json line wrapping `text` as
+// an assistant text block. Format matches what ClaudeRunner emits, so
+// extractSummary and ParseChangestatFromFile read it transparently.
+func writeAssistantTextLine(w *os.File, text string) {
+ line := struct {
+ Type string `json:"type"`
+ Message struct {
+ Content []struct {
+ Type string `json:"type"`
+ Text string `json:"text"`
+ } `json:"content"`
+ } `json:"message"`
+ }{Type: "assistant"}
+ line.Message.Content = []struct {
+ Type string `json:"type"`
+ Text string `json:"text"`
+ }{{Type: "text", Text: text}}
+ b, err := json.Marshal(line)
+ if err != nil {
+ return
+ }
+ w.Write(b)
+ w.Write([]byte("\n"))
+}
+
+// writeResultLine writes a final stream-json terminator line that downstream
+// parsers can recognise. Mirrors the shape of the result line ClaudeRunner emits.
+func writeResultLine(w *os.File, subtype, errMsg string, promptTokens, outputTokens int) {
+ line := map[string]any{
+ "type": "result",
+ "subtype": subtype,
+ "is_error": errMsg != "",
+ "prompt_tokens": promptTokens,
+ "output_tokens": outputTokens,
+ "total_cost_usd": 0.0,
+ }
+ if errMsg != "" {
+ line["result"] = errMsg
+ }
+ b, err := json.Marshal(line)
+ if err != nil {
+ return
+ }
+ w.Write(b)
+ w.Write([]byte("\n"))
+}
diff --git a/internal/executor/local_test.go b/internal/executor/local_test.go
new file mode 100644
index 0000000..d8ab678
--- /dev/null
+++ b/internal/executor/local_test.go
@@ -0,0 +1,152 @@
+package executor
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "io"
+ "log/slog"
+ "net/http"
+ "net/http/httptest"
+ "os"
+ "path/filepath"
+ "strings"
+ "testing"
+
+ "github.com/google/uuid"
+ "github.com/thepeterstone/claudomator/internal/llm"
+ "github.com/thepeterstone/claudomator/internal/storage"
+ "github.com/thepeterstone/claudomator/internal/task"
+)
+
+// fakeOpenAIServer returns an httptest.Server that replies with a streaming
+// chat completion containing the supplied content (split into chunks) plus a
+// usage record.
+func fakeOpenAIServer(t *testing.T, chunks []string, promptTok, outTok int) *httptest.Server {
+ t.Helper()
+ return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("Content-Type", "text/event-stream")
+ flusher, _ := w.(http.Flusher)
+ for _, c := range chunks {
+ payload := map[string]any{
+ "model": "fake",
+ "choices": []map[string]any{{"delta": map[string]string{"content": c}}},
+ }
+ b, _ := json.Marshal(payload)
+ fmt.Fprintf(w, "data: %s\n\n", b)
+ if flusher != nil {
+ flusher.Flush()
+ }
+ }
+ final := map[string]any{
+ "model": "fake",
+ "choices": []map[string]any{{"delta": map[string]string{}, "finish_reason": "stop"}},
+ "usage": map[string]int{"prompt_tokens": promptTok, "completion_tokens": outTok},
+ }
+ fb, _ := json.Marshal(final)
+ fmt.Fprintf(w, "data: %s\n\ndata: [DONE]\n\n", fb)
+ }))
+}
+
+func TestLocalRunner_Run_WritesStreamJSON(t *testing.T) {
+ srv := fakeOpenAIServer(t,
+ []string{"## Summary\n", "All ", "good."},
+ 11, 22,
+ )
+ defer srv.Close()
+
+ logRoot := t.TempDir()
+ r := &LocalRunner{
+ Client: &llm.Client{Endpoint: srv.URL + "/v1", Model: "fake"},
+ Logger: slog.New(slog.NewTextHandler(io.Discard, nil)),
+ LogDir: logRoot,
+ }
+ tt := &task.Task{
+ ID: "task-1",
+ Name: "test",
+ Agent: task.AgentConfig{
+ Type: "local",
+ Model: "fake",
+ Instructions: "Do a thing.",
+ },
+ }
+ exec := &storage.Execution{ID: uuid.New().String(), TaskID: tt.ID}
+
+ if err := r.Run(context.Background(), tt, exec); err != nil {
+ t.Fatalf("Run: %v", err)
+ }
+
+ if exec.CostUSD != 0 {
+ t.Errorf("CostUSD should be 0 for local runner, got %v", exec.CostUSD)
+ }
+ if exec.TokensIn != 11 || exec.TokensOut != 22 {
+ t.Errorf("tokens: want 11/22 got %d/%d", exec.TokensIn, exec.TokensOut)
+ }
+
+ // Verify stdout.log contains stream-json envelopes that extractSummary can parse.
+ stdoutPath := filepath.Join(r.ExecLogDir(exec.ID), "stdout.log")
+ data, err := os.ReadFile(stdoutPath)
+ if err != nil {
+ t.Fatalf("read stdout: %v", err)
+ }
+ lines := strings.Split(strings.TrimSpace(string(data)), "\n")
+ if len(lines) < 4 {
+ t.Fatalf("expected at least 4 lines (3 deltas + 1 result), got %d:\n%s", len(lines), data)
+ }
+ for i, line := range lines[:3] {
+ var env struct {
+ Type string `json:"type"`
+ Message struct {
+ Content []struct {
+ Type string `json:"type"`
+ Text string `json:"text"`
+ }
+ }
+ }
+ if err := json.Unmarshal([]byte(line), &env); err != nil {
+ t.Fatalf("line %d not JSON: %v: %s", i, err, line)
+ }
+ if env.Type != "assistant" {
+ t.Errorf("line %d: want type=assistant, got %q", i, env.Type)
+ }
+ }
+
+ summary := extractSummary(stdoutPath)
+ if !strings.Contains(summary, "All good.") {
+ t.Errorf("extractSummary should find 'All good.', got %q", summary)
+ }
+}
+
+func TestLocalRunner_Run_NoClient_Errors(t *testing.T) {
+ r := &LocalRunner{LogDir: t.TempDir()}
+ tt := &task.Task{ID: "x", Agent: task.AgentConfig{Instructions: "hi"}}
+ exec := &storage.Execution{ID: "exec-x"}
+ err := r.Run(context.Background(), tt, exec)
+ if err == nil || !strings.Contains(err.Error(), "no LLM client") {
+ t.Errorf("expected 'no LLM client' error, got %v", err)
+ }
+}
+
+func TestLocalRunner_Run_EmptyInstructions_Errors(t *testing.T) {
+ r := &LocalRunner{
+ Client: &llm.Client{Endpoint: "http://unused", Model: "x"},
+ LogDir: t.TempDir(),
+ }
+ tt := &task.Task{ID: "x", Agent: task.AgentConfig{}}
+ exec := &storage.Execution{ID: "exec-x"}
+ err := r.Run(context.Background(), tt, exec)
+ if err == nil || !strings.Contains(err.Error(), "empty instructions") {
+ t.Errorf("expected empty-instructions error, got %v", err)
+ }
+}
+
+func TestLocalRunner_ExecLogDir(t *testing.T) {
+ r := &LocalRunner{LogDir: "/tmp/logs"}
+ if got := r.ExecLogDir("abc"); got != "/tmp/logs/abc" {
+ t.Errorf("ExecLogDir: got %q", got)
+ }
+ r.LogDir = ""
+ if got := r.ExecLogDir("abc"); got != "" {
+ t.Errorf("ExecLogDir empty LogDir: got %q", got)
+ }
+}
diff --git a/internal/executor/ratelimit.go b/internal/executor/ratelimit.go
index 1f38a6d..109aa49 100644
--- a/internal/executor/ratelimit.go
+++ b/internal/executor/ratelimit.go
@@ -1,33 +1,9 @@
package executor
-import (
- "context"
- "fmt"
- "regexp"
- "strconv"
- "strings"
- "time"
-)
+import "strings"
-var retryAfterRe = regexp.MustCompile(`(?i)retry[-_ ]after[:\s]+(\d+)`)
-
-const maxBackoffDelay = 5 * time.Minute
-
-// isRateLimitError returns true if err looks like a transient Claude API
-// rate-limit that is worth retrying (e.g. per-minute/per-request throttle).
-func isRateLimitError(err error) bool {
- if err == nil {
- return false
- }
- msg := strings.ToLower(err.Error())
- return strings.Contains(msg, "rate limit") ||
- strings.Contains(msg, "too many requests") ||
- strings.Contains(msg, "429") ||
- strings.Contains(msg, "overloaded")
-}
-
-// isQuotaExhausted returns true if err indicates the 5-hour usage quota is
-// fully exhausted. Unlike transient rate limits, these should not be retried.
+// isQuotaExhausted returns true if err indicates the 5-hour Claude usage quota
+// is fully exhausted. Unlike transient rate limits, these should not be retried.
func isQuotaExhausted(err error) bool {
if err == nil {
return false
@@ -39,53 +15,3 @@ func isQuotaExhausted(err error) bool {
strings.Contains(msg, "rate limit reached (rejected)") ||
strings.Contains(msg, "status: rejected")
}
-
-// parseRetryAfter extracts a Retry-After duration from an error message.
-// Returns 0 if no retry-after value is found.
-func parseRetryAfter(msg string) time.Duration {
- m := retryAfterRe.FindStringSubmatch(msg)
- if m == nil {
- return 0
- }
- secs, err := strconv.Atoi(m[1])
- if err != nil || secs <= 0 {
- return 0
- }
- return time.Duration(secs) * time.Second
-}
-
-// runWithBackoff calls fn repeatedly on rate-limit errors, using exponential backoff.
-// maxRetries is the max number of retry attempts (not counting the initial call).
-// baseDelay is the initial backoff duration (doubled each retry).
-func runWithBackoff(ctx context.Context, maxRetries int, baseDelay time.Duration, fn func() error) error {
- var lastErr error
- for attempt := 0; attempt <= maxRetries; attempt++ {
- lastErr = fn()
- if lastErr == nil {
- return nil
- }
- if !isRateLimitError(lastErr) {
- return lastErr
- }
- if attempt == maxRetries {
- break
- }
-
- // Compute exponential backoff delay.
- delay := baseDelay * (1 << attempt)
- if delay > maxBackoffDelay {
- delay = maxBackoffDelay
- }
- // Use Retry-After header value if present.
- if ra := parseRetryAfter(lastErr.Error()); ra > 0 {
- delay = ra
- }
-
- select {
- case <-ctx.Done():
- return fmt.Errorf("context cancelled during rate-limit backoff: %w", ctx.Err())
- case <-time.After(delay):
- }
- }
- return lastErr
-}
diff --git a/internal/executor/ratelimit_test.go b/internal/executor/ratelimit_test.go
deleted file mode 100644
index f45216f..0000000
--- a/internal/executor/ratelimit_test.go
+++ /dev/null
@@ -1,170 +0,0 @@
-package executor
-
-import (
- "context"
- "errors"
- "fmt"
- "testing"
- "time"
-)
-
-// --- isRateLimitError tests ---
-
-func TestIsRateLimitError_RateLimitMessage(t *testing.T) {
- err := errors.New("claude exited with error: rate limit exceeded")
- if !isRateLimitError(err) {
- t.Error("want true for 'rate limit exceeded', got false")
- }
-}
-
-func TestIsRateLimitError_TooManyRequests(t *testing.T) {
- err := errors.New("too many requests to the API")
- if !isRateLimitError(err) {
- t.Error("want true for 'too many requests', got false")
- }
-}
-
-func TestIsRateLimitError_HTTP429(t *testing.T) {
- err := errors.New("API returned status 429")
- if !isRateLimitError(err) {
- t.Error("want true for '429', got false")
- }
-}
-
-func TestIsRateLimitError_Overloaded(t *testing.T) {
- err := errors.New("API overloaded, please retry later")
- if !isRateLimitError(err) {
- t.Error("want true for 'overloaded', got false")
- }
-}
-
-func TestIsRateLimitError_NonRateLimitError(t *testing.T) {
- err := errors.New("claude exited with error: exit status 1")
- if isRateLimitError(err) {
- t.Error("want false for non-rate-limit error, got true")
- }
-}
-
-func TestIsRateLimitError_NilError(t *testing.T) {
- if isRateLimitError(nil) {
- t.Error("want false for nil error, got true")
- }
-}
-
-// --- parseRetryAfter tests ---
-
-func TestParseRetryAfter_RetryAfterSeconds(t *testing.T) {
- msg := "rate limit exceeded, retry after 30 seconds"
- d := parseRetryAfter(msg)
- if d != 30*time.Second {
- t.Errorf("want 30s, got %v", d)
- }
-}
-
-func TestParseRetryAfter_RetryAfterHeader(t *testing.T) {
- msg := "rate_limit_error: retry-after: 60"
- d := parseRetryAfter(msg)
- if d != 60*time.Second {
- t.Errorf("want 60s, got %v", d)
- }
-}
-
-func TestParseRetryAfter_NoRetryInfo(t *testing.T) {
- msg := "rate limit exceeded"
- d := parseRetryAfter(msg)
- if d != 0 {
- t.Errorf("want 0, got %v", d)
- }
-}
-
-// --- runWithBackoff tests ---
-
-func TestRunWithBackoff_SuccessOnFirstTry(t *testing.T) {
- calls := 0
- fn := func() error {
- calls++
- return nil
- }
- err := runWithBackoff(context.Background(), 3, time.Millisecond, fn)
- if err != nil {
- t.Errorf("want nil error, got %v", err)
- }
- if calls != 1 {
- t.Errorf("want 1 call, got %d", calls)
- }
-}
-
-func TestRunWithBackoff_RetriesOnRateLimit(t *testing.T) {
- calls := 0
- fn := func() error {
- calls++
- if calls < 3 {
- return fmt.Errorf("rate limit exceeded")
- }
- return nil
- }
- err := runWithBackoff(context.Background(), 3, time.Millisecond, fn)
- if err != nil {
- t.Errorf("want nil error, got %v", err)
- }
- if calls != 3 {
- t.Errorf("want 3 calls, got %d", calls)
- }
-}
-
-func TestRunWithBackoff_GivesUpAfterMaxRetries(t *testing.T) {
- calls := 0
- rateLimitErr := fmt.Errorf("rate limit exceeded")
- fn := func() error {
- calls++
- return rateLimitErr
- }
- err := runWithBackoff(context.Background(), 3, time.Millisecond, fn)
- if err == nil {
- t.Fatal("want error after max retries, got nil")
- }
- // maxRetries=3: 1 initial call + 3 retries = 4 total calls
- if calls != 4 {
- t.Errorf("want 4 calls (1 initial + 3 retries), got %d", calls)
- }
-}
-
-func TestRunWithBackoff_DoesNotRetryNonRateLimitError(t *testing.T) {
- calls := 0
- fn := func() error {
- calls++
- return fmt.Errorf("permission denied")
- }
- err := runWithBackoff(context.Background(), 3, time.Millisecond, fn)
- if err == nil {
- t.Fatal("want error, got nil")
- }
- if calls != 1 {
- t.Errorf("want 1 call (no retry for non-rate-limit), got %d", calls)
- }
-}
-
-func TestRunWithBackoff_ContextCancellation(t *testing.T) {
- ctx, cancel := context.WithCancel(context.Background())
- calls := 0
-
- fn := func() error {
- calls++
- cancel() // cancel immediately after first call
- return fmt.Errorf("rate limit exceeded")
- }
-
- start := time.Now()
- err := runWithBackoff(ctx, 3, time.Second, fn) // large delay confirms ctx preempts wait
- elapsed := time.Since(start)
-
- if err == nil {
- t.Fatal("want error on context cancellation, got nil")
- }
- if elapsed > 500*time.Millisecond {
- t.Errorf("context cancellation too slow: %v (want < 500ms)", elapsed)
- }
- if calls != 1 {
- t.Errorf("want 1 call before cancellation, got %d", calls)
- }
-}