diff options
Diffstat (limited to 'internal/executor/executor.go')
| -rw-r--r-- | internal/executor/executor.go | 136 |
1 files changed, 97 insertions, 39 deletions
diff --git a/internal/executor/executor.go b/internal/executor/executor.go index 6bd1c68..d1c8e72 100644 --- a/internal/executor/executor.go +++ b/internal/executor/executor.go @@ -26,12 +26,20 @@ type Runner interface { Run(ctx context.Context, t *task.Task, exec *storage.Execution) error } +// workItem is an entry in the pool's internal work queue. +type workItem struct { + ctx context.Context + task *task.Task + exec *storage.Execution // non-nil for resume submissions +} + // Pool manages a bounded set of concurrent task workers. type Pool struct { - maxConcurrent int - runners map[string]Runner - store *storage.DB - logger *slog.Logger + maxConcurrent int + runners map[string]Runner + store *storage.DB + logger *slog.Logger + depPollInterval time.Duration // how often waitForDependencies polls; defaults to 5s mu sync.Mutex active int @@ -39,6 +47,8 @@ type Pool struct { rateLimited map[string]time.Time // agentType -> until cancels map[string]context.CancelFunc // taskID → cancel resultCh chan *Result + workCh chan workItem // internal bounded queue; Submit enqueues here + doneCh chan struct{} // signals when a worker slot is freed Questions *QuestionRegistry Classifier *Classifier } @@ -54,33 +64,57 @@ func NewPool(maxConcurrent int, runners map[string]Runner, store *storage.DB, lo if maxConcurrent < 1 { maxConcurrent = 1 } - return &Pool{ - maxConcurrent: maxConcurrent, - runners: runners, - store: store, - logger: logger, - activePerAgent: make(map[string]int), - rateLimited: make(map[string]time.Time), - cancels: make(map[string]context.CancelFunc), - resultCh: make(chan *Result, maxConcurrent*2), - Questions: NewQuestionRegistry(), + p := &Pool{ + maxConcurrent: maxConcurrent, + runners: runners, + store: store, + logger: logger, + depPollInterval: 5 * time.Second, + activePerAgent: make(map[string]int), + rateLimited: make(map[string]time.Time), + cancels: make(map[string]context.CancelFunc), + resultCh: make(chan *Result, maxConcurrent*2), + workCh: make(chan workItem, maxConcurrent*10+100), + doneCh: make(chan struct{}, maxConcurrent), + Questions: NewQuestionRegistry(), } + go p.dispatch() + return p } -// Submit dispatches a task for execution. Blocks if pool is at capacity. -func (p *Pool) Submit(ctx context.Context, t *task.Task) error { - p.mu.Lock() - if p.active >= p.maxConcurrent { - active := p.active - max := p.maxConcurrent - p.mu.Unlock() - return fmt.Errorf("executor pool at capacity (%d/%d)", active, max) +// dispatch is a long-running goroutine that reads from the internal work queue +// and launches goroutines as soon as a pool slot is available. This prevents +// tasks from being rejected when the pool is temporarily at capacity. +func (p *Pool) dispatch() { + for item := range p.workCh { + for { + p.mu.Lock() + if p.active < p.maxConcurrent { + p.active++ + p.mu.Unlock() + if item.exec != nil { + go p.executeResume(item.ctx, item.task, item.exec) + } else { + go p.execute(item.ctx, item.task) + } + break + } + p.mu.Unlock() + <-p.doneCh // wait for a worker to finish + } } - p.active++ - p.mu.Unlock() +} - go p.execute(ctx, t) - return nil +// Submit enqueues a task for execution. Returns an error only if the internal +// work queue is full. When the pool is at capacity the task is buffered and +// dispatched as soon as a slot becomes available. +func (p *Pool) Submit(ctx context.Context, t *task.Task) error { + select { + case p.workCh <- workItem{ctx: ctx, task: t}: + return nil + default: + return fmt.Errorf("executor work queue full (capacity %d)", cap(p.workCh)) + } } // Results returns the channel for reading execution results. @@ -104,18 +138,18 @@ func (p *Pool) Cancel(taskID string) bool { // SubmitResume re-queues a blocked task using the provided resume execution. // The execution must have ResumeSessionID and ResumeAnswer set. func (p *Pool) SubmitResume(ctx context.Context, t *task.Task, exec *storage.Execution) error { - p.mu.Lock() - if p.active >= p.maxConcurrent { - active := p.active - max := p.maxConcurrent - p.mu.Unlock() - return fmt.Errorf("executor pool at capacity (%d/%d)", active, max) + if t.State != task.StateBlocked && t.State != task.StateTimedOut { + return fmt.Errorf("task %s must be in BLOCKED or TIMED_OUT state to resume (current: %s)", t.ID, t.State) + } + if exec.ResumeSessionID == "" { + return fmt.Errorf("resume execution for task %s must have a ResumeSessionID", t.ID) + } + select { + case p.workCh <- workItem{ctx: ctx, task: t, exec: exec}: + return nil + default: + return fmt.Errorf("executor work queue full (capacity %d)", cap(p.workCh)) } - p.active++ - p.mu.Unlock() - - go p.executeResume(ctx, t, exec) - return nil } func (p *Pool) getRunner(t *task.Task) (Runner, error) { @@ -145,6 +179,10 @@ func (p *Pool) executeResume(ctx context.Context, t *task.Task, exec *storage.Ex p.active-- p.activePerAgent[agentType]-- p.mu.Unlock() + select { + case p.doneCh <- struct{}{}: + default: + } }() runner, err := p.getRunner(t) @@ -178,7 +216,15 @@ func (p *Pool) executeResume(ctx context.Context, t *task.Task, exec *storage.Ex } else { ctx, cancel = context.WithCancel(ctx) } - defer cancel() + p.mu.Lock() + p.cancels[t.ID] = cancel + p.mu.Unlock() + defer func() { + cancel() + p.mu.Lock() + delete(p.cancels, t.ID) + p.mu.Unlock() + }() err = runner.Run(ctx, t, exec) exec.EndTime = time.Now().UTC() @@ -207,6 +253,10 @@ func (p *Pool) executeResume(ctx context.Context, t *task.Task, exec *storage.Ex exec.Status = "CANCELLED" exec.ErrorMsg = "execution cancelled" p.store.UpdateTaskState(t.ID, task.StateCancelled) + } else if isQuotaExhausted(err) { + exec.Status = "BUDGET_EXCEEDED" + exec.ErrorMsg = err.Error() + p.store.UpdateTaskState(t.ID, task.StateBudgetExceeded) } else { exec.Status = "FAILED" exec.ErrorMsg = err.Error() @@ -276,6 +326,10 @@ func (p *Pool) execute(ctx context.Context, t *task.Task) { p.active-- p.activePerAgent[agentType]-- p.mu.Unlock() + select { + case p.doneCh <- struct{}{}: + default: + } }() runner, err := p.getRunner(t) @@ -390,6 +444,10 @@ func (p *Pool) execute(ctx context.Context, t *task.Task) { exec.Status = "CANCELLED" exec.ErrorMsg = "execution cancelled" p.store.UpdateTaskState(t.ID, task.StateCancelled) + } else if isQuotaExhausted(err) { + exec.Status = "BUDGET_EXCEEDED" + exec.ErrorMsg = err.Error() + p.store.UpdateTaskState(t.ID, task.StateBudgetExceeded) } else { exec.Status = "FAILED" exec.ErrorMsg = err.Error() @@ -444,7 +502,7 @@ func (p *Pool) waitForDependencies(ctx context.Context, t *task.Task) error { select { case <-ctx.Done(): return ctx.Err() - case <-time.After(5 * time.Second): + case <-time.After(p.depPollInterval): } } } |
