diff options
| author | Peter Stone <thepeterstone@gmail.com> | 2026-03-07 23:52:02 +0000 |
|---|---|---|
| committer | Peter Stone <thepeterstone@gmail.com> | 2026-03-08 04:52:10 +0000 |
| commit | f2d6822db559f680766daf9d66dd3631ed4adcaa (patch) | |
| tree | 0537889a7b88b6009964754129e48de25b3355d7 /internal/executor | |
| parent | 9d2fc7663a161bb471a9defc3000212264767866 (diff) | |
feat(executor): support multiple runners in Pool
Diffstat (limited to 'internal/executor')
| -rw-r--r-- | internal/executor/executor.go | 53 |
1 files changed, 46 insertions, 7 deletions
diff --git a/internal/executor/executor.go b/internal/executor/executor.go index 62fed2e..0786d86 100644 --- a/internal/executor/executor.go +++ b/internal/executor/executor.go @@ -29,7 +29,7 @@ type Runner interface { // Pool manages a bounded set of concurrent task workers. type Pool struct { maxConcurrent int - runner Runner + runners map[string]Runner store *storage.DB logger *slog.Logger @@ -47,13 +47,13 @@ type Result struct { Err error } -func NewPool(maxConcurrent int, runner Runner, store *storage.DB, logger *slog.Logger) *Pool { +func NewPool(maxConcurrent int, runners map[string]Runner, store *storage.DB, logger *slog.Logger) *Pool { if maxConcurrent < 1 { maxConcurrent = 1 } return &Pool{ maxConcurrent: maxConcurrent, - runner: runner, + runners: runners, store: store, logger: logger, cancels: make(map[string]context.CancelFunc), @@ -113,6 +113,18 @@ func (p *Pool) SubmitResume(ctx context.Context, t *task.Task, exec *storage.Exe return nil } +func (p *Pool) getRunner(t *task.Task) (Runner, error) { + agentType := t.Agent.Type + if agentType == "" { + agentType = "claude" // Default for backward compatibility + } + runner, ok := p.runners[agentType] + if !ok { + return nil, fmt.Errorf("unsupported agent type: %q", agentType) + } + return runner, nil +} + func (p *Pool) executeResume(ctx context.Context, t *task.Task, exec *storage.Execution) { defer func() { p.mu.Lock() @@ -120,8 +132,15 @@ func (p *Pool) executeResume(ctx context.Context, t *task.Task, exec *storage.Ex p.mu.Unlock() }() + runner, err := p.getRunner(t) + if err != nil { + p.logger.Error("failed to get runner for resume", "error", err, "taskID", t.ID) + p.resultCh <- &Result{TaskID: t.ID, Execution: exec, Err: err} + return + } + // Pre-populate log paths. - if lp, ok := p.runner.(LogPather); ok { + if lp, ok := runner.(LogPather); ok { if logDir := lp.ExecLogDir(exec.ID); logDir != "" { exec.StdoutPath = filepath.Join(logDir, "stdout.log") exec.StderrPath = filepath.Join(logDir, "stderr.log") @@ -146,7 +165,7 @@ func (p *Pool) executeResume(ctx context.Context, t *task.Task, exec *storage.Ex } defer cancel() - err := p.runner.Run(ctx, t, exec) + err = runner.Run(ctx, t, exec) exec.EndTime = time.Now().UTC() if err != nil { @@ -198,6 +217,26 @@ func (p *Pool) execute(ctx context.Context, t *task.Task) { p.mu.Unlock() }() + runner, err := p.getRunner(t) + if err != nil { + p.logger.Error("failed to get runner", "error", err, "taskID", t.ID) + now := time.Now().UTC() + exec := &storage.Execution{ + ID: uuid.New().String(), + TaskID: t.ID, + StartTime: now, + EndTime: now, + Status: "FAILED", + ErrorMsg: err.Error(), + } + if createErr := p.store.CreateExecution(exec); createErr != nil { + p.logger.Error("failed to create execution record", "error", createErr) + } + p.store.UpdateTaskState(t.ID, task.StateFailed) + p.resultCh <- &Result{TaskID: t.ID, Execution: exec, Err: err} + return + } + // Wait for all dependencies to complete before starting execution. if len(t.DependsOn) > 0 { if err := p.waitForDependencies(ctx, t); err != nil { @@ -229,7 +268,7 @@ func (p *Pool) execute(ctx context.Context, t *task.Task) { // Pre-populate log paths so they're available in the DB immediately — // before the subprocess starts — enabling live tailing and debugging. - if lp, ok := p.runner.(LogPather); ok { + if lp, ok := runner.(LogPather); ok { if logDir := lp.ExecLogDir(execID); logDir != "" { exec.StdoutPath = filepath.Join(logDir, "stdout.log") exec.StderrPath = filepath.Join(logDir, "stderr.log") @@ -263,7 +302,7 @@ func (p *Pool) execute(ctx context.Context, t *task.Task) { }() // Run the task. - err := p.runner.Run(ctx, t, exec) + err = runner.Run(ctx, t, exec) exec.EndTime = time.Now().UTC() if err != nil { |
