summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorClaudomator Agent <agent@claudomator>2026-03-09 07:30:18 +0000
committerClaudomator Agent <agent@claudomator>2026-03-09 07:30:24 +0000
commit933af819ae3ee7ea0cf6b750815ab185043e19fc (patch)
treed3de98a3168e578cbe06222e7141c66e678eb82a
parentb33566b534185a444a392c36e3f307a5ecad8d4b (diff)
api: validate ?state= param in handleListTasks; standardize operation response shapes
- handleListTasks: validate ?state= against known states, return 400 with clear error for unrecognized values (e.g. ?state=BOGUS) - handleCancelTask: replace {"status":"cancelling"|"cancelled"} with {"message":"...","task_id":"..."} to match run/resume shape - handleAnswerQuestion: replace {"status":"queued"} with {"message":"task queued for resume","task_id":"..."} - Tests: add TestListTasks_InvalidState_Returns400, TestListTasks_ValidState_Returns200, TestCancelTask_ResponseShape, TestAnswerQuestion_ResponseShape, TestRunTask_ResponseShape, TestResumeTimedOut_ResponseShape Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
-rw-r--r--internal/api/server.go27
-rw-r--r--internal/api/server_test.go140
2 files changed, 163 insertions, 4 deletions
diff --git a/internal/api/server.go b/internal/api/server.go
index 0868295..a6e708a 100644
--- a/internal/api/server.go
+++ b/internal/api/server.go
@@ -203,7 +203,7 @@ func (s *Server) handleCancelTask(w http.ResponseWriter, r *http.Request) {
}
// If the task is actively running in the pool, cancel it there.
if s.pool.Cancel(taskID) {
- writeJSON(w, http.StatusOK, map[string]string{"status": "cancelling"})
+ writeJSON(w, http.StatusOK, map[string]string{"message": "task cancellation requested", "task_id": taskID})
return
}
// For non-running tasks (PENDING, QUEUED), transition directly to CANCELLED.
@@ -215,7 +215,7 @@ func (s *Server) handleCancelTask(w http.ResponseWriter, r *http.Request) {
writeJSON(w, http.StatusInternalServerError, map[string]string{"error": "failed to cancel task"})
return
}
- writeJSON(w, http.StatusOK, map[string]string{"status": "cancelled"})
+ writeJSON(w, http.StatusOK, map[string]string{"message": "task cancelled", "task_id": taskID})
}
func (s *Server) handleAnswerQuestion(w http.ResponseWriter, r *http.Request) {
@@ -272,7 +272,7 @@ func (s *Server) handleAnswerQuestion(w http.ResponseWriter, r *http.Request) {
return
}
- writeJSON(w, http.StatusOK, map[string]string{"status": "queued"})
+ writeJSON(w, http.StatusOK, map[string]string{"message": "task queued for resume", "task_id": taskID})
}
func (s *Server) handleResumeTimedOutTask(w http.ResponseWriter, r *http.Request) {
@@ -412,10 +412,29 @@ func (s *Server) handleCreateTask(w http.ResponseWriter, r *http.Request) {
writeJSON(w, http.StatusCreated, t)
}
+// validTaskStates is the set of all known task states for query param validation.
+var validTaskStates = map[task.State]bool{
+ task.StatePending: true,
+ task.StateQueued: true,
+ task.StateRunning: true,
+ task.StateReady: true,
+ task.StateCompleted: true,
+ task.StateFailed: true,
+ task.StateTimedOut: true,
+ task.StateCancelled: true,
+ task.StateBudgetExceeded: true,
+ task.StateBlocked: true,
+}
+
func (s *Server) handleListTasks(w http.ResponseWriter, r *http.Request) {
filter := storage.TaskFilter{}
if state := r.URL.Query().Get("state"); state != "" {
- filter.State = task.State(state)
+ ts := task.State(state)
+ if !validTaskStates[ts] {
+ writeJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid state: " + state})
+ return
+ }
+ filter.State = ts
}
tasks, err := s.store.ListTasks(filter)
if err != nil {
diff --git a/internal/api/server_test.go b/internal/api/server_test.go
index 5047af6..9b5c7ae 100644
--- a/internal/api/server_test.go
+++ b/internal/api/server_test.go
@@ -1039,6 +1039,146 @@ func TestListWorkspaces_SuppressesRawError(t *testing.T) {
}
}
+func TestListTasks_InvalidState_Returns400(t *testing.T) {
+ srv, _ := testServer(t)
+
+ req := httptest.NewRequest("GET", "/api/tasks?state=BOGUS", nil)
+ w := httptest.NewRecorder()
+ srv.Handler().ServeHTTP(w, req)
+
+ if w.Code != http.StatusBadRequest {
+ t.Fatalf("status: want 400, got %d; body: %s", w.Code, w.Body.String())
+ }
+ var body map[string]string
+ json.NewDecoder(w.Body).Decode(&body)
+ if body["error"] == "" {
+ t.Error("expected non-empty error message")
+ }
+}
+
+func TestListTasks_ValidState_Returns200(t *testing.T) {
+ srv, _ := testServer(t)
+
+ req := httptest.NewRequest("GET", "/api/tasks?state=PENDING", nil)
+ w := httptest.NewRecorder()
+ srv.Handler().ServeHTTP(w, req)
+
+ if w.Code != http.StatusOK {
+ t.Fatalf("status: want 200, got %d; body: %s", w.Code, w.Body.String())
+ }
+}
+
+func TestCancelTask_ResponseShape(t *testing.T) {
+ srv, store := testServer(t)
+ createTaskWithState(t, store, "cancel-shape-1", task.StatePending)
+
+ req := httptest.NewRequest("POST", "/api/tasks/cancel-shape-1/cancel", nil)
+ w := httptest.NewRecorder()
+ srv.Handler().ServeHTTP(w, req)
+
+ if w.Code != http.StatusOK {
+ t.Fatalf("status: want 200, got %d; body: %s", w.Code, w.Body.String())
+ }
+ var body map[string]string
+ json.NewDecoder(w.Body).Decode(&body)
+ if body["task_id"] != "cancel-shape-1" {
+ t.Errorf("task_id: want 'cancel-shape-1', got %q", body["task_id"])
+ }
+ if body["message"] == "" {
+ t.Error("expected non-empty message field")
+ }
+ if _, hasStatus := body["status"]; hasStatus {
+ t.Error("response must not contain legacy 'status' field")
+ }
+}
+
+func TestAnswerQuestion_ResponseShape(t *testing.T) {
+ srv, store := testServer(t)
+ createTaskWithState(t, store, "answer-shape-1", task.StateBlocked)
+
+ exec := &storage.Execution{
+ ID: "exec-shape-1",
+ TaskID: "answer-shape-1",
+ SessionID: "550e8400-e29b-41d4-a716-446655440010",
+ Status: "BLOCKED",
+ }
+ if err := store.CreateExecution(exec); err != nil {
+ t.Fatalf("create execution: %v", err)
+ }
+
+ req := httptest.NewRequest("POST", "/api/tasks/answer-shape-1/answer", bytes.NewBufferString(`{"answer":"yes"}`))
+ req.Header.Set("Content-Type", "application/json")
+ w := httptest.NewRecorder()
+ srv.Handler().ServeHTTP(w, req)
+
+ if w.Code != http.StatusOK {
+ t.Fatalf("status: want 200, got %d; body: %s", w.Code, w.Body.String())
+ }
+ var body map[string]string
+ json.NewDecoder(w.Body).Decode(&body)
+ if body["task_id"] != "answer-shape-1" {
+ t.Errorf("task_id: want 'answer-shape-1', got %q", body["task_id"])
+ }
+ if body["message"] == "" {
+ t.Error("expected non-empty message field")
+ }
+ if _, hasStatus := body["status"]; hasStatus {
+ t.Error("response must not contain legacy 'status' field")
+ }
+}
+
+func TestRunTask_ResponseShape(t *testing.T) {
+ srv, store := testServer(t)
+ createTaskWithState(t, store, "run-shape-1", task.StatePending)
+
+ req := httptest.NewRequest("POST", "/api/tasks/run-shape-1/run", nil)
+ w := httptest.NewRecorder()
+ srv.Handler().ServeHTTP(w, req)
+
+ if w.Code != http.StatusAccepted {
+ t.Fatalf("status: want 202, got %d; body: %s", w.Code, w.Body.String())
+ }
+ var body map[string]string
+ json.NewDecoder(w.Body).Decode(&body)
+ if body["task_id"] != "run-shape-1" {
+ t.Errorf("task_id: want 'run-shape-1', got %q", body["task_id"])
+ }
+ if body["message"] == "" {
+ t.Error("expected non-empty message field")
+ }
+}
+
+func TestResumeTimedOut_ResponseShape(t *testing.T) {
+ srv, store := testServer(t)
+ createTaskWithState(t, store, "resume-shape-1", task.StateTimedOut)
+
+ exec := &storage.Execution{
+ ID: "exec-resume-shape-1",
+ TaskID: "resume-shape-1",
+ SessionID: "550e8400-e29b-41d4-a716-446655440020",
+ Status: "TIMED_OUT",
+ }
+ if err := store.CreateExecution(exec); err != nil {
+ t.Fatalf("create execution: %v", err)
+ }
+
+ req := httptest.NewRequest("POST", "/api/tasks/resume-shape-1/resume", nil)
+ w := httptest.NewRecorder()
+ srv.Handler().ServeHTTP(w, req)
+
+ if w.Code != http.StatusAccepted {
+ t.Fatalf("status: want 202, got %d; body: %s", w.Code, w.Body.String())
+ }
+ var body map[string]string
+ json.NewDecoder(w.Body).Decode(&body)
+ if body["task_id"] != "resume-shape-1" {
+ t.Errorf("task_id: want 'resume-shape-1', got %q", body["task_id"])
+ }
+ if body["message"] == "" {
+ t.Error("expected non-empty message field")
+ }
+}
+
func TestRateLimit_ValidateRejectsExcess(t *testing.T) {
srv, _ := testServer(t)
srv.elaborateLimiter = newIPRateLimiter(0, 1)