summaryrefslogtreecommitdiff
path: root/internal/executor/executor_test.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/executor/executor_test.go')
-rw-r--r--internal/executor/executor_test.go82
1 files changed, 82 insertions, 0 deletions
diff --git a/internal/executor/executor_test.go b/internal/executor/executor_test.go
index ad496e7..e16185d 100644
--- a/internal/executor/executor_test.go
+++ b/internal/executor/executor_test.go
@@ -2055,3 +2055,85 @@ func TestPool_ValidationTask_Fail_SetsNeedsFix(t *testing.T) {
t.Errorf("story status: want NEEDS_FIX, got %q", got.Status)
}
}
+
+func TestPool_Shutdown_WaitsForWorkers(t *testing.T) {
+ store := testStore(t)
+ started := make(chan struct{})
+ unblock := make(chan struct{})
+ runner := &mockRunner{
+ onRun: func(t *task.Task, e *storage.Execution) error {
+ close(started)
+ <-unblock
+ return nil
+ },
+ }
+ pool := NewPool(1, map[string]Runner{"claude": runner}, store,
+ slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})))
+
+ tk := makeTask("shutdown-task")
+ tk.State = task.StateQueued
+ store.CreateTask(tk)
+ pool.Submit(context.Background(), tk)
+
+ // Wait until the worker has started.
+ select {
+ case <-started:
+ case <-time.After(5 * time.Second):
+ t.Fatal("worker did not start")
+ }
+
+ // Shutdown should block until we unblock the worker.
+ done := make(chan error, 1)
+ go func() {
+ ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+ defer cancel()
+ done <- pool.Shutdown(ctx)
+ }()
+
+ // Shutdown should not have returned yet.
+ select {
+ case err := <-done:
+ t.Fatalf("Shutdown returned early: %v", err)
+ case <-time.After(50 * time.Millisecond):
+ }
+
+ close(unblock) // let the worker finish
+
+ select {
+ case err := <-done:
+ if err != nil {
+ t.Errorf("Shutdown returned error: %v", err)
+ }
+ case <-time.After(5 * time.Second):
+ t.Fatal("Shutdown did not return after worker finished")
+ }
+}
+
+func TestPool_Shutdown_TimesOut(t *testing.T) {
+ store := testStore(t)
+ unblock := make(chan struct{})
+ runner := &mockRunner{
+ onRun: func(t *task.Task, e *storage.Execution) error {
+ <-unblock // never unblocked
+ return nil
+ },
+ }
+ pool := NewPool(1, map[string]Runner{"claude": runner}, store,
+ slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})))
+
+ tk := makeTask("shutdown-timeout-task")
+ tk.State = task.StateQueued
+ store.CreateTask(tk)
+ pool.Submit(context.Background(), tk)
+
+ // Give worker a moment to start.
+ time.Sleep(50 * time.Millisecond)
+
+ ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
+ defer cancel()
+ err := pool.Shutdown(ctx)
+ if err == nil {
+ t.Error("expected timeout error, got nil")
+ }
+ close(unblock) // cleanup
+}