diff options
| -rw-r--r-- | internal/api/server.go | 23 | ||||
| -rw-r--r-- | internal/api/server_test.go | 5 | ||||
| -rw-r--r-- | internal/executor/executor.go | 8 | ||||
| -rw-r--r-- | internal/executor/ratelimit.go | 3 | ||||
| -rw-r--r-- | internal/storage/db.go | 38 |
5 files changed, 52 insertions, 25 deletions
diff --git a/internal/api/server.go b/internal/api/server.go index a6e708a..944e450 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -459,24 +459,19 @@ func (s *Server) handleGetTask(w http.ResponseWriter, r *http.Request) { func (s *Server) handleRunTask(w http.ResponseWriter, r *http.Request) { id := r.PathValue("id") - t, err := s.store.GetTask(id) + t, err := s.store.ResetTaskForRetry(id) if err != nil { - writeJSON(w, http.StatusNotFound, map[string]string{"error": "task not found"}) - return - } - - if !task.ValidTransition(t.State, task.StateQueued) { - writeJSON(w, http.StatusConflict, map[string]string{ - "error": fmt.Sprintf("task cannot be queued from state %s", t.State), - }) - return - } - - if err := s.store.UpdateTaskState(id, task.StateQueued); err != nil { + if strings.Contains(err.Error(), "not found") { + writeJSON(w, http.StatusNotFound, map[string]string{"error": "task not found"}) + return + } + if strings.Contains(err.Error(), "invalid state transition") { + writeJSON(w, http.StatusConflict, map[string]string{"error": err.Error()}) + return + } writeJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) return } - t.State = task.StateQueued if err := s.pool.Submit(context.Background(), t); err != nil { writeJSON(w, http.StatusServiceUnavailable, map[string]string{"error": fmt.Sprintf("executor pool: %v", err)}) diff --git a/internal/api/server_test.go b/internal/api/server_test.go index 9b5c7ae..afdc9d2 100644 --- a/internal/api/server_test.go +++ b/internal/api/server_test.go @@ -397,9 +397,8 @@ func TestRunTask_CompletedTask_Returns409(t *testing.T) { } var body map[string]string json.NewDecoder(w.Body).Decode(&body) - wantMsg := "task cannot be queued from state COMPLETED" - if body["error"] != wantMsg { - t.Errorf("error body: want %q, got %q", wantMsg, body["error"]) + if !strings.Contains(body["error"], "invalid state transition") { + t.Errorf("error body: want it to contain 'invalid state transition', got %q", body["error"]) } } diff --git a/internal/executor/executor.go b/internal/executor/executor.go index 8924830..c04f68e 100644 --- a/internal/executor/executor.go +++ b/internal/executor/executor.go @@ -355,18 +355,12 @@ func (p *Pool) execute(ctx context.Context, t *task.Task) { if deadline, ok := p.rateLimited[agent]; ok && now.After(deadline) { delete(p.rateLimited, agent) } - activeUntil := p.rateLimited[agent] - isLimited := now.Before(activeUntil) - rateLimited[agent] = isLimited - if isLimited { - p.logger.Debug("agent rate limited", "agent", agent, "until", activeUntil) - } + rateLimited[agent] = now.Before(p.rateLimited[agent]) } status := SystemStatus{ ActiveTasks: activeTasks, RateLimited: rateLimited, } - p.logger.Debug("classifying task", "taskID", t.ID, "status", status) p.mu.Unlock() cls, err := p.Classifier.Classify(ctx, t.Name, t.Agent.Instructions, status) diff --git a/internal/executor/ratelimit.go b/internal/executor/ratelimit.go index aa9df99..1f38a6d 100644 --- a/internal/executor/ratelimit.go +++ b/internal/executor/ratelimit.go @@ -36,7 +36,8 @@ func isQuotaExhausted(err error) bool { return strings.Contains(msg, "hit your limit") || strings.Contains(msg, "you've hit your limit") || strings.Contains(msg, "you have hit your limit") || - strings.Contains(msg, "rate limit reached (rejected)") + strings.Contains(msg, "rate limit reached (rejected)") || + strings.Contains(msg, "status: rejected") } // parseRetryAfter extracts a Retry-After duration from an error message. diff --git a/internal/storage/db.go b/internal/storage/db.go index 31d38ed..01ce902 100644 --- a/internal/storage/db.go +++ b/internal/storage/db.go @@ -210,6 +210,44 @@ func (s *DB) UpdateTaskState(id string, newState task.State) error { return tx.Commit() } +// ResetTaskForRetry sets a task to QUEUED and clears its agent type/model +// so it can be re-classified. +func (s *DB) ResetTaskForRetry(id string) (*task.Task, error) { + tx, err := s.db.Begin() + if err != nil { + return nil, err + } + defer tx.Rollback() //nolint:errcheck + + t, err := scanTask(tx.QueryRow(`SELECT id, name, description, config_json, priority, timeout_ns, retry_json, tags_json, depends_on_json, parent_task_id, state, created_at, updated_at, rejection_comment, question_json FROM tasks WHERE id = ?`, id)) + if err != nil { + if err == sql.ErrNoRows { + return nil, fmt.Errorf("task %q not found", id) + } + return nil, err + } + + if !task.ValidTransition(t.State, task.StateQueued) { + return nil, fmt.Errorf("invalid state transition %s → %s for task %q", t.State, task.StateQueued, id) + } + + t.Agent.Type = "" + t.Agent.Model = "" + configJSON, _ := json.Marshal(t.Agent) + + now := time.Now().UTC() + if _, err := tx.Exec(`UPDATE tasks SET state = ?, config_json = ?, updated_at = ? WHERE id = ?`, + string(task.StateQueued), string(configJSON), now, id); err != nil { + return nil, err + } + + if err := tx.Commit(); err != nil { + return nil, err + } + t.State = task.StateQueued + return t, nil +} + // RejectTask sets a task's state to PENDING and stores the rejection comment. func (s *DB) RejectTask(id, comment string) error { now := time.Now().UTC() |
