diff options
| field | value | date |
|---|---|---|
| author | Peter Stone <thepeterstone@gmail.com> | 2026-03-26 09:09:19 +0000 |
| committer | Peter Stone <thepeterstone@gmail.com> | 2026-03-26 09:14:14 +0000 |
| commit | 3f9843b34d7ae9df2dd9c69427ecab45744b97e9 (patch) | |
| tree | 1c667c17d77b43a1e5fbcae464068a74c2857fb5 | |
| parent | dac676e8284725c8ec6de08282fe08a9b519ccc8 (diff) | |
feat: graceful shutdown — drain workers before exit (default 3m timeout)
- Add workerWg to Pool; Shutdown() closes workCh and waits for all
in-flight execute/executeResume goroutines to finish
- Signal handler now shuts down HTTP first, then drains the pool
- ShutdownTimeout config field (toml: shutdown_timeout); default 3m
- Tests: WaitsForWorkers and TimesOut
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
| mode | file | lines changed |
|---|---|---|
| -rw-r--r-- | internal/api/server.go | 3 |
| -rw-r--r-- | internal/cli/serve.go | 26 |
| -rw-r--r-- | internal/config/config.go | 4 |
| -rw-r--r-- | internal/executor/executor.go | 35 |
| -rw-r--r-- | internal/executor/executor_test.go | 82 |
5 files changed, 142 insertions, 8 deletions
diff --git a/internal/api/server.go b/internal/api/server.go index 8eba829..be944a3 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -79,6 +79,9 @@ func (s *Server) SetWorkspaceRoot(path string) { s.workspaceRoot = path } +// Pool returns the executor pool, for graceful shutdown by the caller. +func (s *Server) Pool() *executor.Pool { return s.pool } + func NewServer(store *storage.DB, pool *executor.Pool, logger *slog.Logger, claudeBinPath, geminiBinPath string) *Server { wd, _ := os.Getwd() s := &Server{ diff --git a/internal/cli/serve.go b/internal/cli/serve.go index 644392e..f7493ed 100644 --- a/internal/cli/serve.go +++ b/internal/cli/serve.go @@ -174,15 +174,31 @@ func serve(addr string) error { ctx, cancel := context.WithCancel(context.Background()) defer cancel() + workerTimeout := 3 * time.Minute + if cfg.ShutdownTimeout > 0 { + workerTimeout = cfg.ShutdownTimeout + } + sigCh := make(chan os.Signal, 1) signal.Notify(sigCh, os.Interrupt, syscall.SIGTERM) go func() { <-sigCh - logger.Info("shutting down server...") - shutdownCtx, shutdownCancel := context.WithTimeout(ctx, 5*time.Second) - defer shutdownCancel() - if err := httpSrv.Shutdown(shutdownCtx); err != nil { - logger.Warn("shutdown error", "err", err) + logger.Info("shutting down: draining workers...", "timeout", workerTimeout) + + // Stop the HTTP server so no new requests come in. + httpCtx, httpCancel := context.WithTimeout(ctx, 5*time.Second) + defer httpCancel() + if err := httpSrv.Shutdown(httpCtx); err != nil { + logger.Warn("http shutdown error", "err", err) + } + + // Wait for in-flight task workers to finish. 
+ workerCtx, workerCancel := context.WithTimeout(context.Background(), workerTimeout) + defer workerCancel() + if err := srv.Pool().Shutdown(workerCtx); err != nil { + logger.Warn("worker drain timed out", "err", err) + } else { + logger.Info("all workers finished cleanly") } }() diff --git a/internal/config/config.go b/internal/config/config.go index 428712f..71258c1 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -5,6 +5,7 @@ import ( "fmt" "os" "path/filepath" + "time" "github.com/BurntSushi/toml" ) @@ -25,7 +26,8 @@ type Config struct { GeminiBinaryPath string `toml:"gemini_binary_path"` ClaudeImage string `toml:"claude_image"` GeminiImage string `toml:"gemini_image"` - MaxConcurrent int `toml:"max_concurrent"` + MaxConcurrent int `toml:"max_concurrent"` + ShutdownTimeout time.Duration `toml:"shutdown_timeout"` DefaultTimeout string `toml:"default_timeout"` ServerAddr string `toml:"server_addr"` WebhookURL string `toml:"webhook_url"` diff --git a/internal/executor/executor.go b/internal/executor/executor.go index 6aef736..ae040c2 100644 --- a/internal/executor/executor.go +++ b/internal/executor/executor.go @@ -80,6 +80,8 @@ type Pool struct { startedCh chan string // task IDs that just transitioned to RUNNING workCh chan workItem // internal bounded queue; Submit enqueues here doneCh chan struct{} // signals when a worker slot is freed + workerWg sync.WaitGroup // tracks in-flight execute/executeResume goroutines + dispatchDone chan struct{} // closed when the dispatch goroutine exits Questions *QuestionRegistry Classifier *Classifier } @@ -112,6 +114,7 @@ func NewPool(maxConcurrent int, runners map[string]Runner, store Store, logger * startedCh: make(chan string, maxConcurrent*2), workCh: make(chan workItem, maxConcurrent*10+100), doneCh: make(chan struct{}, maxConcurrent), + dispatchDone: make(chan struct{}), Questions: NewQuestionRegistry(), } go p.dispatch() @@ -122,6 +125,7 @@ func NewPool(maxConcurrent int, runners 
map[string]Runner, store Store, logger * // and launches goroutines as soon as a pool slot is available. This prevents // tasks from being rejected when the pool is temporarily at capacity. func (p *Pool) dispatch() { + defer close(p.dispatchDone) for item := range p.workCh { for { p.mu.Lock() @@ -129,9 +133,9 @@ func (p *Pool) dispatch() { p.active++ p.mu.Unlock() if item.exec != nil { - go p.executeResume(item.ctx, item.task, item.exec) + p.workerWg.Add(1); go func(i workItem) { defer p.workerWg.Done(); p.executeResume(i.ctx, i.task, i.exec) }(item) } else { - go p.execute(item.ctx, item.task) + p.workerWg.Add(1); go func(i workItem) { defer p.workerWg.Done(); p.execute(i.ctx, i.task) }(item) } break } @@ -163,6 +167,33 @@ func (p *Pool) Results() <-chan *Result { return p.resultCh } +// Shutdown stops accepting new work and waits for all in-flight workers to +// finish. Returns ctx.Err() if the context deadline is exceeded before all +// workers complete. +func (p *Pool) Shutdown(ctx context.Context) error { + // Stop the dispatch goroutine. We must wait for it to exit before calling + // workerWg.Wait() to avoid a race between dispatch's Add(1) and Wait(). + close(p.workCh) + select { + case <-p.dispatchDone: + case <-ctx.Done(): + return ctx.Err() + } + + done := make(chan struct{}) + go func() { + p.workerWg.Wait() + close(done) + }() + + select { + case <-done: + return nil + case <-ctx.Done(): + return ctx.Err() + } +} + // Cancel requests cancellation of a running task. Returns false if the task // is not currently running in this pool. 
func (p *Pool) Cancel(taskID string) bool { diff --git a/internal/executor/executor_test.go b/internal/executor/executor_test.go index ad496e7..e16185d 100644 --- a/internal/executor/executor_test.go +++ b/internal/executor/executor_test.go @@ -2055,3 +2055,85 @@ func TestPool_ValidationTask_Fail_SetsNeedsFix(t *testing.T) { t.Errorf("story status: want NEEDS_FIX, got %q", got.Status) } } + +func TestPool_Shutdown_WaitsForWorkers(t *testing.T) { + store := testStore(t) + started := make(chan struct{}) + unblock := make(chan struct{}) + runner := &mockRunner{ + onRun: func(t *task.Task, e *storage.Execution) error { + close(started) + <-unblock + return nil + }, + } + pool := NewPool(1, map[string]Runner{"claude": runner}, store, + slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))) + + tk := makeTask("shutdown-task") + tk.State = task.StateQueued + store.CreateTask(tk) + pool.Submit(context.Background(), tk) + + // Wait until the worker has started. + select { + case <-started: + case <-time.After(5 * time.Second): + t.Fatal("worker did not start") + } + + // Shutdown should block until we unblock the worker. + done := make(chan error, 1) + go func() { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + done <- pool.Shutdown(ctx) + }() + + // Shutdown should not have returned yet. 
+ select { + case err := <-done: + t.Fatalf("Shutdown returned early: %v", err) + case <-time.After(50 * time.Millisecond): + } + + close(unblock) // let the worker finish + + select { + case err := <-done: + if err != nil { + t.Errorf("Shutdown returned error: %v", err) + } + case <-time.After(5 * time.Second): + t.Fatal("Shutdown did not return after worker finished") + } +} + +func TestPool_Shutdown_TimesOut(t *testing.T) { + store := testStore(t) + unblock := make(chan struct{}) + runner := &mockRunner{ + onRun: func(t *task.Task, e *storage.Execution) error { + <-unblock // never unblocked + return nil + }, + } + pool := NewPool(1, map[string]Runner{"claude": runner}, store, + slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))) + + tk := makeTask("shutdown-timeout-task") + tk.State = task.StateQueued + store.CreateTask(tk) + pool.Submit(context.Background(), tk) + + // Give worker a moment to start. + time.Sleep(50 * time.Millisecond) + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + err := pool.Shutdown(ctx) + if err == nil { + t.Error("expected timeout error, got nil") + } + close(unblock) // cleanup +} |
