summaryrefslogtreecommitdiff
path: root/internal/storage
diff options
context:
space:
mode:
authorPeter Stone <thepeterstone@gmail.com>2026-03-03 21:15:01 +0000
committerPeter Stone <thepeterstone@gmail.com>2026-03-03 21:15:01 +0000
commit704d007a26cac804148a51d35e129beaea382fb0 (patch)
tree5061ca129ea033e8689d0a5bdc9d7ddbb9c09f56 /internal/storage
parent58f1f0909b8329b1219c5de9d0df2b4c6c93fec9 (diff)
Add subtask support: parent_task_id, ListSubtasks, UpdateTask
- Task struct gains ParentTaskID field - DB schema adds parent_task_id column (additive migration) - DB.ListSubtasks fetches children of a parent task - DB.UpdateTask allows partial field updates (name, description, state, etc.) - Templates table added to initial schema Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Diffstat (limited to 'internal/storage')
-rw-r--r--internal/storage/db.go149
-rw-r--r--internal/storage/db_test.go89
2 files changed, 222 insertions, 16 deletions
diff --git a/internal/storage/db.go b/internal/storage/db.go
index b004cd1..0117ae7 100644
--- a/internal/storage/db.go
+++ b/internal/storage/db.go
@@ -4,6 +4,7 @@ import (
"database/sql"
"encoding/json"
"fmt"
+ "strings"
"time"
"github.com/thepeterstone/claudomator/internal/task"
@@ -43,6 +44,7 @@ func (s *DB) migrate() error {
retry_json TEXT NOT NULL DEFAULT '{}',
tags_json TEXT NOT NULL DEFAULT '[]',
depends_on_json TEXT NOT NULL DEFAULT '[]',
+ parent_task_id TEXT,
state TEXT NOT NULL DEFAULT 'PENDING',
created_at DATETIME NOT NULL,
updated_at DATETIME NOT NULL
@@ -66,9 +68,40 @@ func (s *DB) migrate() error {
CREATE INDEX IF NOT EXISTS idx_tasks_state ON tasks(state);
CREATE INDEX IF NOT EXISTS idx_executions_status ON executions(status);
CREATE INDEX IF NOT EXISTS idx_executions_task_id ON executions(task_id);
+
+ CREATE TABLE IF NOT EXISTS templates (
+ id TEXT PRIMARY KEY,
+ name TEXT NOT NULL,
+ description TEXT,
+ config_json TEXT NOT NULL DEFAULT '{}',
+ timeout TEXT,
+ priority TEXT NOT NULL DEFAULT 'normal',
+ tags_json TEXT NOT NULL DEFAULT '[]',
+ created_at DATETIME NOT NULL,
+ updated_at DATETIME NOT NULL
+ );
`
- _, err := s.db.Exec(schema)
- return err
+ if _, err := s.db.Exec(schema); err != nil {
+ return err
+ }
+ // Additive migrations for columns added after initial schema.
+ migrations := []string{
+ `ALTER TABLE tasks ADD COLUMN parent_task_id TEXT`,
+ }
+ for _, m := range migrations {
+ if _, err := s.db.Exec(m); err != nil {
+ // SQLite returns an error if the column already exists; ignore it.
+ if !isColumnExistsError(err) {
+ return err
+ }
+ }
+ }
+ return nil
+}
+
+func isColumnExistsError(err error) bool {
+ msg := err.Error()
+ return strings.Contains(msg, "duplicate column name") || strings.Contains(msg, "already exists")
}
// CreateTask inserts a task into the database.
@@ -91,24 +124,24 @@ func (s *DB) CreateTask(t *task.Task) error {
}
_, err = s.db.Exec(`
- INSERT INTO tasks (id, name, description, config_json, priority, timeout_ns, retry_json, tags_json, depends_on_json, state, created_at, updated_at)
- VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
+ INSERT INTO tasks (id, name, description, config_json, priority, timeout_ns, retry_json, tags_json, depends_on_json, parent_task_id, state, created_at, updated_at)
+ VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
t.ID, t.Name, t.Description, string(configJSON), string(t.Priority),
t.Timeout.Duration.Nanoseconds(), string(retryJSON), string(tagsJSON), string(depsJSON),
- string(t.State), t.CreatedAt.UTC(), t.UpdatedAt.UTC(),
+ t.ParentTaskID, string(t.State), t.CreatedAt.UTC(), t.UpdatedAt.UTC(),
)
return err
}
// GetTask retrieves a task by ID.
func (s *DB) GetTask(id string) (*task.Task, error) {
- row := s.db.QueryRow(`SELECT id, name, description, config_json, priority, timeout_ns, retry_json, tags_json, depends_on_json, state, created_at, updated_at FROM tasks WHERE id = ?`, id)
+ row := s.db.QueryRow(`SELECT id, name, description, config_json, priority, timeout_ns, retry_json, tags_json, depends_on_json, parent_task_id, state, created_at, updated_at FROM tasks WHERE id = ?`, id)
return scanTask(row)
}
// ListTasks returns tasks matching the given filter.
func (s *DB) ListTasks(filter TaskFilter) ([]*task.Task, error) {
- query := `SELECT id, name, description, config_json, priority, timeout_ns, retry_json, tags_json, depends_on_json, state, created_at, updated_at FROM tasks WHERE 1=1`
+ query := `SELECT id, name, description, config_json, priority, timeout_ns, retry_json, tags_json, depends_on_json, parent_task_id, state, created_at, updated_at FROM tasks WHERE 1=1`
var args []interface{}
if filter.State != "" {
@@ -138,6 +171,25 @@ func (s *DB) ListTasks(filter TaskFilter) ([]*task.Task, error) {
return tasks, rows.Err()
}
+// ListSubtasks returns all tasks whose parent_task_id matches the given ID.
+func (s *DB) ListSubtasks(parentID string) ([]*task.Task, error) {
+ rows, err := s.db.Query(`SELECT id, name, description, config_json, priority, timeout_ns, retry_json, tags_json, depends_on_json, parent_task_id, state, created_at, updated_at FROM tasks WHERE parent_task_id = ? ORDER BY created_at ASC`, parentID)
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close()
+
+ var tasks []*task.Task
+ for rows.Next() {
+ t, err := scanTaskRows(rows)
+ if err != nil {
+ return nil, err
+ }
+ tasks = append(tasks, t)
+ }
+ return tasks, rows.Err()
+}
+
// UpdateTaskState atomically updates a task's state.
func (s *DB) UpdateTaskState(id string, newState task.State) error {
now := time.Now().UTC()
@@ -155,6 +207,69 @@ func (s *DB) UpdateTaskState(id string, newState task.State) error {
return nil
}
+// TaskUpdate holds the fields that UpdateTask may change.
+type TaskUpdate struct {
+ Name string
+ Description string
+ Config task.ClaudeConfig
+ Priority task.Priority
+ TimeoutNS int64
+ Retry task.RetryConfig
+ Tags []string
+ DependsOn []string
+}
+
+// UpdateTask replaces editable fields on a task and resets its state to PENDING.
+// Returns an error if the task does not exist.
+func (s *DB) UpdateTask(id string, u TaskUpdate) error {
+ configJSON, err := json.Marshal(u.Config)
+ if err != nil {
+ return fmt.Errorf("marshaling config: %w", err)
+ }
+ retryJSON, err := json.Marshal(u.Retry)
+ if err != nil {
+ return fmt.Errorf("marshaling retry: %w", err)
+ }
+ tags := u.Tags
+ if tags == nil {
+ tags = []string{}
+ }
+ tagsJSON, err := json.Marshal(tags)
+ if err != nil {
+ return fmt.Errorf("marshaling tags: %w", err)
+ }
+ deps := u.DependsOn
+ if deps == nil {
+ deps = []string{}
+ }
+ depsJSON, err := json.Marshal(deps)
+ if err != nil {
+ return fmt.Errorf("marshaling depends_on: %w", err)
+ }
+
+ now := time.Now().UTC()
+ result, err := s.db.Exec(`
+ UPDATE tasks
+ SET name = ?, description = ?, config_json = ?, priority = ?, timeout_ns = ?,
+ retry_json = ?, tags_json = ?, depends_on_json = ?, state = ?, updated_at = ?
+ WHERE id = ?`,
+ u.Name, u.Description, string(configJSON), string(u.Priority), u.TimeoutNS,
+ string(retryJSON), string(tagsJSON), string(depsJSON), string(task.StatePending), now,
+ id,
+ )
+ if err != nil {
+ return err
+ }
+ n, err := result.RowsAffected()
+ if err != nil {
+ return err
+ }
+ if n == 0 {
+ return fmt.Errorf("task %q not found", id)
+ }
+ return nil
+}
+
// TaskFilter specifies criteria for listing tasks.
type TaskFilter struct {
State task.State
@@ -228,16 +343,18 @@ type scanner interface {
func scanTask(row scanner) (*task.Task, error) {
var (
- t task.Task
- configJSON string
- retryJSON string
- tagsJSON string
- depsJSON string
- state string
- priority string
- timeoutNS int64
+ t task.Task
+ configJSON string
+ retryJSON string
+ tagsJSON string
+ depsJSON string
+ state string
+ priority string
+ timeoutNS int64
+ parentTaskID sql.NullString
)
- err := row.Scan(&t.ID, &t.Name, &t.Description, &configJSON, &priority, &timeoutNS, &retryJSON, &tagsJSON, &depsJSON, &state, &t.CreatedAt, &t.UpdatedAt)
+ err := row.Scan(&t.ID, &t.Name, &t.Description, &configJSON, &priority, &timeoutNS, &retryJSON, &tagsJSON, &depsJSON, &parentTaskID, &state, &t.CreatedAt, &t.UpdatedAt)
+ t.ParentTaskID = parentTaskID.String
if err != nil {
return nil, err
}
diff --git a/internal/storage/db_test.go b/internal/storage/db_test.go
index db6c8ad..7eb81d2 100644
--- a/internal/storage/db_test.go
+++ b/internal/storage/db_test.go
@@ -251,6 +251,95 @@ func TestListExecutions(t *testing.T) {
}
}
+func TestDB_UpdateTask(t *testing.T) {
+ t.Run("happy path", func(t *testing.T) {
+ db := testDB(t)
+ now := time.Now().UTC().Truncate(time.Second)
+
+ tk := &task.Task{
+ ID: "upd-1",
+ Name: "Original Name",
+ Description: "original desc",
+ Claude: task.ClaudeConfig{Model: "sonnet", Instructions: "original"},
+ Priority: task.PriorityNormal,
+ Retry: task.RetryConfig{MaxAttempts: 1, Backoff: "linear"},
+ Tags: []string{"old"},
+ DependsOn: []string{},
+ State: task.StateCompleted,
+ CreatedAt: now,
+ UpdatedAt: now,
+ }
+ tk.Timeout.Duration = 5 * time.Minute
+ if err := db.CreateTask(tk); err != nil {
+ t.Fatalf("creating task: %v", err)
+ }
+
+ u := TaskUpdate{
+ Name: "Updated Name",
+ Description: "updated desc",
+ Config: task.ClaudeConfig{Model: "opus", Instructions: "updated"},
+ Priority: task.PriorityHigh,
+ TimeoutNS: int64(15 * time.Minute),
+ Retry: task.RetryConfig{MaxAttempts: 3, Backoff: "exponential"},
+ Tags: []string{"new", "tag"},
+ DependsOn: []string{"dep-1"},
+ }
+ if err := db.UpdateTask("upd-1", u); err != nil {
+ t.Fatalf("UpdateTask: %v", err)
+ }
+
+ got, err := db.GetTask("upd-1")
+ if err != nil {
+ t.Fatalf("GetTask: %v", err)
+ }
+ if got.Name != "Updated Name" {
+ t.Errorf("name: want 'Updated Name', got %q", got.Name)
+ }
+ if got.Description != "updated desc" {
+ t.Errorf("description: want 'updated desc', got %q", got.Description)
+ }
+ if got.Claude.Model != "opus" {
+ t.Errorf("model: want 'opus', got %q", got.Claude.Model)
+ }
+ if got.Priority != task.PriorityHigh {
+ t.Errorf("priority: want 'high', got %q", got.Priority)
+ }
+ if got.Timeout.Duration != 15*time.Minute {
+ t.Errorf("timeout: want 15m, got %v", got.Timeout.Duration)
+ }
+ if got.Retry.MaxAttempts != 3 {
+ t.Errorf("retry.max_attempts: want 3, got %d", got.Retry.MaxAttempts)
+ }
+ if got.Retry.Backoff != "exponential" {
+ t.Errorf("retry.backoff: want 'exponential', got %q", got.Retry.Backoff)
+ }
+ if len(got.Tags) != 2 || got.Tags[0] != "new" || got.Tags[1] != "tag" {
+ t.Errorf("tags: want [new tag], got %v", got.Tags)
+ }
+ if len(got.DependsOn) != 1 || got.DependsOn[0] != "dep-1" {
+ t.Errorf("depends_on: want [dep-1], got %v", got.DependsOn)
+ }
+ if got.State != task.StatePending {
+ t.Errorf("state: want PENDING after update, got %v", got.State)
+ }
+ // id and created_at must be unchanged
+ if got.ID != "upd-1" {
+ t.Errorf("id changed: got %q", got.ID)
+ }
+ if !got.CreatedAt.Equal(now) {
+ t.Errorf("created_at changed: want %v, got %v", now, got.CreatedAt)
+ }
+ })
+
+ t.Run("not found", func(t *testing.T) {
+ db := testDB(t)
+ err := db.UpdateTask("nonexistent", TaskUpdate{Name: "x"})
+ if err == nil {
+ t.Fatal("expected error for nonexistent task, got nil")
+ }
+ })
+}
+
func TestUpdateExecution(t *testing.T) {
db := testDB(t)
now := time.Now().UTC()