diff options
| author | Peter Stone <thepeterstone@gmail.com> | 2026-03-14 00:39:22 +0000 |
|---|---|---|
| committer | Peter Stone <thepeterstone@gmail.com> | 2026-03-14 00:39:22 +0000 |
| commit | 2ee988ccc04c09ceb6de7cdb75c94114e85d01b9 (patch) | |
| tree | 29100e3e4b33748c544b9a42cb74e964df49b96e | |
| parent | 98ccde12b08ad0b7f53e42de959a72d8382179e3 (diff) | |
feat: add agent selector to UI and support direct agent assignment
- Added an agent selector (Auto, Claude, Gemini) to the Start Next Task button.
- Updated the backend to pass query parameters as environment variables to scripts.
- Modified the executor pool to skip classification when a specific agent is requested.
- Added --agent flag to claudomator start command.
- Updated tests to cover the new functionality.
| -rw-r--r-- | internal/api/scripts.go | 8 | ||||
| -rw-r--r-- | internal/api/server.go | 7 | ||||
| -rw-r--r-- | internal/api/server_test.go | 17 | ||||
| -rw-r--r-- | internal/cli/create.go | 2 | ||||
| -rw-r--r-- | internal/cli/create_test.go | 6 | ||||
| -rw-r--r-- | internal/cli/start.go | 16 | ||||
| -rw-r--r-- | internal/executor/executor.go | 29 | ||||
| -rw-r--r-- | internal/executor/executor_test.go | 38 | ||||
| -rwxr-xr-x | scripts/start-next-task | 7 | ||||
| -rw-r--r-- | web/index.html | 11 | ||||
| -rw-r--r-- | web/test/start-next-task.test.mjs | 21 |
11 files changed, 134 insertions, 28 deletions
diff --git a/internal/api/scripts.go b/internal/api/scripts.go index 822bd32..8db937b 100644 --- a/internal/api/scripts.go +++ b/internal/api/scripts.go @@ -4,7 +4,9 @@ import ( "bytes" "context" "net/http" + "os" "os/exec" + "strings" "time" ) @@ -33,6 +35,12 @@ func (s *Server) handleScript(w http.ResponseWriter, r *http.Request) { defer cancel() cmd := exec.CommandContext(ctx, scriptPath) + cmd.Env = os.Environ() + for k, v := range r.URL.Query() { + if len(v) > 0 { + cmd.Env = append(cmd.Env, "CLAUDOMATOR_"+strings.ToUpper(k)+"="+v[0]) + } + } var stdout, stderr bytes.Buffer cmd.Stdout = &stdout diff --git a/internal/api/server.go b/internal/api/server.go index df35536..163f2b8 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -488,9 +488,10 @@ func (s *Server) handleGetTask(w http.ResponseWriter, r *http.Request) { } writeJSON(w, http.StatusOK, t) } - func (s *Server) handleRunTask(w http.ResponseWriter, r *http.Request) { id := r.PathValue("id") + agent := r.URL.Query().Get("agent") + t, err := s.store.ResetTaskForRetry(id) if err != nil { if strings.Contains(err.Error(), "not found") { @@ -505,6 +506,10 @@ func (s *Server) handleRunTask(w http.ResponseWriter, r *http.Request) { return } + if agent != "" && agent != "auto" { + t.Agent.Type = agent + } + if err := s.pool.Submit(context.Background(), t); err != nil { writeJSON(w, http.StatusServiceUnavailable, map[string]string{"error": fmt.Sprintf("executor pool: %v", err)}) return diff --git a/internal/api/server_test.go b/internal/api/server_test.go index c90e3b3..2209a69 100644 --- a/internal/api/server_test.go +++ b/internal/api/server_test.go @@ -384,6 +384,23 @@ func TestRunTask_TimedOutTask_Returns202(t *testing.T) { } } +func TestRunTask_WithAgentParam(t *testing.T) { + srv, store := testServer(t) + createTaskWithState(t, store, "run-agent-param", task.StatePending) + + // Request run with agent=gemini. + req := httptest.NewRequest("POST", "/api/tasks/run-agent-param/run?agent=gemini", nil) + w := httptest.NewRecorder() + srv.Handler().ServeHTTP(w, req) + + if w.Code != http.StatusAccepted { + t.Fatalf("status: want 202, got %d; body: %s", w.Code, w.Body.String()) + } + + // Wait for the task to complete via the mock runner. + pollState(t, store, "run-agent-param", task.StateReady, 2*time.Second) +} + func TestRunTask_CompletedTask_Returns409(t *testing.T) { srv, store := testServer(t) createTaskWithState(t, store, "run-completed", task.StateCompleted) diff --git a/internal/cli/create.go b/internal/cli/create.go index e5435d3..396cd77 100644 --- a/internal/cli/create.go +++ b/internal/cli/create.go @@ -88,7 +88,7 @@ func createTask(serverURL, name, instructions, workingDir, model, agentType, par fmt.Printf("Created task %s\n", id) if autoStart { - return startTask(serverURL, id) + return startTask(serverURL, id, agentType) } return nil } diff --git a/internal/cli/create_test.go b/internal/cli/create_test.go index 4ce1071..71b403e 100644 --- a/internal/cli/create_test.go +++ b/internal/cli/create_test.go @@ -42,7 +42,7 @@ func TestStartTask_EscapesTaskID(t *testing.T) { })) defer srv.Close() - err := startTask(srv.URL, "task/with/slashes") + err := startTask(srv.URL, "task/with/slashes", "") if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -93,7 +93,7 @@ func TestStartTask_NonJSONResponse_ReturnsError(t *testing.T) { })) defer srv.Close() - err := startTask(srv.URL, "task-abc") + err := startTask(srv.URL, "task-abc", "") if err == nil { t.Fatal("expected error for non-JSON response, got nil") } @@ -115,7 +115,7 @@ func TestStartTask_TimesOut(t *testing.T) { httpClient = &http.Client{Timeout: 50 * time.Millisecond} defer func() { httpClient = orig }() - err := startTask(srv.URL, "task-abc") + err := startTask(srv.URL, "task-abc", "") if err == nil { t.Fatal("expected timeout error, got nil") } diff --git a/internal/cli/start.go b/internal/cli/start.go index 9e66e00..99af9a5 100644 --- a/internal/cli/start.go +++ b/internal/cli/start.go @@ -8,28 +8,32 @@ import ( "github.com/spf13/cobra" ) - func newStartCmd() *cobra.Command { var serverURL string + var agent string cmd := &cobra.Command{ Use: "start <task-id>", Short: "Queue a task for execution via the running server", Args: cobra.ExactArgs(1), RunE: func(cmd *cobra.Command, args []string) error { - return startTask(serverURL, args[0]) + return startTask(serverURL, args[0], agent) }, } cmd.Flags().StringVar(&serverURL, "server", "http://localhost:8484", "claudomator server URL") + cmd.Flags().StringVar(&agent, "agent", "", "agent to use (claude, gemini, or auto)") return cmd } -func startTask(serverURL, id string) error { - url := fmt.Sprintf("%s/api/tasks/%s/run", serverURL, url.PathEscape(id)) - resp, err := httpClient.Post(url, "application/json", nil) //nolint:noctx +func startTask(serverURL, id, agent string) error { + u := fmt.Sprintf("%s/api/tasks/%s/run", serverURL, url.PathEscape(id)) + if agent != "" { + u += "?agent=" + url.QueryEscape(agent) + } + resp, err := httpClient.Post(u, "application/json", nil) //nolint:noctx if err != nil { - return fmt.Errorf("POST %s: %w", url, err) + return fmt.Errorf("POST %s: %w", u, err) } defer resp.Body.Close() diff --git a/internal/executor/executor.go b/internal/executor/executor.go index 7ae4e2d..bf209b7 100644 --- a/internal/executor/executor.go +++ b/internal/executor/executor.go @@ -414,19 +414,24 @@ func (p *Pool) execute(ctx context.Context, t *task.Task) { } p.mu.Unlock() - // Deterministically pick the agent with fewest active tasks. - selectedAgent := pickAgent(status) - if selectedAgent != "" { - t.Agent.Type = selectedAgent - } + // If a specific agent is already requested, skip selection and classification. + skipClassification := t.Agent.Type == "claude" || t.Agent.Type == "gemini" + + if !skipClassification { + // Deterministically pick the agent with fewest active tasks. + selectedAgent := pickAgent(status) + if selectedAgent != "" { + t.Agent.Type = selectedAgent + } - if p.Classifier != nil { - cls, err := p.Classifier.Classify(ctx, t.Name, t.Agent.Instructions, status, t.Agent.Type) - if err == nil { - p.logger.Info("task classified", "taskID", t.ID, "agent", t.Agent.Type, "model", cls.Model, "reason", cls.Reason) - t.Agent.Model = cls.Model - } else { - p.logger.Error("classification failed", "error", err, "taskID", t.ID) + if p.Classifier != nil { + cls, err := p.Classifier.Classify(ctx, t.Name, t.Agent.Instructions, status, t.Agent.Type) + if err == nil { + p.logger.Info("task classified", "taskID", t.ID, "agent", t.Agent.Type, "model", cls.Model, "reason", cls.Reason) + t.Agent.Model = cls.Model + } else { + p.logger.Error("classification failed", "error", err, "taskID", t.ID) + } } } diff --git a/internal/executor/executor_test.go b/internal/executor/executor_test.go index 7e676eb..17982f8 100644 --- a/internal/executor/executor_test.go +++ b/internal/executor/executor_test.go @@ -1121,3 +1121,41 @@ func TestPool_LoadBalancing_OverridesAgentType(t *testing.T) { t.Errorf("expected claude runner to be called once, got %d", runner.callCount()) } } + +// TestPool_SpecificAgent_SkipsLoadBalancing verifies that if a specific +// registered agent is requested (claude or gemini), it is used directly +// and load balancing (pickAgent) is skipped. +func TestPool_SpecificAgent_SkipsLoadBalancing(t *testing.T) { + store := testStore(t) + claudeRunner := &mockRunner{} + geminiRunner := &mockRunner{} + runners := map[string]Runner{ + "claude": claudeRunner, + "gemini": geminiRunner, + } + logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) + pool := NewPool(4, runners, store, logger) + + // Inject 2 active tasks for gemini, 0 for claude. + // pickAgent would normally pick "claude". + pool.mu.Lock() + pool.activePerAgent["gemini"] = 2 + pool.mu.Unlock() + + tk := makeTask("specific-gemini") + tk.Agent.Type = "gemini" + store.CreateTask(tk) + + if err := pool.Submit(context.Background(), tk); err != nil { + t.Fatalf("submit: %v", err) + } + + <-pool.Results() + + if geminiRunner.callCount() != 1 { + t.Errorf("expected gemini runner to be called once, got %d", geminiRunner.callCount()) + } + if claudeRunner.callCount() != 0 { + t.Errorf("expected claude runner to NOT be called, got %d", claudeRunner.callCount()) + } +} diff --git a/scripts/start-next-task b/scripts/start-next-task index da019b6..d53f008 100755 --- a/scripts/start-next-task +++ b/scripts/start-next-task @@ -12,4 +12,9 @@ if [[ -z "$task_id" ]]; then exit 0 fi -claudomator start "$task_id" +ARGS=() +if [[ -n "${CLAUDOMATOR_AGENT:-}" ]]; then + ARGS+=(--agent "$CLAUDOMATOR_AGENT") +fi + +claudomator start "${ARGS[@]}" "$task_id" diff --git a/web/index.html b/web/index.html index d56dcb3..19cba2c 100644 --- a/web/index.html +++ b/web/index.html @@ -11,8 +11,15 @@ <body> <header> <h1>Claudomator</h1> - <button id="btn-start-next" class="btn-secondary">Start Next</button> - <button id="btn-new-task" class="btn-primary">New Task</button> + <div class="header-actions"> + <select id="select-agent" class="agent-selector"> + <option value="auto">Auto</option> + <option value="claude">Claude</option> + <option value="gemini">Gemini</option> + </select> + <button id="btn-start-next" class="btn-secondary">Start Next</button> + <button id="btn-new-task" class="btn-primary">New Task</button> + </div> </header> <nav class="tab-bar"> <button class="tab active" data-tab="queue" title="Queue">⏳</button> diff --git a/web/test/start-next-task.test.mjs b/web/test/start-next-task.test.mjs index 6863f7e..eaf3087 100644 --- a/web/test/start-next-task.test.mjs +++ b/web/test/start-next-task.test.mjs @@ -9,8 +9,11 @@ import assert from 'node:assert/strict'; // Returns {output, exit_code} on HTTP 2xx // Throws on HTTP error -async function startNextTask(basePath, fetchFn) { - const res = await fetchFn(`${basePath}/api/scripts/start-next-task`, { method: 'POST' }); +async function startNextTask(basePath, fetchFn, agent) { + const url = agent && agent !== 'auto' + ? `${basePath}/api/scripts/start-next-task?agent=${agent}` + : `${basePath}/api/scripts/start-next-task`; + const res = await fetchFn(url, { method: 'POST' }); if (!res.ok) { let msg = `HTTP ${res.status}`; try { const body = await res.json(); msg = body.error || msg; } catch {} @@ -35,6 +38,20 @@ describe('startNextTask', () => { assert.equal(captured.opts.method, 'POST'); }); + it('appends agent as query parameter when provided', async () => { + let captured = null; + const mockFetch = (url, opts) => { + captured = { url, opts }; + return Promise.resolve({ + ok: true, + json: () => Promise.resolve({ output: 'claudomator start --agent claude abc-123\n', exit_code: 0 }), + }); + }; + + await startNextTask('http://localhost:8484', mockFetch, 'claude'); + assert.equal(captured.url, 'http://localhost:8484/api/scripts/start-next-task?agent=claude'); + }); + it('returns output and exit_code on success', async () => { const mockFetch = () => Promise.resolve({ ok: true, |
