diff options
| author | Peter Stone <thepeterstone@gmail.com> | 2026-03-04 21:25:34 +0000 |
|---|---|---|
| committer | Peter Stone <thepeterstone@gmail.com> | 2026-03-04 21:25:34 +0000 |
| commit | 6511d6e0ff139495413c7848a9b4aabb9d9ee4e2 (patch) | |
| tree | 95bd6a0efc0ace206a5716da62a5956491cb46e7 | |
| parent | 3962597950421e422b6e1ce57764550f5600ded6 (diff) | |
Add READY state for human-in-the-loop verification
Top-level tasks now land in READY after successful execution instead of
going directly to COMPLETED. Subtasks (with parent_task_id) skip the gate
and remain COMPLETED. Users accept or reject via new API endpoints:
POST /api/tasks/{id}/accept → READY → COMPLETED
POST /api/tasks/{id}/reject → READY → PENDING (with rejection_comment)
- task: add StateReady, RejectionComment field, update ValidTransition
- storage: migrate rejection_comment column, add RejectTask method
- executor: route top-level vs subtask to READY vs COMPLETED
- api: /accept and /reject handlers with 409 on invalid state
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
| -rw-r--r-- | internal/api/server.go | 49 | ||||
| -rw-r--r-- | internal/api/server_test.go | 69 | ||||
| -rw-r--r-- | internal/executor/executor.go | 9 | ||||
| -rw-r--r-- | internal/executor/executor_test.go | 39 | ||||
| -rw-r--r-- | internal/storage/db.go | 53 | ||||
| -rw-r--r-- | internal/storage/db_test.go | 49 | ||||
| -rw-r--r-- | internal/task/task.go | 11 | ||||
| -rw-r--r-- | internal/task/task_test.go | 14 |
8 files changed, 266 insertions, 27 deletions
diff --git a/internal/api/server.go b/internal/api/server.go index 8415b28..608cdd4 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -60,6 +60,8 @@ func (s *Server) routes() { s.mux.HandleFunc("GET /api/tasks", s.handleListTasks) s.mux.HandleFunc("GET /api/tasks/{id}", s.handleGetTask) s.mux.HandleFunc("POST /api/tasks/{id}/run", s.handleRunTask) + s.mux.HandleFunc("POST /api/tasks/{id}/accept", s.handleAcceptTask) + s.mux.HandleFunc("POST /api/tasks/{id}/reject", s.handleRejectTask) s.mux.HandleFunc("GET /api/tasks/{id}/subtasks", s.handleListSubtasks) s.mux.HandleFunc("GET /api/tasks/{id}/executions", s.handleListExecutions) s.mux.HandleFunc("GET /api/executions/{id}", s.handleGetExecution) @@ -210,6 +212,53 @@ func (s *Server) handleRunTask(w http.ResponseWriter, r *http.Request) { }) } +func (s *Server) handleAcceptTask(w http.ResponseWriter, r *http.Request) { + id := r.PathValue("id") + t, err := s.store.GetTask(id) + if err != nil { + writeJSON(w, http.StatusNotFound, map[string]string{"error": "task not found"}) + return + } + if !task.ValidTransition(t.State, task.StateCompleted) { + writeJSON(w, http.StatusConflict, map[string]string{ + "error": fmt.Sprintf("task cannot be accepted from state %s", t.State), + }) + return + } + if err := s.store.UpdateTaskState(id, task.StateCompleted); err != nil { + writeJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) + return + } + writeJSON(w, http.StatusOK, map[string]string{"message": "task accepted", "task_id": id}) +} + +func (s *Server) handleRejectTask(w http.ResponseWriter, r *http.Request) { + id := r.PathValue("id") + t, err := s.store.GetTask(id) + if err != nil { + writeJSON(w, http.StatusNotFound, map[string]string{"error": "task not found"}) + return + } + if !task.ValidTransition(t.State, task.StatePending) { + writeJSON(w, http.StatusConflict, map[string]string{ + "error": fmt.Sprintf("task cannot be rejected from state %s", t.State), + }) + return + } + var input struct { + Comment string `json:"comment"` + } + if err := json.NewDecoder(r.Body).Decode(&input); err != nil { + writeJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid JSON: " + err.Error()}) + return + } + if err := s.store.RejectTask(id, input.Comment); err != nil { + writeJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) + return + } + writeJSON(w, http.StatusOK, map[string]string{"message": "task rejected", "task_id": id}) +} + func (s *Server) handleListSubtasks(w http.ResponseWriter, r *http.Request) { parentID := r.PathValue("id") tasks, err := s.store.ListSubtasks(parentID) diff --git a/internal/api/server_test.go b/internal/api/server_test.go index 68f3657..5094961 100644 --- a/internal/api/server_test.go +++ b/internal/api/server_test.go @@ -252,6 +252,75 @@ func TestRunTask_CompletedTask_Returns409(t *testing.T) { } } +func TestAcceptTask_ReadyTask_Returns200(t *testing.T) { + srv, store := testServer(t) + createTaskWithState(t, store, "accept-ready", task.StateReady) + + req := httptest.NewRequest("POST", "/api/tasks/accept-ready/accept", nil) + w := httptest.NewRecorder() + srv.Handler().ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("status: want 200, got %d; body: %s", w.Code, w.Body.String()) + } + + got, _ := store.GetTask("accept-ready") + if got.State != task.StateCompleted { + t.Errorf("task state: want COMPLETED, got %v", got.State) + } +} + +func TestAcceptTask_NonReadyTask_Returns409(t *testing.T) { + srv, store := testServer(t) + createTaskWithState(t, store, "accept-pending", task.StatePending) + + req := httptest.NewRequest("POST", "/api/tasks/accept-pending/accept", nil) + w := httptest.NewRecorder() + srv.Handler().ServeHTTP(w, req) + + if w.Code != http.StatusConflict { + t.Errorf("status: want 409, got %d; body: %s", w.Code, w.Body.String()) + } +} + +func TestRejectTask_ReadyTask_Returns200(t *testing.T) { + srv, store := testServer(t) + createTaskWithState(t, store, "reject-ready", task.StateReady) + + body := bytes.NewBufferString(`{"comment": "needs more detail"}`) + req := httptest.NewRequest("POST", "/api/tasks/reject-ready/reject", body) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + srv.Handler().ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("status: want 200, got %d; body: %s", w.Code, w.Body.String()) + } + + got, _ := store.GetTask("reject-ready") + if got.State != task.StatePending { + t.Errorf("task state: want PENDING, got %v", got.State) + } + if got.RejectionComment != "needs more detail" { + t.Errorf("rejection_comment: want 'needs more detail', got %q", got.RejectionComment) + } +} + +func TestRejectTask_NonReadyTask_Returns409(t *testing.T) { + srv, store := testServer(t) + createTaskWithState(t, store, "reject-pending", task.StatePending) + + body := bytes.NewBufferString(`{"comment": "comment"}`) + req := httptest.NewRequest("POST", "/api/tasks/reject-pending/reject", body) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + srv.Handler().ServeHTTP(w, req) + + if w.Code != http.StatusConflict { + t.Errorf("status: want 409, got %d; body: %s", w.Code, w.Body.String()) + } +} + func TestCORS_Headers(t *testing.T) { srv, _ := testServer(t) diff --git a/internal/executor/executor.go b/internal/executor/executor.go index 68ebdf3..d25d3b4 100644 --- a/internal/executor/executor.go +++ b/internal/executor/executor.go @@ -149,8 +149,13 @@ func (p *Pool) execute(ctx context.Context, t *task.Task) { p.store.UpdateTaskState(t.ID, task.StateFailed) } } else { - exec.Status = "COMPLETED" - p.store.UpdateTaskState(t.ID, task.StateCompleted) + if t.ParentTaskID == "" { + exec.Status = "READY" + p.store.UpdateTaskState(t.ID, task.StateReady) + } else { + exec.Status = "COMPLETED" + p.store.UpdateTaskState(t.ID, task.StateCompleted) + } } if updateErr := p.store.UpdateExecution(exec); updateErr != nil { diff --git a/internal/executor/executor_test.go b/internal/executor/executor_test.go index 5d6a55a..18a79bb 100644 --- a/internal/executor/executor_test.go +++ b/internal/executor/executor_test.go @@ -73,13 +73,41 @@ func makeTask(id string) *task.Task { } } -func TestPool_Submit_Success(t *testing.T) { +func TestPool_Submit_TopLevel_GoesToReady(t *testing.T) { store := testStore(t) runner := &mockRunner{} logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) pool := NewPool(2, runner, store, logger) - tk := makeTask("ps-1") + tk := makeTask("ps-1") // no ParentTaskID → top-level + store.CreateTask(tk) + + if err := pool.Submit(context.Background(), tk); err != nil { + t.Fatalf("submit: %v", err) + } + + result := <-pool.Results() + if result.Err != nil { + t.Errorf("expected no error, got: %v", result.Err) + } + if result.Execution.Status != "READY" { + t.Errorf("status: want READY, got %q", result.Execution.Status) + } + + got, _ := store.GetTask("ps-1") + if got.State != task.StateReady { + t.Errorf("task state: want READY, got %v", got.State) + } +} + +func TestPool_Submit_Subtask_GoesToCompleted(t *testing.T) { + store := testStore(t) + runner := &mockRunner{} + logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) + pool := NewPool(2, runner, store, logger) + + tk := makeTask("sub-1") + tk.ParentTaskID = "parent-99" // subtask store.CreateTask(tk) if err := pool.Submit(context.Background(), tk); err != nil { @@ -94,8 +122,7 @@ func TestPool_Submit_Success(t *testing.T) { t.Errorf("status: want COMPLETED, got %q", result.Execution.Status) } - // Verify task state in DB. - got, _ := store.GetTask("ps-1") + got, _ := store.GetTask("sub-1") if got.State != task.StateCompleted { t.Errorf("task state: want COMPLETED, got %v", got.State) } @@ -195,8 +222,8 @@ func TestPool_ConcurrentExecution(t *testing.T) { for i := 0; i < 3; i++ { result := <-pool.Results() - if result.Execution.Status != "COMPLETED" { - t.Errorf("task %s: want COMPLETED, got %q", result.TaskID, result.Execution.Status) + if result.Execution.Status != "READY" { + t.Errorf("task %s: want READY, got %q", result.TaskID, result.Execution.Status) } } diff --git a/internal/storage/db.go b/internal/storage/db.go index 0117ae7..e656f98 100644 --- a/internal/storage/db.go +++ b/internal/storage/db.go @@ -87,6 +87,7 @@ func (s *DB) migrate() error { // Additive migrations for columns added after initial schema. migrations := []string{ `ALTER TABLE tasks ADD COLUMN parent_task_id TEXT`, + `ALTER TABLE tasks ADD COLUMN rejection_comment TEXT`, } for _, m := range migrations { if _, err := s.db.Exec(m); err != nil { @@ -135,13 +136,13 @@ func (s *DB) CreateTask(t *task.Task) error { // GetTask retrieves a task by ID. func (s *DB) GetTask(id string) (*task.Task, error) { - row := s.db.QueryRow(`SELECT id, name, description, config_json, priority, timeout_ns, retry_json, tags_json, depends_on_json, parent_task_id, state, created_at, updated_at FROM tasks WHERE id = ?`, id) + row := s.db.QueryRow(`SELECT id, name, description, config_json, priority, timeout_ns, retry_json, tags_json, depends_on_json, parent_task_id, state, created_at, updated_at, rejection_comment FROM tasks WHERE id = ?`, id) return scanTask(row) } // ListTasks returns tasks matching the given filter. func (s *DB) ListTasks(filter TaskFilter) ([]*task.Task, error) { - query := `SELECT id, name, description, config_json, priority, timeout_ns, retry_json, tags_json, depends_on_json, parent_task_id, state, created_at, updated_at FROM tasks WHERE 1=1` + query := `SELECT id, name, description, config_json, priority, timeout_ns, retry_json, tags_json, depends_on_json, parent_task_id, state, created_at, updated_at, rejection_comment FROM tasks WHERE 1=1` var args []interface{} if filter.State != "" { @@ -173,7 +174,7 @@ func (s *DB) ListTasks(filter TaskFilter) ([]*task.Task, error) { // ListSubtasks returns all tasks whose parent_task_id matches the given ID. func (s *DB) ListSubtasks(parentID string) ([]*task.Task, error) { - rows, err := s.db.Query(`SELECT id, name, description, config_json, priority, timeout_ns, retry_json, tags_json, depends_on_json, parent_task_id, state, created_at, updated_at FROM tasks WHERE parent_task_id = ? ORDER BY created_at ASC`, parentID) + rows, err := s.db.Query(`SELECT id, name, description, config_json, priority, timeout_ns, retry_json, tags_json, depends_on_json, parent_task_id, state, created_at, updated_at, rejection_comment FROM tasks WHERE parent_task_id = ? ORDER BY created_at ASC`, parentID) if err != nil { return nil, err } @@ -207,6 +208,24 @@ func (s *DB) UpdateTaskState(id string, newState task.State) error { return nil } +// RejectTask sets a task's state to PENDING and stores the rejection comment. +func (s *DB) RejectTask(id, comment string) error { + now := time.Now().UTC() + result, err := s.db.Exec(`UPDATE tasks SET state = ?, rejection_comment = ?, updated_at = ? WHERE id = ?`, + string(task.StatePending), comment, now, id) + if err != nil { + return err + } + n, err := result.RowsAffected() + if err != nil { + return err + } + if n == 0 { + return fmt.Errorf("task %q not found", id) + } + return nil +} + // TaskUpdate holds the fields that UpdateTask may change. type TaskUpdate struct { Name string @@ -330,9 +349,11 @@ func (s *DB) ListExecutions(taskID string) ([]*Execution, error) { // UpdateExecution updates a completed execution. func (s *DB) UpdateExecution(e *Execution) error { _, err := s.db.Exec(` - UPDATE executions SET end_time = ?, exit_code = ?, status = ?, cost_usd = ?, error_msg = ? + UPDATE executions SET end_time = ?, exit_code = ?, status = ?, cost_usd = ?, error_msg = ?, + stdout_path = ?, stderr_path = ?, artifact_dir = ? WHERE id = ?`, - e.EndTime.UTC(), e.ExitCode, e.Status, e.CostUSD, e.ErrorMsg, e.ID, + e.EndTime.UTC(), e.ExitCode, e.Status, e.CostUSD, e.ErrorMsg, + e.StdoutPath, e.StderrPath, e.ArtifactDir, e.ID, ) return err } @@ -343,18 +364,20 @@ type scanner interface { func scanTask(row scanner) (*task.Task, error) { var ( - t task.Task - configJSON string - retryJSON string - tagsJSON string - depsJSON string - state string - priority string - timeoutNS int64 - parentTaskID sql.NullString + t task.Task + configJSON string + retryJSON string + tagsJSON string + depsJSON string + state string + priority string + timeoutNS int64 + parentTaskID sql.NullString + rejectionComment sql.NullString ) - err := row.Scan(&t.ID, &t.Name, &t.Description, &configJSON, &priority, &timeoutNS, &retryJSON, &tagsJSON, &depsJSON, &parentTaskID, &state, &t.CreatedAt, &t.UpdatedAt) + err := row.Scan(&t.ID, &t.Name, &t.Description, &configJSON, &priority, &timeoutNS, &retryJSON, &tagsJSON, &depsJSON, &parentTaskID, &state, &t.CreatedAt, &t.UpdatedAt, &rejectionComment) t.ParentTaskID = parentTaskID.String + t.RejectionComment = rejectionComment.String if err != nil { return nil, err } diff --git a/internal/storage/db_test.go b/internal/storage/db_test.go index 7eb81d2..4f9069a 100644 --- a/internal/storage/db_test.go +++ b/internal/storage/db_test.go @@ -340,6 +340,43 @@ func TestDB_UpdateTask(t *testing.T) { }) } +func TestRejectTask(t *testing.T) { + db := testDB(t) + now := time.Now().UTC() + tk := &task.Task{ + ID: "reject-1", Name: "R", Claude: task.ClaudeConfig{Instructions: "x"}, + Priority: task.PriorityNormal, Retry: task.RetryConfig{MaxAttempts: 1, Backoff: "linear"}, + Tags: []string{}, DependsOn: []string{}, + State: task.StateReady, CreatedAt: now, UpdatedAt: now, + } + if err := db.CreateTask(tk); err != nil { + t.Fatal(err) + } + + if err := db.RejectTask("reject-1", "needs more detail"); err != nil { + t.Fatalf("RejectTask: %v", err) + } + + got, err := db.GetTask("reject-1") + if err != nil { + t.Fatalf("GetTask: %v", err) + } + if got.State != task.StatePending { + t.Errorf("state: want PENDING, got %v", got.State) + } + if got.RejectionComment != "needs more detail" { + t.Errorf("rejection_comment: want 'needs more detail', got %q", got.RejectionComment) + } +} + +func TestRejectTask_NotFound(t *testing.T) { + db := testDB(t) + err := db.RejectTask("nonexistent", "comment") + if err == nil { + t.Fatal("expected error for nonexistent task") + } +} + func TestUpdateExecution(t *testing.T) { db := testDB(t) now := time.Now().UTC() @@ -360,6 +397,9 @@ func TestUpdateExecution(t *testing.T) { exec.ExitCode = 1 exec.ErrorMsg = "something broke" exec.EndTime = now.Add(2 * time.Minute) + exec.StdoutPath = "/tmp/exec/stdout.log" + exec.StderrPath = "/tmp/exec/stderr.log" + exec.ArtifactDir = "/tmp/exec" if err := db.UpdateExecution(exec); err != nil { t.Fatal(err) } @@ -371,4 +411,13 @@ func TestUpdateExecution(t *testing.T) { if got.ErrorMsg != "something broke" { t.Errorf("error: want 'something broke', got %q", got.ErrorMsg) } + if got.StdoutPath != "/tmp/exec/stdout.log" { + t.Errorf("stdout_path: want /tmp/exec/stdout.log, got %q", got.StdoutPath) + } + if got.StderrPath != "/tmp/exec/stderr.log" { + t.Errorf("stderr_path: want /tmp/exec/stderr.log, got %q", got.StderrPath) + } + if got.ArtifactDir != "/tmp/exec" { + t.Errorf("artifact_dir: want /tmp/exec, got %q", got.ArtifactDir) + } } diff --git a/internal/task/task.go b/internal/task/task.go index d360a07..3e74a82 100644 --- a/internal/task/task.go +++ b/internal/task/task.go @@ -8,6 +8,7 @@ const ( StatePending State = "PENDING" StateQueued State = "QUEUED" StateRunning State = "RUNNING" + StateReady State = "READY" StateCompleted State = "COMPLETED" StateFailed State = "FAILED" StateTimedOut State = "TIMED_OUT" @@ -53,9 +54,10 @@ type Task struct { Priority Priority `yaml:"priority" json:"priority"` Tags []string `yaml:"tags" json:"tags"` DependsOn []string `yaml:"depends_on" json:"depends_on"` - State State `yaml:"-" json:"state"` - CreatedAt time.Time `yaml:"-" json:"created_at"` - UpdatedAt time.Time `yaml:"-" json:"updated_at"` + State State `yaml:"-" json:"state"` + RejectionComment string `yaml:"-" json:"rejection_comment,omitempty"` + CreatedAt time.Time `yaml:"-" json:"created_at"` + UpdatedAt time.Time `yaml:"-" json:"updated_at"` } // Duration wraps time.Duration for YAML unmarshaling from strings like "30m". @@ -90,7 +92,8 @@ func ValidTransition(from, to State) bool { transitions := map[State][]State{ StatePending: {StateQueued, StateCancelled}, StateQueued: {StateRunning, StateCancelled}, - StateRunning: {StateCompleted, StateFailed, StateTimedOut, StateCancelled, StateBudgetExceeded}, + StateRunning: {StateReady, StateCompleted, StateFailed, StateTimedOut, StateCancelled, StateBudgetExceeded}, + StateReady: {StateCompleted, StatePending}, StateFailed: {StateQueued}, // retry StateTimedOut: {StateQueued}, // retry } diff --git a/internal/task/task_test.go b/internal/task/task_test.go index a8e0a84..6498271 100644 --- a/internal/task/task_test.go +++ b/internal/task/task_test.go @@ -45,6 +45,7 @@ func TestValidTransition_DisallowedTransitions(t *testing.T) { {"completed to queued", StateCompleted, StateQueued}, {"failed to completed", StateFailed, StateCompleted}, {"timed out to completed", StateTimedOut, StateCompleted}, + {"ready to queued", StateReady, StateQueued}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -55,6 +56,19 @@ func TestValidTransition_DisallowedTransitions(t *testing.T) { } } +func TestValidTransition_ReadyState(t *testing.T) { + valid := []struct{ from, to State }{ + {StateRunning, StateReady}, + {StateReady, StateCompleted}, + {StateReady, StatePending}, + } + for _, tt := range valid { + if !ValidTransition(tt.from, tt.to) { + t.Errorf("expected transition %s -> %s to be valid", tt.from, tt.to) + } + } +} + func TestDuration_UnmarshalYAML(t *testing.T) { var d Duration unmarshal := func(v interface{}) error { |
