diff options
Diffstat (limited to 'internal/task')
| -rw-r--r-- | internal/task/parser.go | 61 | ||||
| -rw-r--r-- | internal/task/parser_test.go | 152 | ||||
| -rw-r--r-- | internal/task/task.go | 100 | ||||
| -rw-r--r-- | internal/task/task_test.go | 80 | ||||
| -rw-r--r-- | internal/task/validator.go | 65 | ||||
| -rw-r--r-- | internal/task/validator_test.go | 115 |
6 files changed, 573 insertions, 0 deletions
diff --git a/internal/task/parser.go b/internal/task/parser.go new file mode 100644 index 0000000..7a450b8 --- /dev/null +++ b/internal/task/parser.go @@ -0,0 +1,61 @@ +package task + +import ( + "fmt" + "os" + "time" + + "github.com/google/uuid" + "gopkg.in/yaml.v3" +) + +// ParseFile reads a YAML file and returns tasks. Supports both single-task +// and batch (tasks: [...]) formats. +func ParseFile(path string) ([]Task, error) { + data, err := os.ReadFile(path) + if err != nil { + return nil, fmt.Errorf("reading task file: %w", err) + } + return Parse(data) +} + +// Parse parses YAML bytes into tasks. +func Parse(data []byte) ([]Task, error) { + // Try batch format first. + var batch BatchFile + if err := yaml.Unmarshal(data, &batch); err == nil && len(batch.Tasks) > 0 { + return initTasks(batch.Tasks), nil + } + + // Try single task. + var t Task + if err := yaml.Unmarshal(data, &t); err != nil { + return nil, fmt.Errorf("parsing task YAML: %w", err) + } + if t.Name == "" { + return nil, fmt.Errorf("task must have a name") + } + return initTasks([]Task{t}), nil +} + +func initTasks(tasks []Task) []Task { + now := time.Now() + for i := range tasks { + if tasks[i].ID == "" { + tasks[i].ID = uuid.New().String() + } + if tasks[i].Priority == "" { + tasks[i].Priority = PriorityNormal + } + if tasks[i].Retry.MaxAttempts == 0 { + tasks[i].Retry.MaxAttempts = 1 + } + if tasks[i].Retry.Backoff == "" { + tasks[i].Retry.Backoff = "exponential" + } + tasks[i].State = StatePending + tasks[i].CreatedAt = now + tasks[i].UpdatedAt = now + } + return tasks +} diff --git a/internal/task/parser_test.go b/internal/task/parser_test.go new file mode 100644 index 0000000..cb68e86 --- /dev/null +++ b/internal/task/parser_test.go @@ -0,0 +1,152 @@ +package task + +import ( + "os" + "path/filepath" + "testing" + "time" +) + +func TestParse_SingleTask(t *testing.T) { + yaml := ` +name: "Test Task" +description: "A simple test" +claude: + model: "sonnet" + instructions: "Do something" + working_dir: "/tmp" +timeout: "10m" +tags: + - "test" +` + tasks, err := Parse([]byte(yaml)) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(tasks) != 1 { + t.Fatalf("expected 1 task, got %d", len(tasks)) + } + task := tasks[0] + if task.Name != "Test Task" { + t.Errorf("expected name 'Test Task', got %q", task.Name) + } + if task.Claude.Model != "sonnet" { + t.Errorf("expected model 'sonnet', got %q", task.Claude.Model) + } + if task.Timeout.Duration != 10*time.Minute { + t.Errorf("expected timeout 10m, got %v", task.Timeout.Duration) + } + if task.State != StatePending { + t.Errorf("expected state PENDING, got %v", task.State) + } + if task.ID == "" { + t.Error("expected auto-generated ID") + } + if task.Priority != PriorityNormal { + t.Errorf("expected default priority 'normal', got %q", task.Priority) + } +} + +func TestParse_BatchTasks(t *testing.T) { + yaml := ` +tasks: + - name: "Task A" + claude: + instructions: "Do A" + working_dir: "/tmp" + tags: ["alpha"] + - name: "Task B" + claude: + instructions: "Do B" + working_dir: "/tmp" + tags: ["beta"] +` + tasks, err := Parse([]byte(yaml)) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(tasks) != 2 { + t.Fatalf("expected 2 tasks, got %d", len(tasks)) + } + if tasks[0].Name != "Task A" { + t.Errorf("expected 'Task A', got %q", tasks[0].Name) + } + if tasks[1].Name != "Task B" { + t.Errorf("expected 'Task B', got %q", tasks[1].Name) + } +} + +func TestParse_MissingName_ReturnsError(t *testing.T) { + yaml := ` +description: "no name" +claude: + instructions: "something" +` + _, err := Parse([]byte(yaml)) + if err == nil { + t.Fatal("expected error for missing name") + } +} + +func TestParse_DefaultRetryConfig(t *testing.T) { + yaml := ` +name: "Defaults" +claude: + instructions: "test" +` + tasks, err := Parse([]byte(yaml)) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if tasks[0].Retry.MaxAttempts != 1 { + t.Errorf("expected default max_attempts=1, got %d", tasks[0].Retry.MaxAttempts) + } + if tasks[0].Retry.Backoff != "exponential" { + t.Errorf("expected default backoff 'exponential', got %q", tasks[0].Retry.Backoff) + } +} + +func TestParse_WithPriority(t *testing.T) { + yaml := ` +name: "High Priority" +priority: "high" +claude: + instructions: "urgent" +` + tasks, err := Parse([]byte(yaml)) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if tasks[0].Priority != PriorityHigh { + t.Errorf("expected priority 'high', got %q", tasks[0].Priority) + } +} + +func TestParseFile(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "task.yaml") + content := ` +name: "File Task" +claude: + instructions: "from file" + working_dir: "/tmp" +` + if err := os.WriteFile(path, []byte(content), 0644); err != nil { + t.Fatal(err) + } + + tasks, err := ParseFile(path) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(tasks) != 1 || tasks[0].Name != "File Task" { + t.Errorf("unexpected tasks: %+v", tasks) + } +} + +func TestParseFile_NotFound(t *testing.T) { + _, err := ParseFile("/nonexistent/task.yaml") + if err == nil { + t.Fatal("expected error for nonexistent file") + } +} diff --git a/internal/task/task.go b/internal/task/task.go new file mode 100644 index 0000000..3796cf3 --- /dev/null +++ b/internal/task/task.go @@ -0,0 +1,100 @@ +package task + +import "time" + +type State string + +const ( + StatePending State = "PENDING" + StateQueued State = "QUEUED" + StateRunning State = "RUNNING" + StateCompleted State = "COMPLETED" + StateFailed State = "FAILED" + StateTimedOut State = "TIMED_OUT" + StateCancelled State = "CANCELLED" + StateBudgetExceeded State = "BUDGET_EXCEEDED" +) + +type Priority string + +const ( + PriorityHigh Priority = "high" + PriorityNormal Priority = "normal" + PriorityLow Priority = "low" +) + +type ClaudeConfig struct { + Model string `yaml:"model"` + ContextFiles []string `yaml:"context_files"` + Instructions string `yaml:"instructions"` + WorkingDir string `yaml:"working_dir"` + MaxBudgetUSD float64 `yaml:"max_budget_usd"` + PermissionMode string `yaml:"permission_mode"` + AllowedTools []string `yaml:"allowed_tools"` + DisallowedTools []string `yaml:"disallowed_tools"` + SystemPromptAppend string `yaml:"system_prompt_append"` + AdditionalArgs []string `yaml:"additional_args"` +} + +type RetryConfig struct { + MaxAttempts int `yaml:"max_attempts"` + Backoff string `yaml:"backoff"` // "linear", "exponential" +} + +type Task struct { + ID string `yaml:"id"` + Name string `yaml:"name"` + Description string `yaml:"description"` + Claude ClaudeConfig `yaml:"claude"` + Timeout Duration `yaml:"timeout"` + Retry RetryConfig `yaml:"retry"` + Priority Priority `yaml:"priority"` + Tags []string `yaml:"tags"` + DependsOn []string `yaml:"depends_on"` + State State `yaml:"-"` + CreatedAt time.Time `yaml:"-"` + UpdatedAt time.Time `yaml:"-"` +} + +// Duration wraps time.Duration for YAML unmarshaling from strings like "30m". +type Duration struct { + time.Duration +} + +func (d *Duration) UnmarshalYAML(unmarshal func(interface{}) error) error { + var s string + if err := unmarshal(&s); err != nil { + return err + } + dur, err := time.ParseDuration(s) + if err != nil { + return err + } + d.Duration = dur + return nil +} + +func (d Duration) MarshalYAML() (interface{}, error) { + return d.Duration.String(), nil +} + +// BatchFile represents a YAML file containing multiple tasks. +type BatchFile struct { + Tasks []Task `yaml:"tasks"` +} + +// ValidTransition returns true if moving from the current state to next is allowed. +func ValidTransition(from, to State) bool { + transitions := map[State][]State{ + StatePending: {StateQueued, StateCancelled}, + StateQueued: {StateRunning, StateCancelled}, + StateRunning: {StateCompleted, StateFailed, StateTimedOut, StateCancelled, StateBudgetExceeded}, + StateFailed: {StateQueued}, // retry + } + for _, allowed := range transitions[from] { + if allowed == to { + return true + } + } + return false +} diff --git a/internal/task/task_test.go b/internal/task/task_test.go new file mode 100644 index 0000000..96f5f6f --- /dev/null +++ b/internal/task/task_test.go @@ -0,0 +1,80 @@ +package task + +import ( + "testing" + "time" +) + +func TestValidTransition_AllowedTransitions(t *testing.T) { + tests := []struct { + name string + from State + to State + }{ + {"pending to queued", StatePending, StateQueued}, + {"pending to cancelled", StatePending, StateCancelled}, + {"queued to running", StateQueued, StateRunning}, + {"queued to cancelled", StateQueued, StateCancelled}, + {"running to completed", StateRunning, StateCompleted}, + {"running to failed", StateRunning, StateFailed}, + {"running to timed out", StateRunning, StateTimedOut}, + {"running to cancelled", StateRunning, StateCancelled}, + {"running to budget exceeded", StateRunning, StateBudgetExceeded}, + {"failed to queued (retry)", StateFailed, StateQueued}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if !ValidTransition(tt.from, tt.to) { + t.Errorf("expected transition %s -> %s to be valid", tt.from, tt.to) + } + }) + } +} + +func TestValidTransition_DisallowedTransitions(t *testing.T) { + tests := []struct { + name string + from State + to State + }{ + {"pending to running", StatePending, StateRunning}, + {"pending to completed", StatePending, StateCompleted}, + {"queued to completed", StateQueued, StateCompleted}, + {"completed to running", StateCompleted, StateRunning}, + {"completed to queued", StateCompleted, StateQueued}, + {"failed to completed", StateFailed, StateCompleted}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if ValidTransition(tt.from, tt.to) { + t.Errorf("expected transition %s -> %s to be invalid", tt.from, tt.to) + } + }) + } +} + +func TestDuration_UnmarshalYAML(t *testing.T) { + var d Duration + unmarshal := func(v interface{}) error { + ptr := v.(*string) + *ptr = "30m" + return nil + } + if err := d.UnmarshalYAML(unmarshal); err != nil { + t.Fatalf("unexpected error: %v", err) + } + if d.Duration != 30*time.Minute { + t.Errorf("expected 30m, got %v", d.Duration) + } +} + +func TestDuration_MarshalYAML(t *testing.T) { + d := Duration{Duration: 15 * time.Minute} + v, err := d.MarshalYAML() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if v != "15m0s" { + t.Errorf("expected '15m0s', got %v", v) + } +} diff --git a/internal/task/validator.go b/internal/task/validator.go new file mode 100644 index 0000000..ea0b1c2 --- /dev/null +++ b/internal/task/validator.go @@ -0,0 +1,65 @@ +package task + +import ( + "fmt" + "strings" +) + +// ValidationError collects multiple validation failures. +type ValidationError struct { + Errors []string +} + +func (e *ValidationError) Error() string { + return fmt.Sprintf("validation failed: %s", strings.Join(e.Errors, "; ")) +} + +func (e *ValidationError) Add(msg string) { + e.Errors = append(e.Errors, msg) +} + +func (e *ValidationError) HasErrors() bool { + return len(e.Errors) > 0 +} + +// Validate checks a task for required fields and valid values. +func Validate(t *Task) error { + ve := &ValidationError{} + + if t.Name == "" { + ve.Add("name is required") + } + if t.Claude.Instructions == "" { + ve.Add("claude.instructions is required") + } + if t.Claude.MaxBudgetUSD < 0 { + ve.Add("claude.max_budget_usd must be non-negative") + } + if t.Timeout.Duration < 0 { + ve.Add("timeout must be non-negative") + } + if t.Retry.MaxAttempts < 1 { + ve.Add("retry.max_attempts must be at least 1") + } + if t.Retry.Backoff != "" && t.Retry.Backoff != "linear" && t.Retry.Backoff != "exponential" { + ve.Add("retry.backoff must be 'linear' or 'exponential'") + } + validPriorities := map[Priority]bool{PriorityHigh: true, PriorityNormal: true, PriorityLow: true} + if t.Priority != "" && !validPriorities[t.Priority] { + ve.Add(fmt.Sprintf("invalid priority %q; must be high, normal, or low", t.Priority)) + } + if t.Claude.PermissionMode != "" { + validModes := map[string]bool{ + "default": true, "acceptEdits": true, "bypassPermissions": true, + "plan": true, "dontAsk": true, "delegate": true, + } + if !validModes[t.Claude.PermissionMode] { + ve.Add(fmt.Sprintf("invalid permission_mode %q", t.Claude.PermissionMode)) + } + } + + if ve.HasErrors() { + return ve + } + return nil +} diff --git a/internal/task/validator_test.go b/internal/task/validator_test.go new file mode 100644 index 0000000..967eed3 --- /dev/null +++ b/internal/task/validator_test.go @@ -0,0 +1,115 @@ +package task + +import ( + "strings" + "testing" +) + +func validTask() *Task { + return &Task{ + ID: "test-id", + Name: "Valid Task", + Claude: ClaudeConfig{ + Instructions: "do something", + WorkingDir: "/tmp", + }, + Priority: PriorityNormal, + Retry: RetryConfig{MaxAttempts: 1, Backoff: "exponential"}, + } +} + +func TestValidate_ValidTask_NoError(t *testing.T) { + task := validTask() + if err := Validate(task); err != nil { + t.Errorf("expected no error, got: %v", err) + } +} + +func TestValidate_MissingName_ReturnsError(t *testing.T) { + task := validTask() + task.Name = "" + err := Validate(task) + if err == nil { + t.Fatal("expected error") + } + if !strings.Contains(err.Error(), "name is required") { + t.Errorf("expected 'name is required' in error, got: %v", err) + } +} + +func TestValidate_MissingInstructions_ReturnsError(t *testing.T) { + task := validTask() + task.Claude.Instructions = "" + err := Validate(task) + if err == nil { + t.Fatal("expected error") + } + if !strings.Contains(err.Error(), "instructions is required") { + t.Errorf("expected 'instructions is required' in error, got: %v", err) + } +} + +func TestValidate_NegativeBudget_ReturnsError(t *testing.T) { + task := validTask() + task.Claude.MaxBudgetUSD = -1.0 + err := Validate(task) + if err == nil { + t.Fatal("expected error") + } + if !strings.Contains(err.Error(), "max_budget_usd") { + t.Errorf("expected budget error, got: %v", err) + } +} + +func TestValidate_InvalidBackoff_ReturnsError(t *testing.T) { + task := validTask() + task.Retry.Backoff = "random" + err := Validate(task) + if err == nil { + t.Fatal("expected error") + } + if !strings.Contains(err.Error(), "backoff") { + t.Errorf("expected backoff error, got: %v", err) + } +} + +func TestValidate_InvalidPriority_ReturnsError(t *testing.T) { + task := validTask() + task.Priority = "urgent" + err := Validate(task) + if err == nil { + t.Fatal("expected error") + } + if !strings.Contains(err.Error(), "invalid priority") { + t.Errorf("expected priority error, got: %v", err) + } +} + +func TestValidate_InvalidPermissionMode_ReturnsError(t *testing.T) { + task := validTask() + task.Claude.PermissionMode = "yolo" + err := Validate(task) + if err == nil { + t.Fatal("expected error") + } + if !strings.Contains(err.Error(), "permission_mode") { + t.Errorf("expected permission_mode error, got: %v", err) + } +} + +func TestValidate_MultipleErrors(t *testing.T) { + task := &Task{ + Retry: RetryConfig{MaxAttempts: 0, Backoff: "bad"}, + } + err := Validate(task) + if err == nil { + t.Fatal("expected error") + } + ve, ok := err.(*ValidationError) + if !ok { + t.Fatalf("expected *ValidationError, got %T", err) + } + if len(ve.Errors) < 3 { + t.Errorf("expected at least 3 errors, got %d: %v", len(ve.Errors), ve.Errors) + } +} |
