summaryrefslogtreecommitdiff
path: root/internal
diff options
context:
space:
mode:
Diffstat (limited to 'internal')
-rw-r--r--internal/api/scripts.go8
-rw-r--r--internal/api/server.go7
-rw-r--r--internal/api/server_test.go17
-rw-r--r--internal/cli/create.go2
-rw-r--r--internal/cli/create_test.go6
-rw-r--r--internal/cli/start.go16
-rw-r--r--internal/executor/executor.go29
-rw-r--r--internal/executor/executor_test.go38
8 files changed, 100 insertions, 23 deletions
diff --git a/internal/api/scripts.go b/internal/api/scripts.go
index 822bd32..8db937b 100644
--- a/internal/api/scripts.go
+++ b/internal/api/scripts.go
@@ -4,7 +4,9 @@ import (
"bytes"
"context"
"net/http"
+ "os"
"os/exec"
+ "strings"
"time"
)
@@ -33,6 +35,12 @@ func (s *Server) handleScript(w http.ResponseWriter, r *http.Request) {
defer cancel()
cmd := exec.CommandContext(ctx, scriptPath)
+ cmd.Env = os.Environ()
+ for k, v := range r.URL.Query() {
+ if len(v) > 0 {
+ cmd.Env = append(cmd.Env, "CLAUDOMATOR_"+strings.ToUpper(k)+"="+v[0])
+ }
+ }
var stdout, stderr bytes.Buffer
cmd.Stdout = &stdout
diff --git a/internal/api/server.go b/internal/api/server.go
index df35536..163f2b8 100644
--- a/internal/api/server.go
+++ b/internal/api/server.go
@@ -488,9 +488,10 @@ func (s *Server) handleGetTask(w http.ResponseWriter, r *http.Request) {
}
writeJSON(w, http.StatusOK, t)
}
-
func (s *Server) handleRunTask(w http.ResponseWriter, r *http.Request) {
id := r.PathValue("id")
+ agent := r.URL.Query().Get("agent")
+
t, err := s.store.ResetTaskForRetry(id)
if err != nil {
if strings.Contains(err.Error(), "not found") {
@@ -505,6 +506,10 @@ func (s *Server) handleRunTask(w http.ResponseWriter, r *http.Request) {
return
}
+ if agent != "" && agent != "auto" {
+ t.Agent.Type = agent
+ }
+
if err := s.pool.Submit(context.Background(), t); err != nil {
writeJSON(w, http.StatusServiceUnavailable, map[string]string{"error": fmt.Sprintf("executor pool: %v", err)})
return
diff --git a/internal/api/server_test.go b/internal/api/server_test.go
index c90e3b3..2209a69 100644
--- a/internal/api/server_test.go
+++ b/internal/api/server_test.go
@@ -384,6 +384,23 @@ func TestRunTask_TimedOutTask_Returns202(t *testing.T) {
}
}
+func TestRunTask_WithAgentParam(t *testing.T) {
+ srv, store := testServer(t)
+ createTaskWithState(t, store, "run-agent-param", task.StatePending)
+
+ // Request run with agent=gemini.
+ req := httptest.NewRequest("POST", "/api/tasks/run-agent-param/run?agent=gemini", nil)
+ w := httptest.NewRecorder()
+ srv.Handler().ServeHTTP(w, req)
+
+ if w.Code != http.StatusAccepted {
+ t.Fatalf("status: want 202, got %d; body: %s", w.Code, w.Body.String())
+ }
+
+ // Wait for the task to complete via the mock runner.
+ pollState(t, store, "run-agent-param", task.StateReady, 2*time.Second)
+}
+
func TestRunTask_CompletedTask_Returns409(t *testing.T) {
srv, store := testServer(t)
createTaskWithState(t, store, "run-completed", task.StateCompleted)
diff --git a/internal/cli/create.go b/internal/cli/create.go
index e5435d3..396cd77 100644
--- a/internal/cli/create.go
+++ b/internal/cli/create.go
@@ -88,7 +88,7 @@ func createTask(serverURL, name, instructions, workingDir, model, agentType, par
fmt.Printf("Created task %s\n", id)
if autoStart {
- return startTask(serverURL, id)
+ return startTask(serverURL, id, agentType)
}
return nil
}
diff --git a/internal/cli/create_test.go b/internal/cli/create_test.go
index 4ce1071..71b403e 100644
--- a/internal/cli/create_test.go
+++ b/internal/cli/create_test.go
@@ -42,7 +42,7 @@ func TestStartTask_EscapesTaskID(t *testing.T) {
}))
defer srv.Close()
- err := startTask(srv.URL, "task/with/slashes")
+ err := startTask(srv.URL, "task/with/slashes", "")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
@@ -93,7 +93,7 @@ func TestStartTask_NonJSONResponse_ReturnsError(t *testing.T) {
}))
defer srv.Close()
- err := startTask(srv.URL, "task-abc")
+ err := startTask(srv.URL, "task-abc", "")
if err == nil {
t.Fatal("expected error for non-JSON response, got nil")
}
@@ -115,7 +115,7 @@ func TestStartTask_TimesOut(t *testing.T) {
httpClient = &http.Client{Timeout: 50 * time.Millisecond}
defer func() { httpClient = orig }()
- err := startTask(srv.URL, "task-abc")
+ err := startTask(srv.URL, "task-abc", "")
if err == nil {
t.Fatal("expected timeout error, got nil")
}
diff --git a/internal/cli/start.go b/internal/cli/start.go
index 9e66e00..99af9a5 100644
--- a/internal/cli/start.go
+++ b/internal/cli/start.go
@@ -8,28 +8,32 @@ import (
"github.com/spf13/cobra"
)
-
func newStartCmd() *cobra.Command {
var serverURL string
+ var agent string
cmd := &cobra.Command{
Use: "start <task-id>",
Short: "Queue a task for execution via the running server",
Args: cobra.ExactArgs(1),
RunE: func(cmd *cobra.Command, args []string) error {
- return startTask(serverURL, args[0])
+ return startTask(serverURL, args[0], agent)
},
}
cmd.Flags().StringVar(&serverURL, "server", "http://localhost:8484", "claudomator server URL")
+ cmd.Flags().StringVar(&agent, "agent", "", "agent to use (claude, gemini, or auto)")
return cmd
}
-func startTask(serverURL, id string) error {
- url := fmt.Sprintf("%s/api/tasks/%s/run", serverURL, url.PathEscape(id))
- resp, err := httpClient.Post(url, "application/json", nil) //nolint:noctx
+func startTask(serverURL, id, agent string) error {
+ u := fmt.Sprintf("%s/api/tasks/%s/run", serverURL, url.PathEscape(id))
+ if agent != "" {
+ u += "?agent=" + url.QueryEscape(agent)
+ }
+ resp, err := httpClient.Post(u, "application/json", nil) //nolint:noctx
if err != nil {
- return fmt.Errorf("POST %s: %w", url, err)
+ return fmt.Errorf("POST %s: %w", u, err)
}
defer resp.Body.Close()
diff --git a/internal/executor/executor.go b/internal/executor/executor.go
index 7ae4e2d..bf209b7 100644
--- a/internal/executor/executor.go
+++ b/internal/executor/executor.go
@@ -414,19 +414,24 @@ func (p *Pool) execute(ctx context.Context, t *task.Task) {
}
p.mu.Unlock()
- // Deterministically pick the agent with fewest active tasks.
- selectedAgent := pickAgent(status)
- if selectedAgent != "" {
- t.Agent.Type = selectedAgent
- }
+ // If a specific agent is already requested, skip selection and classification.
+ skipClassification := t.Agent.Type == "claude" || t.Agent.Type == "gemini"
+
+ if !skipClassification {
+ // Deterministically pick the agent with fewest active tasks.
+ selectedAgent := pickAgent(status)
+ if selectedAgent != "" {
+ t.Agent.Type = selectedAgent
+ }
- if p.Classifier != nil {
- cls, err := p.Classifier.Classify(ctx, t.Name, t.Agent.Instructions, status, t.Agent.Type)
- if err == nil {
- p.logger.Info("task classified", "taskID", t.ID, "agent", t.Agent.Type, "model", cls.Model, "reason", cls.Reason)
- t.Agent.Model = cls.Model
- } else {
- p.logger.Error("classification failed", "error", err, "taskID", t.ID)
+ if p.Classifier != nil {
+ cls, err := p.Classifier.Classify(ctx, t.Name, t.Agent.Instructions, status, t.Agent.Type)
+ if err == nil {
+ p.logger.Info("task classified", "taskID", t.ID, "agent", t.Agent.Type, "model", cls.Model, "reason", cls.Reason)
+ t.Agent.Model = cls.Model
+ } else {
+ p.logger.Error("classification failed", "error", err, "taskID", t.ID)
+ }
}
}
diff --git a/internal/executor/executor_test.go b/internal/executor/executor_test.go
index 7e676eb..17982f8 100644
--- a/internal/executor/executor_test.go
+++ b/internal/executor/executor_test.go
@@ -1121,3 +1121,41 @@ func TestPool_LoadBalancing_OverridesAgentType(t *testing.T) {
t.Errorf("expected claude runner to be called once, got %d", runner.callCount())
}
}
+
+// TestPool_SpecificAgent_SkipsLoadBalancing verifies that if a specific
+// registered agent is requested (claude or gemini), it is used directly
+// and load balancing (pickAgent) is skipped.
+func TestPool_SpecificAgent_SkipsLoadBalancing(t *testing.T) {
+ store := testStore(t)
+ claudeRunner := &mockRunner{}
+ geminiRunner := &mockRunner{}
+ runners := map[string]Runner{
+ "claude": claudeRunner,
+ "gemini": geminiRunner,
+ }
+ logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))
+ pool := NewPool(4, runners, store, logger)
+
+ // Inject 2 active tasks for gemini, 0 for claude.
+ // pickAgent would normally pick "claude".
+ pool.mu.Lock()
+ pool.activePerAgent["gemini"] = 2
+ pool.mu.Unlock()
+
+ tk := makeTask("specific-gemini")
+ tk.Agent.Type = "gemini"
+ store.CreateTask(tk)
+
+ if err := pool.Submit(context.Background(), tk); err != nil {
+ t.Fatalf("submit: %v", err)
+ }
+
+ <-pool.Results()
+
+ if geminiRunner.callCount() != 1 {
+ t.Errorf("expected gemini runner to be called once, got %d", geminiRunner.callCount())
+ }
+ if claudeRunner.callCount() != 0 {
+ t.Errorf("expected claude runner to NOT be called, got %d", claudeRunner.callCount())
+ }
+}