diff options
Diffstat (limited to 'internal/storage/db.go')
| -rw-r--r-- | internal/storage/db.go | 149 |
1 files changed, 133 insertions, 16 deletions
diff --git a/internal/storage/db.go b/internal/storage/db.go index b004cd1..0117ae7 100644 --- a/internal/storage/db.go +++ b/internal/storage/db.go @@ -4,6 +4,7 @@ import ( "database/sql" "encoding/json" "fmt" + "strings" "time" "github.com/thepeterstone/claudomator/internal/task" @@ -43,6 +44,7 @@ func (s *DB) migrate() error { retry_json TEXT NOT NULL DEFAULT '{}', tags_json TEXT NOT NULL DEFAULT '[]', depends_on_json TEXT NOT NULL DEFAULT '[]', + parent_task_id TEXT, state TEXT NOT NULL DEFAULT 'PENDING', created_at DATETIME NOT NULL, updated_at DATETIME NOT NULL @@ -66,9 +68,40 @@ func (s *DB) migrate() error { CREATE INDEX IF NOT EXISTS idx_tasks_state ON tasks(state); CREATE INDEX IF NOT EXISTS idx_executions_status ON executions(status); CREATE INDEX IF NOT EXISTS idx_executions_task_id ON executions(task_id); + + CREATE TABLE IF NOT EXISTS templates ( + id TEXT PRIMARY KEY, + name TEXT NOT NULL, + description TEXT, + config_json TEXT NOT NULL DEFAULT '{}', + timeout TEXT, + priority TEXT NOT NULL DEFAULT 'normal', + tags_json TEXT NOT NULL DEFAULT '[]', + created_at DATETIME NOT NULL, + updated_at DATETIME NOT NULL + ); ` - _, err := s.db.Exec(schema) - return err + if _, err := s.db.Exec(schema); err != nil { + return err + } + // Additive migrations for columns added after initial schema. + migrations := []string{ + `ALTER TABLE tasks ADD COLUMN parent_task_id TEXT`, + } + for _, m := range migrations { + if _, err := s.db.Exec(m); err != nil { + // SQLite returns an error if the column already exists; ignore it. + if !isColumnExistsError(err) { + return err + } + } + } + return nil +} + +func isColumnExistsError(err error) bool { + msg := err.Error() + return strings.Contains(msg, "duplicate column name") || strings.Contains(msg, "already exists") } // CreateTask inserts a task into the database. @@ -91,24 +124,24 @@ func (s *DB) CreateTask(t *task.Task) error { } _, err = s.db.Exec(` - INSERT INTO tasks (id, name, description, config_json, priority, timeout_ns, retry_json, tags_json, depends_on_json, state, created_at, updated_at) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, + INSERT INTO tasks (id, name, description, config_json, priority, timeout_ns, retry_json, tags_json, depends_on_json, parent_task_id, state, created_at, updated_at) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, t.ID, t.Name, t.Description, string(configJSON), string(t.Priority), t.Timeout.Duration.Nanoseconds(), string(retryJSON), string(tagsJSON), string(depsJSON), - string(t.State), t.CreatedAt.UTC(), t.UpdatedAt.UTC(), + t.ParentTaskID, string(t.State), t.CreatedAt.UTC(), t.UpdatedAt.UTC(), ) return err } // GetTask retrieves a task by ID. func (s *DB) GetTask(id string) (*task.Task, error) { - row := s.db.QueryRow(`SELECT id, name, description, config_json, priority, timeout_ns, retry_json, tags_json, depends_on_json, state, created_at, updated_at FROM tasks WHERE id = ?`, id) + row := s.db.QueryRow(`SELECT id, name, description, config_json, priority, timeout_ns, retry_json, tags_json, depends_on_json, parent_task_id, state, created_at, updated_at FROM tasks WHERE id = ?`, id) return scanTask(row) } // ListTasks returns tasks matching the given filter. func (s *DB) ListTasks(filter TaskFilter) ([]*task.Task, error) { - query := `SELECT id, name, description, config_json, priority, timeout_ns, retry_json, tags_json, depends_on_json, state, created_at, updated_at FROM tasks WHERE 1=1` + query := `SELECT id, name, description, config_json, priority, timeout_ns, retry_json, tags_json, depends_on_json, parent_task_id, state, created_at, updated_at FROM tasks WHERE 1=1` var args []interface{} if filter.State != "" { @@ -138,6 +171,25 @@ func (s *DB) ListTasks(filter TaskFilter) ([]*task.Task, error) { return tasks, rows.Err() } +// ListSubtasks returns all tasks whose parent_task_id matches the given ID. +func (s *DB) ListSubtasks(parentID string) ([]*task.Task, error) { + rows, err := s.db.Query(`SELECT id, name, description, config_json, priority, timeout_ns, retry_json, tags_json, depends_on_json, parent_task_id, state, created_at, updated_at FROM tasks WHERE parent_task_id = ? ORDER BY created_at ASC`, parentID) + if err != nil { + return nil, err + } + defer rows.Close() + + var tasks []*task.Task + for rows.Next() { + t, err := scanTaskRows(rows) + if err != nil { + return nil, err + } + tasks = append(tasks, t) + } + return tasks, rows.Err() +} + // UpdateTaskState atomically updates a task's state. func (s *DB) UpdateTaskState(id string, newState task.State) error { now := time.Now().UTC() @@ -155,6 +207,69 @@ func (s *DB) UpdateTaskState(id string, newState task.State) error { return nil } +// TaskUpdate holds the fields that UpdateTask may change. +type TaskUpdate struct { + Name string + Description string + Config task.ClaudeConfig + Priority task.Priority + TimeoutNS int64 + Retry task.RetryConfig + Tags []string + DependsOn []string +} + +// UpdateTask replaces editable fields on a task and resets its state to PENDING. +// Returns an error if the task does not exist. +func (s *DB) UpdateTask(id string, u TaskUpdate) error { + configJSON, err := json.Marshal(u.Config) + if err != nil { + return fmt.Errorf("marshaling config: %w", err) + } + retryJSON, err := json.Marshal(u.Retry) + if err != nil { + return fmt.Errorf("marshaling retry: %w", err) + } + tags := u.Tags + if tags == nil { + tags = []string{} + } + tagsJSON, err := json.Marshal(tags) + if err != nil { + return fmt.Errorf("marshaling tags: %w", err) + } + deps := u.DependsOn + if deps == nil { + deps = []string{} + } + depsJSON, err := json.Marshal(deps) + if err != nil { + return fmt.Errorf("marshaling depends_on: %w", err) + } + + now := time.Now().UTC() + result, err := s.db.Exec(` + UPDATE tasks + SET name = ?, description = ?, config_json = ?, priority = ?, timeout_ns = ?, + retry_json = ?, tags_json = ?, depends_on_json = ?, state = ?, updated_at = ? + WHERE id = ?`, + u.Name, u.Description, string(configJSON), string(u.Priority), u.TimeoutNS, + string(retryJSON), string(tagsJSON), string(depsJSON), string(task.StatePending), now, + id, + ) + if err != nil { + return err + } + n, err := result.RowsAffected() + if err != nil { + return err + } + if n == 0 { + return fmt.Errorf("task %q not found", id) + } + return nil +} + // TaskFilter specifies criteria for listing tasks. type TaskFilter struct { State task.State @@ -228,16 +343,18 @@ type scanner interface { func scanTask(row scanner) (*task.Task, error) { var ( - t task.Task - configJSON string - retryJSON string - tagsJSON string - depsJSON string - state string - priority string - timeoutNS int64 + t task.Task + configJSON string + retryJSON string + tagsJSON string + depsJSON string + state string + priority string + timeoutNS int64 + parentTaskID sql.NullString ) - err := row.Scan(&t.ID, &t.Name, &t.Description, &configJSON, &priority, &timeoutNS, &retryJSON, &tagsJSON, &depsJSON, &state, &t.CreatedAt, &t.UpdatedAt) + err := row.Scan(&t.ID, &t.Name, &t.Description, &configJSON, &priority, &timeoutNS, &retryJSON, &tagsJSON, &depsJSON, &parentTaskID, &state, &t.CreatedAt, &t.UpdatedAt) + t.ParentTaskID = parentTaskID.String if err != nil { return nil, err } |
