diff options
Diffstat (limited to 'internal')
| -rw-r--r-- | internal/api/server.go | 14 | ||||
| -rw-r--r-- | internal/executor/executor.go | 27 | ||||
| -rw-r--r-- | internal/executor/executor_test.go | 32 |
3 files changed, 71 insertions, 2 deletions
diff --git a/internal/api/server.go b/internal/api/server.go index 5758347..18c58e9 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -61,6 +61,7 @@ func (s *Server) routes() { s.mux.HandleFunc("GET /api/tasks", s.handleListTasks) s.mux.HandleFunc("GET /api/tasks/{id}", s.handleGetTask) s.mux.HandleFunc("POST /api/tasks/{id}/run", s.handleRunTask) + s.mux.HandleFunc("POST /api/tasks/{id}/cancel", s.handleCancelTask) s.mux.HandleFunc("POST /api/tasks/{id}/accept", s.handleAcceptTask) s.mux.HandleFunc("POST /api/tasks/{id}/reject", s.handleRejectTask) s.mux.HandleFunc("GET /api/tasks/{id}/subtasks", s.handleListSubtasks) @@ -109,6 +110,19 @@ func (s *Server) BroadcastQuestion(taskID, toolUseID string, questionData json.R s.hub.Broadcast(data) } +func (s *Server) handleCancelTask(w http.ResponseWriter, r *http.Request) { + taskID := r.PathValue("id") + if _, err := s.store.GetTask(taskID); err != nil { + writeJSON(w, http.StatusNotFound, map[string]string{"error": "task not found"}) + return + } + if !s.pool.Cancel(taskID) { + writeJSON(w, http.StatusConflict, map[string]string{"error": "task is not running"}) + return + } + writeJSON(w, http.StatusOK, map[string]string{"status": "cancelling"}) +} + func (s *Server) handleAnswerQuestion(w http.ResponseWriter, r *http.Request) { taskID := r.PathValue("id") diff --git a/internal/executor/executor.go b/internal/executor/executor.go index f6932f4..62fed2e 100644 --- a/internal/executor/executor.go +++ b/internal/executor/executor.go @@ -35,6 +35,7 @@ type Pool struct { mu sync.Mutex active int + cancels map[string]context.CancelFunc // taskID → cancel resultCh chan *Result Questions *QuestionRegistry } @@ -55,6 +56,7 @@ func NewPool(maxConcurrent int, runner Runner, store *storage.DB, logger *slog.L runner: runner, store: store, logger: logger, + cancels: make(map[string]context.CancelFunc), resultCh: make(chan *Result, maxConcurrent*2), Questions: NewQuestionRegistry(), } @@ -81,6 +83,19 @@ func (p *Pool) Results() <-chan *Result { return p.resultCh } +// Cancel requests cancellation of a running task. Returns false if the task +// is not currently running in this pool. +func (p *Pool) Cancel(taskID string) bool { + p.mu.Lock() + cancel, ok := p.cancels[taskID] + p.mu.Unlock() + if !ok { + return false + } + cancel() + return true +} + // SubmitResume re-queues a blocked task using the provided resume execution. // The execution must have ResumeSessionID and ResumeAnswer set. func (p *Pool) SubmitResume(ctx context.Context, t *task.Task, exec *storage.Execution) error { @@ -230,14 +245,22 @@ func (p *Pool) execute(ctx context.Context, t *task.Task) { p.logger.Error("failed to update task state", "error", err) } - // Apply task timeout. + // Apply task timeout and register cancel so callers can stop this task. var cancel context.CancelFunc if t.Timeout.Duration > 0 { ctx, cancel = context.WithTimeout(ctx, t.Timeout.Duration) } else { ctx, cancel = context.WithCancel(ctx) } - defer cancel() + p.mu.Lock() + p.cancels[t.ID] = cancel + p.mu.Unlock() + defer func() { + cancel() + p.mu.Lock() + delete(p.cancels, t.ID) + p.mu.Unlock() + }() // Run the task. err := p.runner.Run(ctx, t, exec) diff --git a/internal/executor/executor_test.go b/internal/executor/executor_test.go index b3e6dae..6d13873 100644 --- a/internal/executor/executor_test.go +++ b/internal/executor/executor_test.go @@ -185,6 +185,38 @@ func TestPool_Submit_Cancellation(t *testing.T) { } } +func TestPool_Cancel_StopsRunningTask(t *testing.T) { + store := testStore(t) + runner := &mockRunner{delay: 5 * time.Second} + logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) + pool := NewPool(2, runner, store, logger) + + tk := makeTask("cancel-1") + store.CreateTask(tk) + pool.Submit(context.Background(), tk) + time.Sleep(20 * time.Millisecond) // let goroutine start + + if ok := pool.Cancel("cancel-1"); !ok { + t.Fatal("Cancel returned false for a running task") + } + + result := <-pool.Results() + if result.Execution.Status != "CANCELLED" { + t.Errorf("status: want CANCELLED, got %q", result.Execution.Status) + } +} + +func TestPool_Cancel_UnknownTask_ReturnsFalse(t *testing.T) { + store := testStore(t) + runner := &mockRunner{} + logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) + pool := NewPool(2, runner, store, logger) + + if ok := pool.Cancel("nonexistent"); ok { + t.Error("Cancel returned true for unknown task") + } +} + func TestPool_AtCapacity(t *testing.T) { store := testStore(t) runner := &mockRunner{delay: time.Second} |
