summaryrefslogtreecommitdiff
path: root/internal
diff options
context:
space:
mode:
Diffstat (limited to 'internal')
-rw-r--r--internal/storage/db.go54
1 files changed, 52 insertions, 2 deletions
diff --git a/internal/storage/db.go b/internal/storage/db.go
index 0e4d6f1..1aac754 100644
--- a/internal/storage/db.go
+++ b/internal/storage/db.go
@@ -359,6 +359,56 @@ func (s *DB) GetLatestExecution(taskID string) (*Execution, error) {
return scanExecution(row)
}
+// DeleteTask removes a task and all its executions and subtasks (recursively).
+// Returns an error if the task does not exist.
+func (s *DB) DeleteTask(id string) error {
+ // Collect all task IDs to delete (the task + all descendant subtasks).
+ toDelete := []string{id}
+ // BFS over children.
+ queue := []string{id}
+ for len(queue) > 0 {
+ parentID := queue[0]
+ queue = queue[1:]
+ rows, err := s.db.Query(`SELECT id FROM tasks WHERE parent_task_id = ?`, parentID)
+ if err != nil {
+ return fmt.Errorf("listing subtasks of %q: %w", parentID, err)
+ }
+ for rows.Next() {
+ var childID string
+ if err := rows.Scan(&childID); err != nil {
+ rows.Close()
+ return err
+ }
+ toDelete = append(toDelete, childID)
+ queue = append(queue, childID)
+ }
+ rows.Close()
+ if err := rows.Err(); err != nil {
+ return err
+ }
+ }
+
+ // Delete executions for all collected tasks then the tasks themselves.
+ for _, tid := range toDelete {
+ if _, err := s.db.Exec(`DELETE FROM executions WHERE task_id = ?`, tid); err != nil {
+ return fmt.Errorf("deleting executions for task %q: %w", tid, err)
+ }
+ }
+ for _, tid := range toDelete {
+ result, err := s.db.Exec(`DELETE FROM tasks WHERE id = ?`, tid)
+ if err != nil {
+ return fmt.Errorf("deleting task %q: %w", tid, err)
+ }
+ if tid == id {
+ n, _ := result.RowsAffected()
+ if n == 0 {
+ return fmt.Errorf("task %q not found", id)
+ }
+ }
+ }
+ return nil
+}
+
// UpdateTaskQuestion stores the pending question JSON on a task.
// Pass empty string to clear the question after it has been answered.
func (s *DB) UpdateTaskQuestion(taskID, questionJSON string) error {
@@ -371,10 +421,10 @@ func (s *DB) UpdateTaskQuestion(taskID, questionJSON string) error {
func (s *DB) UpdateExecution(e *Execution) error {
_, err := s.db.Exec(`
UPDATE executions SET end_time = ?, exit_code = ?, status = ?, cost_usd = ?, error_msg = ?,
- stdout_path = ?, stderr_path = ?, artifact_dir = ?
+ stdout_path = ?, stderr_path = ?, artifact_dir = ?, session_id = ?
WHERE id = ?`,
e.EndTime.UTC(), e.ExitCode, e.Status, e.CostUSD, e.ErrorMsg,
- e.StdoutPath, e.StderrPath, e.ArtifactDir, e.ID,
+ e.StdoutPath, e.StderrPath, e.ArtifactDir, e.SessionID, e.ID,
)
return err
}