diff options
Diffstat (limited to 'internal/api')
| -rw-r--r-- | internal/api/scripts.go | 50 | ||||
| -rw-r--r-- | internal/api/server.go | 52 | ||||
| -rw-r--r-- | internal/api/server_test.go | 140 |
3 files changed, 239 insertions, 3 deletions
diff --git a/internal/api/scripts.go b/internal/api/scripts.go new file mode 100644 index 0000000..492570b --- /dev/null +++ b/internal/api/scripts.go @@ -0,0 +1,50 @@ +package api + +import ( + "bytes" + "context" + "net/http" + "os/exec" + "path/filepath" + "time" +) + +const scriptTimeout = 30 * time.Second + +func (s *Server) startNextTaskScriptPath() string { + if s.startNextTaskScript != "" { + return s.startNextTaskScript + } + return filepath.Join(s.workDir, "scripts", "start-next-task") +} + +func (s *Server) handleStartNextTask(w http.ResponseWriter, r *http.Request) { + ctx, cancel := context.WithTimeout(r.Context(), scriptTimeout) + defer cancel() + + scriptPath := s.startNextTaskScriptPath() + cmd := exec.CommandContext(ctx, scriptPath) + + var stdout, stderr bytes.Buffer + cmd.Stdout = &stdout + cmd.Stderr = &stderr + + err := cmd.Run() + exitCode := 0 + if err != nil { + if exitErr, ok := err.(*exec.ExitError); ok { + exitCode = exitErr.ExitCode() + } else { + s.logger.Error("start-next-task: script execution failed", "error", err, "path", scriptPath) + writeJSON(w, http.StatusInternalServerError, map[string]string{ + "error": "script execution failed: " + err.Error(), + }) + return + } + } + + writeJSON(w, http.StatusOK, map[string]interface{}{ + "output": stdout.String(), + "exit_code": exitCode, + }) +} diff --git a/internal/api/server.go b/internal/api/server.go index 608cdd4..bac98b6 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -24,9 +24,10 @@ type Server struct { hub *Hub logger *slog.Logger mux *http.ServeMux - claudeBinPath string // path to claude binary; defaults to "claude" - elaborateCmdPath string // overrides claudeBinPath; used in tests - workDir string // working directory injected into elaborate system prompt + claudeBinPath string // path to claude binary; defaults to "claude" + elaborateCmdPath string // overrides claudeBinPath; used in tests + startNextTaskScript string // path to start-next-task script; overridden in tests + workDir string // working directory injected into elaborate system prompt } func NewServer(store *storage.DB, pool *executor.Pool, logger *slog.Logger, claudeBinPath string) *Server { @@ -71,6 +72,8 @@ func (s *Server) routes() { s.mux.HandleFunc("GET /api/templates/{id}", s.handleGetTemplate) s.mux.HandleFunc("PUT /api/templates/{id}", s.handleUpdateTemplate) s.mux.HandleFunc("DELETE /api/templates/{id}", s.handleDeleteTemplate) + s.mux.HandleFunc("POST /api/tasks/{id}/answer", s.handleAnswerQuestion) + s.mux.HandleFunc("POST /api/scripts/start-next-task", s.handleStartNextTask) s.mux.HandleFunc("GET /api/ws", s.handleWebSocket) s.mux.HandleFunc("GET /api/health", s.handleHealth) s.mux.Handle("GET /", http.FileServerFS(webui.Files)) @@ -93,6 +96,49 @@ func (s *Server) forwardResults() { } } +// BroadcastQuestion sends a task_question event to all WebSocket clients. +func (s *Server) BroadcastQuestion(taskID, toolUseID string, questionData json.RawMessage) { + event := map[string]interface{}{ + "type": "task_question", + "task_id": taskID, + "question_id": toolUseID, + "data": json.RawMessage(questionData), + "timestamp": time.Now().UTC(), + } + data, _ := json.Marshal(event) + s.hub.Broadcast(data) +} + +func (s *Server) handleAnswerQuestion(w http.ResponseWriter, r *http.Request) { + taskID := r.PathValue("id") + + if _, err := s.store.GetTask(taskID); err != nil { + writeJSON(w, http.StatusNotFound, map[string]string{"error": "task not found"}) + return + } + + var input struct { + QuestionID string `json:"question_id"` + Answer string `json:"answer"` + } + if err := json.NewDecoder(r.Body).Decode(&input); err != nil { + writeJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid JSON: " + err.Error()}) + return + } + if input.QuestionID == "" { + writeJSON(w, http.StatusBadRequest, map[string]string{"error": "question_id is required"}) + return + } + + ok := s.pool.Questions.Answer(input.QuestionID, input.Answer) + if !ok { + writeJSON(w, http.StatusNotFound, map[string]string{"error": "no pending question with that ID"}) + return + } + + writeJSON(w, http.StatusOK, map[string]string{"status": "delivered"}) +} + func (s *Server) handleHealth(w http.ResponseWriter, r *http.Request) { writeJSON(w, http.StatusOK, map[string]string{"status": "ok"}) } diff --git a/internal/api/server_test.go b/internal/api/server_test.go index 5094961..af93a77 100644 --- a/internal/api/server_test.go +++ b/internal/api/server_test.go @@ -335,3 +335,143 @@ func TestCORS_Headers(t *testing.T) { t.Errorf("OPTIONS status: want 200, got %d", w.Code) } } + +func TestAnswerQuestion_NoTask_Returns404(t *testing.T) { + srv, _ := testServer(t) + + payload := `{"question_id": "toolu_abc", "answer": "blue"}` + req := httptest.NewRequest("POST", "/api/tasks/nonexistent/answer", bytes.NewBufferString(payload)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + + srv.Handler().ServeHTTP(w, req) + + if w.Code != http.StatusNotFound { + t.Errorf("status: want 404, got %d; body: %s", w.Code, w.Body.String()) + } +} + +func TestAnswerQuestion_NoPendingQuestion_Returns404(t *testing.T) { + srv, store := testServer(t) + createTaskWithState(t, store, "answer-task-1", task.StatePending) + + payload := `{"question_id": "toolu_nonexistent", "answer": "blue"}` + req := httptest.NewRequest("POST", "/api/tasks/answer-task-1/answer", bytes.NewBufferString(payload)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + + srv.Handler().ServeHTTP(w, req) + + if w.Code != http.StatusNotFound { + t.Errorf("status: want 404, got %d; body: %s", w.Code, w.Body.String()) + } + var body map[string]string + json.NewDecoder(w.Body).Decode(&body) + if body["error"] != "no pending question with that ID" { + t.Errorf("error: want 'no pending question with that ID', got %q", body["error"]) + } +} + +func TestAnswerQuestion_WithPendingQuestion_Returns200(t *testing.T) { + srv, store := testServer(t) + createTaskWithState(t, store, "answer-task-2", task.StateRunning) + + ch := srv.pool.Questions.Register("answer-task-2", "toolu_Q1", []byte(`{}`)) + + go func() { + payload := `{"question_id": "toolu_Q1", "answer": "red"}` + req := httptest.NewRequest("POST", "/api/tasks/answer-task-2/answer", bytes.NewBufferString(payload)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + srv.Handler().ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("status: want 200, got %d; body: %s", w.Code, w.Body.String()) + } + }() + + answer := <-ch + if answer != "red" { + t.Errorf("answer: want 'red', got %q", answer) + } +} + +func TestAnswerQuestion_MissingQuestionID_Returns400(t *testing.T) { + srv, store := testServer(t) + createTaskWithState(t, store, "answer-task-3", task.StateRunning) + + payload := `{"answer": "blue"}` + req := httptest.NewRequest("POST", "/api/tasks/answer-task-3/answer", bytes.NewBufferString(payload)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + + srv.Handler().ServeHTTP(w, req) + + if w.Code != http.StatusBadRequest { + t.Errorf("status: want 400, got %d", w.Code) + } +} + +func TestHandleStartNextTask_Success(t *testing.T) { + dir := t.TempDir() + script := filepath.Join(dir, "start-next-task") + if err := os.WriteFile(script, []byte("#!/bin/sh\necho 'claudomator start abc-123'\n"), 0755); err != nil { + t.Fatal(err) + } + + srv, _ := testServer(t) + srv.startNextTaskScript = script + + req := httptest.NewRequest("POST", "/api/scripts/start-next-task", nil) + w := httptest.NewRecorder() + srv.Handler().ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("want 200, got %d; body: %s", w.Code, w.Body.String()) + } + var body map[string]interface{} + json.NewDecoder(w.Body).Decode(&body) + if body["output"] != "claudomator start abc-123\n" { + t.Errorf("unexpected output: %v", body["output"]) + } + if body["exit_code"] != float64(0) { + t.Errorf("unexpected exit_code: %v", body["exit_code"]) + } +} + +func TestHandleStartNextTask_NoTask(t *testing.T) { + dir := t.TempDir() + script := filepath.Join(dir, "start-next-task") + if err := os.WriteFile(script, []byte("#!/bin/sh\necho 'No task to start.'\n"), 0755); err != nil { + t.Fatal(err) + } + + srv, _ := testServer(t) + srv.startNextTaskScript = script + + req := httptest.NewRequest("POST", "/api/scripts/start-next-task", nil) + w := httptest.NewRecorder() + srv.Handler().ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("want 200, got %d; body: %s", w.Code, w.Body.String()) + } + var body map[string]interface{} + json.NewDecoder(w.Body).Decode(&body) + if body["output"] != "No task to start.\n" { + t.Errorf("unexpected output: %v", body["output"]) + } +} + +func TestHandleStartNextTask_ScriptNotFound(t *testing.T) { + srv, _ := testServer(t) + srv.startNextTaskScript = "/nonexistent/start-next-task" + + req := httptest.NewRequest("POST", "/api/scripts/start-next-task", nil) + w := httptest.NewRecorder() + srv.Handler().ServeHTTP(w, req) + + if w.Code != http.StatusInternalServerError { + t.Errorf("want 500, got %d; body: %s", w.Code, w.Body.String()) + } +} |
