diff options
Diffstat (limited to 'internal/api')
| -rw-r--r-- | internal/api/server.go | 27 | ||||
| -rw-r--r-- | internal/api/server_test.go | 140 |
2 files changed, 163 insertions, 4 deletions
diff --git a/internal/api/server.go b/internal/api/server.go index 0868295..a6e708a 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -203,7 +203,7 @@ func (s *Server) handleCancelTask(w http.ResponseWriter, r *http.Request) { } // If the task is actively running in the pool, cancel it there. if s.pool.Cancel(taskID) { - writeJSON(w, http.StatusOK, map[string]string{"status": "cancelling"}) + writeJSON(w, http.StatusOK, map[string]string{"message": "task cancellation requested", "task_id": taskID}) return } // For non-running tasks (PENDING, QUEUED), transition directly to CANCELLED. @@ -215,7 +215,7 @@ func (s *Server) handleCancelTask(w http.ResponseWriter, r *http.Request) { writeJSON(w, http.StatusInternalServerError, map[string]string{"error": "failed to cancel task"}) return } - writeJSON(w, http.StatusOK, map[string]string{"status": "cancelled"}) + writeJSON(w, http.StatusOK, map[string]string{"message": "task cancelled", "task_id": taskID}) } func (s *Server) handleAnswerQuestion(w http.ResponseWriter, r *http.Request) { @@ -272,7 +272,7 @@ func (s *Server) handleAnswerQuestion(w http.ResponseWriter, r *http.Request) { return } - writeJSON(w, http.StatusOK, map[string]string{"status": "queued"}) + writeJSON(w, http.StatusOK, map[string]string{"message": "task queued for resume", "task_id": taskID}) } func (s *Server) handleResumeTimedOutTask(w http.ResponseWriter, r *http.Request) { @@ -412,10 +412,29 @@ func (s *Server) handleCreateTask(w http.ResponseWriter, r *http.Request) { writeJSON(w, http.StatusCreated, t) } +// validTaskStates is the set of all known task states for query param validation. +var validTaskStates = map[task.State]bool{ + task.StatePending: true, + task.StateQueued: true, + task.StateRunning: true, + task.StateReady: true, + task.StateCompleted: true, + task.StateFailed: true, + task.StateTimedOut: true, + task.StateCancelled: true, + task.StateBudgetExceeded: true, + task.StateBlocked: true, +} + func (s *Server) handleListTasks(w http.ResponseWriter, r *http.Request) { filter := storage.TaskFilter{} if state := r.URL.Query().Get("state"); state != "" { - filter.State = task.State(state) + ts := task.State(state) + if !validTaskStates[ts] { + writeJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid state: " + state}) + return + } + filter.State = ts } tasks, err := s.store.ListTasks(filter) if err != nil { diff --git a/internal/api/server_test.go b/internal/api/server_test.go index 5047af6..9b5c7ae 100644 --- a/internal/api/server_test.go +++ b/internal/api/server_test.go @@ -1039,6 +1039,146 @@ func TestListWorkspaces_SuppressesRawError(t *testing.T) { } } +func TestListTasks_InvalidState_Returns400(t *testing.T) { + srv, _ := testServer(t) + + req := httptest.NewRequest("GET", "/api/tasks?state=BOGUS", nil) + w := httptest.NewRecorder() + srv.Handler().ServeHTTP(w, req) + + if w.Code != http.StatusBadRequest { + t.Fatalf("status: want 400, got %d; body: %s", w.Code, w.Body.String()) + } + var body map[string]string + json.NewDecoder(w.Body).Decode(&body) + if body["error"] == "" { + t.Error("expected non-empty error message") + } +} + +func TestListTasks_ValidState_Returns200(t *testing.T) { + srv, _ := testServer(t) + + req := httptest.NewRequest("GET", "/api/tasks?state=PENDING", nil) + w := httptest.NewRecorder() + srv.Handler().ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("status: want 200, got %d; body: %s", w.Code, w.Body.String()) + } +} + +func TestCancelTask_ResponseShape(t *testing.T) { + srv, store := testServer(t) + createTaskWithState(t, store, "cancel-shape-1", task.StatePending) + + req := httptest.NewRequest("POST", "/api/tasks/cancel-shape-1/cancel", nil) + w := httptest.NewRecorder() + srv.Handler().ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("status: want 200, got %d; body: %s", w.Code, w.Body.String()) + } + var body map[string]string + json.NewDecoder(w.Body).Decode(&body) + if body["task_id"] != "cancel-shape-1" { + t.Errorf("task_id: want 'cancel-shape-1', got %q", body["task_id"]) + } + if body["message"] == "" { + t.Error("expected non-empty message field") + } + if _, hasStatus := body["status"]; hasStatus { + t.Error("response must not contain legacy 'status' field") + } +} + +func TestAnswerQuestion_ResponseShape(t *testing.T) { + srv, store := testServer(t) + createTaskWithState(t, store, "answer-shape-1", task.StateBlocked) + + exec := &storage.Execution{ + ID: "exec-shape-1", + TaskID: "answer-shape-1", + SessionID: "550e8400-e29b-41d4-a716-446655440010", + Status: "BLOCKED", + } + if err := store.CreateExecution(exec); err != nil { + t.Fatalf("create execution: %v", err) + } + + req := httptest.NewRequest("POST", "/api/tasks/answer-shape-1/answer", bytes.NewBufferString(`{"answer":"yes"}`)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + srv.Handler().ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("status: want 200, got %d; body: %s", w.Code, w.Body.String()) + } + var body map[string]string + json.NewDecoder(w.Body).Decode(&body) + if body["task_id"] != "answer-shape-1" { + t.Errorf("task_id: want 'answer-shape-1', got %q", body["task_id"]) + } + if body["message"] == "" { + t.Error("expected non-empty message field") + } + if _, hasStatus := body["status"]; hasStatus { + t.Error("response must not contain legacy 'status' field") + } +} + +func TestRunTask_ResponseShape(t *testing.T) { + srv, store := testServer(t) + createTaskWithState(t, store, "run-shape-1", task.StatePending) + + req := httptest.NewRequest("POST", "/api/tasks/run-shape-1/run", nil) + w := httptest.NewRecorder() + srv.Handler().ServeHTTP(w, req) + + if w.Code != http.StatusAccepted { + t.Fatalf("status: want 202, got %d; body: %s", w.Code, w.Body.String()) + } + var body map[string]string + json.NewDecoder(w.Body).Decode(&body) + if body["task_id"] != "run-shape-1" { + t.Errorf("task_id: want 'run-shape-1', got %q", body["task_id"]) + } + if body["message"] == "" { + t.Error("expected non-empty message field") + } +} + +func TestResumeTimedOut_ResponseShape(t *testing.T) { + srv, store := testServer(t) + createTaskWithState(t, store, "resume-shape-1", task.StateTimedOut) + + exec := &storage.Execution{ + ID: "exec-resume-shape-1", + TaskID: "resume-shape-1", + SessionID: "550e8400-e29b-41d4-a716-446655440020", + Status: "TIMED_OUT", + } + if err := store.CreateExecution(exec); err != nil { + t.Fatalf("create execution: %v", err) + } + + req := httptest.NewRequest("POST", "/api/tasks/resume-shape-1/resume", nil) + w := httptest.NewRecorder() + srv.Handler().ServeHTTP(w, req) + + if w.Code != http.StatusAccepted { + t.Fatalf("status: want 202, got %d; body: %s", w.Code, w.Body.String()) + } + var body map[string]string + json.NewDecoder(w.Body).Decode(&body) + if body["task_id"] != "resume-shape-1" { + t.Errorf("task_id: want 'resume-shape-1', got %q", body["task_id"]) + } + if body["message"] == "" { + t.Error("expected non-empty message field") + } +} + func TestRateLimit_ValidateRejectsExcess(t *testing.T) { srv, _ := testServer(t) srv.elaborateLimiter = newIPRateLimiter(0, 1) |
