diff options
30 files changed, 1005 insertions, 206 deletions
diff --git a/internal/api/elaborate.go b/internal/api/elaborate.go index 8a18dee..907cb98 100644 --- a/internal/api/elaborate.go +++ b/internal/api/elaborate.go @@ -18,7 +18,7 @@ func buildElaboratePrompt(workDir string) string { if workDir != "" { workDirLine = fmt.Sprintf(` "project_dir": string — use %q for tasks that operate on this codebase, empty string otherwise,`, workDir) } - return `You are a task configuration assistant for Claudomator, an AI task runner that executes tasks by running Claude as a subprocess. + return `You are a task configuration assistant for Claudomator, an AI task runner that executes tasks by running Claude or Gemini as a subprocess. Your ONLY job is to convert any user request into a Claudomator task JSON object. You MUST always output valid JSON. Never ask clarifying questions. Never explain. Never refuse. Make reasonable assumptions and produce the JSON. @@ -27,9 +27,10 @@ Output ONLY a valid JSON object matching this schema (no markdown fences, no pro { "name": string — short imperative title (≤60 chars), "description": string — 1-2 sentence summary, - "claude": { - "model": string — "sonnet" unless the task obviously needs opus, - "instructions": string — detailed, step-by-step instructions for Claude, + "agent": { + "type": "claude" | "gemini", + "model": string — "sonnet" for claude, "gemini-2.0-flash" for gemini, + "instructions": string — detailed, step-by-step instructions for the agent, ` + workDirLine + ` "max_budget_usd": number — conservative estimate (0.25–5.00), "allowed_tools": array — only tools the task genuinely needs @@ -44,13 +45,14 @@ Output ONLY a valid JSON object matching this schema (no markdown fences, no pro type elaboratedTask struct { Name string `json:"name"` Description string `json:"description"` - Claude elaboratedClaude `json:"claude"` + Agent elaboratedAgent `json:"agent"` Timeout string `json:"timeout"` Priority string `json:"priority"` Tags []string `json:"tags"` } -type elaboratedClaude struct { +type elaboratedAgent struct { + Type string `json:"type"` Model string `json:"model"` Instructions string `json:"instructions"` ProjectDir string `json:"project_dir"` @@ -149,12 +151,16 @@ func (s *Server) handleElaborateTask(w http.ResponseWriter, r *http.Request) { return } - if result.Name == "" || result.Claude.Instructions == "" { + if result.Name == "" || result.Agent.Instructions == "" { writeJSON(w, http.StatusBadGateway, map[string]string{ "error": "elaboration failed: missing required fields in response", }) return } + if result.Agent.Type == "" { + result.Agent.Type = "claude" + } + writeJSON(w, http.StatusOK, result) } diff --git a/internal/api/elaborate_test.go b/internal/api/elaborate_test.go index 09f7fbe..b33ca11 100644 --- a/internal/api/elaborate_test.go +++ b/internal/api/elaborate_test.go @@ -53,7 +53,8 @@ func TestElaborateTask_Success(t *testing.T) { task := elaboratedTask{ Name: "Run Go tests with race detector", Description: "Runs the Go test suite with -race flag and checks coverage.", - Claude: elaboratedClaude{ + Agent: elaboratedAgent{ + Type: "claude", Model: "sonnet", Instructions: "Run go test -race ./... and report results.", ProjectDir: "", @@ -94,7 +95,7 @@ func TestElaborateTask_Success(t *testing.T) { if result.Name == "" { t.Error("expected non-empty name") } - if result.Claude.Instructions == "" { + if result.Agent.Instructions == "" { t.Error("expected non-empty instructions") } } @@ -127,7 +128,8 @@ func TestElaborateTask_MarkdownFencedJSON(t *testing.T) { task := elaboratedTask{ Name: "Test task", Description: "Does something.", - Claude: elaboratedClaude{ + Agent: elaboratedAgent{ + Type: "claude", Model: "sonnet", Instructions: "Do the thing.", MaxBudgetUSD: 0.5, diff --git a/internal/api/server.go b/internal/api/server.go index 833be8b..3d7cb1e 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -37,6 +37,7 @@ type Server struct { logger *slog.Logger mux *http.ServeMux claudeBinPath string // path to claude binary; defaults to "claude" + geminiBinPath string // path to gemini binary; defaults to "gemini" elaborateCmdPath string // overrides claudeBinPath; used in tests validateCmdPath string // overrides claudeBinPath for validate; used in tests scripts ScriptRegistry // optional; maps endpoint name → script path @@ -56,7 +57,7 @@ func (s *Server) SetNotifier(n notify.Notifier) { s.notifier = n } -func NewServer(store *storage.DB, pool *executor.Pool, logger *slog.Logger, claudeBinPath string) *Server { +func NewServer(store *storage.DB, pool *executor.Pool, logger *slog.Logger, claudeBinPath, geminiBinPath string) *Server { wd, _ := os.Getwd() s := &Server{ store: store, @@ -68,6 +69,7 @@ func NewServer(store *storage.DB, pool *executor.Pool, logger *slog.Logger, clau logger: logger, mux: http.NewServeMux(), claudeBinPath: claudeBinPath, + geminiBinPath: geminiBinPath, workDir: wd, } s.routes() @@ -346,7 +348,7 @@ func (s *Server) handleCreateTask(w http.ResponseWriter, r *http.Request) { var input struct { Name string `json:"name"` Description string `json:"description"` - Claude task.ClaudeConfig `json:"claude"` + Agent task.AgentConfig `json:"agent"` Timeout string `json:"timeout"` Priority string `json:"priority"` Tags []string `json:"tags"` @@ -362,7 +364,7 @@ func (s *Server) handleCreateTask(w http.ResponseWriter, r *http.Request) { ID: uuid.New().String(), Name: input.Name, Description: input.Description, - Claude: input.Claude, + Agent: input.Agent, Priority: task.Priority(input.Priority), Tags: input.Tags, DependsOn: []string{}, @@ -372,6 +374,9 @@ func (s *Server) handleCreateTask(w http.ResponseWriter, r *http.Request) { UpdatedAt: now, ParentTaskID: input.ParentTaskID, } + if t.Agent.Type == "" { + t.Agent.Type = "claude" + } if t.Priority == "" { t.Priority = task.PriorityNormal } diff --git a/internal/api/server_test.go b/internal/api/server_test.go index c3b12ce..cd415ae 100644 --- a/internal/api/server_test.go +++ b/internal/api/server_test.go @@ -84,8 +84,12 @@ func testServer(t *testing.T) (*Server, *storage.DB) { logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) runner := &mockRunner{} - pool := executor.NewPool(2, runner, store, logger) - srv := NewServer(store, pool, logger, "claude") + runners := map[string]executor.Runner{ + "claude": runner, + "gemini": runner, + } + pool := executor.NewPool(2, runners, store, logger) + srv := NewServer(store, pool, logger, "claude", "gemini") return srv, store } @@ -118,7 +122,8 @@ func TestCreateTask_Success(t *testing.T) { payload := `{ "name": "API Task", "description": "Created via API", - "claude": { + "agent": { + "type": "claude", "instructions": "do the thing", "model": "sonnet" }, @@ -160,7 +165,7 @@ func TestCreateTask_InvalidJSON(t *testing.T) { func TestCreateTask_ValidationFailure(t *testing.T) { srv, _ := testServer(t) - payload := `{"name": "", "claude": {"instructions": ""}}` + payload := `{"name": "", "agent": {"type": "claude", "instructions": ""}}` req := httptest.NewRequest("POST", "/api/tasks", bytes.NewBufferString(payload)) w := httptest.NewRecorder() srv.Handler().ServeHTTP(w, req) @@ -207,7 +212,7 @@ func TestListTasks_WithTasks(t *testing.T) { for i := 0; i < 3; i++ { tk := &task.Task{ ID: fmt.Sprintf("lt-%d", i), Name: fmt.Sprintf("T%d", i), - Claude: task.ClaudeConfig{Instructions: "x"}, Priority: task.PriorityNormal, + Agent: task.AgentConfig{Type: "claude", Instructions: "x"}, Priority: task.PriorityNormal, Retry: task.RetryConfig{MaxAttempts: 1, Backoff: "linear"}, Tags: []string{}, DependsOn: []string{}, State: task.StatePending, } @@ -244,7 +249,7 @@ func createTaskWithState(t *testing.T, store *storage.DB, id string, state task. tk := &task.Task{ ID: id, Name: "test-task-" + id, - Claude: task.ClaudeConfig{Instructions: "do something"}, + Agent: task.AgentConfig{Type: "claude", Instructions: "do something"}, Priority: task.PriorityNormal, Retry: task.RetryConfig{MaxAttempts: 1, Backoff: "linear"}, Tags: []string{}, DependsOn: []string{}, State: task.StatePending, @@ -606,7 +611,7 @@ func TestRunTask_RetryLimitReached_Returns409(t *testing.T) { tk := &task.Task{ ID: "retry-limit-1", Name: "Retry Limit Task", - Claude: task.ClaudeConfig{Instructions: "do something"}, + Agent: task.AgentConfig{Instructions: "do something"}, Priority: task.PriorityNormal, Retry: task.RetryConfig{MaxAttempts: 1, Backoff: "linear"}, Tags: []string{}, @@ -647,7 +652,7 @@ func TestRunTask_WithinRetryLimit_Returns202(t *testing.T) { tk := &task.Task{ ID: "retry-within-1", Name: "Retry Within Task", - Claude: task.ClaudeConfig{Instructions: "do something"}, + Agent: task.AgentConfig{Instructions: "do something"}, Priority: task.PriorityNormal, Retry: task.RetryConfig{MaxAttempts: 3, Backoff: "linear"}, Tags: []string{}, @@ -694,7 +699,7 @@ func TestDeleteTask_Success(t *testing.T) { srv, store := testServer(t) // Create a task to delete. - created := createTestTask(t, srv, `{"name":"Delete Me","claude":{"instructions":"x","model":"sonnet"}}`) + created := createTestTask(t, srv, `{"name":"Delete Me","agent":{"type":"claude","instructions":"x","model":"sonnet"}}`) req := httptest.NewRequest("DELETE", "/api/tasks/"+created.ID, nil) w := httptest.NewRecorder() @@ -729,7 +734,7 @@ func TestDeleteTask_RunningTaskRejected(t *testing.T) { tk := &task.Task{ ID: "running-task-del", Name: "Running Task", - Claude: task.ClaudeConfig{Instructions: "x", Model: "sonnet"}, + Agent: task.AgentConfig{Instructions: "x", Model: "sonnet"}, Priority: task.PriorityNormal, Tags: []string{}, DependsOn: []string{}, diff --git a/internal/api/templates.go b/internal/api/templates.go index 0139895..024a6df 100644 --- a/internal/api/templates.go +++ b/internal/api/templates.go @@ -27,7 +27,7 @@ func (s *Server) handleCreateTemplate(w http.ResponseWriter, r *http.Request) { var input struct { Name string `json:"name"` Description string `json:"description"` - Claude task.ClaudeConfig `json:"claude"` + Agent task.AgentConfig `json:"agent"` Timeout string `json:"timeout"` Priority string `json:"priority"` Tags []string `json:"tags"` @@ -46,13 +46,16 @@ func (s *Server) handleCreateTemplate(w http.ResponseWriter, r *http.Request) { ID: uuid.New().String(), Name: input.Name, Description: input.Description, - Claude: input.Claude, + Agent: input.Agent, Timeout: input.Timeout, Priority: input.Priority, Tags: input.Tags, CreatedAt: now, UpdatedAt: now, } + if tmpl.Agent.Type == "" { + tmpl.Agent.Type = "claude" + } if tmpl.Priority == "" { tmpl.Priority = "normal" } @@ -98,7 +101,7 @@ func (s *Server) handleUpdateTemplate(w http.ResponseWriter, r *http.Request) { var input struct { Name string `json:"name"` Description string `json:"description"` - Claude task.ClaudeConfig `json:"claude"` + Agent task.AgentConfig `json:"agent"` Timeout string `json:"timeout"` Priority string `json:"priority"` Tags []string `json:"tags"` @@ -114,7 +117,10 @@ func (s *Server) handleUpdateTemplate(w http.ResponseWriter, r *http.Request) { existing.Name = input.Name existing.Description = input.Description - existing.Claude = input.Claude + existing.Agent = input.Agent + if existing.Agent.Type == "" { + existing.Agent.Type = "claude" + } existing.Timeout = input.Timeout existing.Priority = input.Priority if input.Tags != nil { diff --git a/internal/api/templates_test.go b/internal/api/templates_test.go index bbcfc87..474c5d4 100644 --- a/internal/api/templates_test.go +++ b/internal/api/templates_test.go @@ -34,7 +34,8 @@ func TestCreateTemplate_Success(t *testing.T) { payload := `{ "name": "Go: Run Tests", "description": "Run the full test suite with race detector", - "claude": { + "agent": { + "type": "claude", "model": "sonnet", "instructions": "Run go test -race ./...", "max_budget_usd": 0.50, @@ -65,7 +66,7 @@ func TestCreateTemplate_Success(t *testing.T) { func TestGetTemplate_AfterCreate(t *testing.T) { srv, _ := testServer(t) - payload := `{"name": "Fetch Me", "claude": {"instructions": "do thing", "model": "haiku"}}` + payload := `{"name": "Fetch Me", "agent": {"type": "claude", "instructions": "do thing", "model": "haiku"}}` req := httptest.NewRequest("POST", "/api/templates", bytes.NewBufferString(payload)) w := httptest.NewRecorder() srv.Handler().ServeHTTP(w, req) @@ -107,14 +108,14 @@ func TestGetTemplate_NotFound(t *testing.T) { func TestUpdateTemplate(t *testing.T) { srv, _ := testServer(t) - payload := `{"name": "Original Name", "claude": {"instructions": "original"}}` + payload := `{"name": "Original Name", "agent": {"type": "claude", "instructions": "original"}}` req := httptest.NewRequest("POST", "/api/templates", bytes.NewBufferString(payload)) w := httptest.NewRecorder() srv.Handler().ServeHTTP(w, req) var created storage.Template json.NewDecoder(w.Body).Decode(&created) - update := `{"name": "Updated Name", "claude": {"instructions": "updated"}}` + update := `{"name": "Updated Name", "agent": {"type": "claude", "instructions": "updated"}}` req2 := httptest.NewRequest("PUT", fmt.Sprintf("/api/templates/%s", created.ID), bytes.NewBufferString(update)) w2 := httptest.NewRecorder() srv.Handler().ServeHTTP(w2, req2) @@ -132,7 +133,7 @@ func TestUpdateTemplate(t *testing.T) { func TestUpdateTemplate_NotFound(t *testing.T) { srv, _ := testServer(t) - update := `{"name": "Ghost", "claude": {"instructions": "x"}}` + update := `{"name": "Ghost", "agent": {"type": "claude", "instructions": "x"}}` req := httptest.NewRequest("PUT", "/api/templates/nonexistent", bytes.NewBufferString(update)) w := httptest.NewRecorder() srv.Handler().ServeHTTP(w, req) @@ -145,7 +146,7 @@ func TestUpdateTemplate_NotFound(t *testing.T) { func TestDeleteTemplate(t *testing.T) { srv, _ := testServer(t) - payload := `{"name": "To Delete", "claude": {"instructions": "bye"}}` + payload := `{"name": "To Delete", "agent": {"type": "claude", "instructions": "bye"}}` req := httptest.NewRequest("POST", "/api/templates", bytes.NewBufferString(payload)) w := httptest.NewRecorder() srv.Handler().ServeHTTP(w, req) diff --git a/internal/api/validate.go b/internal/api/validate.go index 0fcdb47..07d293c 100644 --- a/internal/api/validate.go +++ b/internal/api/validate.go @@ -12,7 +12,7 @@ import ( const validateTimeout = 20 * time.Second -const validateSystemPrompt = `You are a task instruction reviewer for Claudomator, an AI task runner that executes tasks by running Claude as a subprocess. +const validateSystemPrompt = `You are a task instruction reviewer for Claudomator, an AI task runner that executes tasks by running Claude or Gemini as a subprocess. Analyze the given task name and instructions for clarity and completeness. @@ -48,7 +48,7 @@ func (s *Server) validateBinaryPath() string { if s.validateCmdPath != "" { return s.validateCmdPath } - return s.claudeBinaryPath() + return s.claudeBinPath } func (s *Server) handleValidateTask(w http.ResponseWriter, r *http.Request) { @@ -59,11 +59,13 @@ func (s *Server) handleValidateTask(w http.ResponseWriter, r *http.Request) { var input struct { Name string `json:"name"` - Claude struct { + Agent struct { + Type string `json:"type"` Instructions string `json:"instructions"` ProjectDir string `json:"project_dir"` + WorkingDir string `json:"working_dir"` // legacy AllowedTools []string `json:"allowed_tools"` - } `json:"claude"` + } `json:"agent"` } if err := json.NewDecoder(r.Body).Decode(&input); err != nil { writeJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid JSON: " + err.Error()}) @@ -73,17 +75,27 @@ func (s *Server) handleValidateTask(w http.ResponseWriter, r *http.Request) { writeJSON(w, http.StatusBadRequest, map[string]string{"error": "name is required"}) return } - if input.Claude.Instructions == "" { + if input.Agent.Instructions == "" { writeJSON(w, http.StatusBadRequest, map[string]string{"error": "instructions are required"}) return } - userMsg := fmt.Sprintf("Task name: %s\n\nInstructions:\n%s", input.Name, input.Claude.Instructions) - if input.Claude.ProjectDir != "" { - userMsg += fmt.Sprintf("\n\nWorking directory: %s", input.Claude.ProjectDir) + agentType := input.Agent.Type + if agentType == "" { + agentType = "claude" } - if len(input.Claude.AllowedTools) > 0 { - userMsg += fmt.Sprintf("\n\nAllowed tools: %v", input.Claude.AllowedTools) + + projectDir := input.Agent.ProjectDir + if projectDir == "" { + projectDir = input.Agent.WorkingDir + } + + userMsg := fmt.Sprintf("Task name: %s\nAgent: %s\n\nInstructions:\n%s", input.Name, agentType, input.Agent.Instructions) + if projectDir != "" { + userMsg += fmt.Sprintf("\n\nWorking directory: %s", projectDir) + } + if len(input.Agent.AllowedTools) > 0 { + userMsg += fmt.Sprintf("\n\nAllowed tools: %v", input.Agent.AllowedTools) } ctx, cancel := context.WithTimeout(r.Context(), validateTimeout) diff --git a/internal/api/validate_test.go b/internal/api/validate_test.go index 5a1246b..c3d7b1f 100644 --- a/internal/api/validate_test.go +++ b/internal/api/validate_test.go @@ -23,7 +23,7 @@ func TestValidateTask_Success(t *testing.T) { wrapperJSON, _ := json.Marshal(wrapper) srv.validateCmdPath = createFakeClaude(t, string(wrapperJSON), 0) - body := `{"name":"Test Task","claude":{"instructions":"Run go test ./... and report results."}}` + body := `{"name":"Test Task","agent":{"instructions":"Run go test ./... and report results."}}` req := httptest.NewRequest("POST", "/api/tasks/validate", bytes.NewBufferString(body)) req.Header.Set("Content-Type", "application/json") w := httptest.NewRecorder() @@ -46,7 +46,7 @@ func TestValidateTask_Success(t *testing.T) { func TestValidateTask_MissingInstructions(t *testing.T) { srv, _ := testServer(t) - body := `{"name":"Test Task","claude":{"instructions":""}}` + body := `{"name":"Test Task","agent":{"instructions":""}}` req := httptest.NewRequest("POST", "/api/tasks/validate", bytes.NewBufferString(body)) req.Header.Set("Content-Type", "application/json") w := httptest.NewRecorder() @@ -61,7 +61,7 @@ func TestValidateTask_MissingInstructions(t *testing.T) { func TestValidateTask_MissingName(t *testing.T) { srv, _ := testServer(t) - body := `{"name":"","claude":{"instructions":"Do something useful."}}` + body := `{"name":"","agent":{"instructions":"Do something useful."}}` req := httptest.NewRequest("POST", "/api/tasks/validate", bytes.NewBufferString(body)) req.Header.Set("Content-Type", "application/json") w := httptest.NewRecorder() @@ -77,7 +77,7 @@ func TestValidateTask_BadJSONFromClaude(t *testing.T) { srv, _ := testServer(t) srv.validateCmdPath = createFakeClaude(t, "not valid json at all", 0) - body := `{"name":"Test Task","claude":{"instructions":"Do something useful."}}` + body := `{"name":"Test Task","agent":{"instructions":"Do something useful."}}` req := httptest.NewRequest("POST", "/api/tasks/validate", bytes.NewBufferString(body)) req.Header.Set("Content-Type", "application/json") w := httptest.NewRecorder() diff --git a/internal/cli/run.go b/internal/cli/run.go index ebf371c..49aa28e 100644 --- a/internal/cli/run.go +++ b/internal/cli/run.go @@ -54,7 +54,7 @@ func runTasks(file string, parallel int, dryRun bool) error { if dryRun { fmt.Printf("Validated %d task(s) successfully.\n", len(tasks)) for _, t := range tasks { - fmt.Printf(" - %s (model: %s, timeout: %v)\n", t.Name, t.Claude.Model, t.Timeout.Duration) + fmt.Printf(" - %s (model: %s, timeout: %v)\n", t.Name, t.Agent.Model, t.Timeout.Duration) } return nil } @@ -72,12 +72,22 @@ func runTasks(file string, parallel int, dryRun bool) error { logger := newLogger(verbose) - runner := &executor.ClaudeRunner{ - BinaryPath: cfg.ClaudeBinaryPath, - Logger: logger, - LogDir: cfg.LogDir, + runners := map[string]executor.Runner{ + "claude": &executor.ClaudeRunner{ + BinaryPath: cfg.ClaudeBinaryPath, + Logger: logger, + LogDir: cfg.LogDir, + }, + "gemini": &executor.GeminiRunner{ + BinaryPath: cfg.GeminiBinaryPath, + Logger: logger, + LogDir: cfg.LogDir, + }, + } + pool := executor.NewPool(parallel, runners, store, logger) + if cfg.GeminiBinaryPath != "" { + pool.Classifier = &executor.Classifier{GeminiBinaryPath: cfg.GeminiBinaryPath} } - pool := executor.NewPool(parallel, runner, store, logger) // Handle graceful shutdown. ctx, cancel := context.WithCancel(context.Background()) diff --git a/internal/cli/serve.go b/internal/cli/serve.go index cd5bfce..36a53b5 100644 --- a/internal/cli/serve.go +++ b/internal/cli/serve.go @@ -50,15 +50,28 @@ func serve(addr string) error { if len(addr) > 0 && addr[0] != ':' { apiURL = "http://" + addr } - runner := &executor.ClaudeRunner{ - BinaryPath: cfg.ClaudeBinaryPath, - Logger: logger, - LogDir: cfg.LogDir, - APIURL: apiURL, + + runners := map[string]executor.Runner{ + "claude": &executor.ClaudeRunner{ + BinaryPath: cfg.ClaudeBinaryPath, + Logger: logger, + LogDir: cfg.LogDir, + APIURL: apiURL, + }, + "gemini": &executor.GeminiRunner{ + BinaryPath: cfg.GeminiBinaryPath, + Logger: logger, + LogDir: cfg.LogDir, + APIURL: apiURL, + }, + } + + pool := executor.NewPool(cfg.MaxConcurrent, runners, store, logger) + if cfg.GeminiBinaryPath != "" { + pool.Classifier = &executor.Classifier{GeminiBinaryPath: cfg.GeminiBinaryPath} } - pool := executor.NewPool(cfg.MaxConcurrent, runner, store, logger) - srv := api.NewServer(store, pool, logger, cfg.ClaudeBinaryPath) + srv := api.NewServer(store, pool, logger, cfg.ClaudeBinaryPath, cfg.GeminiBinaryPath) if cfg.WebhookURL != "" { srv.SetNotifier(notify.NewWebhookNotifier(cfg.WebhookURL, logger)) } diff --git a/internal/cli/status.go b/internal/cli/status.go index 4e0461e..16b88b0 100644 --- a/internal/cli/status.go +++ b/internal/cli/status.go @@ -38,7 +38,7 @@ func showStatus(id string) error { fmt.Printf("ID: %s\n", t.ID) fmt.Printf("State: %s\n", t.State) fmt.Printf("Priority: %s\n", t.Priority) - fmt.Printf("Model: %s\n", t.Claude.Model) + fmt.Printf("Model: %s\n", t.Agent.Model) if t.Description != "" { fmt.Printf("Description: %s\n", t.Description) } diff --git a/internal/config/config.go b/internal/config/config.go index 12adf68..d3d9d68 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -12,6 +12,7 @@ type Config struct { DBPath string `toml:"-"` LogDir string `toml:"-"` ClaudeBinaryPath string `toml:"claude_binary_path"` + GeminiBinaryPath string `toml:"gemini_binary_path"` MaxConcurrent int `toml:"max_concurrent"` DefaultTimeout string `toml:"default_timeout"` ServerAddr string `toml:"server_addr"` @@ -32,6 +33,7 @@ func Default() (*Config, error) { DBPath: filepath.Join(dataDir, "claudomator.db"), LogDir: filepath.Join(dataDir, "executions"), ClaudeBinaryPath: "claude", + GeminiBinaryPath: "gemini", MaxConcurrent: 3, DefaultTimeout: "15m", ServerAddr: ":8484", diff --git a/internal/executor/classifier.go b/internal/executor/classifier.go new file mode 100644 index 0000000..79ebc27 --- /dev/null +++ b/internal/executor/classifier.go @@ -0,0 +1,109 @@ +package executor + +import ( + "context" + "encoding/json" + "fmt" + "os/exec" + "strings" +) + +type Classification struct { + AgentType string `json:"agent_type"` + Model string `json:"model"` + Reason string `json:"reason"` +} + +type SystemStatus struct { + ActiveTasks map[string]int + RateLimited map[string]bool +} + +type Classifier struct { + GeminiBinaryPath string +} + +const classificationPrompt = ` +You are a task classifier for Claudomator. +Given a task description and system status, select the best agent (claude or gemini) and model to use. + +Agent Types: +- claude: Best for complex coding, reasoning, and tool use. +- gemini: Best for large context, fast reasoning, and multimodal tasks. + +Available Models: +Claude: +- claude-3-5-sonnet-latest (balanced) +- claude-3-5-sonnet-20241022 (stable) +- claude-3-opus-20240229 (most powerful, expensive) +- claude-3-5-haiku-20241022 (fast, cheap) + +Gemini: +- gemini-2.0-flash-lite (fastest, most efficient, best for simple tasks) +- gemini-2.0-flash (fast, multimodal) +- gemini-1.5-flash (fast, balanced) +- gemini-1.5-pro (more powerful, larger context) + +Selection Criteria: +- Agent: Prefer the one with least running tasks and no active rate limit. +- Model: Select based on task complexity. Use powerful models (opus, pro) for complex reasoning/coding, flash-lite/flash/haiku for simple tasks. + +Task: +Name: %s +Instructions: %s + +System Status: +%s + +Respond with ONLY a JSON object: +{ + "agent_type": "claude" | "gemini", + "model": "model-name", + "reason": "brief reason" +} +` + +func (c *Classifier) Classify(ctx context.Context, taskName, instructions string, status SystemStatus) (*Classification, error) { + statusStr := "" + for agent, active := range status.ActiveTasks { + statusStr += fmt.Sprintf("- Agent %s: %d active tasks, Rate Limited: %t\n", agent, active, status.RateLimited[agent]) + } + + prompt := fmt.Sprintf(classificationPrompt, + taskName, instructions, statusStr, + ) + + binary := c.GeminiBinaryPath + if binary == "" { + binary = "gemini" + } + + // Use a minimal model for classification to be fast and cheap. + args := []string{ + "--prompt", prompt, + "--model", "gemini-2.0-flash-lite", + "--output-format", "json", + } + + cmd := exec.CommandContext(ctx, binary, args...) + out, err := cmd.Output() + if err != nil { + if exitErr, ok := err.(*exec.ExitError); ok { + return nil, fmt.Errorf("classifier failed (%v): %s", err, string(exitErr.Stderr)) + } + return nil, fmt.Errorf("classifier failed: %w", err) + } + + var cls Classification + // Gemini might wrap the JSON in markdown code blocks. + cleanOut := strings.TrimSpace(string(out)) + cleanOut = strings.TrimPrefix(cleanOut, "```json") + cleanOut = strings.TrimSuffix(cleanOut, "```") + cleanOut = strings.TrimSpace(cleanOut) + + if err := json.Unmarshal([]byte(cleanOut), &cls); err != nil { + return nil, fmt.Errorf("failed to parse classification JSON: %w\nOutput: %s", err, cleanOut) + } + + return &cls, nil +} diff --git a/internal/executor/classifier_test.go b/internal/executor/classifier_test.go new file mode 100644 index 0000000..4de44ca --- /dev/null +++ b/internal/executor/classifier_test.go @@ -0,0 +1,49 @@ +package executor + +import ( + "context" + "os" + "testing" +) + +// TestClassifier_Classify_Mock tests the classifier with a mocked gemini binary. +func TestClassifier_Classify_Mock(t *testing.T) { + // Create a temporary mock binary. + mockBinary := filepathJoin(t.TempDir(), "mock-gemini") + mockContent := `#!/bin/sh +echo '{"agent_type": "gemini", "model": "gemini-2.0-flash", "reason": "test reason"}' +` + if err := os.WriteFile(mockBinary, []byte(mockContent), 0755); err != nil { + t.Fatal(err) + } + + c := &Classifier{GeminiBinaryPath: mockBinary} + status := SystemStatus{ + ActiveTasks: map[string]int{"claude": 5, "gemini": 1}, + RateLimited: map[string]bool{"claude": false, "gemini": false}, + } + + cls, err := c.Classify(context.Background(), "Test Task", "Test Instructions", status) + if err != nil { + t.Fatalf("Classify failed: %v", err) + } + + if cls.AgentType != "gemini" { + t.Errorf("expected gemini, got %s", cls.AgentType) + } + if cls.Model != "gemini-2.0-flash" { + t.Errorf("expected gemini-2.0-flash, got %s", cls.Model) + } +} + +func filepathJoin(elems ...string) string { + var path string + for i, e := range elems { + if i == 0 { + path = e + } else { + path = path + string(os.PathSeparator) + e + } + } + return path +} diff --git a/internal/executor/claude.go b/internal/executor/claude.go index aa715da..e504369 100644 --- a/internal/executor/claude.go +++ b/internal/executor/claude.go @@ -61,7 +61,7 @@ func (r *ClaudeRunner) binaryPath() string { // changes back to project_dir. On failure the sandbox is preserved and its // path is included in the error. func (r *ClaudeRunner) Run(ctx context.Context, t *task.Task, e *storage.Execution) error { - projectDir := t.Claude.ProjectDir + projectDir := t.Agent.ProjectDir // Validate project_dir exists when set. if projectDir != "" { @@ -319,21 +319,21 @@ func (r *ClaudeRunner) buildArgs(t *task.Task, e *storage.Execution, questionFil "--output-format", "stream-json", "--verbose", } - permMode := t.Claude.PermissionMode + permMode := t.Agent.PermissionMode if permMode == "" { permMode = "bypassPermissions" } args = append(args, "--permission-mode", permMode) - if t.Claude.Model != "" { - args = append(args, "--model", t.Claude.Model) + if t.Agent.Model != "" { + args = append(args, "--model", t.Agent.Model) } return args } - instructions := t.Claude.Instructions - allowedTools := t.Claude.AllowedTools + instructions := t.Agent.Instructions + allowedTools := t.Agent.AllowedTools - if !t.Claude.SkipPlanning { + if !t.Agent.SkipPlanning { instructions = withPlanningPreamble(instructions) // Ensure Bash is available so the agent can POST subtasks and ask questions. hasBash := false @@ -355,33 +355,33 @@ func (r *ClaudeRunner) buildArgs(t *task.Task, e *storage.Execution, questionFil "--verbose", } - if t.Claude.Model != "" { - args = append(args, "--model", t.Claude.Model) + if t.Agent.Model != "" { + args = append(args, "--model", t.Agent.Model) } - if t.Claude.MaxBudgetUSD > 0 { - args = append(args, "--max-budget-usd", fmt.Sprintf("%.2f", t.Claude.MaxBudgetUSD)) + if t.Agent.MaxBudgetUSD > 0 { + args = append(args, "--max-budget-usd", fmt.Sprintf("%.2f", t.Agent.MaxBudgetUSD)) } // Default to bypassPermissions — claudomator runs tasks unattended, so // prompting for write access would always stall execution. Tasks that need // a more restrictive mode can set permission_mode explicitly. - permMode := t.Claude.PermissionMode + permMode := t.Agent.PermissionMode if permMode == "" { permMode = "bypassPermissions" } args = append(args, "--permission-mode", permMode) - if t.Claude.SystemPromptAppend != "" { - args = append(args, "--append-system-prompt", t.Claude.SystemPromptAppend) + if t.Agent.SystemPromptAppend != "" { + args = append(args, "--append-system-prompt", t.Agent.SystemPromptAppend) } for _, tool := range allowedTools { args = append(args, "--allowedTools", tool) } - for _, tool := range t.Claude.DisallowedTools { + for _, tool := range t.Agent.DisallowedTools { args = append(args, "--disallowedTools", tool) } - for _, f := range t.Claude.ContextFiles { + for _, f := range t.Agent.ContextFiles { args = append(args, "--add-dir", f) } - args = append(args, t.Claude.AdditionalArgs...) + args = append(args, t.Agent.AdditionalArgs...) return args } diff --git a/internal/executor/claude_test.go b/internal/executor/claude_test.go index 31dcf52..b5380f4 100644 --- a/internal/executor/claude_test.go +++ b/internal/executor/claude_test.go @@ -14,7 +14,8 @@ import ( func TestClaudeRunner_BuildArgs_BasicTask(t *testing.T) { r := &ClaudeRunner{} tk := &task.Task{ - Claude: task.ClaudeConfig{ + Agent: task.AgentConfig{ + Type: "claude", Instructions: "fix the bug", Model: "sonnet", SkipPlanning: true, @@ -37,7 +38,8 @@ func TestClaudeRunner_BuildArgs_BasicTask(t *testing.T) { func TestClaudeRunner_BuildArgs_FullConfig(t *testing.T) { r := &ClaudeRunner{} tk := &task.Task{ - Claude: task.ClaudeConfig{ + Agent: task.AgentConfig{ + Type: "claude", Instructions: "implement feature", Model: "opus", MaxBudgetUSD: 5.0, @@ -79,7 +81,8 @@ func TestClaudeRunner_BuildArgs_FullConfig(t *testing.T) { func TestClaudeRunner_BuildArgs_DefaultsToBypassPermissions(t *testing.T) { r := &ClaudeRunner{} tk := &task.Task{ - Claude: task.ClaudeConfig{ + Agent: task.AgentConfig{ + Type: "claude", Instructions: "do work", SkipPlanning: true, // PermissionMode intentionally not set @@ -102,7 +105,8 @@ func TestClaudeRunner_BuildArgs_DefaultsToBypassPermissions(t *testing.T) { func TestClaudeRunner_BuildArgs_RespectsExplicitPermissionMode(t *testing.T) { r := &ClaudeRunner{} tk := &task.Task{ - Claude: task.ClaudeConfig{ + Agent: task.AgentConfig{ + Type: "claude", Instructions: "do work", PermissionMode: "default", SkipPlanning: true, @@ -125,7 +129,8 @@ func TestClaudeRunner_BuildArgs_RespectsExplicitPermissionMode(t *testing.T) { func TestClaudeRunner_BuildArgs_AlwaysIncludesVerbose(t *testing.T) { r := &ClaudeRunner{} tk := &task.Task{ - Claude: task.ClaudeConfig{ + Agent: task.AgentConfig{ + Type: "claude", Instructions: "do something", SkipPlanning: true, }, @@ -148,7 +153,8 @@ func TestClaudeRunner_BuildArgs_AlwaysIncludesVerbose(t *testing.T) { func TestClaudeRunner_BuildArgs_PreamblePrepended(t *testing.T) { r := &ClaudeRunner{} tk := &task.Task{ - Claude: task.ClaudeConfig{ + Agent: task.AgentConfig{ + Type: "claude", Instructions: "fix the bug", SkipPlanning: false, }, @@ -171,7 +177,8 @@ func TestClaudeRunner_BuildArgs_PreamblePrepended(t *testing.T) { func TestClaudeRunner_BuildArgs_PreambleAddsBash(t *testing.T) { r := &ClaudeRunner{} tk := &task.Task{ - Claude: task.ClaudeConfig{ + Agent: task.AgentConfig{ + Type: "claude", Instructions: "do work", AllowedTools: []string{"Read"}, SkipPlanning: false, @@ -195,7 +202,8 @@ func TestClaudeRunner_BuildArgs_PreambleAddsBash(t *testing.T) { func TestClaudeRunner_BuildArgs_PreambleBashNotDuplicated(t *testing.T) { r := &ClaudeRunner{} tk := &task.Task{ - Claude: task.ClaudeConfig{ + Agent: task.AgentConfig{ + Type: "claude", Instructions: "do work", AllowedTools: []string{"Bash", "Read"}, SkipPlanning: false, @@ -223,7 +231,8 @@ func TestClaudeRunner_Run_InaccessibleWorkingDir_ReturnsError(t *testing.T) { LogDir: t.TempDir(), } tk := &task.Task{ - Claude: task.ClaudeConfig{ + Agent: task.AgentConfig{ + Type: "claude", ProjectDir: "/nonexistent/path/does/not/exist", SkipPlanning: true, }, diff --git a/internal/executor/executor.go b/internal/executor/executor.go index 0245899..d1c8e72 100644 --- a/internal/executor/executor.go +++ b/internal/executor/executor.go @@ -36,18 +36,21 @@ type workItem struct { // Pool manages a bounded set of concurrent task workers. type Pool struct { maxConcurrent int - runner Runner + runners map[string]Runner store *storage.DB logger *slog.Logger depPollInterval time.Duration // how often waitForDependencies polls; defaults to 5s - mu sync.Mutex - active int - cancels map[string]context.CancelFunc // taskID → cancel - resultCh chan *Result - workCh chan workItem // internal bounded queue; Submit enqueues here - doneCh chan struct{} // signals when a worker slot is freed - Questions *QuestionRegistry + mu sync.Mutex + active int + activePerAgent map[string]int + rateLimited map[string]time.Time // agentType -> until + cancels map[string]context.CancelFunc // taskID → cancel + resultCh chan *Result + workCh chan workItem // internal bounded queue; Submit enqueues here + doneCh chan struct{} // signals when a worker slot is freed + Questions *QuestionRegistry + Classifier *Classifier } // Result is emitted when a task execution completes. @@ -57,16 +60,18 @@ type Result struct { Err error } -func NewPool(maxConcurrent int, runner Runner, store *storage.DB, logger *slog.Logger) *Pool { +func NewPool(maxConcurrent int, runners map[string]Runner, store *storage.DB, logger *slog.Logger) *Pool { if maxConcurrent < 1 { maxConcurrent = 1 } p := &Pool{ maxConcurrent: maxConcurrent, - runner: runner, + runners: runners, store: store, logger: logger, depPollInterval: 5 * time.Second, + activePerAgent: make(map[string]int), + rateLimited: make(map[string]time.Time), cancels: make(map[string]context.CancelFunc), resultCh: make(chan *Result, maxConcurrent*2), workCh: make(chan workItem, maxConcurrent*10+100), @@ -147,10 +152,32 @@ func (p *Pool) SubmitResume(ctx context.Context, t *task.Task, exec *storage.Exe } } +func (p *Pool) getRunner(t *task.Task) (Runner, error) { + agentType := t.Agent.Type + if agentType == "" { + agentType = "claude" // Default for backward compatibility + } + runner, ok := p.runners[agentType] + if !ok { + return nil, fmt.Errorf("unsupported agent type: %q", agentType) + } + return runner, nil +} + func (p *Pool) executeResume(ctx context.Context, t *task.Task, exec *storage.Execution) { + agentType := t.Agent.Type + if agentType == "" { + agentType = "claude" + } + + p.mu.Lock() + p.activePerAgent[agentType]++ + p.mu.Unlock() + defer func() { p.mu.Lock() p.active-- + p.activePerAgent[agentType]-- p.mu.Unlock() select { case p.doneCh <- struct{}{}: @@ -158,8 +185,15 @@ func (p *Pool) executeResume(ctx context.Context, t *task.Task, exec *storage.Ex } }() + runner, err := p.getRunner(t) + if err != nil { + p.logger.Error("failed to get runner for resume", "error", err, "taskID", t.ID) + p.resultCh <- &Result{TaskID: t.ID, Execution: exec, Err: err} + return + } + // Pre-populate log paths. - if lp, ok := p.runner.(LogPather); ok { + if lp, ok := runner.(LogPather); ok { if logDir := lp.ExecLogDir(exec.ID); logDir != "" { exec.StdoutPath = filepath.Join(logDir, "stdout.log") exec.StderrPath = filepath.Join(logDir, "stderr.log") @@ -182,12 +216,30 @@ func (p *Pool) executeResume(ctx context.Context, t *task.Task, exec *storage.Ex } else { ctx, cancel = context.WithCancel(ctx) } - defer cancel() + p.mu.Lock() + p.cancels[t.ID] = cancel + p.mu.Unlock() + defer func() { + cancel() + p.mu.Lock() + delete(p.cancels, t.ID) + p.mu.Unlock() + }() - err := p.runner.Run(ctx, t, exec) + err = runner.Run(ctx, t, exec) exec.EndTime = time.Now().UTC() if err != nil { + if isRateLimitError(err) { + p.mu.Lock() + retryAfter := parseRetryAfter(err.Error()) + if retryAfter == 0 { + retryAfter = 1 * time.Minute + } + p.rateLimited[agentType] = time.Now().Add(retryAfter) + p.mu.Unlock() + } + var blockedErr *BlockedError if errors.As(err, &blockedErr) { exec.Status = "BLOCKED" @@ -234,9 +286,45 @@ func (p *Pool) ActiveCount() int { } func (p *Pool) execute(ctx context.Context, t *task.Task) { + // 1. Classification + if p.Classifier != nil { + p.mu.Lock() + activeTasks := make(map[string]int) + rateLimited := make(map[string]bool) + now := time.Now() + for agent := range p.runners { + activeTasks[agent] = p.activePerAgent[agent] + rateLimited[agent] = now.Before(p.rateLimited[agent]) + } + status := SystemStatus{ + ActiveTasks: activeTasks, + RateLimited: rateLimited, + } + p.mu.Unlock() + + cls, err := p.Classifier.Classify(ctx, t.Name, t.Agent.Instructions, status) + if err == nil { + p.logger.Info("task classified", "taskID", t.ID, "agent", cls.AgentType, "model", cls.Model, "reason", cls.Reason) + t.Agent.Type = cls.AgentType + t.Agent.Model = cls.Model + } else { + p.logger.Error("classification failed", "error", err, "taskID", t.ID) + } + } + + agentType := t.Agent.Type + if agentType == "" { + agentType = "claude" + } + + p.mu.Lock() + p.activePerAgent[agentType]++ + p.mu.Unlock() + defer func() { p.mu.Lock() p.active-- + p.activePerAgent[agentType]-- p.mu.Unlock() select { case p.doneCh <- struct{}{}: @@ -244,6 +332,26 @@ func (p *Pool) execute(ctx context.Context, t *task.Task) { } }() + runner, err := p.getRunner(t) + if err != nil { + p.logger.Error("failed to get runner", "error", err, "taskID", t.ID) + now := time.Now().UTC() + exec := &storage.Execution{ + ID: uuid.New().String(), + TaskID: t.ID, + StartTime: now, + EndTime: now, + Status: "FAILED", + ErrorMsg: err.Error(), + } + if createErr := p.store.CreateExecution(exec); createErr != nil { + p.logger.Error("failed to create execution record", "error", createErr) + } + p.store.UpdateTaskState(t.ID, task.StateFailed) + p.resultCh <- &Result{TaskID: t.ID, Execution: exec, Err: err} + return + } + // Wait for all dependencies to complete before starting execution. if len(t.DependsOn) > 0 { if err := p.waitForDependencies(ctx, t); err != nil { @@ -275,7 +383,7 @@ func (p *Pool) execute(ctx context.Context, t *task.Task) { // Pre-populate log paths so they're available in the DB immediately — // before the subprocess starts — enabling live tailing and debugging. - if lp, ok := p.runner.(LogPather); ok { + if lp, ok := runner.(LogPather); ok { if logDir := lp.ExecLogDir(execID); logDir != "" { exec.StdoutPath = filepath.Join(logDir, "stdout.log") exec.StderrPath = filepath.Join(logDir, "stderr.log") @@ -309,10 +417,20 @@ func (p *Pool) execute(ctx context.Context, t *task.Task) { }() // Run the task. - err := p.runner.Run(ctx, t, exec) + err = runner.Run(ctx, t, exec) exec.EndTime = time.Now().UTC() if err != nil { + if isRateLimitError(err) { + p.mu.Lock() + retryAfter := parseRetryAfter(err.Error()) + if retryAfter == 0 { + retryAfter = 1 * time.Minute + } + p.rateLimited[agentType] = time.Now().Add(retryAfter) + p.mu.Unlock() + } + var blockedErr *BlockedError if errors.As(err, &blockedErr) { exec.Status = "BLOCKED" diff --git a/internal/executor/executor_test.go b/internal/executor/executor_test.go index 414f852..028e5cf 100644 --- a/internal/executor/executor_test.go +++ b/internal/executor/executor_test.go @@ -64,7 +64,7 @@ func makeTask(id string) *task.Task { now := time.Now().UTC() return &task.Task{ ID: id, Name: "Test " + id, - Claude: task.ClaudeConfig{Instructions: "test"}, + Agent: task.AgentConfig{Type: "claude", Instructions: "test"}, Priority: task.PriorityNormal, Retry: task.RetryConfig{MaxAttempts: 1, Backoff: "linear"}, Tags: []string{}, @@ -77,8 +77,9 @@ func makeTask(id string) *task.Task { func TestPool_Submit_TopLevel_GoesToReady(t *testing.T) { store := testStore(t) runner := &mockRunner{} + runners := map[string]Runner{"claude": runner} logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) - pool := NewPool(2, runner, store, logger) + pool := NewPool(2, runners, store, logger) tk := makeTask("ps-1") // no ParentTaskID → top-level store.CreateTask(tk) @@ -104,8 +105,9 @@ func TestPool_Submit_TopLevel_GoesToReady(t *testing.T) { func TestPool_Submit_Subtask_GoesToCompleted(t *testing.T) { store := testStore(t) runner := &mockRunner{} + runners := map[string]Runner{"claude": runner} logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) - pool := NewPool(2, runner, store, logger) + pool := NewPool(2, runners, store, logger) tk := makeTask("sub-1") tk.ParentTaskID = "parent-99" // subtask @@ -132,8 +134,9 @@ func TestPool_Submit_Subtask_GoesToCompleted(t *testing.T) { func TestPool_Submit_Failure(t *testing.T) { store := testStore(t) runner := &mockRunner{err: fmt.Errorf("boom"), exitCode: 1} + runners := map[string]Runner{"claude": runner} logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) - pool := NewPool(2, runner, store, logger) + pool := NewPool(2, runners, store, logger) tk := makeTask("pf-1") store.CreateTask(tk) @@ -151,8 +154,9 @@ func TestPool_Submit_Failure(t *testing.T) { func TestPool_Submit_Timeout(t *testing.T) { store := testStore(t) runner := &mockRunner{delay: 5 * time.Second} + runners := map[string]Runner{"claude": runner} logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) - pool := NewPool(2, runner, store, logger) + pool := NewPool(2, runners, store, logger) tk := makeTask("pt-1") tk.Timeout.Duration = 50 * time.Millisecond @@ -168,8 +172,9 @@ func TestPool_Submit_Timeout(t *testing.T) { func TestPool_Submit_Cancellation(t *testing.T) { store := testStore(t) runner := &mockRunner{delay: 5 * time.Second} + runners := map[string]Runner{"claude": runner} logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) - pool := NewPool(2, runner, store, logger) + pool := NewPool(2, runners, store, logger) ctx, cancel := context.WithCancel(context.Background()) tk := makeTask("pc-1") @@ -188,8 +193,9 @@ func TestPool_Submit_Cancellation(t *testing.T) { func TestPool_Cancel_StopsRunningTask(t *testing.T) { store := testStore(t) runner := &mockRunner{delay: 5 * time.Second} + runners := map[string]Runner{"claude": runner} logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) - pool := NewPool(2, runner, store, logger) + pool := NewPool(2, runners, store, logger) tk := makeTask("cancel-1") store.CreateTask(tk) @@ -209,8 +215,9 @@ func TestPool_Cancel_StopsRunningTask(t *testing.T) { func TestPool_Cancel_UnknownTask_ReturnsFalse(t *testing.T) { store := testStore(t) runner := &mockRunner{} + runners := map[string]Runner{"claude": runner} logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) - pool := NewPool(2, runner, store, logger) + pool := NewPool(2, runners, store, logger) if ok := pool.Cancel("nonexistent"); ok { t.Error("Cancel returned true for unknown task") @@ -222,8 +229,9 @@ func TestPool_Cancel_UnknownTask_ReturnsFalse(t *testing.T) { func TestPool_QueuedWhenAtCapacity(t *testing.T) { store := testStore(t) runner := &mockRunner{delay: 100 * time.Millisecond} + runners := map[string]Runner{"claude": runner} logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) - pool := NewPool(1, runner, store, logger) + pool := NewPool(1, runners, store, logger) tk1 := makeTask("queue-1") store.CreateTask(tk1) @@ -273,8 +281,9 @@ func (m *logPatherMockRunner) Run(ctx context.Context, t *task.Task, e *storage. func TestPool_Execute_LogPathsPreSetBeforeRun(t *testing.T) { store := testStore(t) runner := &logPatherMockRunner{logDir: t.TempDir()} + runners := map[string]Runner{"claude": runner} logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) - pool := NewPool(2, runner, store, logger) + pool := NewPool(2, runners, store, logger) tk := makeTask("lp-1") store.CreateTask(tk) @@ -304,8 +313,9 @@ func TestPool_Execute_LogPathsPreSetBeforeRun(t *testing.T) { func TestPool_Execute_NoLogPather_PathsEmptyBeforeRun(t *testing.T) { store := testStore(t) runner := &mockRunner{} // does NOT implement LogPather + runners := map[string]Runner{"claude": runner} logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) - pool := NewPool(2, runner, store, logger) + pool := NewPool(2, runners, store, logger) tk := makeTask("nolp-1") store.CreateTask(tk) @@ -321,8 +331,9 @@ func TestPool_Execute_NoLogPather_PathsEmptyBeforeRun(t *testing.T) { func TestPool_ConcurrentExecution(t *testing.T) { store := testStore(t) runner := &mockRunner{delay: 50 * time.Millisecond} + runners := map[string]Runner{"claude": runner} logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) - pool := NewPool(3, runner, store, logger) + pool := NewPool(3, runners, store, logger) for i := 0; i < 3; i++ { tk := makeTask(fmt.Sprintf("cc-%d", i)) @@ -343,3 +354,29 @@ func TestPool_ConcurrentExecution(t *testing.T) { t.Errorf("calls: want 3, got %d", runner.callCount()) } } + +func TestPool_UnsupportedAgent(t *testing.T) { + store := testStore(t) + runners := map[string]Runner{"claude": &mockRunner{}} + logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) + pool := NewPool(2, runners, store, logger) + + tk := makeTask("bad-agent") + tk.Agent.Type = "super-ai" + store.CreateTask(tk) + + if err := pool.Submit(context.Background(), tk); err != nil { + t.Fatalf("submit: %v", err) + } + + result := <-pool.Results() + if result.Err == nil { + t.Fatal("expected error for unsupported agent") + } + if !strings.Contains(result.Err.Error(), "unsupported agent type") { + t.Errorf("expected 'unsupported agent type' in error, got: %v", result.Err) + } + if result.Execution.Status != "FAILED" { + t.Errorf("status: want FAILED, got %q", result.Execution.Status) + } +} diff --git a/internal/executor/gemini.go b/internal/executor/gemini.go new file mode 100644 index 0000000..956d8b5 --- /dev/null +++ b/internal/executor/gemini.go @@ -0,0 +1,192 @@ +package executor + +import ( + "context" + "fmt" + "log/slog" + "os" + "os/exec" + "path/filepath" + "strings" + "sync" + "syscall" + + "github.com/thepeterstone/claudomator/internal/storage" + "github.com/thepeterstone/claudomator/internal/task" +) + +// GeminiRunner spawns the `gemini` CLI in non-interactive mode. +type GeminiRunner struct { + BinaryPath string // defaults to "gemini" + Logger *slog.Logger + LogDir string // base directory for execution logs + APIURL string // base URL of the Claudomator API, passed to subprocesses +} + +// ExecLogDir returns the log directory for the given execution ID. +func (r *GeminiRunner) ExecLogDir(execID string) string { + if r.LogDir == "" { + return "" + } + return filepath.Join(r.LogDir, execID) +} + +func (r *GeminiRunner) binaryPath() string { + if r.BinaryPath != "" { + return r.BinaryPath + } + return "gemini" +} + +// Run executes a gemini <instructions> invocation, streaming output to log files. +func (r *GeminiRunner) Run(ctx context.Context, t *task.Task, e *storage.Execution) error { + if t.Agent.ProjectDir != "" { + if _, err := os.Stat(t.Agent.ProjectDir); err != nil { + return fmt.Errorf("project_dir %q: %w", t.Agent.ProjectDir, err) + } + } + + logDir := r.ExecLogDir(e.ID) + if logDir == "" { + logDir = e.ID + } + if err := os.MkdirAll(logDir, 0700); err != nil { + return fmt.Errorf("creating log dir: %w", err) + } + if e.StdoutPath == "" { + e.StdoutPath = filepath.Join(logDir, "stdout.log") + e.StderrPath = filepath.Join(logDir, "stderr.log") + e.ArtifactDir = logDir + } + + if e.SessionID == "" { + e.SessionID = e.ID + } + + questionFile := filepath.Join(logDir, "question.json") + args := r.buildArgs(t, e, questionFile) + + // Gemini CLI doesn't necessarily have the same rate limiting behavior as Claude, + // but we'll use a similar execution pattern. + err := r.execOnce(ctx, args, t.Agent.ProjectDir, e) + if err != nil { + return err + } + + // Check whether the agent left a question before exiting. + data, readErr := os.ReadFile(questionFile) + if readErr == nil { + os.Remove(questionFile) // consumed + return &BlockedError{QuestionJSON: strings.TrimSpace(string(data)), SessionID: e.SessionID} + } + return nil +} + +func (r *GeminiRunner) execOnce(ctx context.Context, args []string, workingDir string, e *storage.Execution) error { + cmd := exec.CommandContext(ctx, r.binaryPath(), args...) + cmd.Env = append(os.Environ(), + "CLAUDOMATOR_API_URL="+r.APIURL, + "CLAUDOMATOR_TASK_ID="+e.TaskID, + "CLAUDOMATOR_QUESTION_FILE="+filepath.Join(e.ArtifactDir, "question.json"), + ) + cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true} + if workingDir != "" { + cmd.Dir = workingDir + } + + stdoutFile, err := os.Create(e.StdoutPath) + if err != nil { + return fmt.Errorf("creating stdout log: %w", err) + } + defer stdoutFile.Close() + + stderrFile, err := os.Create(e.StderrPath) + if err != nil { + return fmt.Errorf("creating stderr log: %w", err) + } + defer stderrFile.Close() + + stdoutR, stdoutW, err := os.Pipe() + if err != nil { + return fmt.Errorf("creating stdout pipe: %w", err) + } + cmd.Stdout = stdoutW + cmd.Stderr = stderrFile + + if err := cmd.Start(); err != nil { + stdoutW.Close() + stdoutR.Close() + return fmt.Errorf("starting gemini: %w", err) + } + stdoutW.Close() + + killDone := make(chan struct{}) + go func() { + select { + case <-ctx.Done(): + syscall.Kill(-cmd.Process.Pid, syscall.SIGKILL) + case <-killDone: + } + }() + + var costUSD float64 + var streamErr error + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + // Reusing parseStream as the JSONL format should be compatible + costUSD, streamErr = parseStream(stdoutR, stdoutFile, r.Logger) + stdoutR.Close() + }() + + waitErr := cmd.Wait() + close(killDone) + wg.Wait() + + e.CostUSD = costUSD + + if waitErr != nil { + if exitErr, ok := waitErr.(*exec.ExitError); ok { + e.ExitCode = exitErr.ExitCode() + } + return fmt.Errorf("gemini exited with error: %w", waitErr) + } + + e.ExitCode = 0 + if streamErr != nil { + return streamErr + } + return nil +} + +func (r *GeminiRunner) buildArgs(t *task.Task, e *storage.Execution, questionFile string) []string { + // Gemini CLI uses a different command structure: gemini "instructions" [flags] + + instructions := t.Agent.Instructions + if !t.Agent.SkipPlanning { + instructions = withPlanningPreamble(instructions) + } + + args := []string{ + instructions, + "--output-format", "stream-json", + } + + // Note: Gemini CLI flags might differ from Claude CLI. + // Assuming common flags for now, but these may need adjustment. + if t.Agent.Model != "" { + args = append(args, "--model", t.Agent.Model) + } + + // Gemini CLI doesn't use --session-id for the first run in the same way, + // or it might use it differently. For now we assume compatibility. + if e.SessionID != "" { + // If it's a resume, it might use different flags. + if e.ResumeSessionID != "" { + // This is a placeholder for Gemini's resume logic + } + } + + return args +} diff --git a/internal/executor/gemini_test.go b/internal/executor/gemini_test.go new file mode 100644 index 0000000..42253da --- /dev/null +++ b/internal/executor/gemini_test.go @@ -0,0 +1,103 @@ +package executor + +import ( + "context" + "io" + "log/slog" + "strings" + "testing" + + "github.com/thepeterstone/claudomator/internal/storage" + "github.com/thepeterstone/claudomator/internal/task" +) + +func TestGeminiRunner_BuildArgs_BasicTask(t *testing.T) { + r := &GeminiRunner{} + tk := &task.Task{ + Agent: task.AgentConfig{ + Type: "gemini", + Instructions: "fix the bug", + Model: "gemini-2.0-flash", + SkipPlanning: true, + }, + } + + args := r.buildArgs(tk, &storage.Execution{ID: "test-exec"}, "/tmp/q.json") + + // Gemini CLI: instructions is the first positional arg + if len(args) < 1 || args[0] != "fix the bug" { + t.Errorf("expected instructions as first arg, got: %v", args) + } + + argMap := make(map[string]bool) + for _, a := range args { + argMap[a] = true + } + for _, want := range []string{"--output-format", "stream-json", "--model", "gemini-2.0-flash"} { + if !argMap[want] { + t.Errorf("missing arg %q in %v", want, args) + } + } +} + +func TestGeminiRunner_BuildArgs_PreamblePrepended(t *testing.T) { + r := &GeminiRunner{} + tk := &task.Task{ + Agent: task.AgentConfig{ + Type: "gemini", + Instructions: "fix the bug", + SkipPlanning: false, + }, + } + + args := r.buildArgs(tk, &storage.Execution{ID: "test-exec"}, "/tmp/q.json") + + if len(args) < 1 { + t.Fatalf("expected at least 1 arg, got: %v", args) + } + if !strings.HasPrefix(args[0], planningPreamble) { + t.Errorf("instructions should start with planning preamble") + } + if !strings.HasSuffix(args[0], "fix the bug") { + t.Errorf("instructions should end with original instructions") + } +} + +func TestGeminiRunner_Run_InaccessibleProjectDir_ReturnsError(t *testing.T) { + r := &GeminiRunner{ + BinaryPath: "true", // would succeed if it ran + Logger: slog.New(slog.NewTextHandler(io.Discard, nil)), + LogDir: t.TempDir(), + } + tk := &task.Task{ + Agent: task.AgentConfig{ + Type: "gemini", + ProjectDir: "/nonexistent/path/does/not/exist", + SkipPlanning: true, + }, + } + exec := &storage.Execution{ID: "test-exec"} + + err := r.Run(context.Background(), tk, exec) + + if err == nil { + t.Fatal("expected error for inaccessible project_dir, got nil") + } + if !strings.Contains(err.Error(), "project_dir") { + t.Errorf("expected 'project_dir' in error, got: %v", err) + } +} + +func TestGeminiRunner_BinaryPath_Default(t *testing.T) { + r := &GeminiRunner{} + if r.binaryPath() != "gemini" { + t.Errorf("want 'gemini', got %q", r.binaryPath()) + } +} + +func TestGeminiRunner_BinaryPath_Custom(t *testing.T) { + r := &GeminiRunner{BinaryPath: "/usr/local/bin/gemini"} + if r.binaryPath() != "/usr/local/bin/gemini" { + t.Errorf("want custom path, got %q", r.binaryPath()) + } +} diff --git a/internal/executor/preamble.go b/internal/executor/preamble.go index 71f8233..b20f7ea 100644 --- a/internal/executor/preamble.go +++ b/internal/executor/preamble.go @@ -23,11 +23,12 @@ Before doing any implementation work: 1. Estimate: will this task take more than 3 minutes of implementation effort? -2. If YES — break it down into subtasks: - - Create 3–7 discrete subtasks using the claudomator CLI, for example: - claudomator create "Subtask name" --instructions "..." --working-dir "/path" --parent-id "$CLAUDOMATOR_TASK_ID" --server "$CLAUDOMATOR_API_URL" - - Do NOT pass --start. Tasks will be queued and started in order by the operator. +2. If YES — break it down: + - Create 3–7 discrete subtasks by POSTing to $CLAUDOMATOR_API_URL/api/tasks + - Each subtask POST body should be JSON with: name, agent.instructions, agent.working_dir (copy from current task), agent.model, agent.allowed_tools, and agent.skip_planning set to true + - Set parent_task_id to $CLAUDOMATOR_TASK_ID in each POST body - After creating all subtasks, output a brief summary and STOP. Do not implement anything. + - You can also specify agent.type (either "claude" or "gemini") to choose the agent for subtasks. 3. If NO — proceed with the task instructions below. diff --git a/internal/storage/db.go b/internal/storage/db.go index cbbd97c..0a4f7a5 100644 --- a/internal/storage/db.go +++ b/internal/storage/db.go @@ -109,7 +109,7 @@ func isColumnExistsError(err error) bool { // CreateTask inserts a task into the database. func (s *DB) CreateTask(t *task.Task) error { - configJSON, err := json.Marshal(t.Claude) + configJSON, err := json.Marshal(t.Agent) if err != nil { return fmt.Errorf("marshaling config: %w", err) } @@ -242,7 +242,7 @@ func (s *DB) RejectTask(id, comment string) error { type TaskUpdate struct { Name string Description string - Config task.ClaudeConfig + Config task.AgentConfig Priority task.Priority TimeoutNS int64 Retry task.RetryConfig @@ -522,8 +522,17 @@ func scanTask(row scanner) (*task.Task, error) { t.State = task.State(state) t.Priority = task.Priority(priority) t.Timeout.Duration = time.Duration(timeoutNS) - if err := json.Unmarshal([]byte(configJSON), &t.Claude); err != nil { - return nil, fmt.Errorf("unmarshaling config: %w", err) + if err := json.Unmarshal([]byte(configJSON), &t.Agent); err != nil { + return nil, fmt.Errorf("unmarshaling agent config: %w", err) + } + // Fallback for legacy 'claude' field + if t.Agent.Instructions == "" { + var legacy struct { + Claude task.AgentConfig `json:"claude"` + } + if err := json.Unmarshal([]byte(configJSON), &legacy); err == nil && legacy.Claude.Instructions != "" { + t.Agent = legacy.Claude + } } if err := json.Unmarshal([]byte(retryJSON), &t.Retry); err != nil { return nil, fmt.Errorf("unmarshaling retry: %w", err) diff --git a/internal/storage/db_test.go b/internal/storage/db_test.go index 2738a41..f737096 100644 --- a/internal/storage/db_test.go +++ b/internal/storage/db_test.go @@ -37,7 +37,8 @@ func TestCreateTask_AndGetTask(t *testing.T) { ID: "task-1", Name: "Test Task", Description: "A test", - Claude: task.ClaudeConfig{ + Agent: task.AgentConfig{ + Type: "claude", Model: "sonnet", Instructions: "do it", ProjectDir: "/tmp", @@ -64,11 +65,11 @@ func TestCreateTask_AndGetTask(t *testing.T) { if got.Name != "Test Task" { t.Errorf("name: want 'Test Task', got %q", got.Name) } - if got.Claude.Model != "sonnet" { - t.Errorf("model: want 'sonnet', got %q", got.Claude.Model) + if got.Agent.Model != "sonnet" { + t.Errorf("model: want 'sonnet', got %q", got.Agent.Model) } - if got.Claude.MaxBudgetUSD != 2.5 { - t.Errorf("budget: want 2.5, got %f", got.Claude.MaxBudgetUSD) + if got.Agent.MaxBudgetUSD != 2.5 { + t.Errorf("budget: want 2.5, got %f", got.Agent.MaxBudgetUSD) } if got.Priority != task.PriorityHigh { t.Errorf("priority: want 'high', got %q", got.Priority) @@ -93,7 +94,7 @@ func TestUpdateTaskState(t *testing.T) { tk := &task.Task{ ID: "task-2", Name: "Stateful", - Claude: task.ClaudeConfig{Instructions: "test"}, + Agent: task.AgentConfig{Type: "claude", Instructions: "test"}, Priority: task.PriorityNormal, Retry: task.RetryConfig{MaxAttempts: 1, Backoff: "linear"}, Tags: []string{}, @@ -162,7 +163,7 @@ func TestListTasks_FilterByState(t *testing.T) { for i, state := range []task.State{task.StatePending, task.StatePending, task.StateRunning} { tk := &task.Task{ ID: fmt.Sprintf("t-%d", i), Name: fmt.Sprintf("Task %d", i), - Claude: task.ClaudeConfig{Instructions: "x"}, Priority: task.PriorityNormal, + Agent: task.AgentConfig{Type: "claude", Instructions: "x"}, Priority: task.PriorityNormal, Retry: task.RetryConfig{MaxAttempts: 1, Backoff: "linear"}, Tags: []string{}, DependsOn: []string{}, State: state, CreatedAt: now, UpdatedAt: now, @@ -195,7 +196,7 @@ func TestListTasks_WithLimit(t *testing.T) { for i := 0; i < 5; i++ { tk := &task.Task{ ID: fmt.Sprintf("lt-%d", i), Name: fmt.Sprintf("T%d", i), - Claude: task.ClaudeConfig{Instructions: "x"}, Priority: task.PriorityNormal, + Agent: task.AgentConfig{Type: "claude", Instructions: "x"}, Priority: task.PriorityNormal, Retry: task.RetryConfig{MaxAttempts: 1, Backoff: "linear"}, Tags: []string{}, DependsOn: []string{}, State: task.StatePending, CreatedAt: now.Add(time.Duration(i) * time.Second), UpdatedAt: now, @@ -218,7 +219,7 @@ func TestCreateExecution_AndGet(t *testing.T) { // Need a task first. tk := &task.Task{ - ID: "etask", Name: "E", Claude: task.ClaudeConfig{Instructions: "x"}, + ID: "etask", Name: "E", Agent: task.AgentConfig{Type: "claude", Instructions: "x"}, Priority: task.PriorityNormal, Retry: task.RetryConfig{MaxAttempts: 1, Backoff: "linear"}, Tags: []string{}, DependsOn: []string{}, State: task.StatePending, CreatedAt: now, UpdatedAt: now, @@ -259,7 +260,7 @@ func TestListExecutions(t *testing.T) { db := testDB(t) now := time.Now().UTC() tk := &task.Task{ - ID: "ltask", Name: "L", Claude: task.ClaudeConfig{Instructions: "x"}, + ID: "ltask", Name: "L", Agent: task.AgentConfig{Type: "claude", Instructions: "x"}, Priority: task.PriorityNormal, Retry: task.RetryConfig{MaxAttempts: 1, Backoff: "linear"}, Tags: []string{}, DependsOn: []string{}, State: task.StatePending, CreatedAt: now, UpdatedAt: now, @@ -292,7 +293,7 @@ func TestDB_UpdateTask(t *testing.T) { ID: "upd-1", Name: "Original Name", Description: "original desc", - Claude: task.ClaudeConfig{Model: "sonnet", Instructions: "original"}, + Agent: task.AgentConfig{Type: "claude", Model: "sonnet", Instructions: "original"}, Priority: task.PriorityNormal, Retry: task.RetryConfig{MaxAttempts: 1, Backoff: "linear"}, Tags: []string{"old"}, @@ -309,7 +310,7 @@ func TestDB_UpdateTask(t *testing.T) { u := TaskUpdate{ Name: "Updated Name", Description: "updated desc", - Config: task.ClaudeConfig{Model: "opus", Instructions: "updated"}, + Config: task.AgentConfig{Type: "claude", Model: "opus", Instructions: "updated"}, Priority: task.PriorityHigh, TimeoutNS: int64(15 * time.Minute), Retry: task.RetryConfig{MaxAttempts: 3, Backoff: "exponential"}, @@ -330,8 +331,8 @@ func TestDB_UpdateTask(t *testing.T) { if got.Description != "updated desc" { t.Errorf("description: want 'updated desc', got %q", got.Description) } - if got.Claude.Model != "opus" { - t.Errorf("model: want 'opus', got %q", got.Claude.Model) + if got.Agent.Model != "opus" { + t.Errorf("model: want 'opus', got %q", got.Agent.Model) } if got.Priority != task.PriorityHigh { t.Errorf("priority: want 'high', got %q", got.Priority) @@ -376,7 +377,7 @@ func TestRejectTask(t *testing.T) { db := testDB(t) now := time.Now().UTC() tk := &task.Task{ - ID: "reject-1", Name: "R", Claude: task.ClaudeConfig{Instructions: "x"}, + ID: "reject-1", Name: "R", Agent: task.AgentConfig{Type: "claude", Instructions: "x"}, Priority: task.PriorityNormal, Retry: task.RetryConfig{MaxAttempts: 1, Backoff: "linear"}, Tags: []string{}, DependsOn: []string{}, State: task.StateReady, CreatedAt: now, UpdatedAt: now, @@ -413,7 +414,7 @@ func TestUpdateExecution(t *testing.T) { db := testDB(t) now := time.Now().UTC() tk := &task.Task{ - ID: "utask", Name: "U", Claude: task.ClaudeConfig{Instructions: "x"}, + ID: "utask", Name: "U", Agent: task.AgentConfig{Type: "claude", Instructions: "x"}, Priority: task.PriorityNormal, Retry: task.RetryConfig{MaxAttempts: 1, Backoff: "linear"}, Tags: []string{}, DependsOn: []string{}, State: task.StatePending, CreatedAt: now, UpdatedAt: now, @@ -456,7 +457,7 @@ func TestUpdateExecution(t *testing.T) { func makeTestTask(id string, now time.Time) *task.Task { return &task.Task{ - ID: id, Name: "T-" + id, Claude: task.ClaudeConfig{Instructions: "x"}, + ID: id, Name: "T-" + id, Agent: task.AgentConfig{Type: "claude", Instructions: "x"}, Priority: task.PriorityNormal, Retry: task.RetryConfig{MaxAttempts: 1, Backoff: "linear"}, Tags: []string{}, DependsOn: []string{}, State: task.StatePending, CreatedAt: now, UpdatedAt: now, @@ -579,3 +580,36 @@ func TestStorage_GetLatestExecution(t *testing.T) { t.Errorf("want le-2, got %q", got.ID) } } + +func TestGetTask_BackwardCompatibility(t *testing.T) { + db := testDB(t) + now := time.Now().UTC().Truncate(time.Second) + + // Legacy config JSON using "claude" field instead of "agent" + legacyConfig := `{"claude":{"model":"haiku","instructions":"legacy instructions","max_budget_usd":0.5}}` + + _, err := db.db.Exec(` + INSERT INTO tasks (id, name, description, config_json, priority, timeout_ns, retry_json, tags_json, depends_on_json, state, created_at, updated_at) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, + "legacy-id", "Legacy Task", "A legacy test", legacyConfig, "normal", + 0, "{}", "[]", "[]", "PENDING", now, now, + ) + if err != nil { + t.Fatalf("inserting legacy task: %v", err) + } + + got, err := db.GetTask("legacy-id") + if err != nil { + t.Fatalf("getting legacy task: %v", err) + } + + if got.Agent.Instructions != "legacy instructions" { + t.Errorf("instructions: want 'legacy instructions', got %q", got.Agent.Instructions) + } + if got.Agent.Model != "haiku" { + t.Errorf("model: want 'haiku', got %q", got.Agent.Model) + } + if got.Agent.MaxBudgetUSD != 0.5 { + t.Errorf("budget: want 0.5, got %f", got.Agent.MaxBudgetUSD) + } +} diff --git a/internal/storage/templates.go b/internal/storage/templates.go index 350b4f8..57abaa4 100644 --- a/internal/storage/templates.go +++ b/internal/storage/templates.go @@ -18,7 +18,7 @@ type Template struct { ID string `json:"id"` Name string `json:"name"` Description string `json:"description"` - Claude task.ClaudeConfig `json:"claude"` + Agent task.AgentConfig `json:"agent"` Timeout string `json:"timeout"` Priority string `json:"priority"` Tags []string `json:"tags"` @@ -28,7 +28,7 @@ type Template struct { // CreateTemplate inserts a new template. func (s *DB) CreateTemplate(tmpl *Template) error { - configJSON, err := json.Marshal(tmpl.Claude) + configJSON, err := json.Marshal(tmpl.Agent) if err != nil { return fmt.Errorf("marshaling config: %w", err) } @@ -73,7 +73,7 @@ func (s *DB) ListTemplates() ([]*Template, error) { // UpdateTemplate fully replaces a template's fields. Returns ErrTemplateNotFound if the ID is missing. func (s *DB) UpdateTemplate(tmpl *Template) error { - configJSON, err := json.Marshal(tmpl.Claude) + configJSON, err := json.Marshal(tmpl.Agent) if err != nil { return fmt.Errorf("marshaling config: %w", err) } @@ -130,7 +130,7 @@ func scanTemplate(row scanner) (*Template, error) { } return nil, err } - if err := json.Unmarshal([]byte(configJSON), &tmpl.Claude); err != nil { + if err := json.Unmarshal([]byte(configJSON), &tmpl.Agent); err != nil { return nil, fmt.Errorf("unmarshaling config: %w", err) } if err := json.Unmarshal([]byte(tagsJSON), &tmpl.Tags); err != nil { diff --git a/internal/task/parser_test.go b/internal/task/parser_test.go index cb68e86..7c3aadc 100644 --- a/internal/task/parser_test.go +++ b/internal/task/parser_test.go @@ -11,7 +11,8 @@ func TestParse_SingleTask(t *testing.T) { yaml := ` name: "Test Task" description: "A simple test" -claude: +agent: + type: "claude" model: "sonnet" instructions: "Do something" working_dir: "/tmp" @@ -30,8 +31,8 @@ tags: if task.Name != "Test Task" { t.Errorf("expected name 'Test Task', got %q", task.Name) } - if task.Claude.Model != "sonnet" { - t.Errorf("expected model 'sonnet', got %q", task.Claude.Model) + if task.Agent.Model != "sonnet" { + t.Errorf("expected model 'sonnet', got %q", task.Agent.Model) } if task.Timeout.Duration != 10*time.Minute { t.Errorf("expected timeout 10m, got %v", task.Timeout.Duration) @@ -51,12 +52,14 @@ func TestParse_BatchTasks(t *testing.T) { yaml := ` tasks: - name: "Task A" - claude: + agent: + type: "claude" instructions: "Do A" working_dir: "/tmp" tags: ["alpha"] - name: "Task B" - claude: + agent: + type: "claude" instructions: "Do B" working_dir: "/tmp" tags: ["beta"] @@ -79,7 +82,8 @@ tasks: func TestParse_MissingName_ReturnsError(t *testing.T) { yaml := ` description: "no name" -claude: +agent: + type: "claude" instructions: "something" ` _, err := Parse([]byte(yaml)) @@ -91,7 +95,8 @@ claude: func TestParse_DefaultRetryConfig(t *testing.T) { yaml := ` name: "Defaults" -claude: +agent: + type: "claude" instructions: "test" ` tasks, err := Parse([]byte(yaml)) @@ -110,7 +115,8 @@ func TestParse_WithPriority(t *testing.T) { yaml := ` name: "High Priority" priority: "high" -claude: +agent: + type: "claude" instructions: "urgent" ` tasks, err := Parse([]byte(yaml)) @@ -127,7 +133,8 @@ func TestParseFile(t *testing.T) { path := filepath.Join(dir, "task.yaml") content := ` name: "File Task" -claude: +agent: + type: "claude" instructions: "from file" working_dir: "/tmp" ` diff --git a/internal/task/task.go b/internal/task/task.go index 498c364..6b240dd 100644 --- a/internal/task/task.go +++ b/internal/task/task.go @@ -28,7 +28,8 @@ const ( PriorityLow Priority = "low" ) -type ClaudeConfig struct { +type AgentConfig struct { + Type string `yaml:"type" json:"type"` Model string `yaml:"model" json:"model"` ContextFiles []string `yaml:"context_files" json:"context_files"` Instructions string `yaml:"instructions" json:"instructions"` @@ -43,8 +44,8 @@ type ClaudeConfig struct { } // UnmarshalJSON reads project_dir with fallback to legacy working_dir. -func (c *ClaudeConfig) UnmarshalJSON(data []byte) error { - type Alias ClaudeConfig +func (c *AgentConfig) UnmarshalJSON(data []byte) error { + type Alias AgentConfig aux := &struct { ProjectDir string `json:"project_dir"` WorkingDir string `json:"working_dir"` // legacy @@ -71,7 +72,8 @@ type Task struct { ParentTaskID string `yaml:"parent_task_id" json:"parent_task_id"` Name string `yaml:"name" json:"name"` Description string `yaml:"description" json:"description"` - Claude ClaudeConfig `yaml:"claude" json:"claude"` + Agent AgentConfig `yaml:"agent" json:"agent"` + Claude AgentConfig `yaml:"claude" json:"claude"` // alias for backward compatibility Timeout Duration `yaml:"timeout" json:"timeout"` Retry RetryConfig `yaml:"retry" json:"retry"` Priority Priority `yaml:"priority" json:"priority"` diff --git a/internal/task/validator.go b/internal/task/validator.go index ea0b1c2..003fab9 100644 --- a/internal/task/validator.go +++ b/internal/task/validator.go @@ -29,11 +29,11 @@ func Validate(t *Task) error { if t.Name == "" { ve.Add("name is required") } - if t.Claude.Instructions == "" { - ve.Add("claude.instructions is required") + if t.Agent.Instructions == "" { + ve.Add("agent.instructions is required") } - if t.Claude.MaxBudgetUSD < 0 { - ve.Add("claude.max_budget_usd must be non-negative") + if t.Agent.MaxBudgetUSD < 0 { + ve.Add("agent.max_budget_usd must be non-negative") } if t.Timeout.Duration < 0 { ve.Add("timeout must be non-negative") @@ -48,13 +48,13 @@ func Validate(t *Task) error { if t.Priority != "" && !validPriorities[t.Priority] { ve.Add(fmt.Sprintf("invalid priority %q; must be high, normal, or low", t.Priority)) } - if t.Claude.PermissionMode != "" { + if t.Agent.PermissionMode != "" { validModes := map[string]bool{ "default": true, "acceptEdits": true, "bypassPermissions": true, "plan": true, "dontAsk": true, "delegate": true, } - if !validModes[t.Claude.PermissionMode] { - ve.Add(fmt.Sprintf("invalid permission_mode %q", t.Claude.PermissionMode)) + if !validModes[t.Agent.PermissionMode] { + ve.Add(fmt.Sprintf("invalid permission_mode %q", t.Agent.PermissionMode)) } } diff --git a/internal/task/validator_test.go b/internal/task/validator_test.go index 02bde45..657d93f 100644 --- a/internal/task/validator_test.go +++ b/internal/task/validator_test.go @@ -9,7 +9,8 @@ func validTask() *Task { return &Task{ ID: "test-id", Name: "Valid Task", - Claude: ClaudeConfig{ + Agent: AgentConfig{ + Type: "claude", Instructions: "do something", ProjectDir: "/tmp", }, @@ -39,7 +40,7 @@ func TestValidate_MissingName_ReturnsError(t *testing.T) { func TestValidate_MissingInstructions_ReturnsError(t *testing.T) { task := validTask() - task.Claude.Instructions = "" + task.Agent.Instructions = "" err := Validate(task) if err == nil { t.Fatal("expected error") @@ -51,7 +52,7 @@ func TestValidate_MissingInstructions_ReturnsError(t *testing.T) { func TestValidate_NegativeBudget_ReturnsError(t *testing.T) { task := validTask() - task.Claude.MaxBudgetUSD = -1.0 + task.Agent.MaxBudgetUSD = -1.0 err := Validate(task) if err == nil { t.Fatal("expected error") @@ -87,7 +88,7 @@ func TestValidate_InvalidPriority_ReturnsError(t *testing.T) { func TestValidate_InvalidPermissionMode_ReturnsError(t *testing.T) { task := validTask() - task.Claude.PermissionMode = "yolo" + task.Agent.PermissionMode = "yolo" err := Validate(task) if err == nil { t.Fatal("expected error") @@ -496,7 +496,7 @@ async function updateTask(taskId, body) { } function createEditForm(task) { - const c = task.claude || {}; + const a = task.agent || {}; const form = document.createElement('div'); form.className = 'task-inline-edit'; @@ -517,10 +517,25 @@ function createEditForm(task) { form.appendChild(makeField('Name', 'input', { type: 'text', name: 'name', value: task.name || '' })); form.appendChild(makeField('Description', 'textarea', { name: 'description', rows: '2', value: task.description || '' })); - form.appendChild(makeField('Instructions', 'textarea', { name: 'instructions', rows: '4', value: c.instructions || '' })); - form.appendChild(makeField('Model', 'input', { type: 'text', name: 'model', value: c.model || 'sonnet' })); - form.appendChild(makeField('Working Directory', 'input', { type: 'text', name: 'project_dir', value: c.project_dir || '', placeholder: '/path/to/repo' })); - form.appendChild(makeField('Max Budget (USD)', 'input', { type: 'number', name: 'max_budget_usd', step: '0.01', value: c.max_budget_usd != null ? String(c.max_budget_usd) : '1.00' })); + form.appendChild(makeField('Instructions', 'textarea', { name: 'instructions', rows: '4', value: a.instructions || '' })); + + const typeLabel = document.createElement('label'); + typeLabel.textContent = 'Agent Type'; + const typeSel = document.createElement('select'); + typeSel.name = 'type'; + for (const val of ['claude', 'gemini']) { + const opt = document.createElement('option'); + opt.value = val; + opt.textContent = val.charAt(0).toUpperCase() + val.slice(1); + if (val === (a.type || 'claude')) opt.selected = true; + typeSel.appendChild(opt); + } + typeLabel.appendChild(typeSel); + form.appendChild(typeLabel); + + form.appendChild(makeField('Model', 'input', { type: 'text', name: 'model', value: a.model || 'sonnet' })); + form.appendChild(makeField('Project Directory', 'input', { type: 'text', name: 'project_dir', value: a.project_dir || a.working_dir || '', placeholder: '/path/to/repo' })); + form.appendChild(makeField('Max Budget (USD)', 'input', { type: 'number', name: 'max_budget_usd', step: '0.01', value: a.max_budget_usd != null ? String(a.max_budget_usd) : '1.00' })); form.appendChild(makeField('Timeout', 'input', { type: 'text', name: 'timeout', value: formatDurationForInput(task.timeout) || '15m', placeholder: '15m' })); const prioLabel = document.createElement('label'); @@ -568,7 +583,8 @@ async function handleEditSave(taskId, form, saveBtn) { const body = { name: get('name'), description: get('description'), - claude: { + agent: { + type: get('type'), model: get('model'), instructions: get('instructions'), project_dir: get('project_dir'), @@ -1052,11 +1068,12 @@ function buildValidatePayload() { const instructions = f.querySelector('[name="instructions"]').value; const project_dir = f.querySelector('[name="project_dir"]').value; const model = f.querySelector('[name="model"]').value; + const type = f.querySelector('[name="type"]').value; const allowedToolsEl = f.querySelector('[name="allowed_tools"]'); const allowed_tools = allowedToolsEl ? allowedToolsEl.value.split(',').map(s => s.trim()).filter(Boolean) : []; - return { name, claude: { instructions, project_dir, model, allowed_tools } }; + return { name, agent: { type, instructions, project_dir, model, allowed_tools } }; } function renderValidationResult(result) { @@ -1145,6 +1162,7 @@ function initProjectSelect() { const select = document.getElementById('project-select'); const newRow = document.getElementById('new-project-row'); const newInput = document.getElementById('new-project-input'); + if (!select) return; select.addEventListener('change', () => { if (select.value === '__new__') { newRow.hidden = false; @@ -1176,7 +1194,8 @@ async function createTask(formData) { const body = { name: formData.get('name'), description: '', - claude: { + agent: { + type: formData.get('type'), model: formData.get('model'), instructions: formData.get('instructions'), project_dir: workingDir, @@ -1220,7 +1239,8 @@ async function saveTemplate(formData) { const body = { name: formData.get('name'), description: formData.get('description'), - claude: { + agent: { + type: formData.get('type'), model: formData.get('model'), instructions: formData.get('instructions'), project_dir: formData.get('project_dir'), @@ -1395,31 +1415,32 @@ function renderTaskPanel(task, executions) { overview.appendChild(overviewGrid); content.appendChild(overview); - // ── Claude Config ── - const c = task.claude || {}; - const claudeSection = makeSection('Claude Config'); - const claudeGrid = document.createElement('div'); - claudeGrid.className = 'meta-grid'; - claudeGrid.append( - makeMetaItem('Model', c.model), - makeMetaItem('Max Budget', c.max_budget_usd != null ? `$${c.max_budget_usd.toFixed(2)}` : '—'), - makeMetaItem('Project Dir', c.project_dir), - makeMetaItem('Permission Mode', c.permission_mode || 'default'), + // ── Agent Config ── + const a = task.agent || {}; + const agentSection = makeSection('Agent Config'); + const agentGrid = document.createElement('div'); + agentGrid.className = 'meta-grid'; + agentGrid.append( + makeMetaItem('Type', a.type || 'claude'), + makeMetaItem('Model', a.model), + makeMetaItem('Max Budget', a.max_budget_usd != null ? `$${a.max_budget_usd.toFixed(2)}` : '—'), + makeMetaItem('Project Dir', a.project_dir || a.working_dir), + makeMetaItem('Permission Mode', a.permission_mode || 'default'), ); - if (c.allowed_tools && c.allowed_tools.length > 0) { - claudeGrid.append(makeMetaItem('Allowed Tools', c.allowed_tools.join(', '), { fullWidth: true })); + if (a.allowed_tools && a.allowed_tools.length > 0) { + agentGrid.append(makeMetaItem('Allowed Tools', a.allowed_tools.join(', '), { fullWidth: true })); } - if (c.disallowed_tools && c.disallowed_tools.length > 0) { - claudeGrid.append(makeMetaItem('Disallowed Tools', c.disallowed_tools.join(', '), { fullWidth: true })); + if (a.disallowed_tools && a.disallowed_tools.length > 0) { + agentGrid.append(makeMetaItem('Disallowed Tools', a.disallowed_tools.join(', '), { fullWidth: true })); } - if (c.instructions) { - claudeGrid.append(makeMetaItem('Instructions', c.instructions, { fullWidth: true, code: true })); + if (a.instructions) { + agentGrid.append(makeMetaItem('Instructions', a.instructions, { fullWidth: true, code: true })); } - if (c.system_prompt_append) { - claudeGrid.append(makeMetaItem('System Prompt Append', c.system_prompt_append, { fullWidth: true, code: true })); + if (a.system_prompt_append) { + agentGrid.append(makeMetaItem('System Prompt Append', a.system_prompt_append, { fullWidth: true, code: true })); } - claudeSection.appendChild(claudeGrid); - content.appendChild(claudeSection); + agentSection.appendChild(agentGrid); + content.appendChild(agentSection); // ── Execution Settings ── const settingsSection = makeSection('Execution Settings'); @@ -2071,23 +2092,51 @@ if (typeof document !== 'undefined') document.addEventListener('DOMContentLoaded const f = document.getElementById('task-form'); if (result.name) f.querySelector('[name="name"]').value = result.name; +<<<<<<< HEAD + if (result.agent && result.agent.instructions) + f.querySelector('[name="instructions"]').value = result.agent.instructions; + if (result.agent && result.agent.working_dir) { + const pSel = document.getElementById('project-select'); + const exists = [...pSel.options].some(o => o.value === result.agent.working_dir); +||||||| cad057f + if (result.claude && result.claude.instructions) + f.querySelector('[name="instructions"]').value = result.claude.instructions; + if (result.claude && result.claude.working_dir) { + const sel = document.getElementById('project-select'); + const exists = [...sel.options].some(o => o.value === result.claude.working_dir); +======= if (result.claude && result.claude.instructions) f.querySelector('[name="instructions"]').value = result.claude.instructions; if (result.claude && result.claude.project_dir) { const sel = document.getElementById('project-select'); const exists = [...sel.options].some(o => o.value === result.claude.project_dir); +>>>>>>> master if (exists) { +<<<<<<< HEAD + pSel.value = result.agent.working_dir; +||||||| cad057f + sel.value = result.claude.working_dir; +======= sel.value = result.claude.project_dir; +>>>>>>> master } else { - sel.value = '__new__'; + pSel.value = '__new__'; document.getElementById('new-project-row').hidden = false; +<<<<<<< HEAD + document.getElementById('new-project-input').value = result.agent.working_dir; +||||||| cad057f + document.getElementById('new-project-input').value = result.claude.working_dir; +======= document.getElementById('new-project-input').value = result.claude.project_dir; +>>>>>>> master } } - if (result.claude && result.claude.model) - f.querySelector('[name="model"]').value = result.claude.model; - if (result.claude && result.claude.max_budget_usd != null) - f.querySelector('[name="max_budget_usd"]').value = result.claude.max_budget_usd; + if (result.agent && result.agent.model) + f.querySelector('[name="model"]').value = result.agent.model; + if (result.agent && result.agent.type) + f.querySelector('[name="type"]').value = result.agent.type; + if (result.agent && result.agent.max_budget_usd != null) + f.querySelector('[name="max_budget_usd"]').value = result.agent.max_budget_usd; if (result.timeout) f.querySelector('[name="timeout"]').value = result.timeout; if (result.priority) { diff --git a/web/index.html b/web/index.html index 842c272..a2800b0 100644 --- a/web/index.html +++ b/web/index.html @@ -52,23 +52,24 @@ <form id="task-form" method="dialog"> <h2>New Task</h2> <div class="elaborate-section"> - <label>Describe what you want Claude to do + <label>Describe what you want the agent to do <textarea id="elaborate-prompt" rows="3" placeholder="e.g. run tests with race detector and check coverage"></textarea> </label> <button type="button" id="btn-elaborate" class="btn-secondary"> Draft with AI ✦ </button> - <p class="elaborate-hint">Claude will fill in the form fields below. You can edit before submitting.</p> + <p class="elaborate-hint">AI will fill in the form fields below. You can edit before submitting.</p> </div> <hr class="form-divider"> <label>Project <select name="project_dir" id="project-select"> <option value="/workspace/claudomator" selected>/workspace/claudomator</option> + <option value="__new__">Create new project…</option> </select> </label> <div id="new-project-row" hidden> - <label>New project path <input id="new-project-input" placeholder="/workspace/my-project"></label> + <label>New Project Path <input id="new-project-input" placeholder="/workspace/my-new-app"></label> </div> <label>Name <input name="name" required></label> <label>Instructions <textarea name="instructions" rows="6" required></textarea></label> @@ -78,7 +79,15 @@ </button> <div id="validate-result" hidden></div> </div> - <label>Model <input name="model" value="sonnet"></label> + <div style="display: grid; grid-template-columns: 1fr 1fr; gap: 12px;"> + <label>Agent Type + <select name="type"> + <option value="claude" selected>Claude</option> + <option value="gemini">Gemini</option> + </select> + </label> + <label>Model <input name="model" value="sonnet" placeholder="e.g. sonnet, gemini-2.0-flash"></label> + </div> <label>Max Budget (USD) <input name="max_budget_usd" type="number" step="0.01" value="1.00"></label> <label>Timeout <input name="timeout" value="15m"></label> <label>Priority @@ -100,7 +109,15 @@ <h2>New Template</h2> <label>Name <input name="name" required></label> <label>Description <textarea name="description" rows="2"></textarea></label> - <label>Model <input name="model" value="sonnet"></label> + <div style="display: grid; grid-template-columns: 1fr 1fr; gap: 12px;"> + <label>Agent Type + <select name="type"> + <option value="claude" selected>Claude</option> + <option value="gemini">Gemini</option> + </select> + </label> + <label>Model <input name="model" value="sonnet" placeholder="e.g. sonnet, gemini-2.0-flash"></label> + </div> <label>Instructions <textarea name="instructions" rows="6" required></textarea></label> <label>Project Directory <input name="project_dir" placeholder="/path/to/repo"></label> <label>Max Budget (USD) <input name="max_budget_usd" type="number" step="0.01" value="1.00"></label> |
