diff options
Diffstat (limited to 'internal/storage/db.go')
| -rw-r--r-- | internal/storage/db.go | 72 |
1 files changed, 38 insertions, 34 deletions
diff --git a/internal/storage/db.go b/internal/storage/db.go index b6af2c8..835ac29 100644 --- a/internal/storage/db.go +++ b/internal/storage/db.go @@ -360,53 +360,57 @@ func (s *DB) GetLatestExecution(taskID string) (*Execution, error) { } // DeleteTask removes a task and all its executions and subtasks (recursively). -// Returns an error if the task does not exist. +// Returns an error if the task does not exist. All deletes run in a single +// transaction so a partial failure cannot leave orphaned executions. func (s *DB) DeleteTask(id string) error { - // Collect all task IDs to delete (the task + all descendant subtasks). - toDelete := []string{id} - // BFS over children. - queue := []string{id} - for len(queue) > 0 { - parentID := queue[0] - queue = queue[1:] - rows, err := s.db.Query(`SELECT id FROM tasks WHERE parent_task_id = ?`, parentID) - if err != nil { - return fmt.Errorf("listing subtasks of %q: %w", parentID, err) - } - for rows.Next() { - var childID string - if err := rows.Scan(&childID); err != nil { - rows.Close() - return err - } - toDelete = append(toDelete, childID) - queue = append(queue, childID) - } - rows.Close() - if err := rows.Err(); err != nil { + tx, err := s.db.Begin() + if err != nil { + return err + } + defer tx.Rollback() //nolint:errcheck + + // Collect all task IDs (root + all descendants) via recursive CTE. + rows, err := tx.Query(` + WITH RECURSIVE subtasks(id) AS ( + SELECT id FROM tasks WHERE id = ? + UNION ALL + SELECT t.id FROM tasks t JOIN subtasks s ON t.parent_task_id = s.id + ) + SELECT id FROM subtasks`, id) + if err != nil { + return fmt.Errorf("collecting subtask ids: %w", err) + } + var toDelete []string + for rows.Next() { + var tid string + if err := rows.Scan(&tid); err != nil { + rows.Close() return err } + toDelete = append(toDelete, tid) + } + rows.Close() + if err := rows.Err(); err != nil { + return err + } + + if len(toDelete) == 0 { + return fmt.Errorf("task %q not found", id) } - // Delete executions for all collected tasks then the tasks themselves. + // Delete executions for all collected tasks, then the tasks themselves. for _, tid := range toDelete { - if _, err := s.db.Exec(`DELETE FROM executions WHERE task_id = ?`, tid); err != nil { + if _, err := tx.Exec(`DELETE FROM executions WHERE task_id = ?`, tid); err != nil { return fmt.Errorf("deleting executions for task %q: %w", tid, err) } } for _, tid := range toDelete { - result, err := s.db.Exec(`DELETE FROM tasks WHERE id = ?`, tid) - if err != nil { + if _, err := tx.Exec(`DELETE FROM tasks WHERE id = ?`, tid); err != nil { return fmt.Errorf("deleting task %q: %w", tid, err) } - if tid == id { - n, _ := result.RowsAffected() - if n == 0 { - return fmt.Errorf("task %q not found", id) - } - } } - return nil + + return tx.Commit() } // RecentExecution is returned by ListRecentExecutions (JOIN with tasks for name). |
