diff options
| -rw-r--r-- | internal/storage/db.go | 72 | ||||
| -rw-r--r-- | internal/storage/db_test.go | 64 |
2 files changed, 102 insertions, 34 deletions
diff --git a/internal/storage/db.go b/internal/storage/db.go index b6af2c8..835ac29 100644 --- a/internal/storage/db.go +++ b/internal/storage/db.go @@ -360,53 +360,57 @@ func (s *DB) GetLatestExecution(taskID string) (*Execution, error) { } // DeleteTask removes a task and all its executions and subtasks (recursively). -// Returns an error if the task does not exist. +// Returns an error if the task does not exist. All deletes run in a single +// transaction so a partial failure cannot leave orphaned executions. func (s *DB) DeleteTask(id string) error { - // Collect all task IDs to delete (the task + all descendant subtasks). - toDelete := []string{id} - // BFS over children. - queue := []string{id} - for len(queue) > 0 { - parentID := queue[0] - queue = queue[1:] - rows, err := s.db.Query(`SELECT id FROM tasks WHERE parent_task_id = ?`, parentID) - if err != nil { - return fmt.Errorf("listing subtasks of %q: %w", parentID, err) - } - for rows.Next() { - var childID string - if err := rows.Scan(&childID); err != nil { - rows.Close() - return err - } - toDelete = append(toDelete, childID) - queue = append(queue, childID) - } - rows.Close() - if err := rows.Err(); err != nil { + tx, err := s.db.Begin() + if err != nil { + return err + } + defer tx.Rollback() //nolint:errcheck + + // Collect all task IDs (root + all descendants) via recursive CTE. + rows, err := tx.Query(` + WITH RECURSIVE subtasks(id) AS ( + SELECT id FROM tasks WHERE id = ? + UNION ALL + SELECT t.id FROM tasks t JOIN subtasks s ON t.parent_task_id = s.id + ) + SELECT id FROM subtasks`, id) + if err != nil { + return fmt.Errorf("collecting subtask ids: %w", err) + } + var toDelete []string + for rows.Next() { + var tid string + if err := rows.Scan(&tid); err != nil { + rows.Close() return err } + toDelete = append(toDelete, tid) + } + rows.Close() + if err := rows.Err(); err != nil { + return err + } + + if len(toDelete) == 0 { + return fmt.Errorf("task %q not found", id) } - // Delete executions for all collected tasks then the tasks themselves. + // Delete executions for all collected tasks, then the tasks themselves. for _, tid := range toDelete { - if _, err := s.db.Exec(`DELETE FROM executions WHERE task_id = ?`, tid); err != nil { + if _, err := tx.Exec(`DELETE FROM executions WHERE task_id = ?`, tid); err != nil { return fmt.Errorf("deleting executions for task %q: %w", tid, err) } } for _, tid := range toDelete { - result, err := s.db.Exec(`DELETE FROM tasks WHERE id = ?`, tid) - if err != nil { + if _, err := tx.Exec(`DELETE FROM tasks WHERE id = ?`, tid); err != nil { return fmt.Errorf("deleting task %q: %w", tid, err) } - if tid == id { - n, _ := result.RowsAffected() - if n == 0 { - return fmt.Errorf("task %q not found", id) - } - } } - return nil + + return tx.Commit() } // RecentExecution is returned by ListRecentExecutions (JOIN with tasks for name). diff --git a/internal/storage/db_test.go b/internal/storage/db_test.go index 8b10817..5f786ac 100644 --- a/internal/storage/db_test.go +++ b/internal/storage/db_test.go @@ -564,6 +564,70 @@ func TestDeleteTask_NotFound(t *testing.T) { } } +func TestDeleteTask_DeepSubtaskCascadeAtomic(t *testing.T) { + db := testDB(t) + now := time.Now().UTC() + + // 3-level hierarchy: root -> child -> grandchild + root := makeTestTask("deep-root", now) + child := makeTestTask("deep-child", now) + child.ParentTaskID = "deep-root" + grandchild := makeTestTask("deep-grandchild", now) + grandchild.ParentTaskID = "deep-child" + + for _, tk := range []*task.Task{root, child, grandchild} { + if err := db.CreateTask(tk); err != nil { + t.Fatalf("creating task %q: %v", tk.ID, err) + } + } + + // Add one execution per level. + for i, tid := range []string{"deep-root", "deep-child", "deep-grandchild"} { + e := &Execution{ + ID: fmt.Sprintf("deep-exec-%d", i), + TaskID: tid, + StartTime: now, + Status: "COMPLETED", + } + if err := db.CreateExecution(e); err != nil { + t.Fatalf("creating execution for %q: %v", tid, err) + } + } + + if err := db.DeleteTask("deep-root"); err != nil { + t.Fatalf("DeleteTask: %v", err) + } + + // All three tasks must be gone. + for _, tid := range []string{"deep-root", "deep-child", "deep-grandchild"} { + _, err := db.GetTask(tid) + if err == nil { + t.Errorf("task %q should have been deleted", tid) + } + } + + // No executions should remain for any deleted task (no orphans). + rows, err := db.db.Query(` + SELECT e.id FROM executions e + LEFT JOIN tasks t ON e.task_id = t.id + WHERE t.id IS NULL`) + if err != nil { + t.Fatalf("orphan check query: %v", err) + } + defer rows.Close() + var orphans []string + for rows.Next() { + var eid string + if err := rows.Scan(&eid); err != nil { + t.Fatal(err) + } + orphans = append(orphans, eid) + } + if len(orphans) != 0 { + t.Errorf("orphaned execution rows after DeleteTask: %v", orphans) + } +} + func TestStorage_GetLatestExecution(t *testing.T) { db := testDB(t) now := time.Now().UTC() |
