summaryrefslogtreecommitdiff
path: root/internal/executor/executor.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/executor/executor.go')
-rw-r--r--internal/executor/executor.go148
1 files changed, 133 insertions, 15 deletions
diff --git a/internal/executor/executor.go b/internal/executor/executor.go
index 0245899..d1c8e72 100644
--- a/internal/executor/executor.go
+++ b/internal/executor/executor.go
@@ -36,18 +36,21 @@ type workItem struct {
// Pool manages a bounded set of concurrent task workers.
type Pool struct {
maxConcurrent int
- runner Runner
+ runners map[string]Runner
store *storage.DB
logger *slog.Logger
depPollInterval time.Duration // how often waitForDependencies polls; defaults to 5s
- mu sync.Mutex
- active int
- cancels map[string]context.CancelFunc // taskID → cancel
- resultCh chan *Result
- workCh chan workItem // internal bounded queue; Submit enqueues here
- doneCh chan struct{} // signals when a worker slot is freed
- Questions *QuestionRegistry
+ mu sync.Mutex
+ active int
+ activePerAgent map[string]int
+ rateLimited map[string]time.Time // agentType -> until
+ cancels map[string]context.CancelFunc // taskID → cancel
+ resultCh chan *Result
+ workCh chan workItem // internal bounded queue; Submit enqueues here
+ doneCh chan struct{} // signals when a worker slot is freed
+ Questions *QuestionRegistry
+ Classifier *Classifier
}
// Result is emitted when a task execution completes.
@@ -57,16 +60,18 @@ type Result struct {
Err error
}
-func NewPool(maxConcurrent int, runner Runner, store *storage.DB, logger *slog.Logger) *Pool {
+func NewPool(maxConcurrent int, runners map[string]Runner, store *storage.DB, logger *slog.Logger) *Pool {
if maxConcurrent < 1 {
maxConcurrent = 1
}
p := &Pool{
maxConcurrent: maxConcurrent,
- runner: runner,
+ runners: runners,
store: store,
logger: logger,
depPollInterval: 5 * time.Second,
+ activePerAgent: make(map[string]int),
+ rateLimited: make(map[string]time.Time),
cancels: make(map[string]context.CancelFunc),
resultCh: make(chan *Result, maxConcurrent*2),
workCh: make(chan workItem, maxConcurrent*10+100),
@@ -147,10 +152,32 @@ func (p *Pool) SubmitResume(ctx context.Context, t *task.Task, exec *storage.Exe
}
}
+func (p *Pool) getRunner(t *task.Task) (Runner, error) {
+ agentType := t.Agent.Type
+ if agentType == "" {
+ agentType = "claude" // Default for backward compatibility
+ }
+ runner, ok := p.runners[agentType]
+ if !ok {
+ return nil, fmt.Errorf("unsupported agent type: %q", agentType)
+ }
+ return runner, nil
+}
+
func (p *Pool) executeResume(ctx context.Context, t *task.Task, exec *storage.Execution) {
+ agentType := t.Agent.Type
+ if agentType == "" {
+ agentType = "claude"
+ }
+
+ p.mu.Lock()
+ p.activePerAgent[agentType]++
+ p.mu.Unlock()
+
defer func() {
p.mu.Lock()
p.active--
+ p.activePerAgent[agentType]--
p.mu.Unlock()
select {
case p.doneCh <- struct{}{}:
@@ -158,8 +185,15 @@ func (p *Pool) executeResume(ctx context.Context, t *task.Task, exec *storage.Ex
}
}()
+ runner, err := p.getRunner(t)
+ if err != nil {
+ p.logger.Error("failed to get runner for resume", "error", err, "taskID", t.ID)
+ p.resultCh <- &Result{TaskID: t.ID, Execution: exec, Err: err}
+ return
+ }
+
// Pre-populate log paths.
- if lp, ok := p.runner.(LogPather); ok {
+ if lp, ok := runner.(LogPather); ok {
if logDir := lp.ExecLogDir(exec.ID); logDir != "" {
exec.StdoutPath = filepath.Join(logDir, "stdout.log")
exec.StderrPath = filepath.Join(logDir, "stderr.log")
@@ -182,12 +216,30 @@ func (p *Pool) executeResume(ctx context.Context, t *task.Task, exec *storage.Ex
} else {
ctx, cancel = context.WithCancel(ctx)
}
- defer cancel()
+ p.mu.Lock()
+ p.cancels[t.ID] = cancel
+ p.mu.Unlock()
+ defer func() {
+ cancel()
+ p.mu.Lock()
+ delete(p.cancels, t.ID)
+ p.mu.Unlock()
+ }()
- err := p.runner.Run(ctx, t, exec)
+ err = runner.Run(ctx, t, exec)
exec.EndTime = time.Now().UTC()
if err != nil {
+ if isRateLimitError(err) {
+ p.mu.Lock()
+ retryAfter := parseRetryAfter(err.Error())
+ if retryAfter == 0 {
+ retryAfter = 1 * time.Minute
+ }
+ p.rateLimited[agentType] = time.Now().Add(retryAfter)
+ p.mu.Unlock()
+ }
+
var blockedErr *BlockedError
if errors.As(err, &blockedErr) {
exec.Status = "BLOCKED"
@@ -234,9 +286,45 @@ func (p *Pool) ActiveCount() int {
}
func (p *Pool) execute(ctx context.Context, t *task.Task) {
+ // 1. Classification
+ if p.Classifier != nil {
+ p.mu.Lock()
+ activeTasks := make(map[string]int)
+ rateLimited := make(map[string]bool)
+ now := time.Now()
+ for agent := range p.runners {
+ activeTasks[agent] = p.activePerAgent[agent]
+ rateLimited[agent] = now.Before(p.rateLimited[agent])
+ }
+ status := SystemStatus{
+ ActiveTasks: activeTasks,
+ RateLimited: rateLimited,
+ }
+ p.mu.Unlock()
+
+ cls, err := p.Classifier.Classify(ctx, t.Name, t.Agent.Instructions, status)
+ if err == nil {
+ p.logger.Info("task classified", "taskID", t.ID, "agent", cls.AgentType, "model", cls.Model, "reason", cls.Reason)
+ t.Agent.Type = cls.AgentType
+ t.Agent.Model = cls.Model
+ } else {
+ p.logger.Error("classification failed", "error", err, "taskID", t.ID)
+ }
+ }
+
+ agentType := t.Agent.Type
+ if agentType == "" {
+ agentType = "claude"
+ }
+
+ p.mu.Lock()
+ p.activePerAgent[agentType]++
+ p.mu.Unlock()
+
defer func() {
p.mu.Lock()
p.active--
+ p.activePerAgent[agentType]--
p.mu.Unlock()
select {
case p.doneCh <- struct{}{}:
@@ -244,6 +332,26 @@ func (p *Pool) execute(ctx context.Context, t *task.Task) {
}
}()
+ runner, err := p.getRunner(t)
+ if err != nil {
+ p.logger.Error("failed to get runner", "error", err, "taskID", t.ID)
+ now := time.Now().UTC()
+ exec := &storage.Execution{
+ ID: uuid.New().String(),
+ TaskID: t.ID,
+ StartTime: now,
+ EndTime: now,
+ Status: "FAILED",
+ ErrorMsg: err.Error(),
+ }
+ if createErr := p.store.CreateExecution(exec); createErr != nil {
+ p.logger.Error("failed to create execution record", "error", createErr)
+ }
+ p.store.UpdateTaskState(t.ID, task.StateFailed)
+ p.resultCh <- &Result{TaskID: t.ID, Execution: exec, Err: err}
+ return
+ }
+
// Wait for all dependencies to complete before starting execution.
if len(t.DependsOn) > 0 {
if err := p.waitForDependencies(ctx, t); err != nil {
@@ -275,7 +383,7 @@ func (p *Pool) execute(ctx context.Context, t *task.Task) {
// Pre-populate log paths so they're available in the DB immediately —
// before the subprocess starts — enabling live tailing and debugging.
- if lp, ok := p.runner.(LogPather); ok {
+ if lp, ok := runner.(LogPather); ok {
if logDir := lp.ExecLogDir(execID); logDir != "" {
exec.StdoutPath = filepath.Join(logDir, "stdout.log")
exec.StderrPath = filepath.Join(logDir, "stderr.log")
@@ -309,10 +417,20 @@ func (p *Pool) execute(ctx context.Context, t *task.Task) {
}()
// Run the task.
- err := p.runner.Run(ctx, t, exec)
+ err = runner.Run(ctx, t, exec)
exec.EndTime = time.Now().UTC()
if err != nil {
+ if isRateLimitError(err) {
+ p.mu.Lock()
+ retryAfter := parseRetryAfter(err.Error())
+ if retryAfter == 0 {
+ retryAfter = 1 * time.Minute
+ }
+ p.rateLimited[agentType] = time.Now().Add(retryAfter)
+ p.mu.Unlock()
+ }
+
var blockedErr *BlockedError
if errors.As(err, &blockedErr) {
exec.Status = "BLOCKED"