summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--internal/api/scripts.go8
-rw-r--r--internal/api/server.go7
-rw-r--r--internal/api/server_test.go17
-rw-r--r--internal/cli/create.go2
-rw-r--r--internal/cli/create_test.go6
-rw-r--r--internal/cli/start.go16
-rw-r--r--internal/executor/executor.go29
-rw-r--r--internal/executor/executor_test.go38
-rwxr-xr-xscripts/start-next-task7
-rw-r--r--web/index.html11
-rw-r--r--web/test/start-next-task.test.mjs21
11 files changed, 134 insertions, 28 deletions
diff --git a/internal/api/scripts.go b/internal/api/scripts.go
index 822bd32..8db937b 100644
--- a/internal/api/scripts.go
+++ b/internal/api/scripts.go
@@ -4,7 +4,9 @@ import (
"bytes"
"context"
"net/http"
+ "os"
"os/exec"
+ "strings"
"time"
)
@@ -33,6 +35,12 @@ func (s *Server) handleScript(w http.ResponseWriter, r *http.Request) {
defer cancel()
cmd := exec.CommandContext(ctx, scriptPath)
+ cmd.Env = os.Environ()
+ for k, v := range r.URL.Query() {
+ if len(v) > 0 {
+ cmd.Env = append(cmd.Env, "CLAUDOMATOR_"+strings.ToUpper(k)+"="+v[0])
+ }
+ }
var stdout, stderr bytes.Buffer
cmd.Stdout = &stdout
diff --git a/internal/api/server.go b/internal/api/server.go
index df35536..163f2b8 100644
--- a/internal/api/server.go
+++ b/internal/api/server.go
@@ -488,9 +488,10 @@ func (s *Server) handleGetTask(w http.ResponseWriter, r *http.Request) {
}
writeJSON(w, http.StatusOK, t)
}
-
func (s *Server) handleRunTask(w http.ResponseWriter, r *http.Request) {
id := r.PathValue("id")
+ agent := r.URL.Query().Get("agent")
+
t, err := s.store.ResetTaskForRetry(id)
if err != nil {
if strings.Contains(err.Error(), "not found") {
@@ -505,6 +506,10 @@ func (s *Server) handleRunTask(w http.ResponseWriter, r *http.Request) {
return
}
+ if agent != "" && agent != "auto" {
+ t.Agent.Type = agent
+ }
+
if err := s.pool.Submit(context.Background(), t); err != nil {
writeJSON(w, http.StatusServiceUnavailable, map[string]string{"error": fmt.Sprintf("executor pool: %v", err)})
return
diff --git a/internal/api/server_test.go b/internal/api/server_test.go
index c90e3b3..2209a69 100644
--- a/internal/api/server_test.go
+++ b/internal/api/server_test.go
@@ -384,6 +384,23 @@ func TestRunTask_TimedOutTask_Returns202(t *testing.T) {
}
}
+func TestRunTask_WithAgentParam(t *testing.T) {
+ srv, store := testServer(t)
+ createTaskWithState(t, store, "run-agent-param", task.StatePending)
+
+ // Request run with agent=gemini.
+ req := httptest.NewRequest("POST", "/api/tasks/run-agent-param/run?agent=gemini", nil)
+ w := httptest.NewRecorder()
+ srv.Handler().ServeHTTP(w, req)
+
+ if w.Code != http.StatusAccepted {
+ t.Fatalf("status: want 202, got %d; body: %s", w.Code, w.Body.String())
+ }
+
+ // Wait for the task to complete via the mock runner.
+ pollState(t, store, "run-agent-param", task.StateReady, 2*time.Second)
+}
+
func TestRunTask_CompletedTask_Returns409(t *testing.T) {
srv, store := testServer(t)
createTaskWithState(t, store, "run-completed", task.StateCompleted)
diff --git a/internal/cli/create.go b/internal/cli/create.go
index e5435d3..396cd77 100644
--- a/internal/cli/create.go
+++ b/internal/cli/create.go
@@ -88,7 +88,7 @@ func createTask(serverURL, name, instructions, workingDir, model, agentType, par
fmt.Printf("Created task %s\n", id)
if autoStart {
- return startTask(serverURL, id)
+ return startTask(serverURL, id, agentType)
}
return nil
}
diff --git a/internal/cli/create_test.go b/internal/cli/create_test.go
index 4ce1071..71b403e 100644
--- a/internal/cli/create_test.go
+++ b/internal/cli/create_test.go
@@ -42,7 +42,7 @@ func TestStartTask_EscapesTaskID(t *testing.T) {
}))
defer srv.Close()
- err := startTask(srv.URL, "task/with/slashes")
+ err := startTask(srv.URL, "task/with/slashes", "")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
@@ -93,7 +93,7 @@ func TestStartTask_NonJSONResponse_ReturnsError(t *testing.T) {
}))
defer srv.Close()
- err := startTask(srv.URL, "task-abc")
+ err := startTask(srv.URL, "task-abc", "")
if err == nil {
t.Fatal("expected error for non-JSON response, got nil")
}
@@ -115,7 +115,7 @@ func TestStartTask_TimesOut(t *testing.T) {
httpClient = &http.Client{Timeout: 50 * time.Millisecond}
defer func() { httpClient = orig }()
- err := startTask(srv.URL, "task-abc")
+ err := startTask(srv.URL, "task-abc", "")
if err == nil {
t.Fatal("expected timeout error, got nil")
}
diff --git a/internal/cli/start.go b/internal/cli/start.go
index 9e66e00..99af9a5 100644
--- a/internal/cli/start.go
+++ b/internal/cli/start.go
@@ -8,28 +8,32 @@ import (
"github.com/spf13/cobra"
)
-
func newStartCmd() *cobra.Command {
var serverURL string
+ var agent string
cmd := &cobra.Command{
Use: "start <task-id>",
Short: "Queue a task for execution via the running server",
Args: cobra.ExactArgs(1),
RunE: func(cmd *cobra.Command, args []string) error {
- return startTask(serverURL, args[0])
+ return startTask(serverURL, args[0], agent)
},
}
cmd.Flags().StringVar(&serverURL, "server", "http://localhost:8484", "claudomator server URL")
+ cmd.Flags().StringVar(&agent, "agent", "", "agent to use (claude, gemini, or auto)")
return cmd
}
-func startTask(serverURL, id string) error {
- url := fmt.Sprintf("%s/api/tasks/%s/run", serverURL, url.PathEscape(id))
- resp, err := httpClient.Post(url, "application/json", nil) //nolint:noctx
+func startTask(serverURL, id, agent string) error {
+ u := fmt.Sprintf("%s/api/tasks/%s/run", serverURL, url.PathEscape(id))
+ if agent != "" {
+ u += "?agent=" + url.QueryEscape(agent)
+ }
+ resp, err := httpClient.Post(u, "application/json", nil) //nolint:noctx
if err != nil {
- return fmt.Errorf("POST %s: %w", url, err)
+ return fmt.Errorf("POST %s: %w", u, err)
}
defer resp.Body.Close()
diff --git a/internal/executor/executor.go b/internal/executor/executor.go
index 7ae4e2d..bf209b7 100644
--- a/internal/executor/executor.go
+++ b/internal/executor/executor.go
@@ -414,19 +414,24 @@ func (p *Pool) execute(ctx context.Context, t *task.Task) {
}
p.mu.Unlock()
- // Deterministically pick the agent with fewest active tasks.
- selectedAgent := pickAgent(status)
- if selectedAgent != "" {
- t.Agent.Type = selectedAgent
- }
+ // If a specific agent is already requested, skip selection and classification.
+ skipClassification := t.Agent.Type == "claude" || t.Agent.Type == "gemini"
+
+ if !skipClassification {
+ // Deterministically pick the agent with fewest active tasks.
+ selectedAgent := pickAgent(status)
+ if selectedAgent != "" {
+ t.Agent.Type = selectedAgent
+ }
- if p.Classifier != nil {
- cls, err := p.Classifier.Classify(ctx, t.Name, t.Agent.Instructions, status, t.Agent.Type)
- if err == nil {
- p.logger.Info("task classified", "taskID", t.ID, "agent", t.Agent.Type, "model", cls.Model, "reason", cls.Reason)
- t.Agent.Model = cls.Model
- } else {
- p.logger.Error("classification failed", "error", err, "taskID", t.ID)
+ if p.Classifier != nil {
+ cls, err := p.Classifier.Classify(ctx, t.Name, t.Agent.Instructions, status, t.Agent.Type)
+ if err == nil {
+ p.logger.Info("task classified", "taskID", t.ID, "agent", t.Agent.Type, "model", cls.Model, "reason", cls.Reason)
+ t.Agent.Model = cls.Model
+ } else {
+ p.logger.Error("classification failed", "error", err, "taskID", t.ID)
+ }
}
}
diff --git a/internal/executor/executor_test.go b/internal/executor/executor_test.go
index 7e676eb..17982f8 100644
--- a/internal/executor/executor_test.go
+++ b/internal/executor/executor_test.go
@@ -1121,3 +1121,41 @@ func TestPool_LoadBalancing_OverridesAgentType(t *testing.T) {
t.Errorf("expected claude runner to be called once, got %d", runner.callCount())
}
}
+
+// TestPool_SpecificAgent_SkipsLoadBalancing verifies that if a specific
+// registered agent is requested (claude or gemini), it is used directly
+// and load balancing (pickAgent) is skipped.
+func TestPool_SpecificAgent_SkipsLoadBalancing(t *testing.T) {
+ store := testStore(t)
+ claudeRunner := &mockRunner{}
+ geminiRunner := &mockRunner{}
+ runners := map[string]Runner{
+ "claude": claudeRunner,
+ "gemini": geminiRunner,
+ }
+ logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))
+ pool := NewPool(4, runners, store, logger)
+
+ // Inject 2 active tasks for gemini, 0 for claude.
+ // pickAgent would normally pick "claude".
+ pool.mu.Lock()
+ pool.activePerAgent["gemini"] = 2
+ pool.mu.Unlock()
+
+ tk := makeTask("specific-gemini")
+ tk.Agent.Type = "gemini"
+ store.CreateTask(tk)
+
+ if err := pool.Submit(context.Background(), tk); err != nil {
+ t.Fatalf("submit: %v", err)
+ }
+
+ <-pool.Results()
+
+ if geminiRunner.callCount() != 1 {
+ t.Errorf("expected gemini runner to be called once, got %d", geminiRunner.callCount())
+ }
+ if claudeRunner.callCount() != 0 {
+ t.Errorf("expected claude runner to NOT be called, got %d", claudeRunner.callCount())
+ }
+}
diff --git a/scripts/start-next-task b/scripts/start-next-task
index da019b6..d53f008 100755
--- a/scripts/start-next-task
+++ b/scripts/start-next-task
@@ -12,4 +12,9 @@ if [[ -z "$task_id" ]]; then
exit 0
fi
-claudomator start "$task_id"
+ARGS=()
+if [[ -n "${CLAUDOMATOR_AGENT:-}" ]]; then
+ ARGS+=(--agent "$CLAUDOMATOR_AGENT")
+fi
+
+claudomator start "${ARGS[@]}" "$task_id"
diff --git a/web/index.html b/web/index.html
index d56dcb3..19cba2c 100644
--- a/web/index.html
+++ b/web/index.html
@@ -11,8 +11,15 @@
<body>
<header>
<h1>Claudomator</h1>
- <button id="btn-start-next" class="btn-secondary">Start Next</button>
- <button id="btn-new-task" class="btn-primary">New Task</button>
+ <div class="header-actions">
+ <select id="select-agent" class="agent-selector">
+ <option value="auto">Auto</option>
+ <option value="claude">Claude</option>
+ <option value="gemini">Gemini</option>
+ </select>
+ <button id="btn-start-next" class="btn-secondary">Start Next</button>
+ <button id="btn-new-task" class="btn-primary">New Task</button>
+ </div>
</header>
<nav class="tab-bar">
<button class="tab active" data-tab="queue" title="Queue">⏳</button>
diff --git a/web/test/start-next-task.test.mjs b/web/test/start-next-task.test.mjs
index 6863f7e..eaf3087 100644
--- a/web/test/start-next-task.test.mjs
+++ b/web/test/start-next-task.test.mjs
@@ -9,8 +9,11 @@ import assert from 'node:assert/strict';
// Returns {output, exit_code} on HTTP 2xx
// Throws on HTTP error
-async function startNextTask(basePath, fetchFn) {
- const res = await fetchFn(`${basePath}/api/scripts/start-next-task`, { method: 'POST' });
+async function startNextTask(basePath, fetchFn, agent) {
+ const url = agent && agent !== 'auto'
+ ? `${basePath}/api/scripts/start-next-task?agent=${agent}`
+ : `${basePath}/api/scripts/start-next-task`;
+ const res = await fetchFn(url, { method: 'POST' });
if (!res.ok) {
let msg = `HTTP ${res.status}`;
try { const body = await res.json(); msg = body.error || msg; } catch {}
@@ -35,6 +38,20 @@ describe('startNextTask', () => {
assert.equal(captured.opts.method, 'POST');
});
+ it('appends agent as query parameter when provided', async () => {
+ let captured = null;
+ const mockFetch = (url, opts) => {
+ captured = { url, opts };
+ return Promise.resolve({
+ ok: true,
+ json: () => Promise.resolve({ output: 'claudomator start --agent claude abc-123\n', exit_code: 0 }),
+ });
+ };
+
+ await startNextTask('http://localhost:8484', mockFetch, 'claude');
+ assert.equal(captured.url, 'http://localhost:8484/api/scripts/start-next-task?agent=claude');
+ });
+
it('returns output and exit_code on success', async () => {
const mockFetch = () => Promise.resolve({
ok: true,