summaryrefslogtreecommitdiff
path: root/internal/executor/executor.go
diff options
context:
space:
mode:
authorPeter Stone <thepeterstone@gmail.com>2026-03-08 20:50:21 +0000
committerPeter Stone <thepeterstone@gmail.com>2026-03-08 20:50:21 +0000
commit406247b14985ab57902e8e42898dc8cb8960290d (patch)
tree4a93be793f541038dd5d3fc154563051ba151b50 /internal/executor/executor.go
parent0ff0bf75544bbf565288e61bb8e10c3f903830f8 (diff)
feat(executor): implement Gemini-based task classification and load balancing
- Add Classifier using gemini-2.0-flash-lite to automatically select agent/model. - Update Pool to track per-agent active tasks and rate limit status. - Enable classification for all tasks (top-level and subtasks). - Refine SystemStatus to be dynamic across all supported agents. - Add unit tests for the classifier and updated pool logic. - Minor UI improvements for project selection and 'Start Next' action.
Diffstat (limited to 'internal/executor/executor.go')
-rw-r--r--internal/executor/executor.go95
1 files changed, 83 insertions, 12 deletions
diff --git a/internal/executor/executor.go b/internal/executor/executor.go
index 0786d86..6bd1c68 100644
--- a/internal/executor/executor.go
+++ b/internal/executor/executor.go
@@ -33,11 +33,14 @@ type Pool struct {
store *storage.DB
logger *slog.Logger
- mu sync.Mutex
- active int
- cancels map[string]context.CancelFunc // taskID → cancel
- resultCh chan *Result
- Questions *QuestionRegistry
+ mu sync.Mutex
+ active int
+ activePerAgent map[string]int
+ rateLimited map[string]time.Time // agentType -> until
+ cancels map[string]context.CancelFunc // taskID → cancel
+ resultCh chan *Result
+ Questions *QuestionRegistry
+ Classifier *Classifier
}
// Result is emitted when a task execution completes.
@@ -52,13 +55,15 @@ func NewPool(maxConcurrent int, runners map[string]Runner, store *storage.DB, lo
maxConcurrent = 1
}
return &Pool{
- maxConcurrent: maxConcurrent,
- runners: runners,
- store: store,
- logger: logger,
- cancels: make(map[string]context.CancelFunc),
- resultCh: make(chan *Result, maxConcurrent*2),
- Questions: NewQuestionRegistry(),
+ maxConcurrent: maxConcurrent,
+ runners: runners,
+ store: store,
+ logger: logger,
+ activePerAgent: make(map[string]int),
+ rateLimited: make(map[string]time.Time),
+ cancels: make(map[string]context.CancelFunc),
+ resultCh: make(chan *Result, maxConcurrent*2),
+ Questions: NewQuestionRegistry(),
}
}
@@ -126,9 +131,19 @@ func (p *Pool) getRunner(t *task.Task) (Runner, error) {
}
func (p *Pool) executeResume(ctx context.Context, t *task.Task, exec *storage.Execution) {
+ agentType := t.Agent.Type
+ if agentType == "" {
+ agentType = "claude"
+ }
+
+ p.mu.Lock()
+ p.activePerAgent[agentType]++
+ p.mu.Unlock()
+
defer func() {
p.mu.Lock()
p.active--
+ p.activePerAgent[agentType]--
p.mu.Unlock()
}()
@@ -169,6 +184,16 @@ func (p *Pool) executeResume(ctx context.Context, t *task.Task, exec *storage.Ex
exec.EndTime = time.Now().UTC()
if err != nil {
+ if isRateLimitError(err) {
+ p.mu.Lock()
+ retryAfter := parseRetryAfter(err.Error())
+ if retryAfter == 0 {
+ retryAfter = 1 * time.Minute
+ }
+ p.rateLimited[agentType] = time.Now().Add(retryAfter)
+ p.mu.Unlock()
+ }
+
var blockedErr *BlockedError
if errors.As(err, &blockedErr) {
exec.Status = "BLOCKED"
@@ -211,9 +236,45 @@ func (p *Pool) ActiveCount() int {
}
func (p *Pool) execute(ctx context.Context, t *task.Task) {
+ // 1. Classification
+ if p.Classifier != nil {
+ p.mu.Lock()
+ activeTasks := make(map[string]int)
+ rateLimited := make(map[string]bool)
+ now := time.Now()
+ for agent := range p.runners {
+ activeTasks[agent] = p.activePerAgent[agent]
+ rateLimited[agent] = now.Before(p.rateLimited[agent])
+ }
+ status := SystemStatus{
+ ActiveTasks: activeTasks,
+ RateLimited: rateLimited,
+ }
+ p.mu.Unlock()
+
+ cls, err := p.Classifier.Classify(ctx, t.Name, t.Agent.Instructions, status)
+ if err == nil {
+ p.logger.Info("task classified", "taskID", t.ID, "agent", cls.AgentType, "model", cls.Model, "reason", cls.Reason)
+ t.Agent.Type = cls.AgentType
+ t.Agent.Model = cls.Model
+ } else {
+ p.logger.Error("classification failed", "error", err, "taskID", t.ID)
+ }
+ }
+
+ agentType := t.Agent.Type
+ if agentType == "" {
+ agentType = "claude"
+ }
+
+ p.mu.Lock()
+ p.activePerAgent[agentType]++
+ p.mu.Unlock()
+
defer func() {
p.mu.Lock()
p.active--
+ p.activePerAgent[agentType]--
p.mu.Unlock()
}()
@@ -306,6 +367,16 @@ func (p *Pool) execute(ctx context.Context, t *task.Task) {
exec.EndTime = time.Now().UTC()
if err != nil {
+ if isRateLimitError(err) {
+ p.mu.Lock()
+ retryAfter := parseRetryAfter(err.Error())
+ if retryAfter == 0 {
+ retryAfter = 1 * time.Minute
+ }
+ p.rateLimited[agentType] = time.Now().Add(retryAfter)
+ p.mu.Unlock()
+ }
+
var blockedErr *BlockedError
if errors.As(err, &blockedErr) {
exec.Status = "BLOCKED"