diff options
Diffstat (limited to 'internal/executor/executor.go')
| -rw-r--r-- | internal/executor/executor.go | 148 |
1 file changed, 133 insertions, 15 deletions
diff --git a/internal/executor/executor.go b/internal/executor/executor.go index 0245899..d1c8e72 100644 --- a/internal/executor/executor.go +++ b/internal/executor/executor.go @@ -36,18 +36,21 @@ type workItem struct { // Pool manages a bounded set of concurrent task workers. type Pool struct { maxConcurrent int - runner Runner + runners map[string]Runner store *storage.DB logger *slog.Logger depPollInterval time.Duration // how often waitForDependencies polls; defaults to 5s - mu sync.Mutex - active int - cancels map[string]context.CancelFunc // taskID → cancel - resultCh chan *Result - workCh chan workItem // internal bounded queue; Submit enqueues here - doneCh chan struct{} // signals when a worker slot is freed - Questions *QuestionRegistry + mu sync.Mutex + active int + activePerAgent map[string]int + rateLimited map[string]time.Time // agentType -> until + cancels map[string]context.CancelFunc // taskID → cancel + resultCh chan *Result + workCh chan workItem // internal bounded queue; Submit enqueues here + doneCh chan struct{} // signals when a worker slot is freed + Questions *QuestionRegistry + Classifier *Classifier } // Result is emitted when a task execution completes. 
@@ -57,16 +60,18 @@ type Result struct { Err error } -func NewPool(maxConcurrent int, runner Runner, store *storage.DB, logger *slog.Logger) *Pool { +func NewPool(maxConcurrent int, runners map[string]Runner, store *storage.DB, logger *slog.Logger) *Pool { if maxConcurrent < 1 { maxConcurrent = 1 } p := &Pool{ maxConcurrent: maxConcurrent, - runner: runner, + runners: runners, store: store, logger: logger, depPollInterval: 5 * time.Second, + activePerAgent: make(map[string]int), + rateLimited: make(map[string]time.Time), cancels: make(map[string]context.CancelFunc), resultCh: make(chan *Result, maxConcurrent*2), workCh: make(chan workItem, maxConcurrent*10+100), @@ -147,10 +152,32 @@ func (p *Pool) SubmitResume(ctx context.Context, t *task.Task, exec *storage.Exe } } +func (p *Pool) getRunner(t *task.Task) (Runner, error) { + agentType := t.Agent.Type + if agentType == "" { + agentType = "claude" // Default for backward compatibility + } + runner, ok := p.runners[agentType] + if !ok { + return nil, fmt.Errorf("unsupported agent type: %q", agentType) + } + return runner, nil +} + func (p *Pool) executeResume(ctx context.Context, t *task.Task, exec *storage.Execution) { + agentType := t.Agent.Type + if agentType == "" { + agentType = "claude" + } + + p.mu.Lock() + p.activePerAgent[agentType]++ + p.mu.Unlock() + defer func() { p.mu.Lock() p.active-- + p.activePerAgent[agentType]-- p.mu.Unlock() select { case p.doneCh <- struct{}{}: @@ -158,8 +185,15 @@ func (p *Pool) executeResume(ctx context.Context, t *task.Task, exec *storage.Ex } }() + runner, err := p.getRunner(t) + if err != nil { + p.logger.Error("failed to get runner for resume", "error", err, "taskID", t.ID) + p.resultCh <- &Result{TaskID: t.ID, Execution: exec, Err: err} + return + } + // Pre-populate log paths. 
- if lp, ok := p.runner.(LogPather); ok { + if lp, ok := runner.(LogPather); ok { if logDir := lp.ExecLogDir(exec.ID); logDir != "" { exec.StdoutPath = filepath.Join(logDir, "stdout.log") exec.StderrPath = filepath.Join(logDir, "stderr.log") @@ -182,12 +216,30 @@ func (p *Pool) executeResume(ctx context.Context, t *task.Task, exec *storage.Ex } else { ctx, cancel = context.WithCancel(ctx) } - defer cancel() + p.mu.Lock() + p.cancels[t.ID] = cancel + p.mu.Unlock() + defer func() { + cancel() + p.mu.Lock() + delete(p.cancels, t.ID) + p.mu.Unlock() + }() - err := p.runner.Run(ctx, t, exec) + err = runner.Run(ctx, t, exec) exec.EndTime = time.Now().UTC() if err != nil { + if isRateLimitError(err) { + p.mu.Lock() + retryAfter := parseRetryAfter(err.Error()) + if retryAfter == 0 { + retryAfter = 1 * time.Minute + } + p.rateLimited[agentType] = time.Now().Add(retryAfter) + p.mu.Unlock() + } + var blockedErr *BlockedError if errors.As(err, &blockedErr) { exec.Status = "BLOCKED" @@ -234,9 +286,45 @@ func (p *Pool) ActiveCount() int { } func (p *Pool) execute(ctx context.Context, t *task.Task) { + // 1. 
Classification + if p.Classifier != nil { + p.mu.Lock() + activeTasks := make(map[string]int) + rateLimited := make(map[string]bool) + now := time.Now() + for agent := range p.runners { + activeTasks[agent] = p.activePerAgent[agent] + rateLimited[agent] = now.Before(p.rateLimited[agent]) + } + status := SystemStatus{ + ActiveTasks: activeTasks, + RateLimited: rateLimited, + } + p.mu.Unlock() + + cls, err := p.Classifier.Classify(ctx, t.Name, t.Agent.Instructions, status) + if err == nil { + p.logger.Info("task classified", "taskID", t.ID, "agent", cls.AgentType, "model", cls.Model, "reason", cls.Reason) + t.Agent.Type = cls.AgentType + t.Agent.Model = cls.Model + } else { + p.logger.Error("classification failed", "error", err, "taskID", t.ID) + } + } + + agentType := t.Agent.Type + if agentType == "" { + agentType = "claude" + } + + p.mu.Lock() + p.activePerAgent[agentType]++ + p.mu.Unlock() + defer func() { p.mu.Lock() p.active-- + p.activePerAgent[agentType]-- p.mu.Unlock() select { case p.doneCh <- struct{}{}: @@ -244,6 +332,26 @@ func (p *Pool) execute(ctx context.Context, t *task.Task) { } }() + runner, err := p.getRunner(t) + if err != nil { + p.logger.Error("failed to get runner", "error", err, "taskID", t.ID) + now := time.Now().UTC() + exec := &storage.Execution{ + ID: uuid.New().String(), + TaskID: t.ID, + StartTime: now, + EndTime: now, + Status: "FAILED", + ErrorMsg: err.Error(), + } + if createErr := p.store.CreateExecution(exec); createErr != nil { + p.logger.Error("failed to create execution record", "error", createErr) + } + p.store.UpdateTaskState(t.ID, task.StateFailed) + p.resultCh <- &Result{TaskID: t.ID, Execution: exec, Err: err} + return + } + // Wait for all dependencies to complete before starting execution. 
if len(t.DependsOn) > 0 { if err := p.waitForDependencies(ctx, t); err != nil { @@ -275,7 +383,7 @@ func (p *Pool) execute(ctx context.Context, t *task.Task) { // Pre-populate log paths so they're available in the DB immediately — // before the subprocess starts — enabling live tailing and debugging. - if lp, ok := p.runner.(LogPather); ok { + if lp, ok := runner.(LogPather); ok { if logDir := lp.ExecLogDir(execID); logDir != "" { exec.StdoutPath = filepath.Join(logDir, "stdout.log") exec.StderrPath = filepath.Join(logDir, "stderr.log") @@ -309,10 +417,20 @@ func (p *Pool) execute(ctx context.Context, t *task.Task) { }() // Run the task. - err := p.runner.Run(ctx, t, exec) + err = runner.Run(ctx, t, exec) exec.EndTime = time.Now().UTC() if err != nil { + if isRateLimitError(err) { + p.mu.Lock() + retryAfter := parseRetryAfter(err.Error()) + if retryAfter == 0 { + retryAfter = 1 * time.Minute + } + p.rateLimited[agentType] = time.Now().Add(retryAfter) + p.mu.Unlock() + } + var blockedErr *BlockedError if errors.As(err, &blockedErr) { exec.Status = "BLOCKED" |
