diff options
Diffstat (limited to 'internal')
| -rw-r--r-- | internal/cli/run.go | 3 | ||||
| -rw-r--r-- | internal/cli/serve.go | 3 | ||||
| -rw-r--r-- | internal/executor/classifier.go | 109 | ||||
| -rw-r--r-- | internal/executor/classifier_test.go | 49 | ||||
| -rw-r--r-- | internal/executor/executor.go | 95 |
5 files changed, 247 insertions, 12 deletions
diff --git a/internal/cli/run.go b/internal/cli/run.go index 3624cea..62e1252 100644 --- a/internal/cli/run.go +++ b/internal/cli/run.go @@ -86,6 +86,9 @@ func runTasks(file string, parallel int, dryRun bool) error { }, } pool := executor.NewPool(parallel, runners, store, logger) + if cfg.GeminiBinaryPath != "" { + pool.Classifier = &executor.Classifier{GeminiBinaryPath: cfg.GeminiBinaryPath} + } // Handle graceful shutdown. ctx, cancel := context.WithCancel(context.Background()) diff --git a/internal/cli/serve.go b/internal/cli/serve.go index 2ecb6cd..b679b38 100644 --- a/internal/cli/serve.go +++ b/internal/cli/serve.go @@ -71,6 +71,9 @@ func serve(addr string) error { } pool := executor.NewPool(cfg.MaxConcurrent, runners, store, logger) + if cfg.GeminiBinaryPath != "" { + pool.Classifier = &executor.Classifier{GeminiBinaryPath: cfg.GeminiBinaryPath} + } srv := api.NewServer(store, pool, logger, cfg.ClaudeBinaryPath, cfg.GeminiBinaryPath) srv.StartHub() diff --git a/internal/executor/classifier.go b/internal/executor/classifier.go new file mode 100644 index 0000000..79ebc27 --- /dev/null +++ b/internal/executor/classifier.go @@ -0,0 +1,109 @@ +package executor + +import ( + "context" + "encoding/json" + "fmt" + "os/exec" + "strings" +) + +type Classification struct { + AgentType string `json:"agent_type"` + Model string `json:"model"` + Reason string `json:"reason"` +} + +type SystemStatus struct { + ActiveTasks map[string]int + RateLimited map[string]bool +} + +type Classifier struct { + GeminiBinaryPath string +} + +const classificationPrompt = ` +You are a task classifier for Claudomator. +Given a task description and system status, select the best agent (claude or gemini) and model to use. + +Agent Types: +- claude: Best for complex coding, reasoning, and tool use. +- gemini: Best for large context, fast reasoning, and multimodal tasks. + +Available Models: +Claude: +- claude-3-5-sonnet-latest (balanced) +- claude-3-5-sonnet-20241022 (stable) +- claude-3-opus-20240229 (most powerful, expensive) +- claude-3-5-haiku-20241022 (fast, cheap) + +Gemini: +- gemini-2.0-flash-lite (fastest, most efficient, best for simple tasks) +- gemini-2.0-flash (fast, multimodal) +- gemini-1.5-flash (fast, balanced) +- gemini-1.5-pro (more powerful, larger context) + +Selection Criteria: +- Agent: Prefer the one with least running tasks and no active rate limit. +- Model: Select based on task complexity. Use powerful models (opus, pro) for complex reasoning/coding, flash-lite/flash/haiku for simple tasks. + +Task: +Name: %s +Instructions: %s + +System Status: +%s + +Respond with ONLY a JSON object: +{ + "agent_type": "claude" | "gemini", + "model": "model-name", + "reason": "brief reason" +} +` + +func (c *Classifier) Classify(ctx context.Context, taskName, instructions string, status SystemStatus) (*Classification, error) { + statusStr := "" + for agent, active := range status.ActiveTasks { + statusStr += fmt.Sprintf("- Agent %s: %d active tasks, Rate Limited: %t\n", agent, active, status.RateLimited[agent]) + } + + prompt := fmt.Sprintf(classificationPrompt, + taskName, instructions, statusStr, + ) + + binary := c.GeminiBinaryPath + if binary == "" { + binary = "gemini" + } + + // Use a minimal model for classification to be fast and cheap. + args := []string{ + "--prompt", prompt, + "--model", "gemini-2.0-flash-lite", + "--output-format", "json", + } + + cmd := exec.CommandContext(ctx, binary, args...) + out, err := cmd.Output() + if err != nil { + if exitErr, ok := err.(*exec.ExitError); ok { + return nil, fmt.Errorf("classifier failed (%v): %s", err, string(exitErr.Stderr)) + } + return nil, fmt.Errorf("classifier failed: %w", err) + } + + var cls Classification + // Gemini might wrap the JSON in markdown code blocks. + cleanOut := strings.TrimSpace(string(out)) + cleanOut = strings.TrimPrefix(cleanOut, "```json") + cleanOut = strings.TrimSuffix(cleanOut, "```") + cleanOut = strings.TrimSpace(cleanOut) + + if err := json.Unmarshal([]byte(cleanOut), &cls); err != nil { + return nil, fmt.Errorf("failed to parse classification JSON: %w\nOutput: %s", err, cleanOut) + } + + return &cls, nil +} diff --git a/internal/executor/classifier_test.go b/internal/executor/classifier_test.go new file mode 100644 index 0000000..4de44ca --- /dev/null +++ b/internal/executor/classifier_test.go @@ -0,0 +1,49 @@ +package executor + +import ( + "context" + "os" + "testing" +) + +// TestClassifier_Classify_Mock tests the classifier with a mocked gemini binary. +func TestClassifier_Classify_Mock(t *testing.T) { + // Create a temporary mock binary. + mockBinary := filepathJoin(t.TempDir(), "mock-gemini") + mockContent := `#!/bin/sh +echo '{"agent_type": "gemini", "model": "gemini-2.0-flash", "reason": "test reason"}' +` + if err := os.WriteFile(mockBinary, []byte(mockContent), 0755); err != nil { + t.Fatal(err) + } + + c := &Classifier{GeminiBinaryPath: mockBinary} + status := SystemStatus{ + ActiveTasks: map[string]int{"claude": 5, "gemini": 1}, + RateLimited: map[string]bool{"claude": false, "gemini": false}, + } + + cls, err := c.Classify(context.Background(), "Test Task", "Test Instructions", status) + if err != nil { + t.Fatalf("Classify failed: %v", err) + } + + if cls.AgentType != "gemini" { + t.Errorf("expected gemini, got %s", cls.AgentType) + } + if cls.Model != "gemini-2.0-flash" { + t.Errorf("expected gemini-2.0-flash, got %s", cls.Model) + } +} + +func filepathJoin(elems ...string) string { + var path string + for i, e := range elems { + if i == 0 { + path = e + } else { + path = path + string(os.PathSeparator) + e + } + } + return path +} diff --git a/internal/executor/executor.go b/internal/executor/executor.go index 0786d86..6bd1c68 100644 --- a/internal/executor/executor.go +++ b/internal/executor/executor.go @@ -33,11 +33,14 @@ type Pool struct { store *storage.DB logger *slog.Logger - mu sync.Mutex - active int - cancels map[string]context.CancelFunc // taskID → cancel - resultCh chan *Result - Questions *QuestionRegistry + mu sync.Mutex + active int + activePerAgent map[string]int + rateLimited map[string]time.Time // agentType -> until + cancels map[string]context.CancelFunc // taskID → cancel + resultCh chan *Result + Questions *QuestionRegistry + Classifier *Classifier } // Result is emitted when a task execution completes. @@ -52,13 +55,15 @@ func NewPool(maxConcurrent int, runners map[string]Runner, store *storage.DB, lo maxConcurrent = 1 } return &Pool{ - maxConcurrent: maxConcurrent, - runners: runners, - store: store, - logger: logger, - cancels: make(map[string]context.CancelFunc), - resultCh: make(chan *Result, maxConcurrent*2), - Questions: NewQuestionRegistry(), + maxConcurrent: maxConcurrent, + runners: runners, + store: store, + logger: logger, + activePerAgent: make(map[string]int), + rateLimited: make(map[string]time.Time), + cancels: make(map[string]context.CancelFunc), + resultCh: make(chan *Result, maxConcurrent*2), + Questions: NewQuestionRegistry(), } } @@ -126,9 +131,19 @@ func (p *Pool) getRunner(t *task.Task) (Runner, error) { } func (p *Pool) executeResume(ctx context.Context, t *task.Task, exec *storage.Execution) { + agentType := t.Agent.Type + if agentType == "" { + agentType = "claude" + } + + p.mu.Lock() + p.activePerAgent[agentType]++ + p.mu.Unlock() + defer func() { p.mu.Lock() p.active-- + p.activePerAgent[agentType]-- p.mu.Unlock() }() @@ -169,6 +184,16 @@ func (p *Pool) executeResume(ctx context.Context, t *task.Task, exec *storage.Ex exec.EndTime = time.Now().UTC() if err != nil { + if isRateLimitError(err) { + p.mu.Lock() + retryAfter := parseRetryAfter(err.Error()) + if retryAfter == 0 { + retryAfter = 1 * time.Minute + } + p.rateLimited[agentType] = time.Now().Add(retryAfter) + p.mu.Unlock() + } + var blockedErr *BlockedError if errors.As(err, &blockedErr) { exec.Status = "BLOCKED" @@ -211,9 +236,45 @@ func (p *Pool) ActiveCount() int { } func (p *Pool) execute(ctx context.Context, t *task.Task) { + // 1. Classification + if p.Classifier != nil { + p.mu.Lock() + activeTasks := make(map[string]int) + rateLimited := make(map[string]bool) + now := time.Now() + for agent := range p.runners { + activeTasks[agent] = p.activePerAgent[agent] + rateLimited[agent] = now.Before(p.rateLimited[agent]) + } + status := SystemStatus{ + ActiveTasks: activeTasks, + RateLimited: rateLimited, + } + p.mu.Unlock() + + cls, err := p.Classifier.Classify(ctx, t.Name, t.Agent.Instructions, status) + if err == nil { + p.logger.Info("task classified", "taskID", t.ID, "agent", cls.AgentType, "model", cls.Model, "reason", cls.Reason) + t.Agent.Type = cls.AgentType + t.Agent.Model = cls.Model + } else { + p.logger.Error("classification failed", "error", err, "taskID", t.ID) + } + } + + agentType := t.Agent.Type + if agentType == "" { + agentType = "claude" + } + + p.mu.Lock() + p.activePerAgent[agentType]++ + p.mu.Unlock() + defer func() { p.mu.Lock() p.active-- + p.activePerAgent[agentType]-- p.mu.Unlock() }() @@ -306,6 +367,16 @@ func (p *Pool) execute(ctx context.Context, t *task.Task) { exec.EndTime = time.Now().UTC() if err != nil { + if isRateLimitError(err) { + p.mu.Lock() + retryAfter := parseRetryAfter(err.Error()) + if retryAfter == 0 { + retryAfter = 1 * time.Minute + } + p.rateLimited[agentType] = time.Now().Add(retryAfter) + p.mu.Unlock() + } + var blockedErr *BlockedError if errors.As(err, &blockedErr) { exec.Status = "BLOCKED" |
