summaryrefslogtreecommitdiff
path: root/internal/executor
diff options
context:
space:
mode:
authorPeter Stone <thepeterstone@gmail.com>2026-03-14 00:39:22 +0000
committerPeter Stone <thepeterstone@gmail.com>2026-03-14 00:39:22 +0000
commit2ee988ccc04c09ceb6de7cdb75c94114e85d01b9 (patch)
tree29100e3e4b33748c544b9a42cb74e964df49b96e /internal/executor
parent98ccde12b08ad0b7f53e42de959a72d8382179e3 (diff)
feat: add agent selector to UI and support direct agent assignment
- Added an agent selector (Auto, Claude, Gemini) to the Start Next Task button. - Updated the backend to pass query parameters as environment variables to scripts. - Modified the executor pool to skip classification when a specific agent is requested. - Added --agent flag to claudomator start command. - Updated tests to cover the new functionality.
Diffstat (limited to 'internal/executor')
-rw-r--r--internal/executor/executor.go29
-rw-r--r--internal/executor/executor_test.go38
2 files changed, 55 insertions, 12 deletions
diff --git a/internal/executor/executor.go b/internal/executor/executor.go
index 7ae4e2d..bf209b7 100644
--- a/internal/executor/executor.go
+++ b/internal/executor/executor.go
@@ -414,19 +414,24 @@ func (p *Pool) execute(ctx context.Context, t *task.Task) {
}
p.mu.Unlock()
- // Deterministically pick the agent with fewest active tasks.
- selectedAgent := pickAgent(status)
- if selectedAgent != "" {
- t.Agent.Type = selectedAgent
- }
+ // If a specific agent is already requested, skip selection and classification.
+ skipClassification := t.Agent.Type == "claude" || t.Agent.Type == "gemini"
+
+ if !skipClassification {
+ // Deterministically pick the agent with fewest active tasks.
+ selectedAgent := pickAgent(status)
+ if selectedAgent != "" {
+ t.Agent.Type = selectedAgent
+ }
- if p.Classifier != nil {
- cls, err := p.Classifier.Classify(ctx, t.Name, t.Agent.Instructions, status, t.Agent.Type)
- if err == nil {
- p.logger.Info("task classified", "taskID", t.ID, "agent", t.Agent.Type, "model", cls.Model, "reason", cls.Reason)
- t.Agent.Model = cls.Model
- } else {
- p.logger.Error("classification failed", "error", err, "taskID", t.ID)
+ if p.Classifier != nil {
+ cls, err := p.Classifier.Classify(ctx, t.Name, t.Agent.Instructions, status, t.Agent.Type)
+ if err == nil {
+ p.logger.Info("task classified", "taskID", t.ID, "agent", t.Agent.Type, "model", cls.Model, "reason", cls.Reason)
+ t.Agent.Model = cls.Model
+ } else {
+ p.logger.Error("classification failed", "error", err, "taskID", t.ID)
+ }
}
}
diff --git a/internal/executor/executor_test.go b/internal/executor/executor_test.go
index 7e676eb..17982f8 100644
--- a/internal/executor/executor_test.go
+++ b/internal/executor/executor_test.go
@@ -1121,3 +1121,41 @@ func TestPool_LoadBalancing_OverridesAgentType(t *testing.T) {
t.Errorf("expected claude runner to be called once, got %d", runner.callCount())
}
}
+
+// TestPool_SpecificAgent_SkipsLoadBalancing verifies that if a specific
+// registered agent is requested (claude or gemini), it is used directly
+// and load balancing (pickAgent) is skipped.
+func TestPool_SpecificAgent_SkipsLoadBalancing(t *testing.T) {
+ store := testStore(t)
+ claudeRunner := &mockRunner{}
+ geminiRunner := &mockRunner{}
+ runners := map[string]Runner{
+ "claude": claudeRunner,
+ "gemini": geminiRunner,
+ }
+ logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))
+ pool := NewPool(4, runners, store, logger)
+
+ // Inject 2 active tasks for gemini, 0 for claude.
+ // pickAgent would normally pick "claude".
+ pool.mu.Lock()
+ pool.activePerAgent["gemini"] = 2
+ pool.mu.Unlock()
+
+ tk := makeTask("specific-gemini")
+ tk.Agent.Type = "gemini"
+ store.CreateTask(tk)
+
+ if err := pool.Submit(context.Background(), tk); err != nil {
+ t.Fatalf("submit: %v", err)
+ }
+
+ <-pool.Results()
+
+ if geminiRunner.callCount() != 1 {
+ t.Errorf("expected gemini runner to be called once, got %d", geminiRunner.callCount())
+ }
+ if claudeRunner.callCount() != 0 {
+ t.Errorf("expected claude runner to NOT be called, got %d", claudeRunner.callCount())
+ }
+}