summaryrefslogtreecommitdiff
path: root/internal/task
diff options
context:
space:
mode:
Diffstat (limited to 'internal/task')
-rw-r--r--internal/task/parser.go61
-rw-r--r--internal/task/parser_test.go152
-rw-r--r--internal/task/task.go100
-rw-r--r--internal/task/task_test.go80
-rw-r--r--internal/task/validator.go65
-rw-r--r--internal/task/validator_test.go115
6 files changed, 573 insertions, 0 deletions
diff --git a/internal/task/parser.go b/internal/task/parser.go
new file mode 100644
index 0000000..7a450b8
--- /dev/null
+++ b/internal/task/parser.go
@@ -0,0 +1,61 @@
+package task
+
+import (
+ "fmt"
+ "os"
+ "time"
+
+ "github.com/google/uuid"
+ "gopkg.in/yaml.v3"
+)
+
+// ParseFile reads a YAML file and returns tasks. Supports both single-task
+// and batch (tasks: [...]) formats.
+func ParseFile(path string) ([]Task, error) {
+ data, err := os.ReadFile(path)
+ if err != nil {
+ return nil, fmt.Errorf("reading task file: %w", err)
+ }
+ return Parse(data)
+}
+
+// Parse parses YAML bytes into tasks.
+func Parse(data []byte) ([]Task, error) {
+ // Try batch format first.
+ var batch BatchFile
+ if err := yaml.Unmarshal(data, &batch); err == nil && len(batch.Tasks) > 0 {
+ return initTasks(batch.Tasks), nil
+ }
+
+ // Try single task.
+ var t Task
+ if err := yaml.Unmarshal(data, &t); err != nil {
+ return nil, fmt.Errorf("parsing task YAML: %w", err)
+ }
+ if t.Name == "" {
+ return nil, fmt.Errorf("task must have a name")
+ }
+ return initTasks([]Task{t}), nil
+}
+
+func initTasks(tasks []Task) []Task {
+ now := time.Now()
+ for i := range tasks {
+ if tasks[i].ID == "" {
+ tasks[i].ID = uuid.New().String()
+ }
+ if tasks[i].Priority == "" {
+ tasks[i].Priority = PriorityNormal
+ }
+ if tasks[i].Retry.MaxAttempts == 0 {
+ tasks[i].Retry.MaxAttempts = 1
+ }
+ if tasks[i].Retry.Backoff == "" {
+ tasks[i].Retry.Backoff = "exponential"
+ }
+ tasks[i].State = StatePending
+ tasks[i].CreatedAt = now
+ tasks[i].UpdatedAt = now
+ }
+ return tasks
+}
diff --git a/internal/task/parser_test.go b/internal/task/parser_test.go
new file mode 100644
index 0000000..cb68e86
--- /dev/null
+++ b/internal/task/parser_test.go
@@ -0,0 +1,152 @@
+package task
+
+import (
+ "os"
+ "path/filepath"
+ "testing"
+ "time"
+)
+
+func TestParse_SingleTask(t *testing.T) {
+ yaml := `
+name: "Test Task"
+description: "A simple test"
+claude:
+ model: "sonnet"
+ instructions: "Do something"
+ working_dir: "/tmp"
+timeout: "10m"
+tags:
+ - "test"
+`
+ tasks, err := Parse([]byte(yaml))
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ if len(tasks) != 1 {
+ t.Fatalf("expected 1 task, got %d", len(tasks))
+ }
+ task := tasks[0]
+ if task.Name != "Test Task" {
+ t.Errorf("expected name 'Test Task', got %q", task.Name)
+ }
+ if task.Claude.Model != "sonnet" {
+ t.Errorf("expected model 'sonnet', got %q", task.Claude.Model)
+ }
+ if task.Timeout.Duration != 10*time.Minute {
+ t.Errorf("expected timeout 10m, got %v", task.Timeout.Duration)
+ }
+ if task.State != StatePending {
+ t.Errorf("expected state PENDING, got %v", task.State)
+ }
+ if task.ID == "" {
+ t.Error("expected auto-generated ID")
+ }
+ if task.Priority != PriorityNormal {
+ t.Errorf("expected default priority 'normal', got %q", task.Priority)
+ }
+}
+
+func TestParse_BatchTasks(t *testing.T) {
+ yaml := `
+tasks:
+ - name: "Task A"
+ claude:
+ instructions: "Do A"
+ working_dir: "/tmp"
+ tags: ["alpha"]
+ - name: "Task B"
+ claude:
+ instructions: "Do B"
+ working_dir: "/tmp"
+ tags: ["beta"]
+`
+ tasks, err := Parse([]byte(yaml))
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ if len(tasks) != 2 {
+ t.Fatalf("expected 2 tasks, got %d", len(tasks))
+ }
+ if tasks[0].Name != "Task A" {
+ t.Errorf("expected 'Task A', got %q", tasks[0].Name)
+ }
+ if tasks[1].Name != "Task B" {
+ t.Errorf("expected 'Task B', got %q", tasks[1].Name)
+ }
+}
+
+func TestParse_MissingName_ReturnsError(t *testing.T) {
+ yaml := `
+description: "no name"
+claude:
+ instructions: "something"
+`
+ _, err := Parse([]byte(yaml))
+ if err == nil {
+ t.Fatal("expected error for missing name")
+ }
+}
+
+func TestParse_DefaultRetryConfig(t *testing.T) {
+ yaml := `
+name: "Defaults"
+claude:
+ instructions: "test"
+`
+ tasks, err := Parse([]byte(yaml))
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ if tasks[0].Retry.MaxAttempts != 1 {
+ t.Errorf("expected default max_attempts=1, got %d", tasks[0].Retry.MaxAttempts)
+ }
+ if tasks[0].Retry.Backoff != "exponential" {
+ t.Errorf("expected default backoff 'exponential', got %q", tasks[0].Retry.Backoff)
+ }
+}
+
+func TestParse_WithPriority(t *testing.T) {
+ yaml := `
+name: "High Priority"
+priority: "high"
+claude:
+ instructions: "urgent"
+`
+ tasks, err := Parse([]byte(yaml))
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ if tasks[0].Priority != PriorityHigh {
+ t.Errorf("expected priority 'high', got %q", tasks[0].Priority)
+ }
+}
+
+func TestParseFile(t *testing.T) {
+ dir := t.TempDir()
+ path := filepath.Join(dir, "task.yaml")
+ content := `
+name: "File Task"
+claude:
+ instructions: "from file"
+ working_dir: "/tmp"
+`
+ if err := os.WriteFile(path, []byte(content), 0644); err != nil {
+ t.Fatal(err)
+ }
+
+ tasks, err := ParseFile(path)
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ if len(tasks) != 1 || tasks[0].Name != "File Task" {
+ t.Errorf("unexpected tasks: %+v", tasks)
+ }
+}
+
+func TestParseFile_NotFound(t *testing.T) {
+ _, err := ParseFile("/nonexistent/task.yaml")
+ if err == nil {
+ t.Fatal("expected error for nonexistent file")
+ }
+}
diff --git a/internal/task/task.go b/internal/task/task.go
new file mode 100644
index 0000000..3796cf3
--- /dev/null
+++ b/internal/task/task.go
@@ -0,0 +1,100 @@
+package task
+
+import "time"
+
+type State string
+
+const (
+ StatePending State = "PENDING"
+ StateQueued State = "QUEUED"
+ StateRunning State = "RUNNING"
+ StateCompleted State = "COMPLETED"
+ StateFailed State = "FAILED"
+ StateTimedOut State = "TIMED_OUT"
+ StateCancelled State = "CANCELLED"
+ StateBudgetExceeded State = "BUDGET_EXCEEDED"
+)
+
+type Priority string
+
+const (
+ PriorityHigh Priority = "high"
+ PriorityNormal Priority = "normal"
+ PriorityLow Priority = "low"
+)
+
+type ClaudeConfig struct {
+ Model string `yaml:"model"`
+ ContextFiles []string `yaml:"context_files"`
+ Instructions string `yaml:"instructions"`
+ WorkingDir string `yaml:"working_dir"`
+ MaxBudgetUSD float64 `yaml:"max_budget_usd"`
+ PermissionMode string `yaml:"permission_mode"`
+ AllowedTools []string `yaml:"allowed_tools"`
+ DisallowedTools []string `yaml:"disallowed_tools"`
+ SystemPromptAppend string `yaml:"system_prompt_append"`
+ AdditionalArgs []string `yaml:"additional_args"`
+}
+
+type RetryConfig struct {
+ MaxAttempts int `yaml:"max_attempts"`
+ Backoff string `yaml:"backoff"` // "linear", "exponential"
+}
+
+type Task struct {
+ ID string `yaml:"id"`
+ Name string `yaml:"name"`
+ Description string `yaml:"description"`
+ Claude ClaudeConfig `yaml:"claude"`
+ Timeout Duration `yaml:"timeout"`
+ Retry RetryConfig `yaml:"retry"`
+ Priority Priority `yaml:"priority"`
+ Tags []string `yaml:"tags"`
+ DependsOn []string `yaml:"depends_on"`
+ State State `yaml:"-"`
+ CreatedAt time.Time `yaml:"-"`
+ UpdatedAt time.Time `yaml:"-"`
+}
+
+// Duration wraps time.Duration for YAML unmarshaling from strings like "30m".
+type Duration struct {
+ time.Duration
+}
+
+func (d *Duration) UnmarshalYAML(unmarshal func(interface{}) error) error {
+ var s string
+ if err := unmarshal(&s); err != nil {
+ return err
+ }
+ dur, err := time.ParseDuration(s)
+ if err != nil {
+ return err
+ }
+ d.Duration = dur
+ return nil
+}
+
+func (d Duration) MarshalYAML() (interface{}, error) {
+ return d.Duration.String(), nil
+}
+
+// BatchFile represents a YAML file containing multiple tasks.
+type BatchFile struct {
+ Tasks []Task `yaml:"tasks"`
+}
+
+// ValidTransition returns true if moving from the current state to next is allowed.
+func ValidTransition(from, to State) bool {
+ transitions := map[State][]State{
+ StatePending: {StateQueued, StateCancelled},
+ StateQueued: {StateRunning, StateCancelled},
+ StateRunning: {StateCompleted, StateFailed, StateTimedOut, StateCancelled, StateBudgetExceeded},
+ StateFailed: {StateQueued}, // retry
+ }
+ for _, allowed := range transitions[from] {
+ if allowed == to {
+ return true
+ }
+ }
+ return false
+}
diff --git a/internal/task/task_test.go b/internal/task/task_test.go
new file mode 100644
index 0000000..96f5f6f
--- /dev/null
+++ b/internal/task/task_test.go
@@ -0,0 +1,80 @@
+package task
+
+import (
+ "testing"
+ "time"
+)
+
+func TestValidTransition_AllowedTransitions(t *testing.T) {
+ tests := []struct {
+ name string
+ from State
+ to State
+ }{
+ {"pending to queued", StatePending, StateQueued},
+ {"pending to cancelled", StatePending, StateCancelled},
+ {"queued to running", StateQueued, StateRunning},
+ {"queued to cancelled", StateQueued, StateCancelled},
+ {"running to completed", StateRunning, StateCompleted},
+ {"running to failed", StateRunning, StateFailed},
+ {"running to timed out", StateRunning, StateTimedOut},
+ {"running to cancelled", StateRunning, StateCancelled},
+ {"running to budget exceeded", StateRunning, StateBudgetExceeded},
+ {"failed to queued (retry)", StateFailed, StateQueued},
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ if !ValidTransition(tt.from, tt.to) {
+ t.Errorf("expected transition %s -> %s to be valid", tt.from, tt.to)
+ }
+ })
+ }
+}
+
+func TestValidTransition_DisallowedTransitions(t *testing.T) {
+ tests := []struct {
+ name string
+ from State
+ to State
+ }{
+ {"pending to running", StatePending, StateRunning},
+ {"pending to completed", StatePending, StateCompleted},
+ {"queued to completed", StateQueued, StateCompleted},
+ {"completed to running", StateCompleted, StateRunning},
+ {"completed to queued", StateCompleted, StateQueued},
+ {"failed to completed", StateFailed, StateCompleted},
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ if ValidTransition(tt.from, tt.to) {
+ t.Errorf("expected transition %s -> %s to be invalid", tt.from, tt.to)
+ }
+ })
+ }
+}
+
+func TestDuration_UnmarshalYAML(t *testing.T) {
+ var d Duration
+ unmarshal := func(v interface{}) error {
+ ptr := v.(*string)
+ *ptr = "30m"
+ return nil
+ }
+ if err := d.UnmarshalYAML(unmarshal); err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ if d.Duration != 30*time.Minute {
+ t.Errorf("expected 30m, got %v", d.Duration)
+ }
+}
+
+func TestDuration_MarshalYAML(t *testing.T) {
+ d := Duration{Duration: 15 * time.Minute}
+ v, err := d.MarshalYAML()
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ if v != "15m0s" {
+ t.Errorf("expected '15m0s', got %v", v)
+ }
+}
diff --git a/internal/task/validator.go b/internal/task/validator.go
new file mode 100644
index 0000000..ea0b1c2
--- /dev/null
+++ b/internal/task/validator.go
@@ -0,0 +1,65 @@
+package task
+
+import (
+ "fmt"
+ "strings"
+)
+
+// ValidationError collects multiple validation failures.
+type ValidationError struct {
+ Errors []string
+}
+
+func (e *ValidationError) Error() string {
+ return fmt.Sprintf("validation failed: %s", strings.Join(e.Errors, "; "))
+}
+
+func (e *ValidationError) Add(msg string) {
+ e.Errors = append(e.Errors, msg)
+}
+
+func (e *ValidationError) HasErrors() bool {
+ return len(e.Errors) > 0
+}
+
+// Validate checks a task for required fields and valid values.
+func Validate(t *Task) error {
+ ve := &ValidationError{}
+
+ if t.Name == "" {
+ ve.Add("name is required")
+ }
+ if t.Claude.Instructions == "" {
+ ve.Add("claude.instructions is required")
+ }
+ if t.Claude.MaxBudgetUSD < 0 {
+ ve.Add("claude.max_budget_usd must be non-negative")
+ }
+ if t.Timeout.Duration < 0 {
+ ve.Add("timeout must be non-negative")
+ }
+ if t.Retry.MaxAttempts < 1 {
+ ve.Add("retry.max_attempts must be at least 1")
+ }
+ if t.Retry.Backoff != "" && t.Retry.Backoff != "linear" && t.Retry.Backoff != "exponential" {
+ ve.Add("retry.backoff must be 'linear' or 'exponential'")
+ }
+ validPriorities := map[Priority]bool{PriorityHigh: true, PriorityNormal: true, PriorityLow: true}
+ if t.Priority != "" && !validPriorities[t.Priority] {
+ ve.Add(fmt.Sprintf("invalid priority %q; must be high, normal, or low", t.Priority))
+ }
+ if t.Claude.PermissionMode != "" {
+ validModes := map[string]bool{
+ "default": true, "acceptEdits": true, "bypassPermissions": true,
+ "plan": true, "dontAsk": true, "delegate": true,
+ }
+ if !validModes[t.Claude.PermissionMode] {
+ ve.Add(fmt.Sprintf("invalid permission_mode %q", t.Claude.PermissionMode))
+ }
+ }
+
+ if ve.HasErrors() {
+ return ve
+ }
+ return nil
+}
diff --git a/internal/task/validator_test.go b/internal/task/validator_test.go
new file mode 100644
index 0000000..967eed3
--- /dev/null
+++ b/internal/task/validator_test.go
@@ -0,0 +1,115 @@
+package task
+
+import (
+ "strings"
+ "testing"
+)
+
+func validTask() *Task {
+ return &Task{
+ ID: "test-id",
+ Name: "Valid Task",
+ Claude: ClaudeConfig{
+ Instructions: "do something",
+ WorkingDir: "/tmp",
+ },
+ Priority: PriorityNormal,
+ Retry: RetryConfig{MaxAttempts: 1, Backoff: "exponential"},
+ }
+}
+
+func TestValidate_ValidTask_NoError(t *testing.T) {
+ task := validTask()
+ if err := Validate(task); err != nil {
+ t.Errorf("expected no error, got: %v", err)
+ }
+}
+
+func TestValidate_MissingName_ReturnsError(t *testing.T) {
+ task := validTask()
+ task.Name = ""
+ err := Validate(task)
+ if err == nil {
+ t.Fatal("expected error")
+ }
+ if !strings.Contains(err.Error(), "name is required") {
+ t.Errorf("expected 'name is required' in error, got: %v", err)
+ }
+}
+
+func TestValidate_MissingInstructions_ReturnsError(t *testing.T) {
+ task := validTask()
+ task.Claude.Instructions = ""
+ err := Validate(task)
+ if err == nil {
+ t.Fatal("expected error")
+ }
+ if !strings.Contains(err.Error(), "instructions is required") {
+ t.Errorf("expected 'instructions is required' in error, got: %v", err)
+ }
+}
+
+func TestValidate_NegativeBudget_ReturnsError(t *testing.T) {
+ task := validTask()
+ task.Claude.MaxBudgetUSD = -1.0
+ err := Validate(task)
+ if err == nil {
+ t.Fatal("expected error")
+ }
+ if !strings.Contains(err.Error(), "max_budget_usd") {
+ t.Errorf("expected budget error, got: %v", err)
+ }
+}
+
+func TestValidate_InvalidBackoff_ReturnsError(t *testing.T) {
+ task := validTask()
+ task.Retry.Backoff = "random"
+ err := Validate(task)
+ if err == nil {
+ t.Fatal("expected error")
+ }
+ if !strings.Contains(err.Error(), "backoff") {
+ t.Errorf("expected backoff error, got: %v", err)
+ }
+}
+
+func TestValidate_InvalidPriority_ReturnsError(t *testing.T) {
+ task := validTask()
+ task.Priority = "urgent"
+ err := Validate(task)
+ if err == nil {
+ t.Fatal("expected error")
+ }
+ if !strings.Contains(err.Error(), "invalid priority") {
+ t.Errorf("expected priority error, got: %v", err)
+ }
+}
+
+func TestValidate_InvalidPermissionMode_ReturnsError(t *testing.T) {
+ task := validTask()
+ task.Claude.PermissionMode = "yolo"
+ err := Validate(task)
+ if err == nil {
+ t.Fatal("expected error")
+ }
+ if !strings.Contains(err.Error(), "permission_mode") {
+ t.Errorf("expected permission_mode error, got: %v", err)
+ }
+}
+
+func TestValidate_MultipleErrors(t *testing.T) {
+ task := &Task{
+ Retry: RetryConfig{MaxAttempts: 0, Backoff: "bad"},
+ }
+ err := Validate(task)
+ if err == nil {
+ t.Fatal("expected error")
+ }
+ ve, ok := err.(*ValidationError)
+ if !ok {
+ t.Fatalf("expected *ValidationError, got %T", err)
+ }
+ if len(ve.Errors) < 3 {
+ t.Errorf("expected at least 3 errors, got %d: %v", len(ve.Errors), ve.Errors)
+ }
+}