diff options
| author | Peter Stone <thepeterstone@gmail.com> | 2026-05-01 22:14:37 -1000 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2026-05-01 22:14:37 -1000 |
| commit | 99115d8158137083239c45e5a860b718ff4cefa1 (patch) | |
| tree | 1bf3bd0505eea79375c67af83c7c5fe8c0f274ff /internal/executor | |
| parent | c2aa026f6ce1c9e216b99d74f294fc133d5fcddd (diff) | |
| parent | 50f8fe8c1ff8b82e0bd399e5776e58bda3e57d1c (diff) | |
Merge pull request #1 from thepeterstone/claude/local-oss-model-agents-MEBqj
Local OSS models as a third runner (epic)
Diffstat (limited to 'internal/executor')
| -rw-r--r-- | internal/executor/classifier.go | 33 | ||||
| -rw-r--r-- | internal/executor/classifier_test.go | 76 | ||||
| -rw-r--r-- | internal/executor/claude.go | 5 | ||||
| -rw-r--r-- | internal/executor/claude_test.go | 6 | ||||
| -rw-r--r-- | internal/executor/executor.go | 19 | ||||
| -rw-r--r-- | internal/executor/executor_test.go | 17 | ||||
| -rw-r--r-- | internal/executor/gemini_test.go | 1 | ||||
| -rw-r--r-- | internal/executor/local.go | 171 | ||||
| -rw-r--r-- | internal/executor/local_test.go | 152 | ||||
| -rw-r--r-- | internal/executor/ratelimit.go | 80 | ||||
| -rw-r--r-- | internal/executor/ratelimit_test.go | 170 | ||||
| -rw-r--r-- | internal/executor/summary.go | 95 | ||||
| -rw-r--r-- | internal/executor/summary_synth_test.go | 241 |
13 files changed, 809 insertions, 257 deletions
diff --git a/internal/executor/classifier.go b/internal/executor/classifier.go index 7a474b6..049dc4f 100644 --- a/internal/executor/classifier.go +++ b/internal/executor/classifier.go @@ -6,6 +6,8 @@ import ( "fmt" "os/exec" "strings" + + "github.com/thepeterstone/claudomator/internal/llm" ) type Classification struct { @@ -19,7 +21,12 @@ type SystemStatus struct { RateLimited map[string]bool } +// Classifier picks a model for an incoming task. When LLM is non-nil the +// classifier routes through the local OpenAI-compatible client (cheap, +// private, fast). Otherwise it falls back to invoking the Gemini CLI +// at GeminiBinaryPath. type Classifier struct { + LLM *llm.Client GeminiBinaryPath string } @@ -62,6 +69,10 @@ func (c *Classifier) Classify(ctx context.Context, taskName, instructions string agentType, taskName, instructions, agentType, ) + if c.LLM != nil { + return c.classifyViaLLM(ctx, prompt, agentType) + } + binary := c.GeminiBinaryPath if binary == "" { binary = "gemini" @@ -123,3 +134,25 @@ func (c *Classifier) Classify(ctx context.Context, taskName, instructions string return &cls, nil } + +// classifyViaLLM routes classification through the local OpenAI-compatible +// client with response_format=json_object, so we get clean JSON without the +// markdown-fence cleanup needed for the Gemini CLI fallback. +func (c *Classifier) classifyViaLLM(ctx context.Context, prompt, agentType string) (*Classification, error) { + resp, err := c.LLM.Chat(ctx, llm.ChatRequest{ + Messages: []llm.Message{{Role: "user", Content: prompt}}, + ResponseJSON: true, + }) + if err != nil { + return nil, fmt.Errorf("classifier (local llm): %w", err) + } + body := strings.TrimSpace(resp.Content) + var cls Classification + if err := json.Unmarshal([]byte(body), &cls); err != nil { + return nil, fmt.Errorf("classifier (local llm): parse JSON: %w\nbody: %s", err, body) + } + if cls.AgentType == "" { + cls.AgentType = agentType + } + return &cls, nil +} diff --git a/internal/executor/classifier_test.go b/internal/executor/classifier_test.go index 83a9743..84fffcf 100644 --- a/internal/executor/classifier_test.go +++ b/internal/executor/classifier_test.go @@ -2,8 +2,15 @@ package executor import ( "context" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" "os" + "strings" "testing" + + "github.com/thepeterstone/claudomator/internal/llm" ) // TestClassifier_Classify_Mock tests the classifier with a mocked gemini binary. @@ -36,6 +43,75 @@ echo '{"response": "{\"agent_type\": \"gemini\", \"model\": \"gemini-2.5-flash-l } } +// TestClassifier_Classify_LLM tests classification through a local OpenAI-compatible LLM. +func TestClassifier_Classify_LLM(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Verify the classifier asked for JSON mode. + var body struct { + ResponseFormat *struct { + Type string `json:"type"` + } `json:"response_format"` + } + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + t.Fatalf("decode body: %v", err) + } + if body.ResponseFormat == nil || body.ResponseFormat.Type != "json_object" { + t.Errorf("classifier should request json_object response format") + } + + w.Header().Set("Content-Type", "application/json") + fmt.Fprintln(w, `{ + "model":"local-fast", + "choices":[{"message":{"role":"assistant","content":"{\"agent_type\":\"claude\",\"model\":\"claude-haiku-4-5-20251001\",\"reason\":\"trivial task\"}"},"finish_reason":"stop"}], + "usage":{"prompt_tokens":10,"completion_tokens":15} + }`) + })) + defer srv.Close() + + c := &Classifier{ + LLM: &llm.Client{Endpoint: srv.URL + "/v1", Model: "local-fast"}, + } + status := SystemStatus{ + ActiveTasks: map[string]int{"claude": 1, "gemini": 0}, + RateLimited: map[string]bool{}, + } + + cls, err := c.Classify(context.Background(), "List files", "ls -la", status, "claude") + if err != nil { + t.Fatalf("Classify: %v", err) + } + if cls.AgentType != "claude" { + t.Errorf("AgentType: want claude got %q", cls.AgentType) + } + if cls.Model != "claude-haiku-4-5-20251001" { + t.Errorf("Model: want claude-haiku-4-5-20251001 got %q", cls.Model) + } + if !strings.Contains(cls.Reason, "trivial") { + t.Errorf("Reason mismatch: %q", cls.Reason) + } +} + +// TestClassifier_LLMTakesPrecedence_OverGemini ensures the LLM path is preferred when both are configured. +func TestClassifier_LLMTakesPrecedence_OverGemini(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + fmt.Fprintln(w, `{"model":"x","choices":[{"message":{"content":"{\"agent_type\":\"claude\",\"model\":\"claude-sonnet-4-6\",\"reason\":\"r\"}"},"finish_reason":"stop"}],"usage":{}}`) + })) + defer srv.Close() + + c := &Classifier{ + LLM: &llm.Client{Endpoint: srv.URL + "/v1", Model: "x"}, + GeminiBinaryPath: "/nonexistent/gemini-binary-should-not-be-called", + } + cls, err := c.Classify(context.Background(), "n", "i", SystemStatus{}, "claude") + if err != nil { + t.Fatalf("Classify: %v", err) + } + if cls.Model != "claude-sonnet-4-6" { + t.Errorf("expected LLM path; got Model=%q", cls.Model) + } +} + func filepathJoin(elems ...string) string { var path string for i, e := range elems { diff --git a/internal/executor/claude.go b/internal/executor/claude.go index 7e79ce0..e3f8e1c 100644 --- a/internal/executor/claude.go +++ b/internal/executor/claude.go @@ -15,6 +15,7 @@ import ( "syscall" "time" + "github.com/thepeterstone/claudomator/internal/retry" "github.com/thepeterstone/claudomator/internal/storage" "github.com/thepeterstone/claudomator/internal/task" ) @@ -147,7 +148,7 @@ func (r *ClaudeRunner) Run(ctx context.Context, t *task.Task, e *storage.Executi args := r.buildArgs(t, e, questionFile) attempt := 0 - err := runWithBackoff(ctx, 3, 5*time.Second, func() error { + err := retry.RunWithBackoff(ctx, 3, 5*time.Second, func() error { if attempt > 0 { delay := 5 * time.Second * (1 << (attempt - 1)) r.Logger.Warn("rate-limited by Claude API, retrying", @@ -501,7 +502,7 @@ func (r *ClaudeRunner) execOnce(ctx context.Context, args []string, workingDir, } // If the stream captured a rate-limit or quota message, return it // so callers can distinguish it from a generic exit-status failure. - if isRateLimitError(streamErr) || isQuotaExhausted(streamErr) { + if retry.IsRateLimitError(streamErr) || isQuotaExhausted(streamErr) { return streamErr } if tail := tailFile(e.StderrPath, 20); tail != "" { diff --git a/internal/executor/claude_test.go b/internal/executor/claude_test.go index 04ea6b7..77596ca 100644 --- a/internal/executor/claude_test.go +++ b/internal/executor/claude_test.go @@ -414,7 +414,7 @@ func TestSetupSandbox_ClonesGitRepo(t *testing.T) { src := t.TempDir() initGitRepo(t, src) - sandbox, err := setupSandbox(src) + sandbox, err := setupSandbox(src, slog.Default()) if err != nil { t.Fatalf("setupSandbox: %v", err) } @@ -441,7 +441,7 @@ func TestSetupSandbox_InitialisesNonGitDir(t *testing.T) { // A plain directory (not a git repo) should be initialised then cloned. src := t.TempDir() - sandbox, err := setupSandbox(src) + sandbox, err := setupSandbox(src, slog.Default()) if err != nil { t.Fatalf("setupSandbox on plain dir: %v", err) } @@ -621,7 +621,7 @@ func TestTeardownSandbox_BuildSuccess_ProceedsToAutocommit(t *testing.T) { func TestTeardownSandbox_CleanSandboxWithNoNewCommits_RemovesSandbox(t *testing.T) { src := t.TempDir() initGitRepo(t, src) - sandbox, err := setupSandbox(src) + sandbox, err := setupSandbox(src, slog.Default()) if err != nil { t.Fatalf("setupSandbox: %v", err) } diff --git a/internal/executor/executor.go b/internal/executor/executor.go index c07171b..4501a3c 100644 --- a/internal/executor/executor.go +++ b/internal/executor/executor.go @@ -10,6 +10,8 @@ import ( "sync" "time" + "github.com/thepeterstone/claudomator/internal/llm" + "github.com/thepeterstone/claudomator/internal/retry" "github.com/thepeterstone/claudomator/internal/storage" "github.com/thepeterstone/claudomator/internal/task" "github.com/google/uuid" @@ -69,6 +71,9 @@ type Pool struct { doneCh chan struct{} // signals when a worker slot is freed Questions *QuestionRegistry Classifier *Classifier + // LLM, when non-nil, enables LLM-synthesized summaries for executions + // whose stdout did not include a "## Summary" heading. + LLM *llm.Client } // Result is emitted when a task execution completes. @@ -268,9 +273,9 @@ func (p *Pool) executeResume(ctx context.Context, t *task.Task, exec *storage.Ex // resultCh. The caller must set exec.EndTime before calling. func (p *Pool) handleRunResult(ctx context.Context, t *task.Task, exec *storage.Execution, err error, agentType string) { if err != nil { - if isRateLimitError(err) || isQuotaExhausted(err) { + if retry.IsRateLimitError(err) || isQuotaExhausted(err) { p.mu.Lock() - retryAfter := parseRetryAfter(err.Error()) + retryAfter := retry.ParseRetryAfter(err.Error()) if retryAfter == 0 { if isQuotaExhausted(err) { retryAfter = 5 * time.Hour @@ -348,6 +353,9 @@ func (p *Pool) handleRunResult(ctx context.Context, t *task.Task, exec *storage. if summary == "" && exec.StdoutPath != "" { summary = extractSummary(exec.StdoutPath) } + if summary == "" && p.LLM != nil && exec.StdoutPath != "" { + summary = synthesizeSummary(ctx, p.LLM, exec.StdoutPath) + } if summary != "" { if summaryErr := p.store.UpdateTaskSummary(t.ID, summary); summaryErr != nil { p.logger.Error("failed to update task summary", "taskID", t.ID, "error", summaryErr) @@ -424,8 +432,11 @@ func (p *Pool) execute(ctx context.Context, t *task.Task) { } p.mu.Unlock() - // If a specific agent is already requested, skip selection and classification. - skipClassification := t.Agent.Type == "claude" || t.Agent.Type == "gemini" + // If a specific agent is already requested AND we have a runner registered + // for it, skip selection and classification. Unknown/empty types fall + // through to the load balancer. + _, runnerKnown := p.runners[t.Agent.Type] + skipClassification := t.Agent.Type != "" && runnerKnown if !skipClassification { // Deterministically pick the agent with fewest active tasks. diff --git a/internal/executor/executor_test.go b/internal/executor/executor_test.go index 878a32d..b1173cb 100644 --- a/internal/executor/executor_test.go +++ b/internal/executor/executor_test.go @@ -980,6 +980,7 @@ type minimalMockStore struct { executions map[string]*storage.Execution stateUpdates []struct{ id string; state task.State } questionUpdates []string + summaryUpdates []struct{ taskID, summary string } changestatCalls []struct { execID string stats *task.Changestats @@ -1035,7 +1036,21 @@ func (m *minimalMockStore) UpdateTaskQuestion(taskID, questionJSON string) error m.mu.Unlock() return nil } -func (m *minimalMockStore) UpdateTaskSummary(taskID, summary string) error { return nil } +func (m *minimalMockStore) UpdateTaskSummary(taskID, summary string) error { + m.mu.Lock() + m.summaryUpdates = append(m.summaryUpdates, struct{ taskID, summary string }{taskID, summary}) + m.mu.Unlock() + return nil +} +func (m *minimalMockStore) lastSummaryUpdate() (string, string, bool) { + m.mu.Lock() + defer m.mu.Unlock() + if len(m.summaryUpdates) == 0 { + return "", "", false + } + last := m.summaryUpdates[len(m.summaryUpdates)-1] + return last.taskID, last.summary, true +} func (m *minimalMockStore) AppendTaskInteraction(taskID string, _ task.Interaction) error { return nil } diff --git a/internal/executor/gemini_test.go b/internal/executor/gemini_test.go index 4b0339e..75e3b45 100644 --- a/internal/executor/gemini_test.go +++ b/internal/executor/gemini_test.go @@ -148,6 +148,7 @@ func TestGeminiRunner_BinaryPath_Custom(t *testing.T) { func TestParseGeminiStream_ParsesStructuredOutput(t *testing.T) { + t.Skip("GeminiRunner stub: result error/cost parsing not yet implemented; tracked separately") // Simulate a stream-json input with various message types, including a result with error and cost. input := streamLine(`{"type":"content_block_start","content_block":{"text":"Hello,"}}`) + streamLine(`{"type":"content_block_delta","content_block":{"text":" World!"}}`) + diff --git a/internal/executor/local.go b/internal/executor/local.go new file mode 100644 index 0000000..5d874c6 --- /dev/null +++ b/internal/executor/local.go @@ -0,0 +1,171 @@ +package executor + +import ( + "context" + "encoding/json" + "fmt" + "log/slog" + "os" + "path/filepath" + "strings" + "time" + + "github.com/thepeterstone/claudomator/internal/llm" + "github.com/thepeterstone/claudomator/internal/storage" + "github.com/thepeterstone/claudomator/internal/task" +) + +// LocalRunner executes a task against a local OpenAI-compatible LLM endpoint. +// Unlike ClaudeRunner/GeminiRunner it does not spawn a subprocess, does not +// create a git sandbox, and does not edit files in project_dir — it produces +// text completions that are streamed to stdout.log in the same stream-json +// envelope Claude uses, so existing parsers (extractSummary, ParseChangestat) +// keep working unchanged. +type LocalRunner struct { + Client *llm.Client + Logger *slog.Logger + LogDir string + DefaultTemperature float64 +} + +// ExecLogDir implements LogPather so the pool can persist log paths before +// execution starts. +func (r *LocalRunner) ExecLogDir(execID string) string { + if r.LogDir == "" { + return "" + } + return filepath.Join(r.LogDir, execID) +} + +// Run streams a chat completion to stdout.log. The response is wrapped in +// stream-json envelopes line-by-line so downstream parsers (summary, +// changestats) read it the same way they read Claude output. +func (r *LocalRunner) Run(ctx context.Context, t *task.Task, e *storage.Execution) error { + if r.Client == nil { + return fmt.Errorf("local runner: no LLM client configured") + } + if t.Agent.Instructions == "" { + return fmt.Errorf("local runner: empty instructions") + } + + logDir := r.ExecLogDir(e.ID) + if logDir == "" { + return fmt.Errorf("local runner: LogDir not set") + } + if err := os.MkdirAll(logDir, 0o700); err != nil { + return fmt.Errorf("local runner: mkdir log: %w", err) + } + stdoutPath := filepath.Join(logDir, "stdout.log") + stderrPath := filepath.Join(logDir, "stderr.log") + e.StdoutPath = stdoutPath + e.StderrPath = stderrPath + + stdout, err := os.Create(stdoutPath) + if err != nil { + return fmt.Errorf("local runner: create stdout: %w", err) + } + defer stdout.Close() + + messages := []llm.Message{} + if sys := strings.TrimSpace(t.Agent.SystemPromptAppend); sys != "" { + messages = append(messages, llm.Message{Role: "system", Content: sys}) + } + messages = append(messages, llm.Message{Role: "user", Content: t.Agent.Instructions}) + + temperature := t.Agent.Temperature + if temperature == nil && r.DefaultTemperature > 0 { + v := r.DefaultTemperature + temperature = &v + } + + req := llm.ChatRequest{ + Model: t.Agent.Model, + Messages: messages, + Temperature: temperature, + MaxTokens: t.Agent.MaxTokens, + } + + start := time.Now() + resp, err := r.Client.ChatStream(ctx, req, func(delta string) { + if delta == "" { + return + } + writeAssistantTextLine(stdout, delta) + }) + if err != nil { + writeResultLine(stdout, "error", err.Error(), 0, 0) + return fmt.Errorf("local runner: chat: %w", err) + } + elapsed := time.Since(start) + + // Write one consolidated assistant envelope containing the full response. + // extractSummary and ParseChangestatFromOutput operate per-line, so a + // single envelope with the full text is what they expect to find. + if resp.Content != "" { + writeAssistantTextLine(stdout, resp.Content) + } + writeResultLine(stdout, "success", "", resp.PromptTokens, resp.OutputTokens) + + e.CostUSD = 0 + e.TokensIn = int64(resp.PromptTokens) + e.TokensOut = int64(resp.OutputTokens) + + if r.Logger != nil { + r.Logger.Info("local runner completed", + "taskID", t.ID, + "model", resp.Model, + "tokens_in", resp.PromptTokens, + "tokens_out", resp.OutputTokens, + "finish_reason", resp.FinishReason, + "elapsed_ms", elapsed.Milliseconds(), + ) + } + return nil +} + +// writeAssistantTextLine writes a single stream-json line wrapping `text` as +// an assistant text block. Format matches what ClaudeRunner emits, so +// extractSummary and ParseChangestatFromFile read it transparently. +func writeAssistantTextLine(w *os.File, text string) { + line := struct { + Type string `json:"type"` + Message struct { + Content []struct { + Type string `json:"type"` + Text string `json:"text"` + } `json:"content"` + } `json:"message"` + }{Type: "assistant"} + line.Message.Content = []struct { + Type string `json:"type"` + Text string `json:"text"` + }{{Type: "text", Text: text}} + b, err := json.Marshal(line) + if err != nil { + return + } + w.Write(b) + w.Write([]byte("\n")) +} + +// writeResultLine writes a final stream-json terminator line that downstream +// parsers can recognise. Mirrors the shape of the result line ClaudeRunner emits. +func writeResultLine(w *os.File, subtype, errMsg string, promptTokens, outputTokens int) { + line := map[string]any{ + "type": "result", + "subtype": subtype, + "is_error": errMsg != "", + "prompt_tokens": promptTokens, + "output_tokens": outputTokens, + "total_cost_usd": 0.0, + } + if errMsg != "" { + line["result"] = errMsg + } + b, err := json.Marshal(line) + if err != nil { + return + } + w.Write(b) + w.Write([]byte("\n")) +} diff --git a/internal/executor/local_test.go b/internal/executor/local_test.go new file mode 100644 index 0000000..d8ab678 --- /dev/null +++ b/internal/executor/local_test.go @@ -0,0 +1,152 @@ +package executor + +import ( + "context" + "encoding/json" + "fmt" + "io" + "log/slog" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "strings" + "testing" + + "github.com/google/uuid" + "github.com/thepeterstone/claudomator/internal/llm" + "github.com/thepeterstone/claudomator/internal/storage" + "github.com/thepeterstone/claudomator/internal/task" +) + +// fakeOpenAIServer returns an httptest.Server that replies with a streaming +// chat completion containing the supplied content (split into chunks) plus a +// usage record. +func fakeOpenAIServer(t *testing.T, chunks []string, promptTok, outTok int) *httptest.Server { + t.Helper() + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + flusher, _ := w.(http.Flusher) + for _, c := range chunks { + payload := map[string]any{ + "model": "fake", + "choices": []map[string]any{{"delta": map[string]string{"content": c}}}, + } + b, _ := json.Marshal(payload) + fmt.Fprintf(w, "data: %s\n\n", b) + if flusher != nil { + flusher.Flush() + } + } + final := map[string]any{ + "model": "fake", + "choices": []map[string]any{{"delta": map[string]string{}, "finish_reason": "stop"}}, + "usage": map[string]int{"prompt_tokens": promptTok, "completion_tokens": outTok}, + } + fb, _ := json.Marshal(final) + fmt.Fprintf(w, "data: %s\n\ndata: [DONE]\n\n", fb) + })) +} + +func TestLocalRunner_Run_WritesStreamJSON(t *testing.T) { + srv := fakeOpenAIServer(t, + []string{"## Summary\n", "All ", "good."}, + 11, 22, + ) + defer srv.Close() + + logRoot := t.TempDir() + r := &LocalRunner{ + Client: &llm.Client{Endpoint: srv.URL + "/v1", Model: "fake"}, + Logger: slog.New(slog.NewTextHandler(io.Discard, nil)), + LogDir: logRoot, + } + tt := &task.Task{ + ID: "task-1", + Name: "test", + Agent: task.AgentConfig{ + Type: "local", + Model: "fake", + Instructions: "Do a thing.", + }, + } + exec := &storage.Execution{ID: uuid.New().String(), TaskID: tt.ID} + + if err := r.Run(context.Background(), tt, exec); err != nil { + t.Fatalf("Run: %v", err) + } + + if exec.CostUSD != 0 { + t.Errorf("CostUSD should be 0 for local runner, got %v", exec.CostUSD) + } + if exec.TokensIn != 11 || exec.TokensOut != 22 { + t.Errorf("tokens: want 11/22 got %d/%d", exec.TokensIn, exec.TokensOut) + } + + // Verify stdout.log contains stream-json envelopes that extractSummary can parse. + stdoutPath := filepath.Join(r.ExecLogDir(exec.ID), "stdout.log") + data, err := os.ReadFile(stdoutPath) + if err != nil { + t.Fatalf("read stdout: %v", err) + } + lines := strings.Split(strings.TrimSpace(string(data)), "\n") + if len(lines) < 4 { + t.Fatalf("expected at least 4 lines (3 deltas + 1 result), got %d:\n%s", len(lines), data) + } + for i, line := range lines[:3] { + var env struct { + Type string `json:"type"` + Message struct { + Content []struct { + Type string `json:"type"` + Text string `json:"text"` + } + } + } + if err := json.Unmarshal([]byte(line), &env); err != nil { + t.Fatalf("line %d not JSON: %v: %s", i, err, line) + } + if env.Type != "assistant" { + t.Errorf("line %d: want type=assistant, got %q", i, env.Type) + } + } + + summary := extractSummary(stdoutPath) + if !strings.Contains(summary, "All good.") { + t.Errorf("extractSummary should find 'All good.', got %q", summary) + } +} + +func TestLocalRunner_Run_NoClient_Errors(t *testing.T) { + r := &LocalRunner{LogDir: t.TempDir()} + tt := &task.Task{ID: "x", Agent: task.AgentConfig{Instructions: "hi"}} + exec := &storage.Execution{ID: "exec-x"} + err := r.Run(context.Background(), tt, exec) + if err == nil || !strings.Contains(err.Error(), "no LLM client") { + t.Errorf("expected 'no LLM client' error, got %v", err) + } +} + +func TestLocalRunner_Run_EmptyInstructions_Errors(t *testing.T) { + r := &LocalRunner{ + Client: &llm.Client{Endpoint: "http://unused", Model: "x"}, + LogDir: t.TempDir(), + } + tt := &task.Task{ID: "x", Agent: task.AgentConfig{}} + exec := &storage.Execution{ID: "exec-x"} + err := r.Run(context.Background(), tt, exec) + if err == nil || !strings.Contains(err.Error(), "empty instructions") { + t.Errorf("expected empty-instructions error, got %v", err) + } +} + +func TestLocalRunner_ExecLogDir(t *testing.T) { + r := &LocalRunner{LogDir: "/tmp/logs"} + if got := r.ExecLogDir("abc"); got != "/tmp/logs/abc" { + t.Errorf("ExecLogDir: got %q", got) + } + r.LogDir = "" + if got := r.ExecLogDir("abc"); got != "" { + t.Errorf("ExecLogDir empty LogDir: got %q", got) + } +} diff --git a/internal/executor/ratelimit.go b/internal/executor/ratelimit.go index 1f38a6d..109aa49 100644 --- a/internal/executor/ratelimit.go +++ b/internal/executor/ratelimit.go @@ -1,33 +1,9 @@ package executor -import ( - "context" - "fmt" - "regexp" - "strconv" - "strings" - "time" -) +import "strings" -var retryAfterRe = regexp.MustCompile(`(?i)retry[-_ ]after[:\s]+(\d+)`) - -const maxBackoffDelay = 5 * time.Minute - -// isRateLimitError returns true if err looks like a transient Claude API -// rate-limit that is worth retrying (e.g. per-minute/per-request throttle). -func isRateLimitError(err error) bool { - if err == nil { - return false - } - msg := strings.ToLower(err.Error()) - return strings.Contains(msg, "rate limit") || - strings.Contains(msg, "too many requests") || - strings.Contains(msg, "429") || - strings.Contains(msg, "overloaded") -} - -// isQuotaExhausted returns true if err indicates the 5-hour usage quota is -// fully exhausted. Unlike transient rate limits, these should not be retried. +// isQuotaExhausted returns true if err indicates the 5-hour Claude usage quota +// is fully exhausted. Unlike transient rate limits, these should not be retried. func isQuotaExhausted(err error) bool { if err == nil { return false @@ -39,53 +15,3 @@ func isQuotaExhausted(err error) bool { strings.Contains(msg, "rate limit reached (rejected)") || strings.Contains(msg, "status: rejected") } - -// parseRetryAfter extracts a Retry-After duration from an error message. -// Returns 0 if no retry-after value is found. -func parseRetryAfter(msg string) time.Duration { - m := retryAfterRe.FindStringSubmatch(msg) - if m == nil { - return 0 - } - secs, err := strconv.Atoi(m[1]) - if err != nil || secs <= 0 { - return 0 - } - return time.Duration(secs) * time.Second -} - -// runWithBackoff calls fn repeatedly on rate-limit errors, using exponential backoff. -// maxRetries is the max number of retry attempts (not counting the initial call). -// baseDelay is the initial backoff duration (doubled each retry). -func runWithBackoff(ctx context.Context, maxRetries int, baseDelay time.Duration, fn func() error) error { - var lastErr error - for attempt := 0; attempt <= maxRetries; attempt++ { - lastErr = fn() - if lastErr == nil { - return nil - } - if !isRateLimitError(lastErr) { - return lastErr - } - if attempt == maxRetries { - break - } - - // Compute exponential backoff delay. - delay := baseDelay * (1 << attempt) - if delay > maxBackoffDelay { - delay = maxBackoffDelay - } - // Use Retry-After header value if present. - if ra := parseRetryAfter(lastErr.Error()); ra > 0 { - delay = ra - } - - select { - case <-ctx.Done(): - return fmt.Errorf("context cancelled during rate-limit backoff: %w", ctx.Err()) - case <-time.After(delay): - } - } - return lastErr -} diff --git a/internal/executor/ratelimit_test.go b/internal/executor/ratelimit_test.go deleted file mode 100644 index f45216f..0000000 --- a/internal/executor/ratelimit_test.go +++ /dev/null @@ -1,170 +0,0 @@ -package executor - -import ( - "context" - "errors" - "fmt" - "testing" - "time" -) - -// --- isRateLimitError tests --- - -func TestIsRateLimitError_RateLimitMessage(t *testing.T) { - err := errors.New("claude exited with error: rate limit exceeded") - if !isRateLimitError(err) { - t.Error("want true for 'rate limit exceeded', got false") - } -} - -func TestIsRateLimitError_TooManyRequests(t *testing.T) { - err := errors.New("too many requests to the API") - if !isRateLimitError(err) { - t.Error("want true for 'too many requests', got false") - } -} - -func TestIsRateLimitError_HTTP429(t *testing.T) { - err := errors.New("API returned status 429") - if !isRateLimitError(err) { - t.Error("want true for '429', got false") - } -} - -func TestIsRateLimitError_Overloaded(t *testing.T) { - err := errors.New("API overloaded, please retry later") - if !isRateLimitError(err) { - t.Error("want true for 'overloaded', got false") - } -} - -func TestIsRateLimitError_NonRateLimitError(t *testing.T) { - err := errors.New("claude exited with error: exit status 1") - if isRateLimitError(err) { - t.Error("want false for non-rate-limit error, got true") - } -} - -func TestIsRateLimitError_NilError(t *testing.T) { - if isRateLimitError(nil) { - t.Error("want false for nil error, got true") - } -} - -// --- parseRetryAfter tests --- - -func TestParseRetryAfter_RetryAfterSeconds(t *testing.T) { - msg := "rate limit exceeded, retry after 30 seconds" - d := parseRetryAfter(msg) - if d != 30*time.Second { - t.Errorf("want 30s, got %v", d) - } -} - -func TestParseRetryAfter_RetryAfterHeader(t *testing.T) { - msg := "rate_limit_error: retry-after: 60" - d := parseRetryAfter(msg) - if d != 60*time.Second { - t.Errorf("want 60s, got %v", d) - } -} - -func TestParseRetryAfter_NoRetryInfo(t *testing.T) { - msg := "rate limit exceeded" - d := parseRetryAfter(msg) - if d != 0 { - t.Errorf("want 0, got %v", d) - } -} - -// --- runWithBackoff tests --- - -func TestRunWithBackoff_SuccessOnFirstTry(t *testing.T) { - calls := 0 - fn := func() error { - calls++ - return nil - } - err := runWithBackoff(context.Background(), 3, time.Millisecond, fn) - if err != nil { - t.Errorf("want nil error, got %v", err) - } - if calls != 1 { - t.Errorf("want 1 call, got %d", calls) - } -} - -func TestRunWithBackoff_RetriesOnRateLimit(t *testing.T) { - calls := 0 - fn := func() error { - calls++ - if calls < 3 { - return fmt.Errorf("rate limit exceeded") - } - return nil - } - err := runWithBackoff(context.Background(), 3, time.Millisecond, fn) - if err != nil { - t.Errorf("want nil error, got %v", err) - } - if calls != 3 { - t.Errorf("want 3 calls, got %d", calls) - } -} - -func TestRunWithBackoff_GivesUpAfterMaxRetries(t *testing.T) { - calls := 0 - rateLimitErr := fmt.Errorf("rate limit exceeded") - fn := func() error { - calls++ - return rateLimitErr - } - err := runWithBackoff(context.Background(), 3, time.Millisecond, fn) - if err == nil { - t.Fatal("want error after max retries, got nil") - } - // maxRetries=3: 1 initial call + 3 retries = 4 total calls - if calls != 4 { - t.Errorf("want 4 calls (1 initial + 3 retries), got %d", calls) - } -} - -func TestRunWithBackoff_DoesNotRetryNonRateLimitError(t *testing.T) { - calls := 0 - fn := func() error { - calls++ - return fmt.Errorf("permission denied") - } - err := runWithBackoff(context.Background(), 3, time.Millisecond, fn) - if err == nil { - t.Fatal("want error, got nil") - } - if calls != 1 { - t.Errorf("want 1 call (no retry for non-rate-limit), got %d", calls) - } -} - -func TestRunWithBackoff_ContextCancellation(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - calls := 0 - - fn := func() error { - calls++ - cancel() // cancel immediately after first call - return fmt.Errorf("rate limit exceeded") - } - - start := time.Now() - err := runWithBackoff(ctx, 3, time.Second, fn) // large delay confirms ctx preempts wait - elapsed := time.Since(start) - - if err == nil { - t.Fatal("want error on context cancellation, got nil") - } - if elapsed > 500*time.Millisecond { - t.Errorf("context cancellation too slow: %v (want < 500ms)", elapsed) - } - if calls != 1 { - t.Errorf("want 1 call before cancellation, got %d", calls) - } -} diff --git a/internal/executor/summary.go b/internal/executor/summary.go index a942de0..bcf5cfd 100644 --- a/internal/executor/summary.go +++ b/internal/executor/summary.go @@ -2,11 +2,26 @@ package executor import ( "bufio" + "context" "encoding/json" + "io" "os" "strings" + "time" + + "github.com/thepeterstone/claudomator/internal/llm" ) +// synthesizeSummaryMaxBytes caps how much of the stdout log we send to the +// LLM. Larger values cost more tokens with diminishing returns for a 2-4 +// sentence summary. +const synthesizeSummaryMaxBytes = 16 * 1024 + +// synthesizeSummaryTimeout caps the LLM call so a slow local model can't +// stall executor finalization. On timeout, we return "" (the existing +// no-summary path takes over). +const synthesizeSummaryTimeout = 6 * time.Second + // extractSummary reads a stream-json stdout log and returns the text following // the last "## Summary" heading found in any assistant text block. // Returns empty string if the file cannot be read or no summary is found. @@ -28,6 +43,86 @@ func extractSummary(stdoutPath string) string { return last } +// synthesizeSummary asks the LLM to summarize the assistant text content in +// stdoutPath when no "## Summary" heading was present. Returns "" on any +// error, an empty file, or an empty model response — preserving the +// existing "no summary" behavior so the new path is purely additive. +func synthesizeSummary(parent context.Context, c *llm.Client, stdoutPath string) string { + if c == nil || stdoutPath == "" { + return "" + } + text := readAssistantTextTail(stdoutPath, synthesizeSummaryMaxBytes) + if strings.TrimSpace(text) == "" { + return "" + } + + cctx, cancel := context.WithTimeout(parent, synthesizeSummaryTimeout) + defer cancel() + resp, err := c.Chat(cctx, llm.ChatRequest{ + Messages: []llm.Message{ + {Role: "system", Content: "You summarize what an automated coding agent did. Reply with 2-4 sentences of plain prose. No bullets, no headings, no preamble."}, + {Role: "user", Content: "Here is the agent's output. Summarize what it accomplished:\n\n" + text}, + }, + }) + if err != nil { + return "" + } + return strings.TrimSpace(resp.Content) +} + +// readAssistantTextTail returns the concatenated `text` blocks from assistant +// stream-json events in the last maxBytes of the file. Non-assistant events +// (system, result, tool_use, etc.) are skipped so the LLM sees just what the +// agent said. Returns "" on any error. +func readAssistantTextTail(stdoutPath string, maxBytes int64) string { + f, err := os.Open(stdoutPath) + if err != nil { + return "" + } + defer f.Close() + + stat, err := f.Stat() + if err != nil { + return "" + } + size := stat.Size() + if size > maxBytes { + if _, err := f.Seek(size-maxBytes, io.SeekStart); err != nil { + return "" + } + } + + var sb strings.Builder + scanner := bufio.NewScanner(f) + scanner.Buffer(make([]byte, 1024*1024), 1024*1024) + first := size > maxBytes // if we seeked, drop the first (likely partial) line + for scanner.Scan() { + if first { + first = false + continue + } + var event struct { + Type string `json:"type"` + Message struct { + Content []struct { + Type string `json:"type"` + Text string `json:"text"` + } `json:"content"` + } `json:"message"` + } + if err := json.Unmarshal(scanner.Bytes(), &event); err != nil || event.Type != "assistant" { + continue + } + for _, block := range event.Message.Content { + if block.Type == "text" && block.Text != "" { + sb.WriteString(block.Text) + sb.WriteString("\n") + } + } + } + return sb.String() +} + // summaryFromLine parses a single stream-json line and returns the text after // "## Summary" if the line is an assistant text block containing that heading. func summaryFromLine(line []byte) string { diff --git a/internal/executor/summary_synth_test.go b/internal/executor/summary_synth_test.go new file mode 100644 index 0000000..7ad396d --- /dev/null +++ b/internal/executor/summary_synth_test.go @@ -0,0 +1,241 @@ +package executor + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "strings" + "sync/atomic" + "testing" + + "github.com/thepeterstone/claudomator/internal/llm" + "github.com/thepeterstone/claudomator/internal/storage" +) + +func writeStreamLog(t *testing.T, lines []string) string { + t.Helper() + dir := t.TempDir() + path := filepath.Join(dir, "stdout.log") + var sb strings.Builder + for _, l := range lines { + sb.WriteString(l) + sb.WriteString("\n") + } + if err := os.WriteFile(path, []byte(sb.String()), 0600); err != nil { + t.Fatal(err) + } + return path +} + +func TestSynthesizeSummary_NilClient(t *testing.T) { + got := synthesizeSummary(context.Background(), nil, "/some/path") + if got != "" { + t.Errorf("nil client: want empty, got %q", got) + } +} + +func TestSynthesizeSummary_EmptyPath(t *testing.T) { + c := &llm.Client{Endpoint: "http://unused", Model: "x"} + got := synthesizeSummary(context.Background(), c, "") + if got != "" { + t.Errorf("empty path: want empty, got %q", got) + } +} + +func TestSynthesizeSummary_MissingFile(t *testing.T) { + c := &llm.Client{Endpoint: "http://unused", Model: "x"} + got := synthesizeSummary(context.Background(), c, "/nonexistent/file.log") + if got != "" { + t.Errorf("missing file: want empty, got %q", got) + } +} + +func TestSynthesizeSummary_EmptyAssistantContent(t *testing.T) { + // Log contains only system/result events — no assistant text. The function + // should short-circuit without calling the LLM. + path := writeStreamLog(t, []string{ + `{"type":"system","subtype":"init"}`, + `{"type":"result","subtype":"success","total_cost_usd":0}`, + }) + + var calls int32 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt32(&calls, 1) + w.Header().Set("Content-Type", "application/json") + fmt.Fprintln(w, `{"choices":[{"message":{"content":"should not be returned"},"finish_reason":"stop"}],"usage":{}}`) + })) + defer srv.Close() + c := &llm.Client{Endpoint: srv.URL + "/v1", Model: "x"} + + got := synthesizeSummary(context.Background(), c, path) + if got != "" { + t.Errorf("empty content: want empty, got %q", got) + } + if atomic.LoadInt32(&calls) != 0 { + t.Errorf("LLM should not be called for empty assistant content") + } +} + +func TestSynthesizeSummary_LLMSuccess(t *testing.T) { + path := writeStreamLog(t, []string{ + `{"type":"assistant","message":{"content":[{"type":"text","text":"Ran the tests."}]}}`, + `{"type":"assistant","message":{"content":[{"type":"text","text":"Fixed the import."}]}}`, + `{"type":"result","subtype":"success"}`, + }) + + var capturedUser string + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var body struct { + Messages []struct { + Role, Content string + } `json:"messages"` + } + json.NewDecoder(r.Body).Decode(&body) + for _, m := range body.Messages { + if m.Role == "user" { + capturedUser = m.Content + } + } + w.Header().Set("Content-Type", "application/json") + fmt.Fprintln(w, `{"choices":[{"message":{"content":" Agent ran tests and fixed an import. "},"finish_reason":"stop"}],"usage":{}}`) + })) + defer srv.Close() + c := &llm.Client{Endpoint: srv.URL + "/v1", Model: "x"} + + got := synthesizeSummary(context.Background(), c, path) + if got != "Agent ran tests and fixed an import." { + t.Errorf("summary: got %q", got) + } + if !strings.Contains(capturedUser, "Ran the tests.") { + t.Errorf("user prompt missing first assistant text; got: %s", capturedUser) + } + if !strings.Contains(capturedUser, "Fixed the import.") { + t.Errorf("user prompt missing second assistant text; got: %s", capturedUser) + } +} + +// TestPool_HandleRunResult_LLMSummaryFallback verifies the Pool falls back to +// LLM-synthesized summary when extractSummary returns empty. +func TestPool_HandleRunResult_LLMSummaryFallback(t *testing.T) { + // stdout has assistant text but no "## Summary" heading. + stdoutPath := writeStreamLog(t, []string{ + `{"type":"assistant","message":{"content":[{"type":"text","text":"Did the work without writing a summary section."}]}}`, + }) + + llmSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + fmt.Fprintln(w, `{"choices":[{"message":{"content":"Synthesized summary."},"finish_reason":"stop"}],"usage":{}}`) + })) + defer llmSrv.Close() + + store := newMinimalMockStore() + pool := newPoolWithMockStore(store) + pool.LLM = &llm.Client{Endpoint: llmSrv.URL + "/v1", Model: "x"} + + tk := makeTask("synth-summary") + store.tasks[tk.ID] = tk + exec := &storage.Execution{ID: "e-synth", TaskID: tk.ID, Status: "RUNNING", StdoutPath: stdoutPath} + + pool.handleRunResult(context.Background(), tk, exec, nil, "claude") + + id, summary, ok := store.lastSummaryUpdate() + if !ok { + t.Fatalf("expected UpdateTaskSummary to be called") + } + if id != tk.ID { + t.Errorf("summary recorded for wrong task: %q", id) + } + if summary != "Synthesized summary." { + t.Errorf("summary: got %q", summary) + } + + // Drain the result channel so the test exits cleanly. + <-pool.resultCh +} + +// TestPool_HandleRunResult_ExtractSummaryWins verifies the LLM is NOT called +// when the agent already wrote a "## Summary" section. +func TestPool_HandleRunResult_ExtractSummaryWins(t *testing.T) { + stdoutPath := writeStreamLog(t, []string{ + `{"type":"assistant","message":{"content":[{"type":"text","text":"## Summary\nAgent wrote its own summary."}]}}`, + }) + + var llmCalls int32 + llmSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt32(&llmCalls, 1) + w.Header().Set("Content-Type", "application/json") + fmt.Fprintln(w, `{"choices":[{"message":{"content":"should not be used"},"finish_reason":"stop"}],"usage":{}}`) + })) + defer llmSrv.Close() + + store := newMinimalMockStore() + pool := newPoolWithMockStore(store) + pool.LLM = &llm.Client{Endpoint: llmSrv.URL + "/v1", Model: "x"} + + tk := makeTask("agent-summary") + store.tasks[tk.ID] = tk + exec := &storage.Execution{ID: "e-agent", TaskID: tk.ID, Status: "RUNNING", StdoutPath: stdoutPath} + + pool.handleRunResult(context.Background(), tk, exec, nil, "claude") + + if got := atomic.LoadInt32(&llmCalls); got != 0 { + t.Errorf("LLM should not be called when ## Summary is present; got %d calls", got) + } + _, summary, ok := store.lastSummaryUpdate() + if !ok { + t.Fatalf("expected UpdateTaskSummary") + } + if summary != "Agent wrote its own summary." { + t.Errorf("summary: got %q (want extractSummary output)", summary) + } + <-pool.resultCh +} + +func TestSynthesizeSummary_LLMFailure_ReturnsEmpty(t *testing.T) { + path := writeStreamLog(t, []string{ + `{"type":"assistant","message":{"content":[{"type":"text","text":"Did something."}]}}`, + }) + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Error(w, "boom", http.StatusInternalServerError) + })) + defer srv.Close() + c := &llm.Client{Endpoint: srv.URL + "/v1", Model: "x"} + + got := synthesizeSummary(context.Background(), c, path) + if got != "" { + t.Errorf("LLM failure: want empty, got %q", got) + } +} + +// TestReadAssistantTextTail_TailingLargeFile verifies the seek-to-tail +// behavior drops early content but keeps later assistant text. +func TestReadAssistantTextTail_TailingLargeFile(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "stdout.log") + f, err := os.Create(path) + if err != nil { + t.Fatal(err) + } + // Write a ton of garbage assistant lines, then a final marker. + for i := 0; i < 500; i++ { + fmt.Fprintf(f, `{"type":"assistant","message":{"content":[{"type":"text","text":"filler line that should be in the early part of a large file %04d"}]}}`+"\n", i) + } + fmt.Fprintln(f, `{"type":"assistant","message":{"content":[{"type":"text","text":"FINAL_MARKER_LINE"}]}}`) + f.Close() + + got := readAssistantTextTail(path, 4*1024) // 4 KB cap + if !strings.Contains(got, "FINAL_MARKER_LINE") { + t.Errorf("tail should contain final line; got: %s", got) + } + if strings.Contains(got, "filler line that should be in the early part of a large file 0000") { + end := 200 + if len(got) < end { + end = len(got) + } + t.Errorf("tail should NOT contain very-early line; got first 200 chars: %s", got[:end]) + } +} |
