summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--internal/api/scripts.go50
-rw-r--r--internal/api/server.go52
-rw-r--r--internal/api/server_test.go140
-rw-r--r--internal/executor/claude.go98
-rw-r--r--internal/executor/executor.go8
-rw-r--r--internal/executor/question.go172
-rw-r--r--internal/executor/question_test.go253
-rw-r--r--internal/executor/ratelimit.go76
-rw-r--r--internal/executor/ratelimit_test.go170
-rw-r--r--web/app.js481
-rw-r--r--web/index.html6
-rw-r--r--web/style.css262
-rw-r--r--web/test/start-next-task.test.mjs84
-rw-r--r--web/test/task-actions.test.mjs53
14 files changed, 1864 insertions, 41 deletions
diff --git a/internal/api/scripts.go b/internal/api/scripts.go
new file mode 100644
index 0000000..492570b
--- /dev/null
+++ b/internal/api/scripts.go
@@ -0,0 +1,50 @@
+package api
+
+import (
+ "bytes"
+ "context"
+ "net/http"
+ "os/exec"
+ "path/filepath"
+ "time"
+)
+
+const scriptTimeout = 30 * time.Second
+
+func (s *Server) startNextTaskScriptPath() string {
+ if s.startNextTaskScript != "" {
+ return s.startNextTaskScript
+ }
+ return filepath.Join(s.workDir, "scripts", "start-next-task")
+}
+
+func (s *Server) handleStartNextTask(w http.ResponseWriter, r *http.Request) {
+ ctx, cancel := context.WithTimeout(r.Context(), scriptTimeout)
+ defer cancel()
+
+ scriptPath := s.startNextTaskScriptPath()
+ cmd := exec.CommandContext(ctx, scriptPath)
+
+ var stdout, stderr bytes.Buffer
+ cmd.Stdout = &stdout
+ cmd.Stderr = &stderr
+
+ err := cmd.Run()
+ exitCode := 0
+ if err != nil {
+ if exitErr, ok := err.(*exec.ExitError); ok {
+ exitCode = exitErr.ExitCode()
+ } else {
+ s.logger.Error("start-next-task: script execution failed", "error", err, "path", scriptPath)
+ writeJSON(w, http.StatusInternalServerError, map[string]string{
+ "error": "script execution failed: " + err.Error(),
+ })
+ return
+ }
+ }
+
+ writeJSON(w, http.StatusOK, map[string]interface{}{
+ "output": stdout.String(),
+ "exit_code": exitCode,
+ })
+}
diff --git a/internal/api/server.go b/internal/api/server.go
index 608cdd4..bac98b6 100644
--- a/internal/api/server.go
+++ b/internal/api/server.go
@@ -24,9 +24,10 @@ type Server struct {
hub *Hub
logger *slog.Logger
mux *http.ServeMux
- claudeBinPath string // path to claude binary; defaults to "claude"
- elaborateCmdPath string // overrides claudeBinPath; used in tests
- workDir string // working directory injected into elaborate system prompt
+ claudeBinPath string // path to claude binary; defaults to "claude"
+ elaborateCmdPath string // overrides claudeBinPath; used in tests
+ startNextTaskScript string // path to start-next-task script; overridden in tests
+ workDir string // working directory injected into elaborate system prompt
}
func NewServer(store *storage.DB, pool *executor.Pool, logger *slog.Logger, claudeBinPath string) *Server {
@@ -71,6 +72,8 @@ func (s *Server) routes() {
s.mux.HandleFunc("GET /api/templates/{id}", s.handleGetTemplate)
s.mux.HandleFunc("PUT /api/templates/{id}", s.handleUpdateTemplate)
s.mux.HandleFunc("DELETE /api/templates/{id}", s.handleDeleteTemplate)
+ s.mux.HandleFunc("POST /api/tasks/{id}/answer", s.handleAnswerQuestion)
+ s.mux.HandleFunc("POST /api/scripts/start-next-task", s.handleStartNextTask)
s.mux.HandleFunc("GET /api/ws", s.handleWebSocket)
s.mux.HandleFunc("GET /api/health", s.handleHealth)
s.mux.Handle("GET /", http.FileServerFS(webui.Files))
@@ -93,6 +96,49 @@ func (s *Server) forwardResults() {
}
}
+// BroadcastQuestion sends a task_question event to all WebSocket clients.
+func (s *Server) BroadcastQuestion(taskID, toolUseID string, questionData json.RawMessage) {
+ event := map[string]interface{}{
+ "type": "task_question",
+ "task_id": taskID,
+ "question_id": toolUseID,
+ "data": json.RawMessage(questionData),
+ "timestamp": time.Now().UTC(),
+ }
+ data, _ := json.Marshal(event)
+ s.hub.Broadcast(data)
+}
+
+func (s *Server) handleAnswerQuestion(w http.ResponseWriter, r *http.Request) {
+ taskID := r.PathValue("id")
+
+ if _, err := s.store.GetTask(taskID); err != nil {
+ writeJSON(w, http.StatusNotFound, map[string]string{"error": "task not found"})
+ return
+ }
+
+ var input struct {
+ QuestionID string `json:"question_id"`
+ Answer string `json:"answer"`
+ }
+ if err := json.NewDecoder(r.Body).Decode(&input); err != nil {
+ writeJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid JSON: " + err.Error()})
+ return
+ }
+ if input.QuestionID == "" {
+ writeJSON(w, http.StatusBadRequest, map[string]string{"error": "question_id is required"})
+ return
+ }
+
+ ok := s.pool.Questions.Answer(input.QuestionID, input.Answer)
+ if !ok {
+ writeJSON(w, http.StatusNotFound, map[string]string{"error": "no pending question with that ID"})
+ return
+ }
+
+ writeJSON(w, http.StatusOK, map[string]string{"status": "delivered"})
+}
+
func (s *Server) handleHealth(w http.ResponseWriter, r *http.Request) {
writeJSON(w, http.StatusOK, map[string]string{"status": "ok"})
}
diff --git a/internal/api/server_test.go b/internal/api/server_test.go
index 5094961..af93a77 100644
--- a/internal/api/server_test.go
+++ b/internal/api/server_test.go
@@ -335,3 +335,143 @@ func TestCORS_Headers(t *testing.T) {
t.Errorf("OPTIONS status: want 200, got %d", w.Code)
}
}
+
+func TestAnswerQuestion_NoTask_Returns404(t *testing.T) {
+ srv, _ := testServer(t)
+
+ payload := `{"question_id": "toolu_abc", "answer": "blue"}`
+ req := httptest.NewRequest("POST", "/api/tasks/nonexistent/answer", bytes.NewBufferString(payload))
+ req.Header.Set("Content-Type", "application/json")
+ w := httptest.NewRecorder()
+
+ srv.Handler().ServeHTTP(w, req)
+
+ if w.Code != http.StatusNotFound {
+ t.Errorf("status: want 404, got %d; body: %s", w.Code, w.Body.String())
+ }
+}
+
+func TestAnswerQuestion_NoPendingQuestion_Returns404(t *testing.T) {
+ srv, store := testServer(t)
+ createTaskWithState(t, store, "answer-task-1", task.StatePending)
+
+ payload := `{"question_id": "toolu_nonexistent", "answer": "blue"}`
+ req := httptest.NewRequest("POST", "/api/tasks/answer-task-1/answer", bytes.NewBufferString(payload))
+ req.Header.Set("Content-Type", "application/json")
+ w := httptest.NewRecorder()
+
+ srv.Handler().ServeHTTP(w, req)
+
+ if w.Code != http.StatusNotFound {
+ t.Errorf("status: want 404, got %d; body: %s", w.Code, w.Body.String())
+ }
+ var body map[string]string
+ json.NewDecoder(w.Body).Decode(&body)
+ if body["error"] != "no pending question with that ID" {
+ t.Errorf("error: want 'no pending question with that ID', got %q", body["error"])
+ }
+}
+
+func TestAnswerQuestion_WithPendingQuestion_Returns200(t *testing.T) {
+ srv, store := testServer(t)
+ createTaskWithState(t, store, "answer-task-2", task.StateRunning)
+
+ ch := srv.pool.Questions.Register("answer-task-2", "toolu_Q1", []byte(`{}`))
+
+ go func() {
+ payload := `{"question_id": "toolu_Q1", "answer": "red"}`
+ req := httptest.NewRequest("POST", "/api/tasks/answer-task-2/answer", bytes.NewBufferString(payload))
+ req.Header.Set("Content-Type", "application/json")
+ w := httptest.NewRecorder()
+ srv.Handler().ServeHTTP(w, req)
+
+ if w.Code != http.StatusOK {
+ t.Errorf("status: want 200, got %d; body: %s", w.Code, w.Body.String())
+ }
+ }()
+
+ answer := <-ch
+ if answer != "red" {
+ t.Errorf("answer: want 'red', got %q", answer)
+ }
+}
+
+func TestAnswerQuestion_MissingQuestionID_Returns400(t *testing.T) {
+ srv, store := testServer(t)
+ createTaskWithState(t, store, "answer-task-3", task.StateRunning)
+
+ payload := `{"answer": "blue"}`
+ req := httptest.NewRequest("POST", "/api/tasks/answer-task-3/answer", bytes.NewBufferString(payload))
+ req.Header.Set("Content-Type", "application/json")
+ w := httptest.NewRecorder()
+
+ srv.Handler().ServeHTTP(w, req)
+
+ if w.Code != http.StatusBadRequest {
+ t.Errorf("status: want 400, got %d", w.Code)
+ }
+}
+
+func TestHandleStartNextTask_Success(t *testing.T) {
+ dir := t.TempDir()
+ script := filepath.Join(dir, "start-next-task")
+ if err := os.WriteFile(script, []byte("#!/bin/sh\necho 'claudomator start abc-123'\n"), 0755); err != nil {
+ t.Fatal(err)
+ }
+
+ srv, _ := testServer(t)
+ srv.startNextTaskScript = script
+
+ req := httptest.NewRequest("POST", "/api/scripts/start-next-task", nil)
+ w := httptest.NewRecorder()
+ srv.Handler().ServeHTTP(w, req)
+
+ if w.Code != http.StatusOK {
+ t.Fatalf("want 200, got %d; body: %s", w.Code, w.Body.String())
+ }
+ var body map[string]interface{}
+ json.NewDecoder(w.Body).Decode(&body)
+ if body["output"] != "claudomator start abc-123\n" {
+ t.Errorf("unexpected output: %v", body["output"])
+ }
+ if body["exit_code"] != float64(0) {
+ t.Errorf("unexpected exit_code: %v", body["exit_code"])
+ }
+}
+
+func TestHandleStartNextTask_NoTask(t *testing.T) {
+ dir := t.TempDir()
+ script := filepath.Join(dir, "start-next-task")
+ if err := os.WriteFile(script, []byte("#!/bin/sh\necho 'No task to start.'\n"), 0755); err != nil {
+ t.Fatal(err)
+ }
+
+ srv, _ := testServer(t)
+ srv.startNextTaskScript = script
+
+ req := httptest.NewRequest("POST", "/api/scripts/start-next-task", nil)
+ w := httptest.NewRecorder()
+ srv.Handler().ServeHTTP(w, req)
+
+ if w.Code != http.StatusOK {
+ t.Fatalf("want 200, got %d; body: %s", w.Code, w.Body.String())
+ }
+ var body map[string]interface{}
+ json.NewDecoder(w.Body).Decode(&body)
+ if body["output"] != "No task to start.\n" {
+ t.Errorf("unexpected output: %v", body["output"])
+ }
+}
+
+func TestHandleStartNextTask_ScriptNotFound(t *testing.T) {
+ srv, _ := testServer(t)
+ srv.startNextTaskScript = "/nonexistent/start-next-task"
+
+ req := httptest.NewRequest("POST", "/api/scripts/start-next-task", nil)
+ w := httptest.NewRecorder()
+ srv.Handler().ServeHTTP(w, req)
+
+ if w.Code != http.StatusInternalServerError {
+ t.Errorf("want 500, got %d; body: %s", w.Code, w.Body.String())
+ }
+}
diff --git a/internal/executor/claude.go b/internal/executor/claude.go
index 7b3884c..0029331 100644
--- a/internal/executor/claude.go
+++ b/internal/executor/claude.go
@@ -10,6 +10,9 @@ import (
"os"
"os/exec"
"path/filepath"
+ "sync"
+ "syscall"
+ "time"
"github.com/thepeterstone/claudomator/internal/storage"
"github.com/thepeterstone/claudomator/internal/task"
@@ -31,71 +34,116 @@ func (r *ClaudeRunner) binaryPath() string {
}
// Run executes a claude -p invocation, streaming output to log files.
+// It retries up to 3 times on rate-limit errors using exponential backoff.
func (r *ClaudeRunner) Run(ctx context.Context, t *task.Task, e *storage.Execution) error {
args := r.buildArgs(t)
- cmd := exec.CommandContext(ctx, r.binaryPath(), args...)
- cmd.Env = append(os.Environ(),
- "CLAUDOMATOR_API_URL="+r.APIURL,
- "CLAUDOMATOR_TASK_ID="+t.ID,
- )
if t.Claude.WorkingDir != "" {
if _, err := os.Stat(t.Claude.WorkingDir); err != nil {
return fmt.Errorf("working_dir %q: %w", t.Claude.WorkingDir, err)
}
- cmd.Dir = t.Claude.WorkingDir
}
- // Setup log directory for this execution.
+ // Setup log directory once; retries overwrite the log files.
logDir := filepath.Join(r.LogDir, e.ID)
if err := os.MkdirAll(logDir, 0700); err != nil {
return fmt.Errorf("creating log dir: %w", err)
}
-
- stdoutPath := filepath.Join(logDir, "stdout.log")
- stderrPath := filepath.Join(logDir, "stderr.log")
- e.StdoutPath = stdoutPath
- e.StderrPath = stderrPath
+ e.StdoutPath = filepath.Join(logDir, "stdout.log")
+ e.StderrPath = filepath.Join(logDir, "stderr.log")
e.ArtifactDir = logDir
- stdoutFile, err := os.Create(stdoutPath)
+ attempt := 0
+ return runWithBackoff(ctx, 3, 5*time.Second, func() error {
+ if attempt > 0 {
+ delay := 5 * time.Second * (1 << (attempt - 1))
+ r.Logger.Warn("rate-limited by Claude API, retrying",
+ "attempt", attempt,
+ "delay", delay,
+ )
+ }
+ attempt++
+ return r.execOnce(ctx, t, args, e)
+ })
+}
+
+// execOnce runs the claude subprocess once, streaming output to e's log paths.
+func (r *ClaudeRunner) execOnce(ctx context.Context, t *task.Task, args []string, e *storage.Execution) error {
+ cmd := exec.CommandContext(ctx, r.binaryPath(), args...)
+ cmd.Env = append(os.Environ(),
+ "CLAUDOMATOR_API_URL="+r.APIURL,
+ "CLAUDOMATOR_TASK_ID="+t.ID,
+ )
+ // Put the subprocess in its own process group so we can SIGKILL the entire
+ // group (MCP servers, bash children, etc.) on cancellation.
+ cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true}
+ if t.Claude.WorkingDir != "" {
+ cmd.Dir = t.Claude.WorkingDir
+ }
+
+ stdoutFile, err := os.Create(e.StdoutPath)
if err != nil {
return fmt.Errorf("creating stdout log: %w", err)
}
defer stdoutFile.Close()
- stderrFile, err := os.Create(stderrPath)
+ stderrFile, err := os.Create(e.StderrPath)
if err != nil {
return fmt.Errorf("creating stderr log: %w", err)
}
defer stderrFile.Close()
- stdoutPipe, err := cmd.StdoutPipe()
+ // Use os.Pipe for stdout so we own the read-end lifetime.
+ // cmd.StdoutPipe() would add the read-end to closeAfterWait, causing
+ // cmd.Wait() to close it before our goroutine finishes reading.
+ stdoutR, stdoutW, err := os.Pipe()
if err != nil {
return fmt.Errorf("creating stdout pipe: %w", err)
}
- stderrPipe, err := cmd.StderrPipe()
- if err != nil {
- return fmt.Errorf("creating stderr pipe: %w", err)
- }
+ cmd.Stdout = stdoutW // *os.File — not added to closeAfterStart/Wait
+ cmd.Stderr = stderrFile
if err := cmd.Start(); err != nil {
+ stdoutW.Close()
+ stdoutR.Close()
return fmt.Errorf("starting claude: %w", err)
}
+ // Close our write-end immediately; the subprocess holds its own copy.
+ // The goroutine below gets EOF when the subprocess exits.
+ stdoutW.Close()
+
+ // killDone is closed when cmd.Wait() returns, stopping the pgid-kill goroutine.
+ killDone := make(chan struct{})
+ go func() {
+ select {
+ case <-ctx.Done():
+ // SIGKILL the entire process group to reap orphan children.
+ syscall.Kill(-cmd.Process.Pid, syscall.SIGKILL)
+ case <-killDone:
+ }
+ }()
- // Stream output to log files and parse cost info.
+ // Stream stdout to the log file and parse cost.
+ // wg ensures costUSD is fully written before we read it after cmd.Wait().
var costUSD float64
+ var wg sync.WaitGroup
+ wg.Add(1)
go func() {
- costUSD = streamAndParseCost(stdoutPipe, stdoutFile, r.Logger)
+ defer wg.Done()
+ costUSD = streamAndParseCost(stdoutR, stdoutFile, r.Logger)
+ stdoutR.Close()
}()
- go io.Copy(stderrFile, stderrPipe)
- if err := cmd.Wait(); err != nil {
- if exitErr, ok := err.(*exec.ExitError); ok {
+ waitErr := cmd.Wait()
+ close(killDone) // stop the pgid-kill goroutine
+ wg.Wait() // drain remaining stdout before reading costUSD
+
+ if waitErr != nil {
+ if exitErr, ok := waitErr.(*exec.ExitError); ok {
e.ExitCode = exitErr.ExitCode()
}
e.CostUSD = costUSD
- return fmt.Errorf("claude exited with error: %w", err)
+ return fmt.Errorf("claude exited with error: %w", waitErr)
}
e.ExitCode = 0
diff --git a/internal/executor/executor.go b/internal/executor/executor.go
index d25d3b4..51f468e 100644
--- a/internal/executor/executor.go
+++ b/internal/executor/executor.go
@@ -24,9 +24,10 @@ type Pool struct {
store *storage.DB
logger *slog.Logger
- mu sync.Mutex
- active int
- resultCh chan *Result
+ mu sync.Mutex
+ active int
+ resultCh chan *Result
+ Questions *QuestionRegistry
}
// Result is emitted when a task execution completes.
@@ -46,6 +47,7 @@ func NewPool(maxConcurrent int, runner Runner, store *storage.DB, logger *slog.L
store: store,
logger: logger,
resultCh: make(chan *Result, maxConcurrent*2),
+ Questions: NewQuestionRegistry(),
}
}
diff --git a/internal/executor/question.go b/internal/executor/question.go
new file mode 100644
index 0000000..9a2b55d
--- /dev/null
+++ b/internal/executor/question.go
@@ -0,0 +1,172 @@
+package executor
+
+import (
+ "bufio"
+ "encoding/json"
+ "io"
+ "log/slog"
+ "sync"
+)
+
+// QuestionHandler is called when an agent invokes AskUserQuestion.
+// Implementations should broadcast the question and block until an answer arrives.
+type QuestionHandler interface {
+ HandleQuestion(taskID, toolUseID string, input json.RawMessage) (string, error)
+}
+
+// PendingQuestion holds state for a question awaiting a user answer.
+type PendingQuestion struct {
+ TaskID string `json:"task_id"`
+ ToolUseID string `json:"tool_use_id"`
+ Input json.RawMessage `json:"input"`
+ AnswerCh chan string `json:"-"`
+}
+
+// QuestionRegistry tracks pending questions across running tasks.
+type QuestionRegistry struct {
+ mu sync.Mutex
+ questions map[string]*PendingQuestion // keyed by toolUseID
+}
+
+// NewQuestionRegistry creates a new registry.
+func NewQuestionRegistry() *QuestionRegistry {
+ return &QuestionRegistry{
+ questions: make(map[string]*PendingQuestion),
+ }
+}
+
+// Register adds a pending question and returns its answer channel.
+func (qr *QuestionRegistry) Register(taskID, toolUseID string, input json.RawMessage) chan string {
+ ch := make(chan string, 1)
+ qr.mu.Lock()
+ qr.questions[toolUseID] = &PendingQuestion{
+ TaskID: taskID,
+ ToolUseID: toolUseID,
+ Input: input,
+ AnswerCh: ch,
+ }
+ qr.mu.Unlock()
+ return ch
+}
+
+// Answer delivers an answer for a pending question. Returns false if no such question exists.
+func (qr *QuestionRegistry) Answer(toolUseID, answer string) bool {
+ qr.mu.Lock()
+ pq, ok := qr.questions[toolUseID]
+ if ok {
+ delete(qr.questions, toolUseID)
+ }
+ qr.mu.Unlock()
+ if !ok {
+ return false
+ }
+ pq.AnswerCh <- answer
+ return true
+}
+
+// Get returns a pending question by tool_use_id, or nil.
+func (qr *QuestionRegistry) Get(toolUseID string) *PendingQuestion {
+ qr.mu.Lock()
+ defer qr.mu.Unlock()
+ return qr.questions[toolUseID]
+}
+
+// PendingForTask returns all pending questions for a given task.
+func (qr *QuestionRegistry) PendingForTask(taskID string) []*PendingQuestion {
+ qr.mu.Lock()
+ defer qr.mu.Unlock()
+ var result []*PendingQuestion
+ for _, pq := range qr.questions {
+ if pq.TaskID == taskID {
+ result = append(result, pq)
+ }
+ }
+ return result
+}
+
+// Remove removes a question without answering it (e.g., on task cancellation).
+func (qr *QuestionRegistry) Remove(toolUseID string) {
+ qr.mu.Lock()
+ delete(qr.questions, toolUseID)
+ qr.mu.Unlock()
+}
+
+// extractAskUserQuestion parses a stream-json line and returns the tool_use_id and input
+// if the line is an assistant event containing an AskUserQuestion tool_use.
+func extractAskUserQuestion(line []byte) (string, json.RawMessage) {
+ var event struct {
+ Type string `json:"type"`
+ Message struct {
+ Content []struct {
+ Type string `json:"type"`
+ ID string `json:"id"`
+ Name string `json:"name"`
+ Input json.RawMessage `json:"input"`
+ } `json:"content"`
+ } `json:"message"`
+ }
+ if err := json.Unmarshal(line, &event); err != nil {
+ return "", nil
+ }
+ if event.Type != "assistant" {
+ return "", nil
+ }
+ for _, block := range event.Message.Content {
+ if block.Type == "tool_use" && block.Name == "AskUserQuestion" {
+ return block.ID, block.Input
+ }
+ }
+ return "", nil
+}
+
+// streamAndParseWithQuestions reads streaming JSON, writes to w, parses cost,
+// and calls onQuestion for each detected AskUserQuestion tool_use.
+func streamAndParseWithQuestions(r io.Reader, w io.Writer, _ *slog.Logger, onQuestion func(string, json.RawMessage)) float64 {
+ tee := io.TeeReader(r, w)
+ scanner := bufio.NewScanner(tee)
+ scanner.Buffer(make([]byte, 1024*1024), 1024*1024)
+
+ var totalCost float64
+ for scanner.Scan() {
+ line := scanner.Bytes()
+
+ if toolUseID, input := extractAskUserQuestion(line); toolUseID != "" {
+ if onQuestion != nil {
+ onQuestion(toolUseID, input)
+ }
+ }
+
+ var msg map[string]interface{}
+ if err := json.Unmarshal(line, &msg); err != nil {
+ continue
+ }
+ if costData, ok := msg["cost_usd"]; ok {
+ if cost, ok := costData.(float64); ok {
+ totalCost = cost
+ }
+ }
+ }
+ return totalCost
+}
+
+// buildToolResultMessage builds a tool_result message to feed back to Claude
+// as the answer to an AskUserQuestion tool_use.
+func buildToolResultMessage(toolUseID, answer string) []byte {
+ answerJSON, _ := json.Marshal(map[string]interface{}{
+ "answers": map[string]string{"answer": answer},
+ })
+ msg := map[string]interface{}{
+ "message": map[string]interface{}{
+ "role": "user",
+ "content": []map[string]interface{}{
+ {
+ "type": "tool_result",
+ "tool_use_id": toolUseID,
+ "content": string(answerJSON),
+ },
+ },
+ },
+ }
+ result, _ := json.Marshal(msg)
+ return result
+}
diff --git a/internal/executor/question_test.go b/internal/executor/question_test.go
new file mode 100644
index 0000000..d0fbed9
--- /dev/null
+++ b/internal/executor/question_test.go
@@ -0,0 +1,253 @@
+package executor
+
+import (
+ "bytes"
+ "encoding/json"
+ "io"
+ "log/slog"
+ "strings"
+ "testing"
+)
+
+func TestQuestionRegistry_RegisterAndAnswer(t *testing.T) {
+ qr := NewQuestionRegistry()
+
+ ch := qr.Register("task-1", "toolu_abc", json.RawMessage(`{"question":"color?"}`))
+
+ // Answer should unblock the channel.
+ go func() {
+ ok := qr.Answer("toolu_abc", "blue")
+ if !ok {
+ t.Error("Answer returned false, expected true")
+ }
+ }()
+
+ answer := <-ch
+ if answer != "blue" {
+ t.Errorf("want 'blue', got %q", answer)
+ }
+
+ // Question should be removed after answering.
+ if qr.Get("toolu_abc") != nil {
+ t.Error("question should be removed after answering")
+ }
+}
+
+func TestQuestionRegistry_AnswerUnknown(t *testing.T) {
+ qr := NewQuestionRegistry()
+ ok := qr.Answer("nonexistent", "anything")
+ if ok {
+ t.Error("expected false for unknown question")
+ }
+}
+
+func TestQuestionRegistry_PendingForTask(t *testing.T) {
+ qr := NewQuestionRegistry()
+ qr.Register("task-1", "toolu_1", json.RawMessage(`{}`))
+ qr.Register("task-1", "toolu_2", json.RawMessage(`{}`))
+ qr.Register("task-2", "toolu_3", json.RawMessage(`{}`))
+
+ pending := qr.PendingForTask("task-1")
+ if len(pending) != 2 {
+ t.Errorf("want 2 pending for task-1, got %d", len(pending))
+ }
+
+ pending2 := qr.PendingForTask("task-2")
+ if len(pending2) != 1 {
+ t.Errorf("want 1 pending for task-2, got %d", len(pending2))
+ }
+}
+
+func TestQuestionRegistry_Remove(t *testing.T) {
+ qr := NewQuestionRegistry()
+ qr.Register("task-1", "toolu_x", json.RawMessage(`{}`))
+ qr.Remove("toolu_x")
+ if qr.Get("toolu_x") != nil {
+ t.Error("question should be removed")
+ }
+}
+
+func TestExtractAskUserQuestion_DetectsQuestion(t *testing.T) {
+ // Simulate a stream-json assistant event containing an AskUserQuestion tool_use.
+ event := map[string]interface{}{
+ "type": "assistant",
+ "message": map[string]interface{}{
+ "content": []interface{}{
+ map[string]interface{}{
+ "type": "tool_use",
+ "id": "toolu_01ABC",
+ "name": "AskUserQuestion",
+ "input": map[string]interface{}{
+ "questions": []interface{}{
+ map[string]interface{}{
+ "question": "Which color?",
+ "header": "Color",
+ "options": []interface{}{
+ map[string]interface{}{"label": "red", "description": "Red color"},
+ map[string]interface{}{"label": "blue", "description": "Blue color"},
+ },
+ "multiSelect": false,
+ },
+ },
+ },
+ },
+ },
+ },
+ }
+ line, _ := json.Marshal(event)
+
+ toolUseID, input := extractAskUserQuestion(line)
+ if toolUseID != "toolu_01ABC" {
+ t.Errorf("toolUseID: want 'toolu_01ABC', got %q", toolUseID)
+ }
+ if input == nil {
+ t.Fatal("input should not be nil")
+ }
+}
+
+func TestExtractAskUserQuestion_IgnoresOtherTools(t *testing.T) {
+ event := map[string]interface{}{
+ "type": "assistant",
+ "message": map[string]interface{}{
+ "content": []interface{}{
+ map[string]interface{}{
+ "type": "tool_use",
+ "id": "toolu_01XYZ",
+ "name": "Read",
+ "input": map[string]interface{}{"file_path": "/foo"},
+ },
+ },
+ },
+ }
+ line, _ := json.Marshal(event)
+
+ toolUseID, input := extractAskUserQuestion(line)
+ if toolUseID != "" || input != nil {
+ t.Error("should not detect non-AskUserQuestion tool_use")
+ }
+}
+
+func TestExtractAskUserQuestion_IgnoresNonAssistant(t *testing.T) {
+ event := map[string]interface{}{
+ "type": "system",
+ "subtype": "init",
+ }
+ line, _ := json.Marshal(event)
+
+ toolUseID, input := extractAskUserQuestion(line)
+ if toolUseID != "" || input != nil {
+ t.Error("should not detect from non-assistant events")
+ }
+}
+
+func TestStreamAndParseQuestions_DetectsQuestionAndCost(t *testing.T) {
+ // Build a stream with an assistant event containing AskUserQuestion and a result with cost.
+ assistantEvent := map[string]interface{}{
+ "type": "assistant",
+ "message": map[string]interface{}{
+ "content": []interface{}{
+ map[string]interface{}{
+ "type": "tool_use",
+ "id": "toolu_Q1",
+ "name": "AskUserQuestion",
+ "input": map[string]interface{}{
+ "questions": []interface{}{
+ map[string]interface{}{
+ "question": "Pick a number",
+ "header": "Num",
+ "options": []interface{}{
+ map[string]interface{}{"label": "1", "description": "One"},
+ map[string]interface{}{"label": "2", "description": "Two"},
+ },
+ "multiSelect": false,
+ },
+ },
+ },
+ },
+ },
+ },
+ }
+ resultEvent := map[string]interface{}{
+ "type": "result",
+ "cost_usd": 0.05,
+ }
+
+ var buf bytes.Buffer
+ json.NewEncoder(&buf).Encode(assistantEvent)
+ json.NewEncoder(&buf).Encode(resultEvent)
+
+ logger := slog.New(slog.NewTextHandler(io.Discard, nil))
+ var questions []questionDetected
+ onQuestion := func(toolUseID string, input json.RawMessage) {
+ questions = append(questions, questionDetected{toolUseID, input})
+ }
+
+ cost := streamAndParseWithQuestions(strings.NewReader(buf.String()), io.Discard, logger, onQuestion)
+
+ if cost != 0.05 {
+ t.Errorf("cost: want 0.05, got %f", cost)
+ }
+ if len(questions) != 1 {
+ t.Fatalf("want 1 question detected, got %d", len(questions))
+ }
+ if questions[0].toolUseID != "toolu_Q1" {
+ t.Errorf("toolUseID: want 'toolu_Q1', got %q", questions[0].toolUseID)
+ }
+}
+
+type questionDetected struct {
+ toolUseID string
+ input json.RawMessage
+}
+
+func TestBuildToolResultMessage_Format(t *testing.T) {
+ msg := buildToolResultMessage("toolu_123", "blue")
+
+ var parsed map[string]interface{}
+ if err := json.Unmarshal(msg, &parsed); err != nil {
+ t.Fatalf("invalid JSON: %v", err)
+ }
+
+ // Should have type "user" with message containing tool_result
+ msgObj, ok := parsed["message"].(map[string]interface{})
+ if !ok {
+ t.Fatal("missing 'message' field")
+ }
+ content, ok := msgObj["content"].([]interface{})
+ if !ok || len(content) == 0 {
+ t.Fatal("missing content array")
+ }
+
+ block := content[0].(map[string]interface{})
+ if block["type"] != "tool_result" {
+ t.Errorf("type: want 'tool_result', got %v", block["type"])
+ }
+ if block["tool_use_id"] != "toolu_123" {
+ t.Errorf("tool_use_id: want 'toolu_123', got %v", block["tool_use_id"])
+ }
+
+ // The content should contain the answer JSON
+ resultContent, ok := block["content"].(string)
+ if !ok {
+ t.Fatal("content should be a string")
+ }
+ var answerData map[string]interface{}
+ if err := json.Unmarshal([]byte(resultContent), &answerData); err != nil {
+ t.Fatalf("answer content is not valid JSON: %v", err)
+ }
+ answers, ok := answerData["answers"].(map[string]interface{})
+ if !ok {
+ t.Fatal("missing answers in result content")
+ }
+ // At least one answer key should have the value "blue"
+ found := false
+ for _, v := range answers {
+ if v == "blue" {
+ found = true
+ break
+ }
+ }
+ if !found {
+ t.Errorf("expected 'blue' in answers, got %v", answers)
+ }
+}
diff --git a/internal/executor/ratelimit.go b/internal/executor/ratelimit.go
new file mode 100644
index 0000000..884da43
--- /dev/null
+++ b/internal/executor/ratelimit.go
@@ -0,0 +1,76 @@
+package executor
+
+import (
+ "context"
+ "fmt"
+ "regexp"
+ "strconv"
+ "strings"
+ "time"
+)
+
+var retryAfterRe = regexp.MustCompile(`(?i)retry[-_ ]after[:\s]+(\d+)`)
+
+const maxBackoffDelay = 5 * time.Minute
+
+// isRateLimitError returns true if err looks like a Claude API rate-limit response.
+func isRateLimitError(err error) bool {
+ if err == nil {
+ return false
+ }
+ msg := strings.ToLower(err.Error())
+ return strings.Contains(msg, "rate limit") ||
+ strings.Contains(msg, "too many requests") ||
+ strings.Contains(msg, "429") ||
+ strings.Contains(msg, "overloaded")
+}
+
+// parseRetryAfter extracts a Retry-After duration from an error message.
+// Returns 0 if no retry-after value is found.
+func parseRetryAfter(msg string) time.Duration {
+ m := retryAfterRe.FindStringSubmatch(msg)
+ if m == nil {
+ return 0
+ }
+ secs, err := strconv.Atoi(m[1])
+ if err != nil || secs <= 0 {
+ return 0
+ }
+ return time.Duration(secs) * time.Second
+}
+
+// runWithBackoff calls fn repeatedly on rate-limit errors, using exponential backoff.
+// maxRetries is the max number of retry attempts (not counting the initial call).
+// baseDelay is the initial backoff duration (doubled each retry).
+func runWithBackoff(ctx context.Context, maxRetries int, baseDelay time.Duration, fn func() error) error {
+ var lastErr error
+ for attempt := 0; attempt <= maxRetries; attempt++ {
+ lastErr = fn()
+ if lastErr == nil {
+ return nil
+ }
+ if !isRateLimitError(lastErr) {
+ return lastErr
+ }
+ if attempt == maxRetries {
+ break
+ }
+
+ // Compute exponential backoff delay.
+ delay := baseDelay * (1 << attempt)
+ if delay > maxBackoffDelay {
+ delay = maxBackoffDelay
+ }
+ // Use Retry-After header value if present.
+ if ra := parseRetryAfter(lastErr.Error()); ra > 0 {
+ delay = ra
+ }
+
+ select {
+ case <-ctx.Done():
+ return fmt.Errorf("context cancelled during rate-limit backoff: %w", ctx.Err())
+ case <-time.After(delay):
+ }
+ }
+ return lastErr
+}
diff --git a/internal/executor/ratelimit_test.go b/internal/executor/ratelimit_test.go
new file mode 100644
index 0000000..f45216f
--- /dev/null
+++ b/internal/executor/ratelimit_test.go
@@ -0,0 +1,170 @@
+package executor
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "testing"
+ "time"
+)
+
+// --- isRateLimitError tests ---
+
+func TestIsRateLimitError_RateLimitMessage(t *testing.T) {
+ err := errors.New("claude exited with error: rate limit exceeded")
+ if !isRateLimitError(err) {
+ t.Error("want true for 'rate limit exceeded', got false")
+ }
+}
+
+func TestIsRateLimitError_TooManyRequests(t *testing.T) {
+ err := errors.New("too many requests to the API")
+ if !isRateLimitError(err) {
+ t.Error("want true for 'too many requests', got false")
+ }
+}
+
+func TestIsRateLimitError_HTTP429(t *testing.T) {
+ err := errors.New("API returned status 429")
+ if !isRateLimitError(err) {
+ t.Error("want true for '429', got false")
+ }
+}
+
+func TestIsRateLimitError_Overloaded(t *testing.T) {
+ err := errors.New("API overloaded, please retry later")
+ if !isRateLimitError(err) {
+ t.Error("want true for 'overloaded', got false")
+ }
+}
+
+func TestIsRateLimitError_NonRateLimitError(t *testing.T) {
+ err := errors.New("claude exited with error: exit status 1")
+ if isRateLimitError(err) {
+ t.Error("want false for non-rate-limit error, got true")
+ }
+}
+
+func TestIsRateLimitError_NilError(t *testing.T) {
+ if isRateLimitError(nil) {
+ t.Error("want false for nil error, got true")
+ }
+}
+
+// --- parseRetryAfter tests ---
+
+func TestParseRetryAfter_RetryAfterSeconds(t *testing.T) {
+ msg := "rate limit exceeded, retry after 30 seconds"
+ d := parseRetryAfter(msg)
+ if d != 30*time.Second {
+ t.Errorf("want 30s, got %v", d)
+ }
+}
+
+func TestParseRetryAfter_RetryAfterHeader(t *testing.T) {
+ msg := "rate_limit_error: retry-after: 60"
+ d := parseRetryAfter(msg)
+ if d != 60*time.Second {
+ t.Errorf("want 60s, got %v", d)
+ }
+}
+
+func TestParseRetryAfter_NoRetryInfo(t *testing.T) {
+ msg := "rate limit exceeded"
+ d := parseRetryAfter(msg)
+ if d != 0 {
+ t.Errorf("want 0, got %v", d)
+ }
+}
+
+// --- runWithBackoff tests ---
+
+func TestRunWithBackoff_SuccessOnFirstTry(t *testing.T) {
+ calls := 0
+ fn := func() error {
+ calls++
+ return nil
+ }
+ err := runWithBackoff(context.Background(), 3, time.Millisecond, fn)
+ if err != nil {
+ t.Errorf("want nil error, got %v", err)
+ }
+ if calls != 1 {
+ t.Errorf("want 1 call, got %d", calls)
+ }
+}
+
+func TestRunWithBackoff_RetriesOnRateLimit(t *testing.T) {
+ calls := 0
+ fn := func() error {
+ calls++
+ if calls < 3 {
+ return fmt.Errorf("rate limit exceeded")
+ }
+ return nil
+ }
+ err := runWithBackoff(context.Background(), 3, time.Millisecond, fn)
+ if err != nil {
+ t.Errorf("want nil error, got %v", err)
+ }
+ if calls != 3 {
+ t.Errorf("want 3 calls, got %d", calls)
+ }
+}
+
+func TestRunWithBackoff_GivesUpAfterMaxRetries(t *testing.T) {
+ calls := 0
+ rateLimitErr := fmt.Errorf("rate limit exceeded")
+ fn := func() error {
+ calls++
+ return rateLimitErr
+ }
+ err := runWithBackoff(context.Background(), 3, time.Millisecond, fn)
+ if err == nil {
+ t.Fatal("want error after max retries, got nil")
+ }
+ // maxRetries=3: 1 initial call + 3 retries = 4 total calls
+ if calls != 4 {
+ t.Errorf("want 4 calls (1 initial + 3 retries), got %d", calls)
+ }
+}
+
+func TestRunWithBackoff_DoesNotRetryNonRateLimitError(t *testing.T) {
+ calls := 0
+ fn := func() error {
+ calls++
+ return fmt.Errorf("permission denied")
+ }
+ err := runWithBackoff(context.Background(), 3, time.Millisecond, fn)
+ if err == nil {
+ t.Fatal("want error, got nil")
+ }
+ if calls != 1 {
+ t.Errorf("want 1 call (no retry for non-rate-limit), got %d", calls)
+ }
+}
+
+func TestRunWithBackoff_ContextCancellation(t *testing.T) {
+ ctx, cancel := context.WithCancel(context.Background())
+ calls := 0
+
+ fn := func() error {
+ calls++
+ cancel() // cancel immediately after first call
+ return fmt.Errorf("rate limit exceeded")
+ }
+
+ start := time.Now()
+ err := runWithBackoff(ctx, 3, time.Second, fn) // large delay confirms ctx preempts wait
+ elapsed := time.Since(start)
+
+ if err == nil {
+ t.Fatal("want error on context cancellation, got nil")
+ }
+ if elapsed > 500*time.Millisecond {
+ t.Errorf("context cancellation too slow: %v (want < 500ms)", elapsed)
+ }
+ if calls != 1 {
+ t.Errorf("want 1 call before cancellation, got %d", calls)
+ }
+}
diff --git a/web/app.js b/web/app.js
index 6289d00..6d2a029 100644
--- a/web/app.js
+++ b/web/app.js
@@ -69,20 +69,41 @@ function createTaskCard(task) {
card.appendChild(desc);
}
- // Footer: Run button (only for PENDING / FAILED)
- if (task.state === 'PENDING' || task.state === 'FAILED') {
+ // Footer: action buttons based on state
+ const RESTART_STATES = new Set(['FAILED', 'TIMED_OUT', 'CANCELLED']);
+ if (task.state === 'PENDING' || task.state === 'RUNNING' || RESTART_STATES.has(task.state)) {
const footer = document.createElement('div');
footer.className = 'task-card-footer';
- const btn = document.createElement('button');
- btn.className = 'btn-run';
- btn.textContent = 'Run';
- btn.addEventListener('click', (e) => {
- e.stopPropagation();
- handleRun(task.id, btn, footer);
- });
+ if (task.state === 'PENDING') {
+ const btn = document.createElement('button');
+ btn.className = 'btn-run';
+ btn.textContent = 'Run';
+ btn.addEventListener('click', (e) => {
+ e.stopPropagation();
+ handleRun(task.id, btn, footer);
+ });
+ footer.appendChild(btn);
+ } else if (task.state === 'RUNNING') {
+ const btn = document.createElement('button');
+ btn.className = 'btn-cancel';
+ btn.textContent = 'Cancel';
+ btn.addEventListener('click', (e) => {
+ e.stopPropagation();
+ handleCancel(task.id, btn, footer);
+ });
+ footer.appendChild(btn);
+ } else if (RESTART_STATES.has(task.state)) {
+ const btn = document.createElement('button');
+ btn.className = 'btn-restart';
+ btn.textContent = 'Restart';
+ btn.addEventListener('click', (e) => {
+ e.stopPropagation();
+ handleRestart(task.id, btn, footer);
+ });
+ footer.appendChild(btn);
+ }
- footer.appendChild(btn);
card.appendChild(footer);
}
@@ -252,6 +273,66 @@ async function handleRun(taskId, btn, footer) {
}
}
+// ── Cancel / Restart actions ──────────────────────────────────────────────────
+
+async function cancelTask(taskId) {
+ const res = await fetch(`${API_BASE}/api/tasks/${taskId}/cancel`, { method: 'POST' });
+ if (!res.ok) {
+ let msg = `HTTP ${res.status}`;
+ try { const body = await res.json(); msg = body.error || body.message || msg; } catch {}
+ throw new Error(msg);
+ }
+ return res.json();
+}
+
+async function restartTask(taskId) {
+ const res = await fetch(`${API_BASE}/api/tasks/${taskId}/restart`, { method: 'POST' });
+ if (!res.ok) {
+ let msg = `HTTP ${res.status}`;
+ try { const body = await res.json(); msg = body.error || body.message || msg; } catch {}
+ throw new Error(msg);
+ }
+ return res.json();
+}
+
+async function handleCancel(taskId, btn, footer) {
+ btn.disabled = true;
+ btn.textContent = 'Cancelling…';
+ const prev = footer.querySelector('.task-error');
+ if (prev) prev.remove();
+
+ try {
+ await cancelTask(taskId);
+ await poll();
+ } catch (err) {
+ btn.disabled = false;
+ btn.textContent = 'Cancel';
+ const errEl = document.createElement('span');
+ errEl.className = 'task-error';
+ errEl.textContent = `Failed: ${err.message}`;
+ footer.appendChild(errEl);
+ }
+}
+
+async function handleRestart(taskId, btn, footer) {
+ btn.disabled = true;
+ btn.textContent = 'Restarting…';
+ const prev = footer.querySelector('.task-error');
+ if (prev) prev.remove();
+
+ try {
+ await restartTask(taskId);
+ await poll();
+ } catch (err) {
+ btn.disabled = false;
+ btn.textContent = 'Restart';
+ const errEl = document.createElement('span');
+ errEl.className = 'task-error';
+ errEl.textContent = `Failed: ${err.message}`;
+ footer.appendChild(errEl);
+ }
+}
+
// ── Delete template ────────────────────────────────────────────────────────────
async function deleteTemplate(id) {
@@ -286,6 +367,147 @@ function startPolling(intervalMs = 10_000) {
setInterval(poll, intervalMs);
}
+
+
+// ── WebSocket (real-time events) ──────────────────────────────────────────────
+
+let ws = null;
+let activeLogSource = null;
+
+function connectWebSocket() {
+ const protocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:';
+ const url = `${protocol}//${window.location.host}${BASE_PATH}/api/ws`;
+ ws = new WebSocket(url);
+
+ ws.onmessage = (event) => {
+ try {
+ const data = JSON.parse(event.data);
+ handleWsEvent(data);
+ } catch { /* ignore parse errors */ }
+ };
+
+ ws.onclose = () => {
+ // Reconnect after 3 seconds.
+ setTimeout(connectWebSocket, 3000);
+ };
+
+ ws.onerror = () => {
+ ws.close();
+ };
+}
+
+function handleWsEvent(data) {
+ switch (data.type) {
+ case 'task_completed':
+ poll(); // refresh task list
+ break;
+ case 'task_question':
+ showQuestionBanner(data);
+ break;
+ }
+}
+
+// ── Question UI ───────────────────────────────────────────────────────────────
+
+function showQuestionBanner(data) {
+ const taskId = data.task_id;
+ const questionId = data.question_id;
+ const questionData = data.data || {};
+ const questions = questionData.questions || [];
+
+ // Find the task card for this task.
+ const card = document.querySelector(`.task-card[data-task-id="${taskId}"]`);
+ if (!card) return;
+
+ // Remove any existing question banner on this card.
+ const existing = card.querySelector('.question-banner');
+ if (existing) existing.remove();
+
+ const banner = document.createElement('div');
+ banner.className = 'question-banner';
+
+ for (const q of questions) {
+ const qDiv = document.createElement('div');
+ qDiv.className = 'question-item';
+
+ const label = document.createElement('div');
+ label.className = 'question-text';
+ label.textContent = q.question || 'The agent has a question';
+ qDiv.appendChild(label);
+
+ const options = q.options || [];
+ if (options.length > 0) {
+ const btnGroup = document.createElement('div');
+ btnGroup.className = 'question-options';
+ for (const opt of options) {
+ const btn = document.createElement('button');
+ btn.className = 'btn-question-option';
+ btn.textContent = opt.label;
+ if (opt.description) btn.title = opt.description;
+ btn.addEventListener('click', () => {
+ submitAnswer(taskId, questionId, opt.label, banner);
+ });
+ btnGroup.appendChild(btn);
+ }
+ qDiv.appendChild(btnGroup);
+ }
+
+ // Always show a free-text input as fallback.
+ const inputRow = document.createElement('div');
+ inputRow.className = 'question-input-row';
+ const input = document.createElement('input');
+ input.type = 'text';
+ input.className = 'question-input';
+ input.placeholder = 'Type an answer…';
+ const sendBtn = document.createElement('button');
+ sendBtn.className = 'btn-question-send';
+ sendBtn.textContent = 'Send';
+ sendBtn.addEventListener('click', () => {
+ const val = input.value.trim();
+ if (val) submitAnswer(taskId, questionId, val, banner);
+ });
+ input.addEventListener('keydown', (e) => {
+ if (e.key === 'Enter') {
+ const val = input.value.trim();
+ if (val) submitAnswer(taskId, questionId, val, banner);
+ }
+ });
+ inputRow.append(input, sendBtn);
+ qDiv.appendChild(inputRow);
+
+ banner.appendChild(qDiv);
+ }
+
+ card.appendChild(banner);
+}
+
+async function submitAnswer(taskId, questionId, answer, banner) {
+ // Disable all buttons in the banner.
+ banner.querySelectorAll('button').forEach(b => { b.disabled = true; });
+ banner.querySelector('.question-input')?.setAttribute('disabled', '');
+
+ try {
+ const res = await fetch(`${API_BASE}/api/tasks/${taskId}/answer`, {
+ method: 'POST',
+ headers: { 'Content-Type': 'application/json' },
+ body: JSON.stringify({ question_id: questionId, answer }),
+ });
+ if (!res.ok) {
+ const body = await res.json().catch(() => ({}));
+ throw new Error(body.error || `HTTP ${res.status}`);
+ }
+ banner.remove();
+ } catch (err) {
+ const errEl = document.createElement('div');
+ errEl.className = 'question-error';
+ errEl.textContent = `Failed: ${err.message}`;
+ banner.appendChild(errEl);
+ // Re-enable buttons.
+ banner.querySelectorAll('button').forEach(b => { b.disabled = false; });
+ banner.querySelector('.question-input')?.removeAttribute('disabled');
+ }
+}
+
// ── Elaborate (Draft with AI) ─────────────────────────────────────────────────
async function elaborateTask(prompt) {
@@ -302,6 +524,86 @@ async function elaborateTask(prompt) {
return res.json();
}
+// ── Validate ──────────────────────────────────────────────────────────────────
+
+async function validateTask(payload) {
+ const res = await fetch(`${API_BASE}/api/tasks/validate`, {
+ method: 'POST',
+ headers: { 'Content-Type': 'application/json' },
+ body: JSON.stringify(payload),
+ });
+ if (!res.ok) {
+ let msg = res.statusText;
+ try { const body = await res.json(); msg = body.error || body.message || msg; } catch {}
+ throw new Error(msg);
+ }
+ return res.json();
+}
+
+function buildValidatePayload() {
+ const f = document.getElementById('task-form');
+ const name = f.querySelector('[name="name"]').value;
+ const instructions = f.querySelector('[name="instructions"]').value;
+ const working_dir = f.querySelector('[name="working_dir"]').value;
+ const model = f.querySelector('[name="model"]').value;
+ const allowedToolsEl = f.querySelector('[name="allowed_tools"]');
+ const allowed_tools = allowedToolsEl
+ ? allowedToolsEl.value.split(',').map(s => s.trim()).filter(Boolean)
+ : [];
+ return { name, claude: { instructions, working_dir, model, allowed_tools } };
+}
+
+function renderValidationResult(result) {
+ const container = document.getElementById('validate-result');
+ container.removeAttribute('hidden');
+ container.dataset.clarity = result.clarity;
+
+ let icon;
+ if (result.ready === true) {
+ icon = '✓';
+ } else if (result.clarity === 'ambiguous') {
+ icon = '⚠';
+ } else {
+ icon = '✗';
+ }
+
+ container.innerHTML = '';
+
+ const header = document.createElement('div');
+ header.className = 'validate-header';
+ const iconSpan = document.createElement('span');
+ iconSpan.className = 'validate-icon';
+ iconSpan.textContent = icon;
+ const summarySpan = document.createElement('span');
+ summarySpan.textContent = ' ' + (result.summary || '');
+ header.append(iconSpan, summarySpan);
+ container.appendChild(header);
+
+ if (result.questions && result.questions.length > 0) {
+ const ul = document.createElement('ul');
+ ul.className = 'validate-questions';
+ for (const q of result.questions) {
+ const li = document.createElement('li');
+ li.className = q.severity === 'blocking' ? 'validate-blocking' : 'validate-minor';
+ li.textContent = q.text;
+ ul.appendChild(li);
+ }
+ container.appendChild(ul);
+ }
+
+ if (result.suggestions && result.suggestions.length > 0) {
+ const ul = document.createElement('ul');
+ ul.className = 'validate-suggestions';
+ for (const s of result.suggestions) {
+ const li = document.createElement('li');
+ li.className = 'validate-suggestion';
+ li.textContent = s;
+ ul.appendChild(li);
+ }
+ container.appendChild(ul);
+ }
+}
+
// ── Task modal ────────────────────────────────────────────────────────────────
function openTaskModal() {
@@ -312,6 +614,10 @@ function closeTaskModal() {
document.getElementById('task-modal').close();
document.getElementById('task-form').reset();
document.getElementById('elaborate-prompt').value = '';
+ const validateResult = document.getElementById('validate-result');
+ validateResult.setAttribute('hidden', '');
+ validateResult.innerHTML = '';
+ validateResult.removeAttribute('data-clarity');
}
async function createTask(formData) {
@@ -671,6 +977,127 @@ async function handleViewLogs(execId) {
}
}
+// ── Log viewer ────────────────────────────────────────────────────────────────
+
+function openLogViewer(execId, containerEl) {
+ // Save original children so Back can restore them (with event listeners intact)
+ const originalChildren = [...containerEl.childNodes];
+
+ containerEl.innerHTML = '';
+
+ const viewer = document.createElement('div');
+ viewer.className = 'log-viewer';
+
+ // Back button
+ const backBtn = document.createElement('button');
+ backBtn.className = 'log-back-btn';
+ backBtn.textContent = '← Back';
+ backBtn.addEventListener('click', () => {
+ closeLogViewer();
+ containerEl.innerHTML = '';
+ for (const node of originalChildren) containerEl.appendChild(node);
+ });
+ viewer.appendChild(backBtn);
+
+ // Pulsing status indicator
+ const statusEl = document.createElement('div');
+ statusEl.className = 'log-status-indicator';
+ statusEl.textContent = 'Streaming...';
+ viewer.appendChild(statusEl);
+
+ // Log output area
+ const logOutput = document.createElement('div');
+ logOutput.className = 'log-output';
+ logOutput.style.fontFamily = 'monospace';
+ logOutput.style.overflowY = 'auto';
+ logOutput.style.maxHeight = '400px';
+ viewer.appendChild(logOutput);
+
+ containerEl.appendChild(viewer);
+
+ let userScrolled = false;
+ logOutput.addEventListener('scroll', () => {
+ const nearBottom = logOutput.scrollHeight - logOutput.scrollTop - logOutput.clientHeight < 50;
+ if (!nearBottom) userScrolled = true;
+ });
+
+ const source = new EventSource(`${API_BASE}/api/executions/${execId}/logs/stream`);
+ activeLogSource = source;
+
+ source.onmessage = (event) => {
+ let data;
+ try { data = JSON.parse(event.data); } catch { return; }
+
+ const line = document.createElement('div');
+ line.className = 'log-line';
+
+ switch (data.type) {
+ case 'text': {
+ line.classList.add('log-text');
+ line.textContent = data.text ?? data.content ?? '';
+ break;
+ }
+ case 'tool_use': {
+ line.classList.add('log-tool-use');
+ const toolName = document.createElement('span');
+ toolName.className = 'tool-name';
+ toolName.textContent = `[${data.name ?? 'Tool'}]`;
+ line.appendChild(toolName);
+ const inputStr = data.input ? JSON.stringify(data.input) : '';
+ const inputPreview = document.createElement('span');
+ inputPreview.textContent = ' ' + inputStr.slice(0, 120);
+ line.appendChild(inputPreview);
+ break;
+ }
+ case 'tool_result': {
+ line.classList.add('log-tool-result');
+ line.style.opacity = '0.6';
+ const content = Array.isArray(data.content)
+ ? data.content.map(c => c.text ?? '').join(' ')
+ : (data.content ?? '');
+ line.textContent = String(content).slice(0, 120);
+ break;
+ }
+ case 'cost': {
+ line.classList.add('log-cost');
+ const cost = data.total_cost ?? data.cost ?? 0;
+ line.textContent = `Cost: $${Number(cost).toFixed(3)}`;
+ break;
+ }
+ default:
+ return;
+ }
+
+ logOutput.appendChild(line);
+ if (!userScrolled) {
+ logOutput.scrollTop = logOutput.scrollHeight;
+ }
+ };
+
+ source.addEventListener('done', () => {
+ source.close();
+ activeLogSource = null;
+ userScrolled = false;
+ statusEl.classList.remove('log-status-indicator');
+ statusEl.textContent = 'Stream complete';
+ });
+
+ source.onerror = () => {
+ source.close();
+ activeLogSource = null;
+ statusEl.hidden = true;
+ const errEl = document.createElement('div');
+ errEl.className = 'log-line log-error';
+ errEl.textContent = 'Connection error. Stream closed.';
+ logOutput.appendChild(errEl);
+ };
+}
+
+function closeLogViewer() {
+ activeLogSource?.close();
+ activeLogSource = null;
+}
+
// ── Tab switching ─────────────────────────────────────────────────────────────
function switchTab(name) {
@@ -711,6 +1138,7 @@ document.addEventListener('DOMContentLoaded', () => {
});
startPolling();
+ connectWebSocket();
// Side panel close
document.getElementById('btn-close-panel').addEventListener('click', closeTaskPanel);
@@ -730,6 +1158,25 @@ document.addEventListener('DOMContentLoaded', () => {
document.getElementById('btn-new-task').addEventListener('click', openTaskModal);
document.getElementById('btn-cancel-task').addEventListener('click', closeTaskModal);
+ // Validate button
+ document.getElementById('btn-validate').addEventListener('click', async () => {
+ const btn = document.getElementById('btn-validate');
+ const resultDiv = document.getElementById('validate-result');
+ btn.disabled = true;
+ btn.textContent = 'Checking…';
+ try {
+ const payload = buildValidatePayload();
+ const result = await validateTask(payload);
+ renderValidationResult(result);
+ } catch (err) {
+ resultDiv.removeAttribute('hidden');
+ resultDiv.textContent = 'Validation failed: ' + err.message;
+ } finally {
+ btn.disabled = false;
+ btn.textContent = 'Validate Instructions';
+ }
+ });
+
// Draft with AI button
const btnElaborate = document.getElementById('btn-elaborate');
btnElaborate.addEventListener('click', async () => {
@@ -782,6 +1229,14 @@ document.addEventListener('DOMContentLoaded', () => {
banner.className = 'elaborate-banner';
banner.textContent = 'AI draft ready — review and submit.';
document.getElementById('task-form').querySelector('.elaborate-section').appendChild(banner);
+
+ // Auto-validate after elaboration
+ try {
+ const result = await validateTask(buildValidatePayload());
+ renderValidationResult(result);
+ } catch (_) {
+ // silent - elaboration already succeeded, validation is bonus
+ }
} catch (err) {
const errEl = document.createElement('p');
errEl.className = 'form-error';
@@ -805,6 +1260,12 @@ document.addEventListener('DOMContentLoaded', () => {
btn.textContent = 'Creating…';
try {
+ const validateResult = document.getElementById('validate-result');
+ if (!validateResult.hasAttribute('hidden') && validateResult.dataset.clarity && validateResult.dataset.clarity !== 'clear') {
+ if (!window.confirm('The validator flagged issues. Create task anyway?')) {
+ return;
+ }
+ }
await createTask(new FormData(e.target));
} catch (err) {
const errEl = document.createElement('p');
diff --git a/web/index.html b/web/index.html
index 6d7f23b..482b9a9 100644
--- a/web/index.html
+++ b/web/index.html
@@ -50,6 +50,12 @@
<hr class="form-divider">
<label>Name <input name="name" required></label>
<label>Instructions <textarea name="instructions" rows="6" required></textarea></label>
+ <div class="validate-section">
+ <button type="button" id="btn-validate" class="btn-secondary">
+ Validate Instructions
+ </button>
+ <div id="validate-result" hidden></div>
+ </div>
<label>Working Directory <input name="working_dir" placeholder="/path/to/repo"></label>
<label>Model <input name="model" value="sonnet"></label>
<label>Max Budget (USD) <input name="max_budget_usd" type="number" step="0.01" value="1.00"></label>
diff --git a/web/style.css b/web/style.css
index de8ce83..268f80c 100644
--- a/web/style.css
+++ b/web/style.css
@@ -225,6 +225,40 @@ main {
cursor: not-allowed;
}
+.btn-cancel {
+ font-size: 0.8rem;
+ font-weight: 600;
+ padding: 0.35em 0.85em;
+ border-radius: 0.375rem;
+ border: none;
+ cursor: pointer;
+ background: var(--state-failed);
+ color: #fff;
+ transition: opacity 0.15s;
+}
+
+.btn-cancel:disabled {
+ opacity: 0.5;
+ cursor: not-allowed;
+}
+
+.btn-restart {
+ font-size: 0.8rem;
+ font-weight: 600;
+ padding: 0.35em 0.85em;
+ border-radius: 0.375rem;
+ border: none;
+ cursor: pointer;
+ background: var(--text-muted);
+ color: #0f172a;
+ transition: opacity 0.15s;
+}
+
+.btn-restart:disabled {
+ opacity: 0.5;
+ cursor: not-allowed;
+}
+
.task-error {
font-size: 0.78rem;
color: var(--state-failed);
@@ -704,3 +738,231 @@ dialog label select:focus {
.logs-modal-body .meta-grid {
row-gap: 0.625rem;
}
+
+/* ── Question banner ───────────────────────────────────────────────────────── */
+
+.question-banner {
+ margin-top: 0.75rem;
+ padding: 0.75rem;
+ background: #1e293b;
+ border: 1px solid #f59e0b;
+ border-radius: 0.5rem;
+}
+
+.question-item {
+ display: flex;
+ flex-direction: column;
+ gap: 0.5rem;
+}
+
+.question-text {
+ font-weight: 600;
+ color: #f59e0b;
+ font-size: 0.9rem;
+}
+
+.question-options {
+ display: flex;
+ flex-wrap: wrap;
+ gap: 0.375rem;
+}
+
+.btn-question-option {
+ padding: 0.375rem 0.75rem;
+ border: 1px solid #475569;
+ border-radius: 0.375rem;
+ background: #334155;
+ color: #e2e8f0;
+ cursor: pointer;
+ font-size: 0.8rem;
+ transition: background 0.15s, border-color 0.15s;
+}
+
+.btn-question-option:hover:not(:disabled) {
+ background: #475569;
+ border-color: #f59e0b;
+}
+
+.btn-question-option:disabled {
+ opacity: 0.5;
+ cursor: not-allowed;
+}
+
+.question-input-row {
+ display: flex;
+ gap: 0.375rem;
+}
+
+.question-input {
+ flex: 1;
+ padding: 0.375rem 0.5rem;
+ border: 1px solid #475569;
+ border-radius: 0.375rem;
+ background: #0f172a;
+ color: #e2e8f0;
+ font-size: 0.8rem;
+}
+
+.question-input:focus {
+ outline: none;
+ border-color: #f59e0b;
+}
+
+.btn-question-send {
+ padding: 0.375rem 0.75rem;
+ border: 1px solid #f59e0b;
+ border-radius: 0.375rem;
+ background: #f59e0b;
+ color: #0f172a;
+ font-weight: 600;
+ cursor: pointer;
+ font-size: 0.8rem;
+}
+
+.btn-question-send:hover:not(:disabled) {
+ background: #fbbf24;
+}
+
+.btn-question-send:disabled {
+ opacity: 0.5;
+ cursor: not-allowed;
+}
+
+.question-error {
+ color: #f87171;
+ font-size: 0.8rem;
+ margin-top: 0.25rem;
+}
+
+/* ── Log Viewer ──────────────────────────────────────────────────────────── */
+
+.log-viewer {
+ width: 100%;
+ padding: 0;
+}
+
+.log-back-btn {
+ font-size: 0.78rem;
+ font-weight: 600;
+ padding: 0.3em 0.75em;
+ border-radius: 0.375rem;
+ border: 1px solid var(--border);
+ background: transparent;
+ color: var(--text-muted);
+ cursor: pointer;
+ transition: background 0.15s, color 0.15s;
+ margin-bottom: 1rem;
+ display: inline-flex;
+ align-items: center;
+}
+
+.log-back-btn:hover {
+ background: var(--border);
+ color: var(--text);
+}
+
+@keyframes pulse-dot {
+ 0%, 100% { opacity: 1; }
+ 50% { opacity: 0.3; }
+}
+
+.log-status-indicator {
+ display: flex;
+ align-items: center;
+ gap: 0.4rem;
+ font-size: 0.78rem;
+ color: var(--text-muted);
+ margin-bottom: 0.75rem;
+}
+
+.log-status-indicator::before {
+ content: '';
+ display: inline-block;
+ width: 6px;
+ height: 6px;
+ border-radius: 50%;
+ background: var(--state-running);
+ flex-shrink: 0;
+ animation: pulse-dot 1.4s ease-in-out infinite;
+}
+
+.log-output {
+ font-family: monospace;
+ font-size: 0.8rem;
+ overflow-y: auto;
+ max-height: 400px;
+ background: var(--bg);
+ padding: 0.625rem 0.75rem;
+ border-radius: 0.375rem;
+ border: 1px solid var(--border);
+}
+
+.log-line {
+ padding: 2px 0;
+ line-height: 1.5;
+}
+
+.log-text {
+ color: var(--text);
+}
+
+.log-tool-use {
+ background: rgba(56, 189, 248, 0.1);
+ padding: 4px 8px;
+ border-radius: 3px;
+ margin: 2px 0;
+}
+
+.tool-name {
+ color: var(--accent);
+ font-weight: bold;
+ margin-right: 6px;
+}
+
+.log-tool-result {
+ color: var(--text-muted);
+ opacity: 0.6;
+}
+
+.log-cost {
+ color: var(--state-running);
+ font-weight: bold;
+ margin-top: 8px;
+}
+
+/* ── Validate section ────────────────────────────────────────────────────── */
+
+.validate-section {
+ margin-top: 8px;
+}
+
+#validate-result {
+ border-left: 3px solid transparent;
+ padding: 8px 12px;
+ margin-top: 8px;
+ font-size: 0.85rem;
+}
+
+#validate-result[data-clarity="clear"] {
+ border-color: var(--state-completed);
+}
+
+#validate-result[data-clarity="ambiguous"] {
+ border-color: var(--state-running);
+}
+
+#validate-result[data-clarity="unclear"] {
+ border-color: var(--state-failed);
+}
+
+.validate-blocking {
+ color: var(--state-failed);
+}
+
+.validate-minor {
+ color: var(--state-running);
+}
+
+.validate-suggestion {
+ color: #94a3b8;
+}
diff --git a/web/test/start-next-task.test.mjs b/web/test/start-next-task.test.mjs
new file mode 100644
index 0000000..6863f7e
--- /dev/null
+++ b/web/test/start-next-task.test.mjs
@@ -0,0 +1,84 @@
+// start-next-task.test.mjs — contract tests for startNextTask fetch helper
+// Run: node --test web/test/start-next-task.test.mjs
+
+import { describe, it } from 'node:test';
+import assert from 'node:assert/strict';
+
+// ── Contract: startNextTask(basePath, fetchFn) ─────────────────────────────────
+// POSTs to ${basePath}/api/scripts/start-next-task
+// Returns {output, exit_code} on HTTP 2xx
+// Throws on HTTP error
+
+async function startNextTask(basePath, fetchFn) {
+ const res = await fetchFn(`${basePath}/api/scripts/start-next-task`, { method: 'POST' });
+ if (!res.ok) {
+ let msg = `HTTP ${res.status}`;
+ try { const body = await res.json(); msg = body.error || msg; } catch {}
+ throw new Error(msg);
+ }
+ return res.json();
+}
+
+describe('startNextTask', () => {
+ it('POSTs to /api/scripts/start-next-task', async () => {
+ let captured = null;
+ const mockFetch = (url, opts) => {
+ captured = { url, opts };
+ return Promise.resolve({
+ ok: true,
+ json: () => Promise.resolve({ output: 'claudomator start abc-123\n', exit_code: 0 }),
+ });
+ };
+
+ await startNextTask('http://localhost:8484', mockFetch);
+ assert.equal(captured.url, 'http://localhost:8484/api/scripts/start-next-task');
+ assert.equal(captured.opts.method, 'POST');
+ });
+
+ it('returns output and exit_code on success', async () => {
+ const mockFetch = () => Promise.resolve({
+ ok: true,
+ json: () => Promise.resolve({ output: 'claudomator start abc-123\n', exit_code: 0 }),
+ });
+
+ const result = await startNextTask('', mockFetch);
+ assert.equal(result.output, 'claudomator start abc-123\n');
+ assert.equal(result.exit_code, 0);
+ });
+
+ it('returns output when no task available', async () => {
+ const mockFetch = () => Promise.resolve({
+ ok: true,
+ json: () => Promise.resolve({ output: 'No task to start.\n', exit_code: 0 }),
+ });
+
+ const result = await startNextTask('', mockFetch);
+ assert.equal(result.output, 'No task to start.\n');
+ });
+
+ it('throws with server error message on HTTP error', async () => {
+ const mockFetch = () => Promise.resolve({
+ ok: false,
+ status: 500,
+ json: () => Promise.resolve({ error: 'script execution failed' }),
+ });
+
+ await assert.rejects(
+ () => startNextTask('', mockFetch),
+ /script execution failed/,
+ );
+ });
+
+ it('throws with HTTP status on non-JSON error response', async () => {
+ const mockFetch = () => Promise.resolve({
+ ok: false,
+ status: 503,
+ json: () => Promise.reject(new Error('not json')),
+ });
+
+ await assert.rejects(
+ () => startNextTask('', mockFetch),
+ /HTTP 503/,
+ );
+ });
+});
diff --git a/web/test/task-actions.test.mjs b/web/test/task-actions.test.mjs
new file mode 100644
index 0000000..f2c21c4
--- /dev/null
+++ b/web/test/task-actions.test.mjs
@@ -0,0 +1,53 @@
+// task-actions.test.mjs — button visibility logic for Cancel/Restart actions
+//
+// Run with: node --test web/test/task-actions.test.mjs
+
+import { describe, it } from 'node:test';
+import assert from 'node:assert/strict';
+
+// ── Logic under test ──────────────────────────────────────────────────────────
+
+const RESTART_STATES = new Set(['FAILED', 'TIMED_OUT', 'CANCELLED']);
+
+function getCardAction(state) {
+ if (state === 'PENDING') return 'run';
+ if (state === 'RUNNING') return 'cancel';
+ if (RESTART_STATES.has(state)) return 'restart';
+ return null;
+}
+
+// ── Tests ─────────────────────────────────────────────────────────────────────
+
+describe('task card action buttons', () => {
+ it('shows Run button for PENDING', () => {
+ assert.equal(getCardAction('PENDING'), 'run');
+ });
+
+ it('shows Cancel button for RUNNING', () => {
+ assert.equal(getCardAction('RUNNING'), 'cancel');
+ });
+
+ it('shows Restart button for FAILED', () => {
+ assert.equal(getCardAction('FAILED'), 'restart');
+ });
+
+ it('shows Restart button for TIMED_OUT', () => {
+ assert.equal(getCardAction('TIMED_OUT'), 'restart');
+ });
+
+ it('shows Restart button for CANCELLED', () => {
+ assert.equal(getCardAction('CANCELLED'), 'restart');
+ });
+
+ it('shows no button for COMPLETED', () => {
+ assert.equal(getCardAction('COMPLETED'), null);
+ });
+
+ it('shows no button for QUEUED', () => {
+ assert.equal(getCardAction('QUEUED'), null);
+ });
+
+ it('shows no button for BUDGET_EXCEEDED', () => {
+ assert.equal(getCardAction('BUDGET_EXCEEDED'), null);
+ });
+});