diff options
Diffstat (limited to 'internal/executor')
| -rw-r--r-- | internal/executor/executor.go | 29 | ||||
| -rw-r--r-- | internal/executor/executor_test.go | 38 |
2 files changed, 55 insertions, 12 deletions
diff --git a/internal/executor/executor.go b/internal/executor/executor.go index 7ae4e2d..bf209b7 100644 --- a/internal/executor/executor.go +++ b/internal/executor/executor.go @@ -414,19 +414,24 @@ func (p *Pool) execute(ctx context.Context, t *task.Task) { } p.mu.Unlock() - // Deterministically pick the agent with fewest active tasks. - selectedAgent := pickAgent(status) - if selectedAgent != "" { - t.Agent.Type = selectedAgent - } + // If a specific agent is already requested, skip selection and classification. + skipClassification := t.Agent.Type == "claude" || t.Agent.Type == "gemini" + + if !skipClassification { + // Deterministically pick the agent with fewest active tasks. + selectedAgent := pickAgent(status) + if selectedAgent != "" { + t.Agent.Type = selectedAgent + } - if p.Classifier != nil { - cls, err := p.Classifier.Classify(ctx, t.Name, t.Agent.Instructions, status, t.Agent.Type) - if err == nil { - p.logger.Info("task classified", "taskID", t.ID, "agent", t.Agent.Type, "model", cls.Model, "reason", cls.Reason) - t.Agent.Model = cls.Model - } else { - p.logger.Error("classification failed", "error", err, "taskID", t.ID) + if p.Classifier != nil { + cls, err := p.Classifier.Classify(ctx, t.Name, t.Agent.Instructions, status, t.Agent.Type) + if err == nil { + p.logger.Info("task classified", "taskID", t.ID, "agent", t.Agent.Type, "model", cls.Model, "reason", cls.Reason) + t.Agent.Model = cls.Model + } else { + p.logger.Error("classification failed", "error", err, "taskID", t.ID) + } } } diff --git a/internal/executor/executor_test.go b/internal/executor/executor_test.go index 7e676eb..17982f8 100644 --- a/internal/executor/executor_test.go +++ b/internal/executor/executor_test.go @@ -1121,3 +1121,41 @@ func TestPool_LoadBalancing_OverridesAgentType(t *testing.T) { t.Errorf("expected claude runner to be called once, got %d", runner.callCount()) } } + +// TestPool_SpecificAgent_SkipsLoadBalancing verifies that if a specific +// registered agent is requested (claude or gemini), it is used directly +// and load balancing (pickAgent) is skipped. +func TestPool_SpecificAgent_SkipsLoadBalancing(t *testing.T) { + store := testStore(t) + claudeRunner := &mockRunner{} + geminiRunner := &mockRunner{} + runners := map[string]Runner{ + "claude": claudeRunner, + "gemini": geminiRunner, + } + logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) + pool := NewPool(4, runners, store, logger) + + // Inject 2 active tasks for gemini, 0 for claude. + // pickAgent would normally pick "claude". + pool.mu.Lock() + pool.activePerAgent["gemini"] = 2 + pool.mu.Unlock() + + tk := makeTask("specific-gemini") + tk.Agent.Type = "gemini" + store.CreateTask(tk) + + if err := pool.Submit(context.Background(), tk); err != nil { + t.Fatalf("submit: %v", err) + } + + <-pool.Results() + + if geminiRunner.callCount() != 1 { + t.Errorf("expected gemini runner to be called once, got %d", geminiRunner.callCount()) + } + if claudeRunner.callCount() != 0 { + t.Errorf("expected claude runner to NOT be called, got %d", claudeRunner.callCount()) + } +} |
