summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--internal/api/server.go49
-rw-r--r--internal/api/server_test.go69
-rw-r--r--internal/executor/executor.go9
-rw-r--r--internal/executor/executor_test.go39
-rw-r--r--internal/storage/db.go53
-rw-r--r--internal/storage/db_test.go49
-rw-r--r--internal/task/task.go11
-rw-r--r--internal/task/task_test.go14
8 files changed, 266 insertions, 27 deletions
diff --git a/internal/api/server.go b/internal/api/server.go
index 8415b28..608cdd4 100644
--- a/internal/api/server.go
+++ b/internal/api/server.go
@@ -60,6 +60,8 @@ func (s *Server) routes() {
s.mux.HandleFunc("GET /api/tasks", s.handleListTasks)
s.mux.HandleFunc("GET /api/tasks/{id}", s.handleGetTask)
s.mux.HandleFunc("POST /api/tasks/{id}/run", s.handleRunTask)
+ s.mux.HandleFunc("POST /api/tasks/{id}/accept", s.handleAcceptTask)
+ s.mux.HandleFunc("POST /api/tasks/{id}/reject", s.handleRejectTask)
s.mux.HandleFunc("GET /api/tasks/{id}/subtasks", s.handleListSubtasks)
s.mux.HandleFunc("GET /api/tasks/{id}/executions", s.handleListExecutions)
s.mux.HandleFunc("GET /api/executions/{id}", s.handleGetExecution)
@@ -210,6 +212,53 @@ func (s *Server) handleRunTask(w http.ResponseWriter, r *http.Request) {
})
}
+func (s *Server) handleAcceptTask(w http.ResponseWriter, r *http.Request) {
+ id := r.PathValue("id")
+ t, err := s.store.GetTask(id)
+ if err != nil {
+ writeJSON(w, http.StatusNotFound, map[string]string{"error": "task not found"})
+ return
+ }
+ if !task.ValidTransition(t.State, task.StateCompleted) {
+ writeJSON(w, http.StatusConflict, map[string]string{
+ "error": fmt.Sprintf("task cannot be accepted from state %s", t.State),
+ })
+ return
+ }
+ if err := s.store.UpdateTaskState(id, task.StateCompleted); err != nil {
+ writeJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()})
+ return
+ }
+ writeJSON(w, http.StatusOK, map[string]string{"message": "task accepted", "task_id": id})
+}
+
+func (s *Server) handleRejectTask(w http.ResponseWriter, r *http.Request) {
+ id := r.PathValue("id")
+ t, err := s.store.GetTask(id)
+ if err != nil {
+ writeJSON(w, http.StatusNotFound, map[string]string{"error": "task not found"})
+ return
+ }
+ if !task.ValidTransition(t.State, task.StatePending) {
+ writeJSON(w, http.StatusConflict, map[string]string{
+ "error": fmt.Sprintf("task cannot be rejected from state %s", t.State),
+ })
+ return
+ }
+ var input struct {
+ Comment string `json:"comment"`
+ }
+ if err := json.NewDecoder(r.Body).Decode(&input); err != nil {
+ writeJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid JSON: " + err.Error()})
+ return
+ }
+ if err := s.store.RejectTask(id, input.Comment); err != nil {
+ writeJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()})
+ return
+ }
+ writeJSON(w, http.StatusOK, map[string]string{"message": "task rejected", "task_id": id})
+}
+
func (s *Server) handleListSubtasks(w http.ResponseWriter, r *http.Request) {
parentID := r.PathValue("id")
tasks, err := s.store.ListSubtasks(parentID)
diff --git a/internal/api/server_test.go b/internal/api/server_test.go
index 68f3657..5094961 100644
--- a/internal/api/server_test.go
+++ b/internal/api/server_test.go
@@ -252,6 +252,75 @@ func TestRunTask_CompletedTask_Returns409(t *testing.T) {
}
}
+func TestAcceptTask_ReadyTask_Returns200(t *testing.T) {
+ srv, store := testServer(t)
+ createTaskWithState(t, store, "accept-ready", task.StateReady)
+
+ req := httptest.NewRequest("POST", "/api/tasks/accept-ready/accept", nil)
+ w := httptest.NewRecorder()
+ srv.Handler().ServeHTTP(w, req)
+
+ if w.Code != http.StatusOK {
+ t.Errorf("status: want 200, got %d; body: %s", w.Code, w.Body.String())
+ }
+
+ got, _ := store.GetTask("accept-ready")
+ if got.State != task.StateCompleted {
+ t.Errorf("task state: want COMPLETED, got %v", got.State)
+ }
+}
+
+func TestAcceptTask_NonReadyTask_Returns409(t *testing.T) {
+ srv, store := testServer(t)
+ createTaskWithState(t, store, "accept-pending", task.StatePending)
+
+ req := httptest.NewRequest("POST", "/api/tasks/accept-pending/accept", nil)
+ w := httptest.NewRecorder()
+ srv.Handler().ServeHTTP(w, req)
+
+ if w.Code != http.StatusConflict {
+ t.Errorf("status: want 409, got %d; body: %s", w.Code, w.Body.String())
+ }
+}
+
+func TestRejectTask_ReadyTask_Returns200(t *testing.T) {
+ srv, store := testServer(t)
+ createTaskWithState(t, store, "reject-ready", task.StateReady)
+
+ body := bytes.NewBufferString(`{"comment": "needs more detail"}`)
+ req := httptest.NewRequest("POST", "/api/tasks/reject-ready/reject", body)
+ req.Header.Set("Content-Type", "application/json")
+ w := httptest.NewRecorder()
+ srv.Handler().ServeHTTP(w, req)
+
+ if w.Code != http.StatusOK {
+ t.Errorf("status: want 200, got %d; body: %s", w.Code, w.Body.String())
+ }
+
+ got, _ := store.GetTask("reject-ready")
+ if got.State != task.StatePending {
+ t.Errorf("task state: want PENDING, got %v", got.State)
+ }
+ if got.RejectionComment != "needs more detail" {
+ t.Errorf("rejection_comment: want 'needs more detail', got %q", got.RejectionComment)
+ }
+}
+
+func TestRejectTask_NonReadyTask_Returns409(t *testing.T) {
+ srv, store := testServer(t)
+ createTaskWithState(t, store, "reject-pending", task.StatePending)
+
+ body := bytes.NewBufferString(`{"comment": "comment"}`)
+ req := httptest.NewRequest("POST", "/api/tasks/reject-pending/reject", body)
+ req.Header.Set("Content-Type", "application/json")
+ w := httptest.NewRecorder()
+ srv.Handler().ServeHTTP(w, req)
+
+ if w.Code != http.StatusConflict {
+ t.Errorf("status: want 409, got %d; body: %s", w.Code, w.Body.String())
+ }
+}
+
func TestCORS_Headers(t *testing.T) {
srv, _ := testServer(t)
diff --git a/internal/executor/executor.go b/internal/executor/executor.go
index 68ebdf3..d25d3b4 100644
--- a/internal/executor/executor.go
+++ b/internal/executor/executor.go
@@ -149,8 +149,13 @@ func (p *Pool) execute(ctx context.Context, t *task.Task) {
p.store.UpdateTaskState(t.ID, task.StateFailed)
}
} else {
- exec.Status = "COMPLETED"
- p.store.UpdateTaskState(t.ID, task.StateCompleted)
+ if t.ParentTaskID == "" {
+ exec.Status = "READY"
+ p.store.UpdateTaskState(t.ID, task.StateReady)
+ } else {
+ exec.Status = "COMPLETED"
+ p.store.UpdateTaskState(t.ID, task.StateCompleted)
+ }
}
if updateErr := p.store.UpdateExecution(exec); updateErr != nil {
diff --git a/internal/executor/executor_test.go b/internal/executor/executor_test.go
index 5d6a55a..18a79bb 100644
--- a/internal/executor/executor_test.go
+++ b/internal/executor/executor_test.go
@@ -73,13 +73,41 @@ func makeTask(id string) *task.Task {
}
}
-func TestPool_Submit_Success(t *testing.T) {
+func TestPool_Submit_TopLevel_GoesToReady(t *testing.T) {
store := testStore(t)
runner := &mockRunner{}
logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))
pool := NewPool(2, runner, store, logger)
- tk := makeTask("ps-1")
+ tk := makeTask("ps-1") // no ParentTaskID → top-level
+ store.CreateTask(tk)
+
+ if err := pool.Submit(context.Background(), tk); err != nil {
+ t.Fatalf("submit: %v", err)
+ }
+
+ result := <-pool.Results()
+ if result.Err != nil {
+ t.Errorf("expected no error, got: %v", result.Err)
+ }
+ if result.Execution.Status != "READY" {
+ t.Errorf("status: want READY, got %q", result.Execution.Status)
+ }
+
+ got, _ := store.GetTask("ps-1")
+ if got.State != task.StateReady {
+ t.Errorf("task state: want READY, got %v", got.State)
+ }
+}
+
+func TestPool_Submit_Subtask_GoesToCompleted(t *testing.T) {
+ store := testStore(t)
+ runner := &mockRunner{}
+ logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))
+ pool := NewPool(2, runner, store, logger)
+
+ tk := makeTask("sub-1")
+ tk.ParentTaskID = "parent-99" // subtask
store.CreateTask(tk)
if err := pool.Submit(context.Background(), tk); err != nil {
@@ -94,8 +122,7 @@ func TestPool_Submit_Success(t *testing.T) {
t.Errorf("status: want COMPLETED, got %q", result.Execution.Status)
}
- // Verify task state in DB.
- got, _ := store.GetTask("ps-1")
+ got, _ := store.GetTask("sub-1")
if got.State != task.StateCompleted {
t.Errorf("task state: want COMPLETED, got %v", got.State)
}
@@ -195,8 +222,8 @@ func TestPool_ConcurrentExecution(t *testing.T) {
for i := 0; i < 3; i++ {
result := <-pool.Results()
- if result.Execution.Status != "COMPLETED" {
- t.Errorf("task %s: want COMPLETED, got %q", result.TaskID, result.Execution.Status)
+ if result.Execution.Status != "READY" {
+ t.Errorf("task %s: want READY, got %q", result.TaskID, result.Execution.Status)
}
}
diff --git a/internal/storage/db.go b/internal/storage/db.go
index 0117ae7..e656f98 100644
--- a/internal/storage/db.go
+++ b/internal/storage/db.go
@@ -87,6 +87,7 @@ func (s *DB) migrate() error {
// Additive migrations for columns added after initial schema.
migrations := []string{
`ALTER TABLE tasks ADD COLUMN parent_task_id TEXT`,
+ `ALTER TABLE tasks ADD COLUMN rejection_comment TEXT`,
}
for _, m := range migrations {
if _, err := s.db.Exec(m); err != nil {
@@ -135,13 +136,13 @@ func (s *DB) CreateTask(t *task.Task) error {
// GetTask retrieves a task by ID.
func (s *DB) GetTask(id string) (*task.Task, error) {
- row := s.db.QueryRow(`SELECT id, name, description, config_json, priority, timeout_ns, retry_json, tags_json, depends_on_json, parent_task_id, state, created_at, updated_at FROM tasks WHERE id = ?`, id)
+ row := s.db.QueryRow(`SELECT id, name, description, config_json, priority, timeout_ns, retry_json, tags_json, depends_on_json, parent_task_id, state, created_at, updated_at, rejection_comment FROM tasks WHERE id = ?`, id)
return scanTask(row)
}
// ListTasks returns tasks matching the given filter.
func (s *DB) ListTasks(filter TaskFilter) ([]*task.Task, error) {
- query := `SELECT id, name, description, config_json, priority, timeout_ns, retry_json, tags_json, depends_on_json, parent_task_id, state, created_at, updated_at FROM tasks WHERE 1=1`
+ query := `SELECT id, name, description, config_json, priority, timeout_ns, retry_json, tags_json, depends_on_json, parent_task_id, state, created_at, updated_at, rejection_comment FROM tasks WHERE 1=1`
var args []interface{}
if filter.State != "" {
@@ -173,7 +174,7 @@ func (s *DB) ListTasks(filter TaskFilter) ([]*task.Task, error) {
// ListSubtasks returns all tasks whose parent_task_id matches the given ID.
func (s *DB) ListSubtasks(parentID string) ([]*task.Task, error) {
- rows, err := s.db.Query(`SELECT id, name, description, config_json, priority, timeout_ns, retry_json, tags_json, depends_on_json, parent_task_id, state, created_at, updated_at FROM tasks WHERE parent_task_id = ? ORDER BY created_at ASC`, parentID)
+ rows, err := s.db.Query(`SELECT id, name, description, config_json, priority, timeout_ns, retry_json, tags_json, depends_on_json, parent_task_id, state, created_at, updated_at, rejection_comment FROM tasks WHERE parent_task_id = ? ORDER BY created_at ASC`, parentID)
if err != nil {
return nil, err
}
@@ -207,6 +208,24 @@ func (s *DB) UpdateTaskState(id string, newState task.State) error {
return nil
}
+// RejectTask sets a task's state to PENDING and stores the rejection comment.
+func (s *DB) RejectTask(id, comment string) error {
+ now := time.Now().UTC()
+ result, err := s.db.Exec(`UPDATE tasks SET state = ?, rejection_comment = ?, updated_at = ? WHERE id = ?`,
+ string(task.StatePending), comment, now, id)
+ if err != nil {
+ return err
+ }
+ n, err := result.RowsAffected()
+ if err != nil {
+ return err
+ }
+ if n == 0 {
+ return fmt.Errorf("task %q not found", id)
+ }
+ return nil
+}
+
// TaskUpdate holds the fields that UpdateTask may change.
type TaskUpdate struct {
Name string
@@ -330,9 +349,11 @@ func (s *DB) ListExecutions(taskID string) ([]*Execution, error) {
// UpdateExecution updates a completed execution.
func (s *DB) UpdateExecution(e *Execution) error {
_, err := s.db.Exec(`
- UPDATE executions SET end_time = ?, exit_code = ?, status = ?, cost_usd = ?, error_msg = ?
+ UPDATE executions SET end_time = ?, exit_code = ?, status = ?, cost_usd = ?, error_msg = ?,
+ stdout_path = ?, stderr_path = ?, artifact_dir = ?
WHERE id = ?`,
- e.EndTime.UTC(), e.ExitCode, e.Status, e.CostUSD, e.ErrorMsg, e.ID,
+ e.EndTime.UTC(), e.ExitCode, e.Status, e.CostUSD, e.ErrorMsg,
+ e.StdoutPath, e.StderrPath, e.ArtifactDir, e.ID,
)
return err
}
@@ -343,18 +364,20 @@ type scanner interface {
func scanTask(row scanner) (*task.Task, error) {
var (
- t task.Task
- configJSON string
- retryJSON string
- tagsJSON string
- depsJSON string
- state string
- priority string
- timeoutNS int64
- parentTaskID sql.NullString
+ t task.Task
+ configJSON string
+ retryJSON string
+ tagsJSON string
+ depsJSON string
+ state string
+ priority string
+ timeoutNS int64
+ parentTaskID sql.NullString
+ rejectionComment sql.NullString
)
- err := row.Scan(&t.ID, &t.Name, &t.Description, &configJSON, &priority, &timeoutNS, &retryJSON, &tagsJSON, &depsJSON, &parentTaskID, &state, &t.CreatedAt, &t.UpdatedAt)
+ err := row.Scan(&t.ID, &t.Name, &t.Description, &configJSON, &priority, &timeoutNS, &retryJSON, &tagsJSON, &depsJSON, &parentTaskID, &state, &t.CreatedAt, &t.UpdatedAt, &rejectionComment)
t.ParentTaskID = parentTaskID.String
+ t.RejectionComment = rejectionComment.String
if err != nil {
return nil, err
}
diff --git a/internal/storage/db_test.go b/internal/storage/db_test.go
index 7eb81d2..4f9069a 100644
--- a/internal/storage/db_test.go
+++ b/internal/storage/db_test.go
@@ -340,6 +340,43 @@ func TestDB_UpdateTask(t *testing.T) {
})
}
+func TestRejectTask(t *testing.T) {
+ db := testDB(t)
+ now := time.Now().UTC()
+ tk := &task.Task{
+ ID: "reject-1", Name: "R", Claude: task.ClaudeConfig{Instructions: "x"},
+ Priority: task.PriorityNormal, Retry: task.RetryConfig{MaxAttempts: 1, Backoff: "linear"},
+ Tags: []string{}, DependsOn: []string{},
+ State: task.StateReady, CreatedAt: now, UpdatedAt: now,
+ }
+ if err := db.CreateTask(tk); err != nil {
+ t.Fatal(err)
+ }
+
+ if err := db.RejectTask("reject-1", "needs more detail"); err != nil {
+ t.Fatalf("RejectTask: %v", err)
+ }
+
+ got, err := db.GetTask("reject-1")
+ if err != nil {
+ t.Fatalf("GetTask: %v", err)
+ }
+ if got.State != task.StatePending {
+ t.Errorf("state: want PENDING, got %v", got.State)
+ }
+ if got.RejectionComment != "needs more detail" {
+ t.Errorf("rejection_comment: want 'needs more detail', got %q", got.RejectionComment)
+ }
+}
+
+func TestRejectTask_NotFound(t *testing.T) {
+ db := testDB(t)
+ err := db.RejectTask("nonexistent", "comment")
+ if err == nil {
+ t.Fatal("expected error for nonexistent task")
+ }
+}
+
func TestUpdateExecution(t *testing.T) {
db := testDB(t)
now := time.Now().UTC()
@@ -360,6 +397,9 @@ func TestUpdateExecution(t *testing.T) {
exec.ExitCode = 1
exec.ErrorMsg = "something broke"
exec.EndTime = now.Add(2 * time.Minute)
+ exec.StdoutPath = "/tmp/exec/stdout.log"
+ exec.StderrPath = "/tmp/exec/stderr.log"
+ exec.ArtifactDir = "/tmp/exec"
if err := db.UpdateExecution(exec); err != nil {
t.Fatal(err)
}
@@ -371,4 +411,13 @@ func TestUpdateExecution(t *testing.T) {
if got.ErrorMsg != "something broke" {
t.Errorf("error: want 'something broke', got %q", got.ErrorMsg)
}
+ if got.StdoutPath != "/tmp/exec/stdout.log" {
+ t.Errorf("stdout_path: want /tmp/exec/stdout.log, got %q", got.StdoutPath)
+ }
+ if got.StderrPath != "/tmp/exec/stderr.log" {
+ t.Errorf("stderr_path: want /tmp/exec/stderr.log, got %q", got.StderrPath)
+ }
+ if got.ArtifactDir != "/tmp/exec" {
+ t.Errorf("artifact_dir: want /tmp/exec, got %q", got.ArtifactDir)
+ }
}
diff --git a/internal/task/task.go b/internal/task/task.go
index d360a07..3e74a82 100644
--- a/internal/task/task.go
+++ b/internal/task/task.go
@@ -8,6 +8,7 @@ const (
StatePending State = "PENDING"
StateQueued State = "QUEUED"
StateRunning State = "RUNNING"
+ StateReady State = "READY"
StateCompleted State = "COMPLETED"
StateFailed State = "FAILED"
StateTimedOut State = "TIMED_OUT"
@@ -53,9 +54,10 @@ type Task struct {
Priority Priority `yaml:"priority" json:"priority"`
Tags []string `yaml:"tags" json:"tags"`
DependsOn []string `yaml:"depends_on" json:"depends_on"`
- State State `yaml:"-" json:"state"`
- CreatedAt time.Time `yaml:"-" json:"created_at"`
- UpdatedAt time.Time `yaml:"-" json:"updated_at"`
+ State State `yaml:"-" json:"state"`
+ RejectionComment string `yaml:"-" json:"rejection_comment,omitempty"`
+ CreatedAt time.Time `yaml:"-" json:"created_at"`
+ UpdatedAt time.Time `yaml:"-" json:"updated_at"`
}
// Duration wraps time.Duration for YAML unmarshaling from strings like "30m".
@@ -90,7 +92,8 @@ func ValidTransition(from, to State) bool {
transitions := map[State][]State{
StatePending: {StateQueued, StateCancelled},
StateQueued: {StateRunning, StateCancelled},
- StateRunning: {StateCompleted, StateFailed, StateTimedOut, StateCancelled, StateBudgetExceeded},
+ StateRunning: {StateReady, StateCompleted, StateFailed, StateTimedOut, StateCancelled, StateBudgetExceeded},
+ StateReady: {StateCompleted, StatePending},
StateFailed: {StateQueued}, // retry
StateTimedOut: {StateQueued}, // retry
}
diff --git a/internal/task/task_test.go b/internal/task/task_test.go
index a8e0a84..6498271 100644
--- a/internal/task/task_test.go
+++ b/internal/task/task_test.go
@@ -45,6 +45,7 @@ func TestValidTransition_DisallowedTransitions(t *testing.T) {
{"completed to queued", StateCompleted, StateQueued},
{"failed to completed", StateFailed, StateCompleted},
{"timed out to completed", StateTimedOut, StateCompleted},
+ {"ready to queued", StateReady, StateQueued},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
@@ -55,6 +56,19 @@ func TestValidTransition_DisallowedTransitions(t *testing.T) {
}
}
+func TestValidTransition_ReadyState(t *testing.T) {
+ valid := []struct{ from, to State }{
+ {StateRunning, StateReady},
+ {StateReady, StateCompleted},
+ {StateReady, StatePending},
+ }
+ for _, tt := range valid {
+ if !ValidTransition(tt.from, tt.to) {
+ t.Errorf("expected transition %s -> %s to be valid", tt.from, tt.to)
+ }
+ }
+}
+
func TestDuration_UnmarshalYAML(t *testing.T) {
var d Duration
unmarshal := func(v interface{}) error {