summaryrefslogtreecommitdiff
path: root/internal
diff options
context:
space:
mode:
Diffstat (limited to 'internal')
-rw-r--r--internal/api/server.go14
-rw-r--r--internal/executor/executor.go27
-rw-r--r--internal/executor/executor_test.go32
3 files changed, 71 insertions, 2 deletions
diff --git a/internal/api/server.go b/internal/api/server.go
index 5758347..18c58e9 100644
--- a/internal/api/server.go
+++ b/internal/api/server.go
@@ -61,6 +61,7 @@ func (s *Server) routes() {
s.mux.HandleFunc("GET /api/tasks", s.handleListTasks)
s.mux.HandleFunc("GET /api/tasks/{id}", s.handleGetTask)
s.mux.HandleFunc("POST /api/tasks/{id}/run", s.handleRunTask)
+ s.mux.HandleFunc("POST /api/tasks/{id}/cancel", s.handleCancelTask)
s.mux.HandleFunc("POST /api/tasks/{id}/accept", s.handleAcceptTask)
s.mux.HandleFunc("POST /api/tasks/{id}/reject", s.handleRejectTask)
s.mux.HandleFunc("GET /api/tasks/{id}/subtasks", s.handleListSubtasks)
@@ -109,6 +110,19 @@ func (s *Server) BroadcastQuestion(taskID, toolUseID string, questionData json.R
s.hub.Broadcast(data)
}
+func (s *Server) handleCancelTask(w http.ResponseWriter, r *http.Request) {
+ taskID := r.PathValue("id")
+ if _, err := s.store.GetTask(taskID); err != nil {
+ writeJSON(w, http.StatusNotFound, map[string]string{"error": "task not found"})
+ return
+ }
+ if !s.pool.Cancel(taskID) {
+ writeJSON(w, http.StatusConflict, map[string]string{"error": "task is not running"})
+ return
+ }
+ writeJSON(w, http.StatusOK, map[string]string{"status": "cancelling"})
+}
+
func (s *Server) handleAnswerQuestion(w http.ResponseWriter, r *http.Request) {
taskID := r.PathValue("id")
diff --git a/internal/executor/executor.go b/internal/executor/executor.go
index f6932f4..62fed2e 100644
--- a/internal/executor/executor.go
+++ b/internal/executor/executor.go
@@ -35,6 +35,7 @@ type Pool struct {
mu sync.Mutex
active int
+ cancels map[string]context.CancelFunc // taskID → cancel
resultCh chan *Result
Questions *QuestionRegistry
}
@@ -55,6 +56,7 @@ func NewPool(maxConcurrent int, runner Runner, store *storage.DB, logger *slog.L
runner: runner,
store: store,
logger: logger,
+ cancels: make(map[string]context.CancelFunc),
resultCh: make(chan *Result, maxConcurrent*2),
Questions: NewQuestionRegistry(),
}
@@ -81,6 +83,19 @@ func (p *Pool) Results() <-chan *Result {
return p.resultCh
}
+// Cancel requests cancellation of a running task. Returns false if the task
+// is not currently running in this pool.
+func (p *Pool) Cancel(taskID string) bool {
+ p.mu.Lock()
+ cancel, ok := p.cancels[taskID]
+ p.mu.Unlock()
+ if !ok {
+ return false
+ }
+ cancel()
+ return true
+}
+
// SubmitResume re-queues a blocked task using the provided resume execution.
// The execution must have ResumeSessionID and ResumeAnswer set.
func (p *Pool) SubmitResume(ctx context.Context, t *task.Task, exec *storage.Execution) error {
@@ -230,14 +245,22 @@ func (p *Pool) execute(ctx context.Context, t *task.Task) {
p.logger.Error("failed to update task state", "error", err)
}
- // Apply task timeout.
+ // Apply task timeout and register cancel so callers can stop this task.
var cancel context.CancelFunc
if t.Timeout.Duration > 0 {
ctx, cancel = context.WithTimeout(ctx, t.Timeout.Duration)
} else {
ctx, cancel = context.WithCancel(ctx)
}
- defer cancel()
+ p.mu.Lock()
+ p.cancels[t.ID] = cancel
+ p.mu.Unlock()
+ defer func() {
+ cancel()
+ p.mu.Lock()
+ delete(p.cancels, t.ID)
+ p.mu.Unlock()
+ }()
// Run the task.
err := p.runner.Run(ctx, t, exec)
diff --git a/internal/executor/executor_test.go b/internal/executor/executor_test.go
index b3e6dae..6d13873 100644
--- a/internal/executor/executor_test.go
+++ b/internal/executor/executor_test.go
@@ -185,6 +185,38 @@ func TestPool_Submit_Cancellation(t *testing.T) {
}
}
+func TestPool_Cancel_StopsRunningTask(t *testing.T) {
+ store := testStore(t)
+ runner := &mockRunner{delay: 5 * time.Second}
+ logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))
+ pool := NewPool(2, runner, store, logger)
+
+ tk := makeTask("cancel-1")
+ store.CreateTask(tk)
+ pool.Submit(context.Background(), tk)
+ time.Sleep(20 * time.Millisecond) // let goroutine start
+
+ if ok := pool.Cancel("cancel-1"); !ok {
+ t.Fatal("Cancel returned false for a running task")
+ }
+
+ result := <-pool.Results()
+ if result.Execution.Status != "CANCELLED" {
+ t.Errorf("status: want CANCELLED, got %q", result.Execution.Status)
+ }
+}
+
+func TestPool_Cancel_UnknownTask_ReturnsFalse(t *testing.T) {
+ store := testStore(t)
+ runner := &mockRunner{}
+ logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))
+ pool := NewPool(2, runner, store, logger)
+
+ if ok := pool.Cancel("nonexistent"); ok {
+ t.Error("Cancel returned true for unknown task")
+ }
+}
+
func TestPool_AtCapacity(t *testing.T) {
store := testStore(t)
runner := &mockRunner{delay: time.Second}