summaryrefslogtreecommitdiff
path: root/internal/storage
diff options
context:
space:
mode:
Diffstat (limited to 'internal/storage')
-rw-r--r--internal/storage/db.go26
-rw-r--r--internal/storage/db_test.go34
2 files changed, 51 insertions, 9 deletions
diff --git a/internal/storage/db.go b/internal/storage/db.go
index c396bbe..0a4f7a5 100644
--- a/internal/storage/db.go
+++ b/internal/storage/db.go
@@ -193,21 +193,31 @@ func (s *DB) ListSubtasks(parentID string) ([]*task.Task, error) {
return tasks, rows.Err()
}
-// UpdateTaskState atomically updates a task's state.
+// UpdateTaskState atomically updates a task's state, enforcing valid transitions.
func (s *DB) UpdateTaskState(id string, newState task.State) error {
- now := time.Now().UTC()
- result, err := s.db.Exec(`UPDATE tasks SET state = ?, updated_at = ? WHERE id = ?`, string(newState), now, id)
+ tx, err := s.db.Begin()
if err != nil {
return err
}
- n, err := result.RowsAffected()
- if err != nil {
+ defer tx.Rollback() //nolint:errcheck
+
+ var currentState string
+ if err := tx.QueryRow(`SELECT state FROM tasks WHERE id = ?`, id).Scan(&currentState); err != nil {
+ if err == sql.ErrNoRows {
+ return fmt.Errorf("task %q not found", id)
+ }
return err
}
- if n == 0 {
- return fmt.Errorf("task %q not found", id)
+
+ if !task.ValidTransition(task.State(currentState), newState) {
+ return fmt.Errorf("invalid state transition %s → %s for task %q", currentState, newState, id)
}
- return nil
+
+ now := time.Now().UTC()
+ if _, err := tx.Exec(`UPDATE tasks SET state = ?, updated_at = ? WHERE id = ?`, string(newState), now, id); err != nil {
+ return err
+ }
+ return tx.Commit()
}
// RejectTask sets a task's state to PENDING and stores the rejection comment.
diff --git a/internal/storage/db_test.go b/internal/storage/db_test.go
index 36f1644..f737096 100644
--- a/internal/storage/db_test.go
+++ b/internal/storage/db_test.go
@@ -41,7 +41,7 @@ func TestCreateTask_AndGetTask(t *testing.T) {
Type: "claude",
Model: "sonnet",
Instructions: "do it",
- WorkingDir: "/tmp",
+ ProjectDir: "/tmp",
MaxBudgetUSD: 2.5,
},
Priority: task.PriorityHigh,
@@ -124,6 +124,38 @@ func TestUpdateTaskState_NotFound(t *testing.T) {
}
}
+func TestUpdateTaskState_InvalidTransition(t *testing.T) {
+ db := testDB(t)
+ now := time.Now().UTC()
+ tk := &task.Task{
+ ID: "task-invalid",
+ Name: "InvalidTransition",
+ Claude: task.ClaudeConfig{Instructions: "test"},
+ Priority: task.PriorityNormal,
+ Retry: task.RetryConfig{MaxAttempts: 1, Backoff: "linear"},
+ Tags: []string{},
+ DependsOn: []string{},
+ State: task.StatePending,
+ CreatedAt: now,
+ UpdatedAt: now,
+ }
+ if err := db.CreateTask(tk); err != nil {
+ t.Fatal(err)
+ }
+
+ // PENDING → COMPLETED is not a valid transition.
+ err := db.UpdateTaskState("task-invalid", task.StateCompleted)
+ if err == nil {
+ t.Fatal("expected error for invalid state transition PENDING → COMPLETED")
+ }
+
+ // State must not have changed.
+ got, _ := db.GetTask("task-invalid")
+ if got.State != task.StatePending {
+ t.Errorf("state must remain PENDING, got %v", got.State)
+ }
+}
+
func TestListTasks_FilterByState(t *testing.T) {
db := testDB(t)
now := time.Now().UTC()