diff options
Diffstat (limited to 'internal')
-rw-r--r--  internal/executor/executor.go       |  92
-rw-r--r--  internal/executor/executor_test.go  | 190
2 files changed, 200 insertions, 82 deletions
diff --git a/internal/executor/executor.go b/internal/executor/executor.go index f445ef3..8924830 100644 --- a/internal/executor/executor.go +++ b/internal/executor/executor.go @@ -246,6 +246,14 @@ func (p *Pool) executeResume(ctx context.Context, t *task.Task, exec *storage.Ex err = runner.Run(ctx, t, exec) exec.EndTime = time.Now().UTC() + p.handleRunResult(ctx, t, exec, err, agentType) +} + +// handleRunResult applies the shared post-run error-classification and +// state-update logic used by both execute() and executeResume(). It sets +// exec.Status and exec.ErrorMsg, updates storage, and emits the result to +// resultCh. The caller must set exec.EndTime before calling. +func (p *Pool) handleRunResult(ctx context.Context, t *task.Task, exec *storage.Execution, err error, agentType string) { if err != nil { if isRateLimitError(err) || isQuotaExhausted(err) { p.mu.Lock() @@ -323,7 +331,7 @@ func (p *Pool) executeResume(ctx context.Context, t *task.Task, exec *storage.Ex } if updateErr := p.store.UpdateExecution(exec); updateErr != nil { - p.logger.Error("failed to update resume execution", "error", updateErr) + p.logger.Error("failed to update execution", "error", updateErr) } p.resultCh <- &Result{TaskID: t.ID, Execution: exec, Err: err} } @@ -493,87 +501,7 @@ func (p *Pool) execute(ctx context.Context, t *task.Task) { err = runner.Run(ctx, t, exec) exec.EndTime = time.Now().UTC() - if err != nil { - if isRateLimitError(err) || isQuotaExhausted(err) { - p.mu.Lock() - retryAfter := parseRetryAfter(err.Error()) - if retryAfter == 0 { - if isQuotaExhausted(err) { - retryAfter = 5 * time.Hour - } else { - retryAfter = 1 * time.Minute - } - } - p.rateLimited[agentType] = time.Now().Add(retryAfter) - p.logger.Info("agent rate limited", "agent", agentType, "retryAfter", retryAfter, "quotaExhausted", isQuotaExhausted(err)) - p.mu.Unlock() - } - - var blockedErr *BlockedError - if errors.As(err, &blockedErr) { - exec.Status = "BLOCKED" - if err := 
p.store.UpdateTaskState(t.ID, task.StateBlocked); err != nil { - p.logger.Error("failed to update task state", "taskID", t.ID, "state", task.StateBlocked, "error", err) - } - if err := p.store.UpdateTaskQuestion(t.ID, blockedErr.QuestionJSON); err != nil { - p.logger.Error("failed to update task question", "taskID", t.ID, "error", err) - } - } else if ctx.Err() == context.DeadlineExceeded { - exec.Status = "TIMED_OUT" - exec.ErrorMsg = "execution timed out" - if err := p.store.UpdateTaskState(t.ID, task.StateTimedOut); err != nil { - p.logger.Error("failed to update task state", "taskID", t.ID, "state", task.StateTimedOut, "error", err) - } - } else if ctx.Err() == context.Canceled { - exec.Status = "CANCELLED" - exec.ErrorMsg = "execution cancelled" - if err := p.store.UpdateTaskState(t.ID, task.StateCancelled); err != nil { - p.logger.Error("failed to update task state", "taskID", t.ID, "state", task.StateCancelled, "error", err) - } - } else if isQuotaExhausted(err) { - exec.Status = "BUDGET_EXCEEDED" - exec.ErrorMsg = err.Error() - if err := p.store.UpdateTaskState(t.ID, task.StateBudgetExceeded); err != nil { - p.logger.Error("failed to update task state", "taskID", t.ID, "state", task.StateBudgetExceeded, "error", err) - } - } else { - exec.Status = "FAILED" - exec.ErrorMsg = err.Error() - if err := p.store.UpdateTaskState(t.ID, task.StateFailed); err != nil { - p.logger.Error("failed to update task state", "taskID", t.ID, "state", task.StateFailed, "error", err) - } - } - } else { - if t.ParentTaskID == "" { - subtasks, subErr := p.store.ListSubtasks(t.ID) - if subErr != nil { - p.logger.Error("failed to list subtasks", "taskID", t.ID, "error", subErr) - } - if subErr == nil && len(subtasks) > 0 { - exec.Status = "BLOCKED" - if err := p.store.UpdateTaskState(t.ID, task.StateBlocked); err != nil { - p.logger.Error("failed to update task state", "taskID", t.ID, "state", task.StateBlocked, "error", err) - } - } else { - exec.Status = "READY" - if err := 
p.store.UpdateTaskState(t.ID, task.StateReady); err != nil { - p.logger.Error("failed to update task state", "taskID", t.ID, "state", task.StateReady, "error", err) - } - } - } else { - exec.Status = "COMPLETED" - if err := p.store.UpdateTaskState(t.ID, task.StateCompleted); err != nil { - p.logger.Error("failed to update task state", "taskID", t.ID, "state", task.StateCompleted, "error", err) - } - p.maybeUnblockParent(t.ParentTaskID) - } - } - - if updateErr := p.store.UpdateExecution(exec); updateErr != nil { - p.logger.Error("failed to update execution", "error", updateErr) - } - - p.resultCh <- &Result{TaskID: t.ID, Execution: exec, Err: err} + p.handleRunResult(ctx, t, exec, err, agentType) } // RecoverStaleRunning marks any tasks stuck in RUNNING state (from a previous diff --git a/internal/executor/executor_test.go b/internal/executor/executor_test.go index 9896ba1..0935545 100644 --- a/internal/executor/executor_test.go +++ b/internal/executor/executor_test.go @@ -805,6 +805,196 @@ func TestPool_Submit_ParentNotBlocked_NoTransition(t *testing.T) { } } +// minimalMockStore is a standalone Store implementation for unit-testing Pool +// methods that do not require a real SQLite database. 
+type minimalMockStore struct { + mu sync.Mutex + tasks map[string]*task.Task + executions map[string]*storage.Execution + stateUpdates []struct{ id string; state task.State } + questionUpdates []string + subtasksFunc func(parentID string) ([]*task.Task, error) + updateExecErr error + updateStateErr error +} + +func newMinimalMockStore() *minimalMockStore { + return &minimalMockStore{ + tasks: make(map[string]*task.Task), + executions: make(map[string]*storage.Execution), + } +} + +func (m *minimalMockStore) GetTask(id string) (*task.Task, error) { + m.mu.Lock() + defer m.mu.Unlock() + t, ok := m.tasks[id] + if !ok { + return nil, fmt.Errorf("task %q not found", id) + } + return t, nil +} +func (m *minimalMockStore) ListTasks(_ storage.TaskFilter) ([]*task.Task, error) { return nil, nil } +func (m *minimalMockStore) ListSubtasks(parentID string) ([]*task.Task, error) { + if m.subtasksFunc != nil { + return m.subtasksFunc(parentID) + } + return nil, nil +} +func (m *minimalMockStore) ListExecutions(_ string) ([]*storage.Execution, error) { return nil, nil } +func (m *minimalMockStore) CreateExecution(e *storage.Execution) error { return nil } +func (m *minimalMockStore) UpdateExecution(e *storage.Execution) error { + return m.updateExecErr +} +func (m *minimalMockStore) UpdateTaskState(id string, newState task.State) error { + if m.updateStateErr != nil { + return m.updateStateErr + } + m.mu.Lock() + m.stateUpdates = append(m.stateUpdates, struct{ id string; state task.State }{id, newState}) + if t, ok := m.tasks[id]; ok { + t.State = newState + } + m.mu.Unlock() + return nil +} +func (m *minimalMockStore) UpdateTaskQuestion(taskID, questionJSON string) error { + m.mu.Lock() + m.questionUpdates = append(m.questionUpdates, questionJSON) + m.mu.Unlock() + return nil +} + +func (m *minimalMockStore) lastStateUpdate() (string, task.State, bool) { + m.mu.Lock() + defer m.mu.Unlock() + if len(m.stateUpdates) == 0 { + return "", "", false + } + u := 
m.stateUpdates[len(m.stateUpdates)-1] + return u.id, u.state, true +} + +func newPoolWithMockStore(store Store) *Pool { + logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) + return &Pool{ + maxConcurrent: 2, + runners: map[string]Runner{"claude": &mockRunner{}}, + store: store, + logger: logger, + activePerAgent: make(map[string]int), + rateLimited: make(map[string]time.Time), + cancels: make(map[string]context.CancelFunc), + resultCh: make(chan *Result, 4), + workCh: make(chan workItem, 4), + doneCh: make(chan struct{}, 2), + Questions: NewQuestionRegistry(), + } +} + +// TestHandleRunResult_SharedPath verifies that handleRunResult correctly +// classifies runner errors and transitions task state via the store. +func TestHandleRunResult_SharedPath(t *testing.T) { + t.Run("generic error sets FAILED", func(t *testing.T) { + store := newMinimalMockStore() + pool := newPoolWithMockStore(store) + tk := makeTask("hrr-fail") + store.tasks[tk.ID] = tk + + exec := &storage.Execution{ID: "e1", TaskID: tk.ID, Status: "RUNNING"} + ctx := context.Background() + + pool.handleRunResult(ctx, tk, exec, fmt.Errorf("something broke"), "claude") + + if exec.Status != "FAILED" { + t.Errorf("exec.Status: want FAILED, got %q", exec.Status) + } + if exec.ErrorMsg != "something broke" { + t.Errorf("exec.ErrorMsg: want %q, got %q", "something broke", exec.ErrorMsg) + } + _, state, ok := store.lastStateUpdate() + if !ok || state != task.StateFailed { + t.Errorf("expected UpdateTaskState(FAILED), got state=%v ok=%v", state, ok) + } + result := <-pool.resultCh + if result.Err == nil || result.Execution.Status != "FAILED" { + t.Errorf("unexpected result: %+v", result) + } + }) + + t.Run("nil error top-level no subtasks sets READY", func(t *testing.T) { + store := newMinimalMockStore() + pool := newPoolWithMockStore(store) + tk := makeTask("hrr-ready") + store.tasks[tk.ID] = tk + + exec := &storage.Execution{ID: "e2", TaskID: tk.ID, Status: 
"RUNNING"} + ctx := context.Background() + + pool.handleRunResult(ctx, tk, exec, nil, "claude") + + if exec.Status != "READY" { + t.Errorf("exec.Status: want READY, got %q", exec.Status) + } + _, state, ok := store.lastStateUpdate() + if !ok || state != task.StateReady { + t.Errorf("expected UpdateTaskState(READY), got state=%v ok=%v", state, ok) + } + result := <-pool.resultCh + if result.Err != nil || result.Execution.Status != "READY" { + t.Errorf("unexpected result: %+v", result) + } + }) + + t.Run("nil error subtask sets COMPLETED", func(t *testing.T) { + store := newMinimalMockStore() + pool := newPoolWithMockStore(store) + parent := makeTask("hrr-parent") + parent.State = task.StateBlocked + store.tasks[parent.ID] = parent + + tk := makeTask("hrr-sub") + tk.ParentTaskID = parent.ID + store.tasks[tk.ID] = tk + + exec := &storage.Execution{ID: "e3", TaskID: tk.ID, Status: "RUNNING"} + ctx := context.Background() + + pool.handleRunResult(ctx, tk, exec, nil, "claude") + + if exec.Status != "COMPLETED" { + t.Errorf("exec.Status: want COMPLETED, got %q", exec.Status) + } + result := <-pool.resultCh + if result.Err != nil || result.Execution.Status != "COMPLETED" { + t.Errorf("unexpected result: %+v", result) + } + }) + + t.Run("timeout sets TIMED_OUT", func(t *testing.T) { + store := newMinimalMockStore() + pool := newPoolWithMockStore(store) + tk := makeTask("hrr-timeout") + store.tasks[tk.ID] = tk + + exec := &storage.Execution{ID: "e4", TaskID: tk.ID, Status: "RUNNING"} + ctx, cancel := context.WithCancel(context.Background()) + cancel() // make ctx.Err() == context.Canceled + + // Simulate deadline exceeded by using a deadline-exceeded context. 
+ dctx, dcancel := context.WithDeadline(context.Background(), time.Now().Add(-time.Second)) + defer dcancel() + + pool.handleRunResult(dctx, tk, exec, context.DeadlineExceeded, "claude") + + if exec.Status != "TIMED_OUT" { + t.Errorf("exec.Status: want TIMED_OUT, got %q", exec.Status) + } + _ = ctx + <-pool.resultCh + }) +} + func TestPool_UnsupportedAgent(t *testing.T) { store := testStore(t) runners := map[string]Runner{"claude": &mockRunner{}} |
