summaryrefslogtreecommitdiff
path: root/internal
diff options
context:
space:
mode:
authorPeter Stone <thepeterstone@gmail.com>2026-03-08 20:50:21 +0000
committerPeter Stone <thepeterstone@gmail.com>2026-03-08 20:50:21 +0000
commit406247b14985ab57902e8e42898dc8cb8960290d (patch)
tree4a93be793f541038dd5d3fc154563051ba151b50 /internal
parent0ff0bf75544bbf565288e61bb8e10c3f903830f8 (diff)
feat(executor): implement Gemini-based task classification and load balancing
- Add Classifier using gemini-2.0-flash-lite to automatically select agent/model. - Update Pool to track per-agent active tasks and rate limit status. - Enable classification for all tasks (top-level and subtasks). - Refine SystemStatus to be dynamic across all supported agents. - Add unit tests for the classifier and updated pool logic. - Minor UI improvements for project selection and 'Start Next' action.
Diffstat (limited to 'internal')
-rw-r--r--internal/cli/run.go3
-rw-r--r--internal/cli/serve.go3
-rw-r--r--internal/executor/classifier.go109
-rw-r--r--internal/executor/classifier_test.go49
-rw-r--r--internal/executor/executor.go95
5 files changed, 247 insertions, 12 deletions
diff --git a/internal/cli/run.go b/internal/cli/run.go
index 3624cea..62e1252 100644
--- a/internal/cli/run.go
+++ b/internal/cli/run.go
@@ -86,6 +86,9 @@ func runTasks(file string, parallel int, dryRun bool) error {
},
}
pool := executor.NewPool(parallel, runners, store, logger)
+ if cfg.GeminiBinaryPath != "" {
+ pool.Classifier = &executor.Classifier{GeminiBinaryPath: cfg.GeminiBinaryPath}
+ }
// Handle graceful shutdown.
ctx, cancel := context.WithCancel(context.Background())
diff --git a/internal/cli/serve.go b/internal/cli/serve.go
index 2ecb6cd..b679b38 100644
--- a/internal/cli/serve.go
+++ b/internal/cli/serve.go
@@ -71,6 +71,9 @@ func serve(addr string) error {
}
pool := executor.NewPool(cfg.MaxConcurrent, runners, store, logger)
+ if cfg.GeminiBinaryPath != "" {
+ pool.Classifier = &executor.Classifier{GeminiBinaryPath: cfg.GeminiBinaryPath}
+ }
srv := api.NewServer(store, pool, logger, cfg.ClaudeBinaryPath, cfg.GeminiBinaryPath)
srv.StartHub()
diff --git a/internal/executor/classifier.go b/internal/executor/classifier.go
new file mode 100644
index 0000000..79ebc27
--- /dev/null
+++ b/internal/executor/classifier.go
@@ -0,0 +1,109 @@
+package executor
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "os/exec"
+ "strings"
+)
+
+type Classification struct {
+ AgentType string `json:"agent_type"`
+ Model string `json:"model"`
+ Reason string `json:"reason"`
+}
+
+type SystemStatus struct {
+ ActiveTasks map[string]int
+ RateLimited map[string]bool
+}
+
+type Classifier struct {
+ GeminiBinaryPath string
+}
+
+const classificationPrompt = `
+You are a task classifier for Claudomator.
+Given a task description and system status, select the best agent (claude or gemini) and model to use.
+
+Agent Types:
+- claude: Best for complex coding, reasoning, and tool use.
+- gemini: Best for large context, fast reasoning, and multimodal tasks.
+
+Available Models:
+Claude:
+- claude-3-5-sonnet-latest (balanced)
+- claude-3-5-sonnet-20241022 (stable)
+- claude-3-opus-20240229 (most powerful, expensive)
+- claude-3-5-haiku-20241022 (fast, cheap)
+
+Gemini:
+- gemini-2.0-flash-lite (fastest, most efficient, best for simple tasks)
+- gemini-2.0-flash (fast, multimodal)
+- gemini-1.5-flash (fast, balanced)
+- gemini-1.5-pro (more powerful, larger context)
+
+Selection Criteria:
+- Agent: Prefer the one with least running tasks and no active rate limit.
+- Model: Select based on task complexity. Use powerful models (opus, pro) for complex reasoning/coding, flash-lite/flash/haiku for simple tasks.
+
+Task:
+Name: %s
+Instructions: %s
+
+System Status:
+%s
+
+Respond with ONLY a JSON object:
+{
+ "agent_type": "claude" | "gemini",
+ "model": "model-name",
+ "reason": "brief reason"
+}
+`
+
+func (c *Classifier) Classify(ctx context.Context, taskName, instructions string, status SystemStatus) (*Classification, error) {
+ statusStr := ""
+ for agent, active := range status.ActiveTasks {
+ statusStr += fmt.Sprintf("- Agent %s: %d active tasks, Rate Limited: %t\n", agent, active, status.RateLimited[agent])
+ }
+
+ prompt := fmt.Sprintf(classificationPrompt,
+ taskName, instructions, statusStr,
+ )
+
+ binary := c.GeminiBinaryPath
+ if binary == "" {
+ binary = "gemini"
+ }
+
+ // Use a minimal model for classification to be fast and cheap.
+ args := []string{
+ "--prompt", prompt,
+ "--model", "gemini-2.0-flash-lite",
+ "--output-format", "json",
+ }
+
+ cmd := exec.CommandContext(ctx, binary, args...)
+ out, err := cmd.Output()
+ if err != nil {
+ if exitErr, ok := err.(*exec.ExitError); ok {
+ return nil, fmt.Errorf("classifier failed (%v): %s", err, string(exitErr.Stderr))
+ }
+ return nil, fmt.Errorf("classifier failed: %w", err)
+ }
+
+ var cls Classification
+ // Gemini might wrap the JSON in markdown code blocks.
+ cleanOut := strings.TrimSpace(string(out))
+ cleanOut = strings.TrimPrefix(cleanOut, "```json")
+ cleanOut = strings.TrimSuffix(cleanOut, "```")
+ cleanOut = strings.TrimSpace(cleanOut)
+
+ if err := json.Unmarshal([]byte(cleanOut), &cls); err != nil {
+ return nil, fmt.Errorf("failed to parse classification JSON: %w\nOutput: %s", err, cleanOut)
+ }
+
+ return &cls, nil
+}
diff --git a/internal/executor/classifier_test.go b/internal/executor/classifier_test.go
new file mode 100644
index 0000000..4de44ca
--- /dev/null
+++ b/internal/executor/classifier_test.go
@@ -0,0 +1,49 @@
+package executor
+
+import (
+ "context"
+ "os"
+ "testing"
+)
+
+// TestClassifier_Classify_Mock tests the classifier with a mocked gemini binary.
+func TestClassifier_Classify_Mock(t *testing.T) {
+ // Create a temporary mock binary.
+ mockBinary := filepathJoin(t.TempDir(), "mock-gemini")
+ mockContent := `#!/bin/sh
+echo '{"agent_type": "gemini", "model": "gemini-2.0-flash", "reason": "test reason"}'
+`
+ if err := os.WriteFile(mockBinary, []byte(mockContent), 0755); err != nil {
+ t.Fatal(err)
+ }
+
+ c := &Classifier{GeminiBinaryPath: mockBinary}
+ status := SystemStatus{
+ ActiveTasks: map[string]int{"claude": 5, "gemini": 1},
+ RateLimited: map[string]bool{"claude": false, "gemini": false},
+ }
+
+ cls, err := c.Classify(context.Background(), "Test Task", "Test Instructions", status)
+ if err != nil {
+ t.Fatalf("Classify failed: %v", err)
+ }
+
+ if cls.AgentType != "gemini" {
+ t.Errorf("expected gemini, got %s", cls.AgentType)
+ }
+ if cls.Model != "gemini-2.0-flash" {
+ t.Errorf("expected gemini-2.0-flash, got %s", cls.Model)
+ }
+}
+
+func filepathJoin(elems ...string) string {
+ var path string
+ for i, e := range elems {
+ if i == 0 {
+ path = e
+ } else {
+ path = path + string(os.PathSeparator) + e
+ }
+ }
+ return path
+}
diff --git a/internal/executor/executor.go b/internal/executor/executor.go
index 0786d86..6bd1c68 100644
--- a/internal/executor/executor.go
+++ b/internal/executor/executor.go
@@ -33,11 +33,14 @@ type Pool struct {
store *storage.DB
logger *slog.Logger
- mu sync.Mutex
- active int
- cancels map[string]context.CancelFunc // taskID → cancel
- resultCh chan *Result
- Questions *QuestionRegistry
+ mu sync.Mutex
+ active int
+ activePerAgent map[string]int
+ rateLimited map[string]time.Time // agentType -> until
+ cancels map[string]context.CancelFunc // taskID → cancel
+ resultCh chan *Result
+ Questions *QuestionRegistry
+ Classifier *Classifier
}
// Result is emitted when a task execution completes.
@@ -52,13 +55,15 @@ func NewPool(maxConcurrent int, runners map[string]Runner, store *storage.DB, lo
maxConcurrent = 1
}
return &Pool{
- maxConcurrent: maxConcurrent,
- runners: runners,
- store: store,
- logger: logger,
- cancels: make(map[string]context.CancelFunc),
- resultCh: make(chan *Result, maxConcurrent*2),
- Questions: NewQuestionRegistry(),
+ maxConcurrent: maxConcurrent,
+ runners: runners,
+ store: store,
+ logger: logger,
+ activePerAgent: make(map[string]int),
+ rateLimited: make(map[string]time.Time),
+ cancels: make(map[string]context.CancelFunc),
+ resultCh: make(chan *Result, maxConcurrent*2),
+ Questions: NewQuestionRegistry(),
}
}
@@ -126,9 +131,19 @@ func (p *Pool) getRunner(t *task.Task) (Runner, error) {
}
func (p *Pool) executeResume(ctx context.Context, t *task.Task, exec *storage.Execution) {
+ agentType := t.Agent.Type
+ if agentType == "" {
+ agentType = "claude"
+ }
+
+ p.mu.Lock()
+ p.activePerAgent[agentType]++
+ p.mu.Unlock()
+
defer func() {
p.mu.Lock()
p.active--
+ p.activePerAgent[agentType]--
p.mu.Unlock()
}()
@@ -169,6 +184,16 @@ func (p *Pool) executeResume(ctx context.Context, t *task.Task, exec *storage.Ex
exec.EndTime = time.Now().UTC()
if err != nil {
+ if isRateLimitError(err) {
+ p.mu.Lock()
+ retryAfter := parseRetryAfter(err.Error())
+ if retryAfter == 0 {
+ retryAfter = 1 * time.Minute
+ }
+ p.rateLimited[agentType] = time.Now().Add(retryAfter)
+ p.mu.Unlock()
+ }
+
var blockedErr *BlockedError
if errors.As(err, &blockedErr) {
exec.Status = "BLOCKED"
@@ -211,9 +236,45 @@ func (p *Pool) ActiveCount() int {
}
func (p *Pool) execute(ctx context.Context, t *task.Task) {
+ // 1. Classification
+ if p.Classifier != nil {
+ p.mu.Lock()
+ activeTasks := make(map[string]int)
+ rateLimited := make(map[string]bool)
+ now := time.Now()
+ for agent := range p.runners {
+ activeTasks[agent] = p.activePerAgent[agent]
+ rateLimited[agent] = now.Before(p.rateLimited[agent])
+ }
+ status := SystemStatus{
+ ActiveTasks: activeTasks,
+ RateLimited: rateLimited,
+ }
+ p.mu.Unlock()
+
+ cls, err := p.Classifier.Classify(ctx, t.Name, t.Agent.Instructions, status)
+ if err == nil {
+ p.logger.Info("task classified", "taskID", t.ID, "agent", cls.AgentType, "model", cls.Model, "reason", cls.Reason)
+ t.Agent.Type = cls.AgentType
+ t.Agent.Model = cls.Model
+ } else {
+ p.logger.Error("classification failed", "error", err, "taskID", t.ID)
+ }
+ }
+
+ agentType := t.Agent.Type
+ if agentType == "" {
+ agentType = "claude"
+ }
+
+ p.mu.Lock()
+ p.activePerAgent[agentType]++
+ p.mu.Unlock()
+
defer func() {
p.mu.Lock()
p.active--
+ p.activePerAgent[agentType]--
p.mu.Unlock()
}()
@@ -306,6 +367,16 @@ func (p *Pool) execute(ctx context.Context, t *task.Task) {
exec.EndTime = time.Now().UTC()
if err != nil {
+ if isRateLimitError(err) {
+ p.mu.Lock()
+ retryAfter := parseRetryAfter(err.Error())
+ if retryAfter == 0 {
+ retryAfter = 1 * time.Minute
+ }
+ p.rateLimited[agentType] = time.Now().Add(retryAfter)
+ p.mu.Unlock()
+ }
+
var blockedErr *BlockedError
if errors.As(err, &blockedErr) {
exec.Status = "BLOCKED"