package cli

import (
	"context"
	"fmt"
	"os"
	"os/signal"
	"syscall"

	"github.com/spf13/cobra"

	"github.com/thepeterstone/claudomator/internal/executor"
	"github.com/thepeterstone/claudomator/internal/storage"
	"github.com/thepeterstone/claudomator/internal/task"
)

// newRunCmd builds the "run" subcommand, which executes the tasks defined
// in a YAML file with bounded concurrency, or merely validates the file
// when --dry-run is given.
func newRunCmd() *cobra.Command {
	var (
		parallel int
		dryRun   bool
	)
	cmd := &cobra.Command{
		Use:   "run <file>",
		Short: "Run task(s) from a YAML file",
		Args:  cobra.ExactArgs(1),
		RunE: func(cmd *cobra.Command, args []string) error {
			return runTasks(args[0], parallel, dryRun)
		},
	}
	cmd.Flags().IntVarP(&parallel, "parallel", "p", 3, "max concurrent executions")
	cmd.Flags().BoolVar(&dryRun, "dry-run", false, "validate without executing")
	return cmd
}

// runTasks parses and validates every task in file, then (unless dryRun)
// dispatches them to an executor pool with at most parallel concurrent
// executions. It blocks until all submitted tasks report a result and
// returns a non-nil error if parsing/validation/setup fails or if any
// task fails.
func runTasks(file string, parallel int, dryRun bool) error {
	if parallel < 1 {
		return fmt.Errorf("--parallel must be at least 1, got %d", parallel)
	}

	tasks, err := task.ParseFile(file)
	if err != nil {
		return fmt.Errorf("parsing: %w", err)
	}

	// Validate everything up front so one bad task aborts the run before
	// any work (or database writes) happen.
	for i := range tasks {
		if err := task.Validate(&tasks[i]); err != nil {
			return fmt.Errorf("task %q: %w", tasks[i].Name, err)
		}
	}

	if dryRun {
		fmt.Printf("Validated %d task(s) successfully.\n", len(tasks))
		for _, t := range tasks {
			fmt.Printf(" - %s (model: %s, timeout: %v)\n", t.Name, t.Agent.Model, t.Timeout.Duration)
		}
		return nil
	}

	// Setup infrastructure.
	if err := cfg.EnsureDirs(); err != nil {
		return fmt.Errorf("creating dirs: %w", err)
	}
	store, err := storage.Open(cfg.DBPath)
	if err != nil {
		return fmt.Errorf("opening db: %w", err)
	}
	defer store.Close()

	logger := newLogger(verbose)

	// cfg.ServerAddr may be a bare port (":8080") or a full host:port.
	apiURL := "http://localhost" + cfg.ServerAddr
	if len(cfg.ServerAddr) > 0 && cfg.ServerAddr[0] != ':' {
		apiURL = "http://" + cfg.ServerAddr
	}

	// Both runners share everything except the container image.
	newRunner := func(image string) *executor.ContainerRunner {
		return &executor.ContainerRunner{
			Image:        image,
			Logger:       logger,
			LogDir:       cfg.LogDir,
			APIURL:       apiURL,
			DropsDir:     cfg.DropsDir,
			SSHAuthSock:  cfg.SSHAuthSock,
			ClaudeBinary: cfg.ClaudeBinaryPath,
			GeminiBinary: cfg.GeminiBinaryPath,
		}
	}
	runners := map[string]executor.Runner{
		"claude": newRunner(cfg.ClaudeImage),
		"gemini": newRunner(cfg.GeminiImage),
	}

	pool := executor.NewPool(parallel, runners, store, logger)
	if cfg.GeminiBinaryPath != "" {
		pool.Classifier = &executor.Classifier{GeminiBinaryPath: cfg.GeminiBinaryPath}
	}

	// Handle graceful shutdown: cancel the run context on SIGINT/SIGTERM.
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	sigCh := make(chan os.Signal, 1)
	signal.Notify(sigCh, os.Interrupt, syscall.SIGTERM)
	defer signal.Stop(sigCh)
	go func() {
		<-sigCh
		fmt.Fprintln(os.Stderr, "\nShutting down...")
		cancel()
	}()

	// Submit all tasks, tracking how many actually entered the pool.
	fmt.Printf("Dispatching %d task(s) (max concurrency: %d)...\n", len(tasks), parallel)
	submitted, failed := 0, 0
	for i := range tasks {
		if err := store.CreateTask(&tasks[i]); err != nil {
			return fmt.Errorf("storing task: %w", err)
		}
		if err := store.UpdateTaskState(tasks[i].ID, task.StateQueued); err != nil {
			return fmt.Errorf("queuing task: %w", err)
		}
		tasks[i].State = task.StateQueued
		if err := pool.Submit(ctx, &tasks[i]); err != nil {
			// A task that never entered the pool will never produce a
			// result; count it as failed now rather than waiting on it
			// below, which would block the results loop forever.
			logger.Warn("could not submit task", "name", tasks[i].Name, "error", err)
			failed++
			continue
		}
		submitted++
	}

	// Wait for exactly one result per successfully submitted task.
	completed := 0
	for i := 0; i < submitted; i++ {
		result := <-pool.Results()
		if result.Err != nil {
			failed++
			fmt.Printf(" FAIL %s: %v\n", result.TaskID, result.Err)
		} else {
			completed++
			fmt.Printf(" OK %s (cost: $%.4f)\n", result.TaskID, result.Execution.CostUSD)
		}
	}

	fmt.Printf("\nDone: %d completed, %d failed\n", completed, failed)
	if failed > 0 {
		return fmt.Errorf("%d task(s) failed", failed)
	}
	return nil
}