diff options
Diffstat (limited to 'internal/executor/executor.go')
| -rw-r--r-- | internal/executor/executor.go | 35 |
1 files changed, 33 insertions, 2 deletions
diff --git a/internal/executor/executor.go b/internal/executor/executor.go index 6aef736..ae040c2 100644 --- a/internal/executor/executor.go +++ b/internal/executor/executor.go @@ -80,6 +80,8 @@ type Pool struct { startedCh chan string // task IDs that just transitioned to RUNNING workCh chan workItem // internal bounded queue; Submit enqueues here doneCh chan struct{} // signals when a worker slot is freed + workerWg sync.WaitGroup // tracks in-flight execute/executeResume goroutines + dispatchDone chan struct{} // closed when the dispatch goroutine exits Questions *QuestionRegistry Classifier *Classifier } @@ -112,6 +114,7 @@ func NewPool(maxConcurrent int, runners map[string]Runner, store Store, logger * startedCh: make(chan string, maxConcurrent*2), workCh: make(chan workItem, maxConcurrent*10+100), doneCh: make(chan struct{}, maxConcurrent), + dispatchDone: make(chan struct{}), Questions: NewQuestionRegistry(), } go p.dispatch() @@ -122,6 +125,7 @@ func NewPool(maxConcurrent int, runners map[string]Runner, store Store, logger * // and launches goroutines as soon as a pool slot is available. This prevents // tasks from being rejected when the pool is temporarily at capacity. func (p *Pool) dispatch() { + defer close(p.dispatchDone) for item := range p.workCh { for { p.mu.Lock() @@ -129,9 +133,9 @@ func (p *Pool) dispatch() { p.active++ p.mu.Unlock() if item.exec != nil { - go p.executeResume(item.ctx, item.task, item.exec) + p.workerWg.Add(1); go func(i workItem) { defer p.workerWg.Done(); p.executeResume(i.ctx, i.task, i.exec) }(item) } else { - go p.execute(item.ctx, item.task) + p.workerWg.Add(1); go func(i workItem) { defer p.workerWg.Done(); p.execute(i.ctx, i.task) }(item) } break } @@ -163,6 +167,33 @@ func (p *Pool) Results() <-chan *Result { return p.resultCh } +// Shutdown stops accepting new work and waits for all in-flight workers to +// finish. Returns ctx.Err() if the context deadline is exceeded before all +// workers complete. +func (p *Pool) Shutdown(ctx context.Context) error { + // Stop the dispatch goroutine. We must wait for it to exit before calling + // workerWg.Wait() to avoid a race between dispatch's Add(1) and Wait(). + close(p.workCh) + select { + case <-p.dispatchDone: + case <-ctx.Done(): + return ctx.Err() + } + + done := make(chan struct{}) + go func() { + p.workerWg.Wait() + close(done) + }() + + select { + case <-done: + return nil + case <-ctx.Done(): + return ctx.Err() + } +} + // Cancel requests cancellation of a running task. Returns false if the task // is not currently running in this pool. func (p *Pool) Cancel(taskID string) bool { |
