diff options
Diffstat (limited to 'internal/executor/executor_test.go')
| -rw-r--r-- | internal/executor/executor_test.go | 82 |
1 files changed, 82 insertions, 0 deletions
diff --git a/internal/executor/executor_test.go b/internal/executor/executor_test.go index ad496e7..e16185d 100644 --- a/internal/executor/executor_test.go +++ b/internal/executor/executor_test.go @@ -2055,3 +2055,85 @@ func TestPool_ValidationTask_Fail_SetsNeedsFix(t *testing.T) { t.Errorf("story status: want NEEDS_FIX, got %q", got.Status) } } + +func TestPool_Shutdown_WaitsForWorkers(t *testing.T) { + store := testStore(t) + started := make(chan struct{}) + unblock := make(chan struct{}) + runner := &mockRunner{ + onRun: func(t *task.Task, e *storage.Execution) error { + close(started) + <-unblock + return nil + }, + } + pool := NewPool(1, map[string]Runner{"claude": runner}, store, + slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))) + + tk := makeTask("shutdown-task") + tk.State = task.StateQueued + store.CreateTask(tk) + pool.Submit(context.Background(), tk) + + // Wait until the worker has started. + select { + case <-started: + case <-time.After(5 * time.Second): + t.Fatal("worker did not start") + } + + // Shutdown should block until we unblock the worker. + done := make(chan error, 1) + go func() { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + done <- pool.Shutdown(ctx) + }() + + // Shutdown should not have returned yet. + select { + case err := <-done: + t.Fatalf("Shutdown returned early: %v", err) + case <-time.After(50 * time.Millisecond): + } + + close(unblock) // let the worker finish + + select { + case err := <-done: + if err != nil { + t.Errorf("Shutdown returned error: %v", err) + } + case <-time.After(5 * time.Second): + t.Fatal("Shutdown did not return after worker finished") + } +} + +func TestPool_Shutdown_TimesOut(t *testing.T) { + store := testStore(t) + unblock := make(chan struct{}) + runner := &mockRunner{ + onRun: func(t *task.Task, e *storage.Execution) error { + <-unblock // never unblocked + return nil + }, + } + pool := NewPool(1, map[string]Runner{"claude": runner}, store, + slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))) + + tk := makeTask("shutdown-timeout-task") + tk.State = task.StateQueued + store.CreateTask(tk) + pool.Submit(context.Background(), tk) + + // Give worker a moment to start. + time.Sleep(50 * time.Millisecond) + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + err := pool.Shutdown(ctx) + if err == nil { + t.Error("expected timeout error, got nil") + } + close(unblock) // cleanup +} |
