summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--internal/storage/db.go72
-rw-r--r--internal/storage/db_test.go64
2 files changed, 102 insertions, 34 deletions
diff --git a/internal/storage/db.go b/internal/storage/db.go
index b6af2c8..835ac29 100644
--- a/internal/storage/db.go
+++ b/internal/storage/db.go
@@ -360,53 +360,57 @@ func (s *DB) GetLatestExecution(taskID string) (*Execution, error) {
}
// DeleteTask removes a task and all its executions and subtasks (recursively).
-// Returns an error if the task does not exist.
+// Returns an error if the task does not exist. All deletes run in a single
+// transaction so a partial failure cannot leave orphaned executions.
func (s *DB) DeleteTask(id string) error {
- // Collect all task IDs to delete (the task + all descendant subtasks).
- toDelete := []string{id}
- // BFS over children.
- queue := []string{id}
- for len(queue) > 0 {
- parentID := queue[0]
- queue = queue[1:]
- rows, err := s.db.Query(`SELECT id FROM tasks WHERE parent_task_id = ?`, parentID)
- if err != nil {
- return fmt.Errorf("listing subtasks of %q: %w", parentID, err)
- }
- for rows.Next() {
- var childID string
- if err := rows.Scan(&childID); err != nil {
- rows.Close()
- return err
- }
- toDelete = append(toDelete, childID)
- queue = append(queue, childID)
- }
- rows.Close()
- if err := rows.Err(); err != nil {
+ tx, err := s.db.Begin()
+ if err != nil {
+ return err
+ }
+ defer tx.Rollback() //nolint:errcheck
+
+ // Collect all task IDs (root + all descendants) via recursive CTE.
+ rows, err := tx.Query(`
+ WITH RECURSIVE subtasks(id) AS (
+ SELECT id FROM tasks WHERE id = ?
+ UNION ALL
+ SELECT t.id FROM tasks t JOIN subtasks s ON t.parent_task_id = s.id
+ )
+ SELECT id FROM subtasks`, id)
+ if err != nil {
+ return fmt.Errorf("collecting subtask ids: %w", err)
+ }
+ var toDelete []string
+ for rows.Next() {
+ var tid string
+ if err := rows.Scan(&tid); err != nil {
+ rows.Close()
return err
}
+ toDelete = append(toDelete, tid)
+ }
+ rows.Close()
+ if err := rows.Err(); err != nil {
+ return err
+ }
+
+ if len(toDelete) == 0 {
+ return fmt.Errorf("task %q not found", id)
}
- // Delete executions for all collected tasks then the tasks themselves.
+ // Delete executions for all collected tasks, then the tasks themselves.
for _, tid := range toDelete {
- if _, err := s.db.Exec(`DELETE FROM executions WHERE task_id = ?`, tid); err != nil {
+ if _, err := tx.Exec(`DELETE FROM executions WHERE task_id = ?`, tid); err != nil {
return fmt.Errorf("deleting executions for task %q: %w", tid, err)
}
}
for _, tid := range toDelete {
- result, err := s.db.Exec(`DELETE FROM tasks WHERE id = ?`, tid)
- if err != nil {
+ if _, err := tx.Exec(`DELETE FROM tasks WHERE id = ?`, tid); err != nil {
return fmt.Errorf("deleting task %q: %w", tid, err)
}
- if tid == id {
- n, _ := result.RowsAffected()
- if n == 0 {
- return fmt.Errorf("task %q not found", id)
- }
- }
}
- return nil
+
+ return tx.Commit()
}
// RecentExecution is returned by ListRecentExecutions (JOIN with tasks for name).
diff --git a/internal/storage/db_test.go b/internal/storage/db_test.go
index 8b10817..5f786ac 100644
--- a/internal/storage/db_test.go
+++ b/internal/storage/db_test.go
@@ -564,6 +564,70 @@ func TestDeleteTask_NotFound(t *testing.T) {
}
}
+func TestDeleteTask_DeepSubtaskCascadeAtomic(t *testing.T) {
+ db := testDB(t)
+ now := time.Now().UTC()
+
+ // 3-level hierarchy: root -> child -> grandchild
+ root := makeTestTask("deep-root", now)
+ child := makeTestTask("deep-child", now)
+ child.ParentTaskID = "deep-root"
+ grandchild := makeTestTask("deep-grandchild", now)
+ grandchild.ParentTaskID = "deep-child"
+
+ for _, tk := range []*task.Task{root, child, grandchild} {
+ if err := db.CreateTask(tk); err != nil {
+ t.Fatalf("creating task %q: %v", tk.ID, err)
+ }
+ }
+
+ // Add one execution per level.
+ for i, tid := range []string{"deep-root", "deep-child", "deep-grandchild"} {
+ e := &Execution{
+ ID: fmt.Sprintf("deep-exec-%d", i),
+ TaskID: tid,
+ StartTime: now,
+ Status: "COMPLETED",
+ }
+ if err := db.CreateExecution(e); err != nil {
+ t.Fatalf("creating execution for %q: %v", tid, err)
+ }
+ }
+
+ if err := db.DeleteTask("deep-root"); err != nil {
+ t.Fatalf("DeleteTask: %v", err)
+ }
+
+ // All three tasks must be gone.
+ for _, tid := range []string{"deep-root", "deep-child", "deep-grandchild"} {
+ _, err := db.GetTask(tid)
+ if err == nil {
+ t.Errorf("task %q should have been deleted", tid)
+ }
+ }
+
+ // No executions should remain for any deleted task (no orphans).
+ rows, err := db.db.Query(`
+ SELECT e.id FROM executions e
+ LEFT JOIN tasks t ON e.task_id = t.id
+ WHERE t.id IS NULL`)
+ if err != nil {
+ t.Fatalf("orphan check query: %v", err)
+ }
+ defer rows.Close()
+ var orphans []string
+ for rows.Next() {
+ var eid string
+ if err := rows.Scan(&eid); err != nil {
+ t.Fatal(err)
+ }
+ orphans = append(orphans, eid)
+ }
+ if len(orphans) != 0 {
+ t.Errorf("orphaned execution rows after DeleteTask: %v", orphans)
+ }
+}
+
func TestStorage_GetLatestExecution(t *testing.T) {
db := testDB(t)
now := time.Now().UTC()