summaryrefslogtreecommitdiff
path: root/internal
diff options
context:
space:
mode:
authorPeter Stone <thepeterstone@gmail.com>2026-03-26 09:09:19 +0000
committerPeter Stone <thepeterstone@gmail.com>2026-03-26 09:14:14 +0000
commit3f9843b34d7ae9df2dd9c69427ecab45744b97e9 (patch)
tree1c667c17d77b43a1e5fbcae464068a74c2857fb5 /internal
parentdac676e8284725c8ec6de08282fe08a9b519ccc8 (diff)
feat: graceful shutdown — drain workers before exit (default 3m timeout)
- Add workerWg to Pool; Shutdown() closes workCh and waits for all in-flight execute/executeResume goroutines to finish
- Signal handler now shuts down HTTP first, then drains the pool
- ShutdownTimeout config field (toml: shutdown_timeout); default 3m
- Tests: WaitsForWorkers and TimesOut

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Diffstat (limited to 'internal')
-rw-r--r--internal/api/server.go3
-rw-r--r--internal/cli/serve.go26
-rw-r--r--internal/config/config.go4
-rw-r--r--internal/executor/executor.go35
-rw-r--r--internal/executor/executor_test.go82
5 files changed, 142 insertions, 8 deletions
diff --git a/internal/api/server.go b/internal/api/server.go
index 8eba829..be944a3 100644
--- a/internal/api/server.go
+++ b/internal/api/server.go
@@ -79,6 +79,9 @@ func (s *Server) SetWorkspaceRoot(path string) {
s.workspaceRoot = path
}
+// Pool returns the executor pool, for graceful shutdown by the caller.
+func (s *Server) Pool() *executor.Pool { return s.pool }
+
func NewServer(store *storage.DB, pool *executor.Pool, logger *slog.Logger, claudeBinPath, geminiBinPath string) *Server {
wd, _ := os.Getwd()
s := &Server{
diff --git a/internal/cli/serve.go b/internal/cli/serve.go
index 644392e..f7493ed 100644
--- a/internal/cli/serve.go
+++ b/internal/cli/serve.go
@@ -174,15 +174,31 @@ func serve(addr string) error {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
+ workerTimeout := 3 * time.Minute
+ if cfg.ShutdownTimeout > 0 {
+ workerTimeout = cfg.ShutdownTimeout
+ }
+
sigCh := make(chan os.Signal, 1)
signal.Notify(sigCh, os.Interrupt, syscall.SIGTERM)
go func() {
<-sigCh
- logger.Info("shutting down server...")
- shutdownCtx, shutdownCancel := context.WithTimeout(ctx, 5*time.Second)
- defer shutdownCancel()
- if err := httpSrv.Shutdown(shutdownCtx); err != nil {
- logger.Warn("shutdown error", "err", err)
+ logger.Info("shutting down: draining workers...", "timeout", workerTimeout)
+
+ // Stop the HTTP server so no new requests come in.
+ httpCtx, httpCancel := context.WithTimeout(ctx, 5*time.Second)
+ defer httpCancel()
+ if err := httpSrv.Shutdown(httpCtx); err != nil {
+ logger.Warn("http shutdown error", "err", err)
+ }
+
+ // Wait for in-flight task workers to finish.
+ workerCtx, workerCancel := context.WithTimeout(context.Background(), workerTimeout)
+ defer workerCancel()
+ if err := srv.Pool().Shutdown(workerCtx); err != nil {
+ logger.Warn("worker drain timed out", "err", err)
+ } else {
+ logger.Info("all workers finished cleanly")
}
}()
diff --git a/internal/config/config.go b/internal/config/config.go
index 428712f..71258c1 100644
--- a/internal/config/config.go
+++ b/internal/config/config.go
@@ -5,6 +5,7 @@ import (
"fmt"
"os"
"path/filepath"
+ "time"
"github.com/BurntSushi/toml"
)
@@ -25,7 +26,8 @@ type Config struct {
GeminiBinaryPath string `toml:"gemini_binary_path"`
ClaudeImage string `toml:"claude_image"`
GeminiImage string `toml:"gemini_image"`
- MaxConcurrent int `toml:"max_concurrent"`
+ MaxConcurrent int `toml:"max_concurrent"`
+ ShutdownTimeout time.Duration `toml:"shutdown_timeout"`
DefaultTimeout string `toml:"default_timeout"`
ServerAddr string `toml:"server_addr"`
WebhookURL string `toml:"webhook_url"`
diff --git a/internal/executor/executor.go b/internal/executor/executor.go
index 6aef736..ae040c2 100644
--- a/internal/executor/executor.go
+++ b/internal/executor/executor.go
@@ -80,6 +80,8 @@ type Pool struct {
startedCh chan string // task IDs that just transitioned to RUNNING
workCh chan workItem // internal bounded queue; Submit enqueues here
doneCh chan struct{} // signals when a worker slot is freed
+ workerWg sync.WaitGroup // tracks in-flight execute/executeResume goroutines
+ dispatchDone chan struct{} // closed when the dispatch goroutine exits
Questions *QuestionRegistry
Classifier *Classifier
}
@@ -112,6 +114,7 @@ func NewPool(maxConcurrent int, runners map[string]Runner, store Store, logger *
startedCh: make(chan string, maxConcurrent*2),
workCh: make(chan workItem, maxConcurrent*10+100),
doneCh: make(chan struct{}, maxConcurrent),
+ dispatchDone: make(chan struct{}),
Questions: NewQuestionRegistry(),
}
go p.dispatch()
@@ -122,6 +125,7 @@ func NewPool(maxConcurrent int, runners map[string]Runner, store Store, logger *
// and launches goroutines as soon as a pool slot is available. This prevents
// tasks from being rejected when the pool is temporarily at capacity.
func (p *Pool) dispatch() {
+ defer close(p.dispatchDone)
for item := range p.workCh {
for {
p.mu.Lock()
@@ -129,9 +133,9 @@ func (p *Pool) dispatch() {
p.active++
p.mu.Unlock()
if item.exec != nil {
- go p.executeResume(item.ctx, item.task, item.exec)
+ p.workerWg.Add(1); go func(i workItem) { defer p.workerWg.Done(); p.executeResume(i.ctx, i.task, i.exec) }(item)
} else {
- go p.execute(item.ctx, item.task)
+ p.workerWg.Add(1); go func(i workItem) { defer p.workerWg.Done(); p.execute(i.ctx, i.task) }(item)
}
break
}
@@ -163,6 +167,33 @@ func (p *Pool) Results() <-chan *Result {
return p.resultCh
}
// Shutdown stops accepting new work and waits for all in-flight workers to
// finish. Returns ctx.Err() if the context deadline is exceeded before all
// workers complete.
//
// NOTE(review): Shutdown closes workCh, so it must not run concurrently with
// Submit (a send on a closed channel panics) and must be called at most once
// (a second close also panics) — confirm callers uphold both.
func (p *Pool) Shutdown(ctx context.Context) error {
	// Stop the dispatch goroutine. We must wait for it to exit before calling
	// workerWg.Wait() to avoid a race between dispatch's Add(1) and Wait().
	close(p.workCh)
	select {
	case <-p.dispatchDone:
	case <-ctx.Done():
		// Dispatch may still be draining queued items; give up waiting.
		return ctx.Err()
	}

	// Wait for the workers in a helper goroutine so the wait itself can be
	// abandoned when ctx expires. If we time out below, this goroutine lives
	// on until the last worker finishes, then exits — a bounded, benign leak.
	done := make(chan struct{})
	go func() {
		p.workerWg.Wait()
		close(done)
	}()

	select {
	case <-done:
		return nil
	case <-ctx.Done():
		return ctx.Err()
	}
}
+
// Cancel requests cancellation of a running task. Returns false if the task
// is not currently running in this pool.
func (p *Pool) Cancel(taskID string) bool {
diff --git a/internal/executor/executor_test.go b/internal/executor/executor_test.go
index ad496e7..e16185d 100644
--- a/internal/executor/executor_test.go
+++ b/internal/executor/executor_test.go
@@ -2055,3 +2055,85 @@ func TestPool_ValidationTask_Fail_SetsNeedsFix(t *testing.T) {
t.Errorf("story status: want NEEDS_FIX, got %q", got.Status)
}
}
+
+func TestPool_Shutdown_WaitsForWorkers(t *testing.T) {
+ store := testStore(t)
+ started := make(chan struct{})
+ unblock := make(chan struct{})
+ runner := &mockRunner{
+ onRun: func(t *task.Task, e *storage.Execution) error {
+ close(started)
+ <-unblock
+ return nil
+ },
+ }
+ pool := NewPool(1, map[string]Runner{"claude": runner}, store,
+ slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})))
+
+ tk := makeTask("shutdown-task")
+ tk.State = task.StateQueued
+ store.CreateTask(tk)
+ pool.Submit(context.Background(), tk)
+
+ // Wait until the worker has started.
+ select {
+ case <-started:
+ case <-time.After(5 * time.Second):
+ t.Fatal("worker did not start")
+ }
+
+ // Shutdown should block until we unblock the worker.
+ done := make(chan error, 1)
+ go func() {
+ ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+ defer cancel()
+ done <- pool.Shutdown(ctx)
+ }()
+
+ // Shutdown should not have returned yet.
+ select {
+ case err := <-done:
+ t.Fatalf("Shutdown returned early: %v", err)
+ case <-time.After(50 * time.Millisecond):
+ }
+
+ close(unblock) // let the worker finish
+
+ select {
+ case err := <-done:
+ if err != nil {
+ t.Errorf("Shutdown returned error: %v", err)
+ }
+ case <-time.After(5 * time.Second):
+ t.Fatal("Shutdown did not return after worker finished")
+ }
+}
+
+func TestPool_Shutdown_TimesOut(t *testing.T) {
+ store := testStore(t)
+ unblock := make(chan struct{})
+ runner := &mockRunner{
+ onRun: func(t *task.Task, e *storage.Execution) error {
+ <-unblock // never unblocked
+ return nil
+ },
+ }
+ pool := NewPool(1, map[string]Runner{"claude": runner}, store,
+ slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})))
+
+ tk := makeTask("shutdown-timeout-task")
+ tk.State = task.StateQueued
+ store.CreateTask(tk)
+ pool.Submit(context.Background(), tk)
+
+ // Give worker a moment to start.
+ time.Sleep(50 * time.Millisecond)
+
+ ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
+ defer cancel()
+ err := pool.Shutdown(ctx)
+ if err == nil {
+ t.Error("expected timeout error, got nil")
+ }
+ close(unblock) // cleanup
+}