From 704d007a26cac804148a51d35e129beaea382fb0 Mon Sep 17 00:00:00 2001 From: Peter Stone Date: Tue, 3 Mar 2026 21:15:01 +0000 Subject: Add subtask support: parent_task_id, ListSubtasks, UpdateTask - Task struct gains ParentTaskID field - DB schema adds parent_task_id column (additive migration) - DB.ListSubtasks fetches children of a parent task - DB.UpdateTask allows partial field updates (name, description, state, etc.) - Templates table added to initial schema Co-Authored-By: Claude Sonnet 4.6 --- internal/storage/db.go | 149 +++++++++++++++++++++++++++++++++++++++----- internal/storage/db_test.go | 89 ++++++++++++++++++++++++++ internal/task/task.go | 7 ++- internal/task/task_test.go | 2 + 4 files changed, 229 insertions(+), 18 deletions(-) (limited to 'internal') diff --git a/internal/storage/db.go b/internal/storage/db.go index b004cd1..0117ae7 100644 --- a/internal/storage/db.go +++ b/internal/storage/db.go @@ -4,6 +4,7 @@ import ( "database/sql" "encoding/json" "fmt" + "strings" "time" "github.com/thepeterstone/claudomator/internal/task" @@ -43,6 +44,7 @@ func (s *DB) migrate() error { retry_json TEXT NOT NULL DEFAULT '{}', tags_json TEXT NOT NULL DEFAULT '[]', depends_on_json TEXT NOT NULL DEFAULT '[]', + parent_task_id TEXT, state TEXT NOT NULL DEFAULT 'PENDING', created_at DATETIME NOT NULL, updated_at DATETIME NOT NULL @@ -66,9 +68,40 @@ func (s *DB) migrate() error { CREATE INDEX IF NOT EXISTS idx_tasks_state ON tasks(state); CREATE INDEX IF NOT EXISTS idx_executions_status ON executions(status); CREATE INDEX IF NOT EXISTS idx_executions_task_id ON executions(task_id); + + CREATE TABLE IF NOT EXISTS templates ( + id TEXT PRIMARY KEY, + name TEXT NOT NULL, + description TEXT, + config_json TEXT NOT NULL DEFAULT '{}', + timeout TEXT, + priority TEXT NOT NULL DEFAULT 'normal', + tags_json TEXT NOT NULL DEFAULT '[]', + created_at DATETIME NOT NULL, + updated_at DATETIME NOT NULL + ); ` - _, err := s.db.Exec(schema) - return err + if _, err := s.db.Exec(schema); err != nil { + return err + } + // Additive migrations for columns added after initial schema. + migrations := []string{ + `ALTER TABLE tasks ADD COLUMN parent_task_id TEXT`, + } + for _, m := range migrations { + if _, err := s.db.Exec(m); err != nil { + // SQLite returns an error if the column already exists; ignore it. + if !isColumnExistsError(err) { + return err + } + } + } + return nil +} + +func isColumnExistsError(err error) bool { + msg := err.Error() + return strings.Contains(msg, "duplicate column name") || strings.Contains(msg, "already exists") } // CreateTask inserts a task into the database. @@ -91,24 +124,24 @@ func (s *DB) CreateTask(t *task.Task) error { } _, err = s.db.Exec(` - INSERT INTO tasks (id, name, description, config_json, priority, timeout_ns, retry_json, tags_json, depends_on_json, state, created_at, updated_at) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, + INSERT INTO tasks (id, name, description, config_json, priority, timeout_ns, retry_json, tags_json, depends_on_json, parent_task_id, state, created_at, updated_at) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, t.ID, t.Name, t.Description, string(configJSON), string(t.Priority), t.Timeout.Duration.Nanoseconds(), string(retryJSON), string(tagsJSON), string(depsJSON), - string(t.State), t.CreatedAt.UTC(), t.UpdatedAt.UTC(), + t.ParentTaskID, string(t.State), t.CreatedAt.UTC(), t.UpdatedAt.UTC(), ) return err } // GetTask retrieves a task by ID. func (s *DB) GetTask(id string) (*task.Task, error) { - row := s.db.QueryRow(`SELECT id, name, description, config_json, priority, timeout_ns, retry_json, tags_json, depends_on_json, state, created_at, updated_at FROM tasks WHERE id = ?`, id) + row := s.db.QueryRow(`SELECT id, name, description, config_json, priority, timeout_ns, retry_json, tags_json, depends_on_json, parent_task_id, state, created_at, updated_at FROM tasks WHERE id = ?`, id) return scanTask(row) } // ListTasks returns tasks matching the given filter. func (s *DB) ListTasks(filter TaskFilter) ([]*task.Task, error) { - query := `SELECT id, name, description, config_json, priority, timeout_ns, retry_json, tags_json, depends_on_json, state, created_at, updated_at FROM tasks WHERE 1=1` + query := `SELECT id, name, description, config_json, priority, timeout_ns, retry_json, tags_json, depends_on_json, parent_task_id, state, created_at, updated_at FROM tasks WHERE 1=1` var args []interface{} if filter.State != "" { @@ -138,6 +171,25 @@ func (s *DB) ListTasks(filter TaskFilter) ([]*task.Task, error) { return tasks, rows.Err() } +// ListSubtasks returns all tasks whose parent_task_id matches the given ID. +func (s *DB) ListSubtasks(parentID string) ([]*task.Task, error) { + rows, err := s.db.Query(`SELECT id, name, description, config_json, priority, timeout_ns, retry_json, tags_json, depends_on_json, parent_task_id, state, created_at, updated_at FROM tasks WHERE parent_task_id = ? ORDER BY created_at ASC`, parentID) + if err != nil { + return nil, err + } + defer rows.Close() + + var tasks []*task.Task + for rows.Next() { + t, err := scanTaskRows(rows) + if err != nil { + return nil, err + } + tasks = append(tasks, t) + } + return tasks, rows.Err() +} + // UpdateTaskState atomically updates a task's state. func (s *DB) UpdateTaskState(id string, newState task.State) error { now := time.Now().UTC() @@ -155,6 +207,69 @@ func (s *DB) UpdateTaskState(id string, newState task.State) error { return nil } +// TaskUpdate holds the fields that UpdateTask may change. +type TaskUpdate struct { + Name string + Description string + Config task.ClaudeConfig + Priority task.Priority + TimeoutNS int64 + Retry task.RetryConfig + Tags []string + DependsOn []string +} + +// UpdateTask replaces editable fields on a task and resets its state to PENDING. +// Returns an error if the task does not exist. +func (s *DB) UpdateTask(id string, u TaskUpdate) error { + configJSON, err := json.Marshal(u.Config) + if err != nil { + return fmt.Errorf("marshaling config: %w", err) + } + retryJSON, err := json.Marshal(u.Retry) + if err != nil { + return fmt.Errorf("marshaling retry: %w", err) + } + tags := u.Tags + if tags == nil { + tags = []string{} + } + tagsJSON, err := json.Marshal(tags) + if err != nil { + return fmt.Errorf("marshaling tags: %w", err) + } + deps := u.DependsOn + if deps == nil { + deps = []string{} + } + depsJSON, err := json.Marshal(deps) + if err != nil { + return fmt.Errorf("marshaling depends_on: %w", err) + } + + now := time.Now().UTC() + result, err := s.db.Exec(` + UPDATE tasks + SET name = ?, description = ?, config_json = ?, priority = ?, timeout_ns = ?, + retry_json = ?, tags_json = ?, depends_on_json = ?, state = ?, updated_at = ? + WHERE id = ?`, + u.Name, u.Description, string(configJSON), string(u.Priority), u.TimeoutNS, + string(retryJSON), string(tagsJSON), string(depsJSON), string(task.StatePending), now, + id, + ) + if err != nil { + return err + } + n, err := result.RowsAffected() + if err != nil { + return err + } + if n == 0 { + return fmt.Errorf("task %q not found", id) + } + return nil +} + // TaskFilter specifies criteria for listing tasks. type TaskFilter struct { State task.State @@ -228,16 +343,18 @@ type scanner interface { func scanTask(row scanner) (*task.Task, error) { var ( - t task.Task - configJSON string - retryJSON string - tagsJSON string - depsJSON string - state string - priority string - timeoutNS int64 + t task.Task + configJSON string + retryJSON string + tagsJSON string + depsJSON string + state string + priority string + timeoutNS int64 + parentTaskID sql.NullString ) - err := row.Scan(&t.ID, &t.Name, &t.Description, &configJSON, &priority, &timeoutNS, &retryJSON, &tagsJSON, &depsJSON, &state, &t.CreatedAt, &t.UpdatedAt) + err := row.Scan(&t.ID, &t.Name, &t.Description, &configJSON, &priority, &timeoutNS, &retryJSON, &tagsJSON, &depsJSON, &parentTaskID, &state, &t.CreatedAt, &t.UpdatedAt) + t.ParentTaskID = parentTaskID.String if err != nil { return nil, err } diff --git a/internal/storage/db_test.go b/internal/storage/db_test.go index db6c8ad..7eb81d2 100644 --- a/internal/storage/db_test.go +++ b/internal/storage/db_test.go @@ -251,6 +251,95 @@ func TestListExecutions(t *testing.T) { } } +func TestDB_UpdateTask(t *testing.T) { + t.Run("happy path", func(t *testing.T) { + db := testDB(t) + now := time.Now().UTC().Truncate(time.Second) + + tk := &task.Task{ + ID: "upd-1", + Name: "Original Name", + Description: "original desc", + Claude: task.ClaudeConfig{Model: "sonnet", Instructions: "original"}, + Priority: task.PriorityNormal, + Retry: task.RetryConfig{MaxAttempts: 1, Backoff: "linear"}, + Tags: []string{"old"}, + DependsOn: []string{}, + State: task.StateCompleted, + CreatedAt: now, + UpdatedAt: now, + } + tk.Timeout.Duration = 5 * time.Minute + if err := db.CreateTask(tk); err != nil { + t.Fatalf("creating task: %v", err) + } + + u := TaskUpdate{ + Name: "Updated Name", + Description: "updated desc", + Config: task.ClaudeConfig{Model: "opus", Instructions: "updated"}, + Priority: task.PriorityHigh, + TimeoutNS: int64(15 * time.Minute), + Retry: task.RetryConfig{MaxAttempts: 3, Backoff: "exponential"}, + Tags: []string{"new", "tag"}, + DependsOn: []string{"dep-1"}, + } + if err := db.UpdateTask("upd-1", u); err != nil { + t.Fatalf("UpdateTask: %v", err) + } + + got, err := db.GetTask("upd-1") + if err != nil { + t.Fatalf("GetTask: %v", err) + } + if got.Name != "Updated Name" { + t.Errorf("name: want 'Updated Name', got %q", got.Name) + } + if got.Description != "updated desc" { + t.Errorf("description: want 'updated desc', got %q", got.Description) + } + if got.Claude.Model != "opus" { + t.Errorf("model: want 'opus', got %q", got.Claude.Model) + } + if got.Priority != task.PriorityHigh { + t.Errorf("priority: want 'high', got %q", got.Priority) + } + if got.Timeout.Duration != 15*time.Minute { + t.Errorf("timeout: want 15m, got %v", got.Timeout.Duration) + } + if got.Retry.MaxAttempts != 3 { + t.Errorf("retry.max_attempts: want 3, got %d", got.Retry.MaxAttempts) + } + if got.Retry.Backoff != "exponential" { + t.Errorf("retry.backoff: want 'exponential', got %q", got.Retry.Backoff) + } + if len(got.Tags) != 2 || got.Tags[0] != "new" || got.Tags[1] != "tag" { + t.Errorf("tags: want [new tag], got %v", got.Tags) + } + if len(got.DependsOn) != 1 || got.DependsOn[0] != "dep-1" { + t.Errorf("depends_on: want [dep-1], got %v", got.DependsOn) + } + if got.State != task.StatePending { + t.Errorf("state: want PENDING after update, got %v", got.State) + } + // id and created_at must be unchanged + if got.ID != "upd-1" { + t.Errorf("id changed: got %q", got.ID) + } + if !got.CreatedAt.Equal(now) { + t.Errorf("created_at changed: want %v, got %v", now, got.CreatedAt) + } + }) + + t.Run("not found", func(t *testing.T) { + db := testDB(t) + err := db.UpdateTask("nonexistent", TaskUpdate{Name: "x"}) + if err == nil { + t.Fatal("expected error for nonexistent task, got nil") + } + }) +} + func TestUpdateExecution(t *testing.T) { db := testDB(t) now := time.Now().UTC() diff --git a/internal/task/task.go b/internal/task/task.go index 5c28f63..d360a07 100644 --- a/internal/task/task.go +++ b/internal/task/task.go @@ -34,6 +34,7 @@ type ClaudeConfig struct { DisallowedTools []string `yaml:"disallowed_tools" json:"disallowed_tools"` SystemPromptAppend string `yaml:"system_prompt_append" json:"system_prompt_append"` AdditionalArgs []string `yaml:"additional_args" json:"additional_args"` + SkipPlanning bool `yaml:"skip_planning" json:"skip_planning"` } type RetryConfig struct { @@ -42,8 +43,9 @@ type RetryConfig struct { } type Task struct { - ID string `yaml:"id" json:"id"` - Name string `yaml:"name" json:"name"` + ID string `yaml:"id" json:"id"` + ParentTaskID string `yaml:"parent_task_id" json:"parent_task_id"` + Name string `yaml:"name" json:"name"` Description string `yaml:"description" json:"description"` Claude ClaudeConfig `yaml:"claude" json:"claude"` Timeout Duration `yaml:"timeout" json:"timeout"` @@ -90,6 +92,7 @@ func ValidTransition(from, to State) bool { StateQueued: {StateRunning, StateCancelled}, StateRunning: {StateCompleted, StateFailed, StateTimedOut, StateCancelled, StateBudgetExceeded}, StateFailed: {StateQueued}, // retry + StateTimedOut: {StateQueued}, // retry } for _, allowed := range transitions[from] { if allowed == to { diff --git a/internal/task/task_test.go b/internal/task/task_test.go index 96f5f6f..a8e0a84 100644 --- a/internal/task/task_test.go +++ b/internal/task/task_test.go @@ -21,6 +21,7 @@ func TestValidTransition_AllowedTransitions(t *testing.T) { {"running to cancelled", StateRunning, StateCancelled}, {"running to budget exceeded", StateRunning, StateBudgetExceeded}, {"failed to queued (retry)", StateFailed, StateQueued}, + {"timed out to queued (retry)", StateTimedOut, StateQueued}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -43,6 +44,7 @@ func TestValidTransition_DisallowedTransitions(t *testing.T) { {"completed to running", StateCompleted, StateRunning}, {"completed to queued", StateCompleted, StateQueued}, {"failed to completed", StateFailed, StateCompleted}, + {"timed out to completed", StateTimedOut, StateCompleted}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { -- cgit v1.2.3