50 files changed, 2945 insertions, 255 deletions
diff --git a/docs/adr/003-security-model.md b/docs/adr/003-security-model.md new file mode 100644 index 0000000..529e50e --- /dev/null +++ b/docs/adr/003-security-model.md @@ -0,0 +1,135 @@ +# ADR-003: Security Model + +## Status +Accepted + +## Context + +Claudomator is a local developer tool: it runs on the developer's own machine, +accepts tasks from the operator (a human or automation they control), and executes +`claude` subprocesses with filesystem access. The primary deployment model is +single-user, single-machine — not a shared or internet-facing service. + +This ADR documents the current security posture, the explicit trust boundary +assumptions, and known risks surfaced during code review (2026-03-08). + +## Trust Boundary + +``` +[ Operator (human / script) ] + │ loopback or LAN + ▼ + [ claudomator HTTP API :8484 ] + │ + [ claude subprocess (bypassPermissions) ] + │ + [ local filesystem, tools ] +``` + +**Trusted:** The operator and all callers of the HTTP API. Any caller can create, +run, delete, and cancel tasks; execute server-side scripts; and read all logs. + +**Untrusted:** The content that Claude processes (web pages, code repos, user +instructions containing adversarial prompts). The tool makes no attempt to sandbox +Claude's output before acting on it. + +## Explicit Design Decisions + +### No Authentication or Authorization + +**Decision:** The HTTP API has no auth middleware. All routes are publicly +accessible to any network-reachable client. + +**Rationale:** Claudomator is a local tool. Adding auth would impose key management +overhead on a single-user workflow. The operator is assumed to control who can +reach `:8484`. 
+ +**Risk:** If the server is exposed on a non-loopback interface (the current default), +any host on the LAN — or the public internet if port-forwarded — can: +- Create and run arbitrary tasks +- Trigger `POST /api/scripts/deploy` and `POST /api/scripts/start-next-task` + (which run shell scripts on the server with no additional gate) +- Read all execution logs and task data via WebSocket or REST + +**Mitigation until auth is added:** Run behind a firewall or bind to `127.0.0.1` +only. The `addr` config key controls the listen address. + +### Permissive CORS (`Access-Control-Allow-Origin: *`) + +**Decision:** All responses include `Access-Control-Allow-Origin: *`. + +**Rationale:** Allows the web UI to be opened from any origin (file://, localhost +variants) during development. Consistent with no-auth posture: if there is no +auth, CORS restriction adds no real security. + +**Risk:** Combined with no auth, any web page the operator visits can issue +cross-origin requests to `localhost:8484` if the browser is on the same machine. + +### `bypassPermissions` as Default Permission Mode + +**Decision:** `executor.ClaudeRunner` defaults `permission_mode` to +`bypassPermissions` when the task does not specify one. + +**Rationale:** The typical claudomator workflow is fully automated, running inside +a container or a dedicated workspace. Stopping for tool-use confirmations would +break unattended execution. The operator is assumed to have reviewed the task +instructions before submission. + +**Risk:** Claude subprocesses run without any tool-use confirmation prompt. +A malicious or miscrafted task instruction can cause Claude to execute arbitrary +shell commands, delete files, or make network requests without operator review. + +**Mitigation:** Tasks are created by the operator. Prompt injection in task +instructions or working directories is the primary attack surface (see below). 
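The default-with-override behavior described above can be sketched as a simple fallback. This is a minimal illustration, not the actual `executor.ClaudeRunner` API — the `permissionMode` helper and its call sites are hypothetical:

```go
package main

import "fmt"

// permissionMode returns the effective permission mode for a task.
// An explicit task-level setting (e.g. permission_mode in task YAML) always
// wins; otherwise the runner falls back to bypassPermissions so unattended
// runs are never blocked on tool-use confirmation prompts.
func permissionMode(taskMode string) string {
	if taskMode != "" {
		return taskMode // operator override via task YAML permission_mode
	}
	return "bypassPermissions" // default: no per-operation confirmations
}

func main() {
	fmt.Println(permissionMode(""))        // unattended default
	fmt.Println(permissionMode("default")) // operator opted back into prompts
}
```

The operator override is what keeps this an acceptable trade-off: the risky default applies only when the task author has not said otherwise.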
+ +### Prompt Injection via `working_dir` in Elaborate + +**Decision:** `POST /api/elaborate` accepts a `working_dir` field from the HTTP +request body and embeds it (via `%q`) into a Claude system prompt. + +**Risk:** A crafted `working_dir` value can inject arbitrary text into the +elaboration prompt. The `%q` quoting prevents trivial injection but does not +eliminate the risk for sophisticated inputs. + +**Mitigation:** `working_dir` should be validated as an absolute filesystem path +before embedding. This is a known gap; see issue tracker. + +## Known Risks (from code review 2026-03-08) + +| ID | Severity | Location | Description | +|----|----------|----------|-------------| +| C1 | High | `server.go:62-92` | No auth on any endpoint | +| C2 | High | `scripts.go:28-88` | Unauthenticated server-side script execution | +| C3 | Medium | `server.go:481` | `Access-Control-Allow-Origin: *` (intentional) | +| C4 | Medium | `elaborate.go:101-103` | Prompt injection via `working_dir` | +| M2 | Medium | `logs.go:79,125` | Path traversal: `exec.StdoutPath` from DB not validated before `os.Open` | +| M3 | Medium | `server.go` | No request body size limit — 1 GB body accepted | +| M6 | Medium | `scripts.go:85` | `stderr` returned to caller may contain internal paths/credentials | +| M7 | Medium | `websocket.go:58` | WebSocket broadcasts task events (cost, errors) to all clients without auth | +| M8 | Medium | `server.go:256-269` | `/api/workspaces` enumerates filesystem layout; raw `os.ReadDir` errors returned | +| L1 | Low | `notify.go:34-39` | Webhook URL not validated — `file://`, internal addresses accepted (SSRF if exposed) | +| L6 | Low | `reporter.go:109-113` | `HTMLReporter` uses `fmt.Fprintf` for HTML — XSS if user-visible fields added | +| X1 | Low | `app.js:1047,1717,2131` | `err.message` injected via `innerHTML` — XSS if server returns HTML in error | + +Risks marked Medium/High are acceptable for the current local-only deployment model +but must be addressed 
before exposing the service to untrusted networks. + +## Recommended Changes (if exposed to untrusted networks) + +1. Add API-key middleware (token in `Authorization` header or `X-API-Key` header). +2. Bind the listen address to `127.0.0.1` by default; require explicit opt-in for + LAN/public exposure. +3. Validate `StdoutPath` is under the known data directory before `os.Open`. +4. Wrap request bodies with `http.MaxBytesReader` (e.g. 10 MB limit). +5. Sanitize `working_dir` in elaborate: must be absolute, no shell metacharacters. +6. Validate webhook URLs to `http://` or `https://` scheme only. +7. Replace `fmt.Fprintf` HTML generation in `HTMLReporter` with `html/template`. +8. Replace `innerHTML` template literals containing `err.message` with `textContent`. + +## Consequences + +- Current posture: local-only, single-user, no auth. Acceptable for the current use case. +- Future: any expansion to shared/remote access requires auth first. +- The `bypassPermissions` default is a permanent trade-off: unattended automation vs. + per-operation safety prompts. Operator override via task YAML `permission_mode` is + always available. 
@@ -11,6 +11,7 @@ require ( ) require ( + github.com/BurntSushi/toml v1.6.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/spf13/pflag v1.0.9 // indirect ) @@ -1,3 +1,5 @@ +github.com/BurntSushi/toml v1.6.0 h1:dRaEfpa2VI55EwlIW72hMRHdWouJeRF7TPYhI+AUQjk= +github.com/BurntSushi/toml v1.6.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho= github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= diff --git a/internal/api/elaborate.go b/internal/api/elaborate.go index 5ab9ff0..907cb98 100644 --- a/internal/api/elaborate.go +++ b/internal/api/elaborate.go @@ -14,9 +14,9 @@ import ( const elaborateTimeout = 30 * time.Second func buildElaboratePrompt(workDir string) string { - workDirLine := ` "working_dir": string — leave empty unless you have a specific reason to set it,` + workDirLine := ` "project_dir": string — leave empty unless you have a specific reason to set it,` if workDir != "" { - workDirLine = fmt.Sprintf(` "working_dir": string — use %q for tasks that operate on this codebase, empty string otherwise,`, workDir) + workDirLine = fmt.Sprintf(` "project_dir": string — use %q for tasks that operate on this codebase, empty string otherwise,`, workDir) } return `You are a task configuration assistant for Claudomator, an AI task runner that executes tasks by running Claude or Gemini as a subprocess. 
@@ -55,7 +55,7 @@ type elaboratedAgent struct { Type string `json:"type"` Model string `json:"model"` Instructions string `json:"instructions"` - WorkingDir string `json:"working_dir"` + ProjectDir string `json:"project_dir"` MaxBudgetUSD float64 `json:"max_budget_usd"` AllowedTools []string `json:"allowed_tools"` } @@ -87,9 +87,14 @@ func (s *Server) claudeBinaryPath() string { } func (s *Server) handleElaborateTask(w http.ResponseWriter, r *http.Request) { + if s.elaborateLimiter != nil && !s.elaborateLimiter.allow(realIP(r)) { + writeJSON(w, http.StatusTooManyRequests, map[string]string{"error": "rate limit exceeded"}) + return + } + var input struct { Prompt string `json:"prompt"` - WorkingDir string `json:"working_dir"` + ProjectDir string `json:"project_dir"` } if err := json.NewDecoder(r.Body).Decode(&input); err != nil { writeJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid JSON: " + err.Error()}) @@ -101,8 +106,8 @@ func (s *Server) handleElaborateTask(w http.ResponseWriter, r *http.Request) { } workDir := s.workDir - if input.WorkingDir != "" { - workDir = input.WorkingDir + if input.ProjectDir != "" { + workDir = input.ProjectDir } ctx, cancel := context.WithTimeout(r.Context(), elaborateTimeout) diff --git a/internal/api/elaborate_test.go b/internal/api/elaborate_test.go index 4939701..b33ca11 100644 --- a/internal/api/elaborate_test.go +++ b/internal/api/elaborate_test.go @@ -57,7 +57,7 @@ func TestElaborateTask_Success(t *testing.T) { Type: "claude", Model: "sonnet", Instructions: "Run go test -race ./... 
and report results.", - WorkingDir: "", + ProjectDir: "", MaxBudgetUSD: 0.5, AllowedTools: []string{"Bash"}, }, diff --git a/internal/api/executions.go b/internal/api/executions.go index d9214c0..114425e 100644 --- a/internal/api/executions.go +++ b/internal/api/executions.go @@ -21,12 +21,16 @@ func (s *Server) handleListRecentExecutions(w http.ResponseWriter, r *http.Reque } } + const maxLimit = 1000 limit := 50 if v := r.URL.Query().Get("limit"); v != "" { if n, err := strconv.Atoi(v); err == nil && n > 0 { limit = n } } + if limit > maxLimit { + limit = maxLimit + } taskID := r.URL.Query().Get("task_id") diff --git a/internal/api/executions_test.go b/internal/api/executions_test.go index a2bba21..45548ad 100644 --- a/internal/api/executions_test.go +++ b/internal/api/executions_test.go @@ -258,6 +258,26 @@ func TestGetExecutionLog_FollowSSEHeaders(t *testing.T) { } } +func TestListRecentExecutions_LimitClamped(t *testing.T) { + srv, _ := testServer(t) + + req := httptest.NewRequest("GET", "/api/executions?limit=10000000", nil) + w := httptest.NewRecorder() + srv.Handler().ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("status: want 200, got %d; body: %s", w.Code, w.Body.String()) + } + // The handler should not pass limit > maxLimit to the store. + // We verify indirectly: if the query param is accepted without error and + // does not cause a panic or 500, the clamp is in effect. + // A direct assertion requires a mock store; here we check the response is valid. 
+ var execs []storage.RecentExecution + if err := json.NewDecoder(w.Body).Decode(&execs); err != nil { + t.Fatalf("decoding response: %v", err) + } +} + func TestListTasks_ReturnsStateField(t *testing.T) { srv, store := testServer(t) createTaskWithState(t, store, "state-check", task.StateRunning) diff --git a/internal/api/logs.go b/internal/api/logs.go index 1ba4b00..4e63489 100644 --- a/internal/api/logs.go +++ b/internal/api/logs.go @@ -36,9 +36,10 @@ var terminalStates = map[string]bool{ } type logStreamMsg struct { - Type string `json:"type"` - Message *logAssistMsg `json:"message,omitempty"` - CostUSD float64 `json:"cost_usd,omitempty"` + Type string `json:"type"` + Message *logAssistMsg `json:"message,omitempty"` + CostUSD float64 `json:"cost_usd,omitempty"` + TotalCostUSD float64 `json:"total_cost_usd,omitempty"` } type logAssistMsg struct { @@ -258,29 +259,31 @@ func emitLogLine(w http.ResponseWriter, flusher http.Flusher, line []byte) { return } for _, block := range msg.Message.Content { - var event map[string]string + var data []byte switch block.Type { case "text": - event = map[string]string{"type": "text", "content": block.Text} + data, _ = json.Marshal(map[string]string{"type": "text", "content": block.Text}) case "tool_use": - summary := string(block.Input) - if len(summary) > 80 { - summary = summary[:80] - } - event = map[string]string{"type": "tool_use", "content": fmt.Sprintf("%s(%s)", block.Name, summary)} + data, _ = json.Marshal(struct { + Type string `json:"type"` + Name string `json:"name"` + Input json.RawMessage `json:"input,omitempty"` + }{Type: "tool_use", Name: block.Name, Input: block.Input}) default: continue } - data, _ := json.Marshal(event) fmt.Fprintf(w, "data: %s\n\n", data) flusher.Flush() } case "result": - event := map[string]string{ - "type": "cost", - "content": fmt.Sprintf("%g", msg.CostUSD), + cost := msg.TotalCostUSD + if cost == 0 { + cost = msg.CostUSD } - data, _ := json.Marshal(event) + data, _ := json.Marshal(struct 
{ + Type string `json:"type"` + TotalCost float64 `json:"total_cost"` + }{Type: "cost", TotalCost: cost}) fmt.Fprintf(w, "data: %s\n\n", data) flusher.Flush() } diff --git a/internal/api/logs_test.go b/internal/api/logs_test.go index 52fa168..6c6be05 100644 --- a/internal/api/logs_test.go +++ b/internal/api/logs_test.go @@ -1,6 +1,7 @@ package api import ( + "encoding/json" "errors" "net/http" "net/http/httptest" @@ -293,6 +294,76 @@ func TestHandleStreamTaskLogs_TerminalExecution_EmitsEventsAndDone(t *testing.T) } } +// TestEmitLogLine_ToolUse_EmitsNameField verifies that emitLogLine emits a tool_use SSE event +// with a "name" field matching the tool name so the web UI can display it as "[ToolName]". +func TestEmitLogLine_ToolUse_EmitsNameField(t *testing.T) { + line := []byte(`{"type":"assistant","message":{"content":[{"type":"tool_use","name":"Bash","input":{"command":"ls -la"}}]}}`) + + w := httptest.NewRecorder() + emitLogLine(w, w, line) + + body := w.Body.String() + var found bool + for _, chunk := range strings.Split(body, "\n\n") { + chunk = strings.TrimSpace(chunk) + if !strings.HasPrefix(chunk, "data: ") { + continue + } + jsonStr := strings.TrimPrefix(chunk, "data: ") + var e map[string]interface{} + if err := json.Unmarshal([]byte(jsonStr), &e); err != nil { + continue + } + if e["type"] == "tool_use" { + if e["name"] != "Bash" { + t.Errorf("tool_use event name: want Bash, got %v", e["name"]) + } + if e["input"] == nil { + t.Error("tool_use event input: expected non-nil") + } + found = true + } + } + if !found { + t.Errorf("no tool_use event found in SSE output:\n%s", body) + } +} + +// TestEmitLogLine_Cost_EmitsTotalCostField verifies that emitLogLine emits a cost SSE event +// with a numeric "total_cost" field so the web UI can display it correctly. 
+func TestEmitLogLine_Cost_EmitsTotalCostField(t *testing.T) { + line := []byte(`{"type":"result","total_cost_usd":0.0042}`) + + w := httptest.NewRecorder() + emitLogLine(w, w, line) + + body := w.Body.String() + var found bool + for _, chunk := range strings.Split(body, "\n\n") { + chunk = strings.TrimSpace(chunk) + if !strings.HasPrefix(chunk, "data: ") { + continue + } + jsonStr := strings.TrimPrefix(chunk, "data: ") + var e map[string]interface{} + if err := json.Unmarshal([]byte(jsonStr), &e); err != nil { + continue + } + if e["type"] == "cost" { + if e["total_cost"] == nil { + t.Error("cost event total_cost: expected non-nil numeric field") + } + if v, ok := e["total_cost"].(float64); !ok || v != 0.0042 { + t.Errorf("cost event total_cost: want 0.0042, got %v", e["total_cost"]) + } + found = true + } + } + if !found { + t.Errorf("no cost event found in SSE output:\n%s", body) + } +} + // TestHandleStreamTaskLogs_RunningExecution_LiveTails verifies that a RUNNING execution is // live-tailed and a done event is emitted once it transitions to a terminal state. func TestHandleStreamTaskLogs_RunningExecution_LiveTails(t *testing.T) { diff --git a/internal/api/ratelimit.go b/internal/api/ratelimit.go new file mode 100644 index 0000000..089354c --- /dev/null +++ b/internal/api/ratelimit.go @@ -0,0 +1,99 @@ +package api + +import ( + "net" + "net/http" + "sync" + "time" +) + +// ipRateLimiter provides per-IP token-bucket rate limiting. +type ipRateLimiter struct { + mu sync.Mutex + limiters map[string]*tokenBucket + rate float64 // tokens replenished per second + burst int // maximum token capacity +} + +// newIPRateLimiter creates a limiter with the given replenishment rate (tokens/sec) +// and burst capacity. Use rate=0 to disable replenishment (tokens never refill). 
+func newIPRateLimiter(rate float64, burst int) *ipRateLimiter { + return &ipRateLimiter{ + limiters: make(map[string]*tokenBucket), + rate: rate, + burst: burst, + } +} + +func (l *ipRateLimiter) allow(ip string) bool { + l.mu.Lock() + b, ok := l.limiters[ip] + if !ok { + b = &tokenBucket{ + tokens: float64(l.burst), + capacity: float64(l.burst), + rate: l.rate, + lastTime: time.Now(), + } + l.limiters[ip] = b + } + l.mu.Unlock() + return b.allow() +} + +// middleware wraps h with per-IP rate limiting, returning 429 when exceeded. +func (l *ipRateLimiter) middleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ip := realIP(r) + if !l.allow(ip) { + writeJSON(w, http.StatusTooManyRequests, map[string]string{"error": "rate limit exceeded"}) + return + } + next.ServeHTTP(w, r) + }) +} + +// tokenBucket is a simple token-bucket rate limiter for a single key. +type tokenBucket struct { + mu sync.Mutex + tokens float64 + capacity float64 + rate float64 // tokens per second + lastTime time.Time +} + +func (b *tokenBucket) allow() bool { + b.mu.Lock() + defer b.mu.Unlock() + now := time.Now() + if !b.lastTime.IsZero() { + elapsed := now.Sub(b.lastTime).Seconds() + b.tokens = min(b.capacity, b.tokens+elapsed*b.rate) + } + b.lastTime = now + if b.tokens >= 1.0 { + b.tokens-- + return true + } + return false +} + +// realIP extracts the client's real IP from a request. 
+func realIP(r *http.Request) string { + if xff := r.Header.Get("X-Forwarded-For"); xff != "" { + for i, c := range xff { + if c == ',' { + return xff[:i] + } + } + return xff + } + if xri := r.Header.Get("X-Real-IP"); xri != "" { + return xri + } + host, _, err := net.SplitHostPort(r.RemoteAddr) + if err != nil { + return r.RemoteAddr + } + return host +} diff --git a/internal/api/scripts.go b/internal/api/scripts.go index 9afbb75..822bd32 100644 --- a/internal/api/scripts.go +++ b/internal/api/scripts.go @@ -5,62 +5,33 @@ import ( "context" "net/http" "os/exec" - "path/filepath" "time" ) const scriptTimeout = 30 * time.Second -func (s *Server) startNextTaskScriptPath() string { - if s.startNextTaskScript != "" { - return s.startNextTaskScript - } - return filepath.Join(s.workDir, "scripts", "start-next-task") -} +// ScriptRegistry maps endpoint names to executable script paths. +// Only registered scripts are exposed via POST /api/scripts/{name}. +type ScriptRegistry map[string]string -func (s *Server) deployScriptPath() string { - if s.deployScript != "" { - return s.deployScript - } - return filepath.Join(s.workDir, "scripts", "deploy") +// SetScripts configures the script registry. The mux is not re-registered; +// the handler looks up the registry at request time, so this may be called +// after NewServer but before the first request. 
+func (s *Server) SetScripts(r ScriptRegistry) { + s.scripts = r } -func (s *Server) handleStartNextTask(w http.ResponseWriter, r *http.Request) { - ctx, cancel := context.WithTimeout(r.Context(), scriptTimeout) - defer cancel() - - scriptPath := s.startNextTaskScriptPath() - cmd := exec.CommandContext(ctx, scriptPath) - - var stdout, stderr bytes.Buffer - cmd.Stdout = &stdout - cmd.Stderr = &stderr - - err := cmd.Run() - exitCode := 0 - if err != nil { - if exitErr, ok := err.(*exec.ExitError); ok { - exitCode = exitErr.ExitCode() - } else { - s.logger.Error("start-next-task: script execution failed", "error", err, "path", scriptPath) - writeJSON(w, http.StatusInternalServerError, map[string]string{ - "error": "script execution failed: " + err.Error(), - }) - return - } +func (s *Server) handleScript(w http.ResponseWriter, r *http.Request) { + name := r.PathValue("name") + scriptPath, ok := s.scripts[name] + if !ok { + writeJSON(w, http.StatusNotFound, map[string]string{"error": "script not found: " + name}) + return } - writeJSON(w, http.StatusOK, map[string]interface{}{ - "output": stdout.String(), - "exit_code": exitCode, - }) -} - -func (s *Server) handleDeploy(w http.ResponseWriter, r *http.Request) { ctx, cancel := context.WithTimeout(r.Context(), scriptTimeout) defer cancel() - scriptPath := s.deployScriptPath() cmd := exec.CommandContext(ctx, scriptPath) var stdout, stderr bytes.Buffer @@ -72,17 +43,18 @@ func (s *Server) handleDeploy(w http.ResponseWriter, r *http.Request) { if err != nil { if exitErr, ok := err.(*exec.ExitError); ok { exitCode = exitErr.ExitCode() + s.logger.Warn("script exited non-zero", "name", name, "exit_code", exitCode, "stderr", stderr.String()) } else { - s.logger.Error("deploy: script execution failed", "error", err, "path", scriptPath) + s.logger.Error("script execution failed", "name", name, "error", err, "path", scriptPath) writeJSON(w, http.StatusInternalServerError, map[string]string{ - "error": "script execution failed: " + 
err.Error(), + "error": "script execution failed", }) return } } writeJSON(w, http.StatusOK, map[string]interface{}{ - "output": stdout.String() + stderr.String(), + "output": stdout.String(), "exit_code": exitCode, }) } diff --git a/internal/api/scripts_test.go b/internal/api/scripts_test.go index 7da133e..f5ece20 100644 --- a/internal/api/scripts_test.go +++ b/internal/api/scripts_test.go @@ -6,20 +6,65 @@ import ( "net/http/httptest" "os" "path/filepath" + "strings" "testing" ) +func TestServer_NoScripts_Returns404(t *testing.T) { + srv, _ := testServer(t) + // No scripts configured — all /api/scripts/* should return 404. + for _, name := range []string{"deploy", "start-next-task", "unknown"} { + req := httptest.NewRequest("POST", "/api/scripts/"+name, nil) + w := httptest.NewRecorder() + srv.Handler().ServeHTTP(w, req) + if w.Code != http.StatusNotFound { + t.Errorf("POST /api/scripts/%s: want 404, got %d", name, w.Code) + } + } +} + +func TestServer_WithScripts_RunsRegisteredScript(t *testing.T) { + srv, _ := testServer(t) + + scriptDir := t.TempDir() + scriptPath := filepath.Join(scriptDir, "my-script") + if err := os.WriteFile(scriptPath, []byte("#!/bin/sh\necho hello"), 0o755); err != nil { + t.Fatal(err) + } + + srv.SetScripts(ScriptRegistry{"my-script": scriptPath}) + + req := httptest.NewRequest("POST", "/api/scripts/my-script", nil) + w := httptest.NewRecorder() + srv.Handler().ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("want 200, got %d; body: %s", w.Code, w.Body.String()) + } +} + +func TestServer_WithScripts_UnregisteredReturns404(t *testing.T) { + srv, _ := testServer(t) + srv.SetScripts(ScriptRegistry{"deploy": "/some/path"}) + + req := httptest.NewRequest("POST", "/api/scripts/other", nil) + w := httptest.NewRecorder() + srv.Handler().ServeHTTP(w, req) + + if w.Code != http.StatusNotFound { + t.Errorf("want 404, got %d", w.Code) + } +} + func TestHandleDeploy_Success(t *testing.T) { srv, _ := testServer(t) - // Create a fake 
deploy script that exits 0 and prints output. scriptDir := t.TempDir() scriptPath := filepath.Join(scriptDir, "deploy") - script := "#!/bin/sh\necho 'deployed successfully'" - if err := os.WriteFile(scriptPath, []byte(script), 0o755); err != nil { + if err := os.WriteFile(scriptPath, []byte("#!/bin/sh\necho 'deployed successfully'"), 0o755); err != nil { t.Fatal(err) } - srv.deployScript = scriptPath + srv.SetScripts(ScriptRegistry{"deploy": scriptPath}) req := httptest.NewRequest("POST", "/api/scripts/deploy", nil) w := httptest.NewRecorder() @@ -35,8 +80,7 @@ func TestHandleDeploy_Success(t *testing.T) { if body["exit_code"] != float64(0) { t.Errorf("exit_code: want 0, got %v", body["exit_code"]) } - output, _ := body["output"].(string) - if output == "" { + if output, _ := body["output"].(string); output == "" { t.Errorf("expected non-empty output") } } @@ -46,11 +90,10 @@ func TestHandleDeploy_ScriptFails(t *testing.T) { scriptDir := t.TempDir() scriptPath := filepath.Join(scriptDir, "deploy") - script := "#!/bin/sh\necho 'build failed' && exit 1" - if err := os.WriteFile(scriptPath, []byte(script), 0o755); err != nil { + if err := os.WriteFile(scriptPath, []byte("#!/bin/sh\necho 'build failed' && exit 1"), 0o755); err != nil { t.Fatal(err) } - srv.deployScript = scriptPath + srv.SetScripts(ScriptRegistry{"deploy": scriptPath}) req := httptest.NewRequest("POST", "/api/scripts/deploy", nil) w := httptest.NewRecorder() @@ -67,3 +110,25 @@ func TestHandleDeploy_ScriptFails(t *testing.T) { t.Errorf("expected non-zero exit_code") } } + +func TestHandleScript_StderrNotLeakedToResponse(t *testing.T) { + srv, _ := testServer(t) + + scriptDir := t.TempDir() + scriptPath := filepath.Join(scriptDir, "deploy") + // Script writes sensitive info to stderr and exits non-zero. 
+ script := "#!/bin/sh\necho 'stdout output'\necho 'SECRET_TOKEN=abc123' >&2\nexit 1" + if err := os.WriteFile(scriptPath, []byte(script), 0o755); err != nil { + t.Fatal(err) + } + srv.SetScripts(ScriptRegistry{"deploy": scriptPath}) + + req := httptest.NewRequest("POST", "/api/scripts/deploy", nil) + w := httptest.NewRecorder() + srv.Handler().ServeHTTP(w, req) + + body := w.Body.String() + if strings.Contains(body, "SECRET_TOKEN") { + t.Errorf("response must not contain stderr content; got: %s", body) + } +} diff --git a/internal/api/server.go b/internal/api/server.go index 6f343b6..3d7cb1e 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -7,40 +7,64 @@ import ( "log/slog" "net/http" "os" + "strings" "time" "github.com/thepeterstone/claudomator/internal/executor" + "github.com/thepeterstone/claudomator/internal/notify" "github.com/thepeterstone/claudomator/internal/storage" "github.com/thepeterstone/claudomator/internal/task" webui "github.com/thepeterstone/claudomator/web" "github.com/google/uuid" ) +// questionStore is the minimal storage interface needed by handleAnswerQuestion. +type questionStore interface { + GetTask(id string) (*task.Task, error) + GetLatestExecution(taskID string) (*storage.Execution, error) + UpdateTaskQuestion(taskID, questionJSON string) error + UpdateTaskState(id string, newState task.State) error +} + // Server provides the REST API and WebSocket endpoint for Claudomator. 
type Server struct { store *storage.DB - logStore logStore // injectable for tests; defaults to store - taskLogStore taskLogStore // injectable for tests; defaults to store + logStore logStore // injectable for tests; defaults to store + taskLogStore taskLogStore // injectable for tests; defaults to store + questionStore questionStore // injectable for tests; defaults to store pool *executor.Pool hub *Hub logger *slog.Logger mux *http.ServeMux - claudeBinPath string // path to claude binary; defaults to "claude" - geminiBinPath string // path to gemini binary; defaults to "gemini" - elaborateCmdPath string // overrides claudeBinPath; used in tests - validateCmdPath string // overrides claudeBinPath for validate; used in tests - startNextTaskScript string // path to start-next-task script; overridden in tests - deployScript string // path to deploy script; overridden in tests - workDir string // working directory injected into elaborate system prompt + claudeBinPath string // path to claude binary; defaults to "claude" + geminiBinPath string // path to gemini binary; defaults to "gemini" + elaborateCmdPath string // overrides claudeBinPath; used in tests + validateCmdPath string // overrides claudeBinPath for validate; used in tests + scripts ScriptRegistry // optional; maps endpoint name → script path + workDir string // working directory injected into elaborate system prompt + notifier notify.Notifier + apiToken string // if non-empty, required for WebSocket (and REST) connections + elaborateLimiter *ipRateLimiter // per-IP rate limiter for elaborate/validate endpoints +} + +// SetAPIToken configures a bearer token that must be supplied to access the API. +func (s *Server) SetAPIToken(token string) { + s.apiToken = token +} + +// SetNotifier configures a notifier that is called on every task completion. 
+func (s *Server) SetNotifier(n notify.Notifier) { + s.notifier = n } func NewServer(store *storage.DB, pool *executor.Pool, logger *slog.Logger, claudeBinPath, geminiBinPath string) *Server { wd, _ := os.Getwd() s := &Server{ - store: store, - logStore: store, - taskLogStore: store, - pool: pool, + store: store, + logStore: store, + taskLogStore: store, + questionStore: store, + pool: pool, hub: NewHub(), logger: logger, mux: http.NewServeMux(), @@ -86,8 +110,7 @@ func (s *Server) routes() { s.mux.HandleFunc("DELETE /api/templates/{id}", s.handleDeleteTemplate) s.mux.HandleFunc("POST /api/tasks/{id}/answer", s.handleAnswerQuestion) s.mux.HandleFunc("POST /api/tasks/{id}/resume", s.handleResumeTimedOutTask) - s.mux.HandleFunc("POST /api/scripts/start-next-task", s.handleStartNextTask) - s.mux.HandleFunc("POST /api/scripts/deploy", s.handleDeploy) + s.mux.HandleFunc("POST /api/scripts/{name}", s.handleScript) s.mux.HandleFunc("GET /api/ws", s.handleWebSocket) s.mux.HandleFunc("GET /api/workspaces", s.handleListWorkspaces) s.mux.HandleFunc("GET /api/health", s.handleHealth) @@ -97,17 +120,44 @@ func (s *Server) routes() { // forwardResults listens on the executor pool's result channel and broadcasts via WebSocket. func (s *Server) forwardResults() { for result := range s.pool.Results() { - event := map[string]interface{}{ - "type": "task_completed", - "task_id": result.TaskID, - "status": result.Execution.Status, - "exit_code": result.Execution.ExitCode, - "cost_usd": result.Execution.CostUSD, - "error": result.Execution.ErrorMsg, - "timestamp": time.Now().UTC(), + s.processResult(result) + } +} + +// processResult broadcasts a task completion event via WebSocket and calls the notifier if set. 
+func (s *Server) processResult(result *executor.Result) { + event := map[string]interface{}{ + "type": "task_completed", + "task_id": result.TaskID, + "status": result.Execution.Status, + "exit_code": result.Execution.ExitCode, + "cost_usd": result.Execution.CostUSD, + "error": result.Execution.ErrorMsg, + "timestamp": time.Now().UTC(), + } + data, _ := json.Marshal(event) + s.hub.Broadcast(data) + + if s.notifier != nil { + var taskName string + if t, err := s.store.GetTask(result.TaskID); err == nil { + taskName = t.Name + } + var dur string + if !result.Execution.StartTime.IsZero() && !result.Execution.EndTime.IsZero() { + dur = result.Execution.EndTime.Sub(result.Execution.StartTime).String() + } + ne := notify.Event{ + TaskID: result.TaskID, + TaskName: taskName, + Status: result.Execution.Status, + CostUSD: result.Execution.CostUSD, + Duration: dur, + Error: result.Execution.ErrorMsg, + } + if err := s.notifier.Notify(ne); err != nil { + s.logger.Error("notifier failed", "error", err) } - data, _ := json.Marshal(event) - s.hub.Broadcast(data) } } @@ -169,7 +219,7 @@ func (s *Server) handleCancelTask(w http.ResponseWriter, r *http.Request) { func (s *Server) handleAnswerQuestion(w http.ResponseWriter, r *http.Request) { taskID := r.PathValue("id") - tk, err := s.store.GetTask(taskID) + tk, err := s.questionStore.GetTask(taskID) if err != nil { writeJSON(w, http.StatusNotFound, map[string]string{"error": "task not found"}) return @@ -192,15 +242,21 @@ func (s *Server) handleAnswerQuestion(w http.ResponseWriter, r *http.Request) { } // Look up the session ID from the most recent execution. - latest, err := s.store.GetLatestExecution(taskID) + latest, err := s.questionStore.GetLatestExecution(taskID) if err != nil || latest.SessionID == "" { writeJSON(w, http.StatusInternalServerError, map[string]string{"error": "no resumable session found"}) return } // Clear the question and transition to QUEUED. 
-	s.store.UpdateTaskQuestion(taskID, "")
-	s.store.UpdateTaskState(taskID, task.StateQueued)
+	if err := s.questionStore.UpdateTaskQuestion(taskID, ""); err != nil {
+		writeJSON(w, http.StatusInternalServerError, map[string]string{"error": "failed to clear question"})
+		return
+	}
+	if err := s.questionStore.UpdateTaskState(taskID, task.StateQueued); err != nil {
+		writeJSON(w, http.StatusInternalServerError, map[string]string{"error": "failed to queue task"})
+		return
+	}

 	// Submit a resume execution.
 	resumeExec := &storage.Execution{
@@ -256,9 +312,23 @@ func (s *Server) handleResumeTimedOutTask(w http.ResponseWriter, r *http.Request
 }

 func (s *Server) handleListWorkspaces(w http.ResponseWriter, r *http.Request) {
+	if s.apiToken != "" {
+		token := r.URL.Query().Get("token")
+		if token == "" {
+			auth := r.Header.Get("Authorization")
+			if strings.HasPrefix(auth, "Bearer ") {
+				token = strings.TrimPrefix(auth, "Bearer ")
+			}
+		}
+		if token != s.apiToken {
+			http.Error(w, "unauthorized", http.StatusUnauthorized)
+			return
+		}
+	}
+
 	entries, err := os.ReadDir("/workspace")
 	if err != nil {
-		writeJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()})
+		writeJSON(w, http.StatusInternalServerError, map[string]string{"error": "failed to list workspaces"})
 		return
 	}
 	var dirs []string
@@ -375,6 +445,21 @@ func (s *Server) handleRunTask(w http.ResponseWriter, r *http.Request) {
 		return
 	}

+	// Enforce retry limit for non-initial runs (PENDING is the initial state).
+	if t.State != task.StatePending {
+		execs, err := s.store.ListExecutions(id)
+		if err != nil {
+			writeJSON(w, http.StatusInternalServerError, map[string]string{"error": "failed to count executions"})
+			return
+		}
+		if t.Retry.MaxAttempts > 0 && len(execs) >= t.Retry.MaxAttempts {
+			writeJSON(w, http.StatusConflict, map[string]string{
+				"error": fmt.Sprintf("retry limit reached (%d/%d attempts used)", len(execs), t.Retry.MaxAttempts),
+			})
+			return
+		}
+	}
+
 	if err := s.store.UpdateTaskState(id, task.StateQueued); err != nil {
 		writeJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()})
 		return
diff --git a/internal/api/server_test.go b/internal/api/server_test.go
index b0ccb4a..cd415ae 100644
--- a/internal/api/server_test.go
+++ b/internal/api/server_test.go
@@ -9,15 +9,70 @@ import (
 	"net/http/httptest"
 	"os"
 	"path/filepath"
+	"strings"
 	"testing"
+	"time"

 	"context"

 	"github.com/thepeterstone/claudomator/internal/executor"
+	"github.com/thepeterstone/claudomator/internal/notify"
 	"github.com/thepeterstone/claudomator/internal/storage"
 	"github.com/thepeterstone/claudomator/internal/task"
 )

+// mockNotifier records calls to Notify.
+type mockNotifier struct {
+	events []notify.Event
+}
+
+func (m *mockNotifier) Notify(e notify.Event) error {
+	m.events = append(m.events, e)
+	return nil
+}
+
+func TestServer_ProcessResult_CallsNotifier(t *testing.T) {
+	srv, store := testServer(t)
+
+	mn := &mockNotifier{}
+	srv.SetNotifier(mn)
+
+	tk := &task.Task{
+		ID:    "task-notifier-test",
+		Name:  "Notifier Task",
+		State: task.StatePending,
+	}
+	if err := store.CreateTask(tk); err != nil {
+		t.Fatal(err)
+	}
+
+	result := &executor.Result{
+		TaskID: tk.ID,
+		Execution: &storage.Execution{
+			ID:       "exec-1",
+			TaskID:   tk.ID,
+			Status:   "COMPLETED",
+			CostUSD:  0.42,
+			ErrorMsg: "",
+		},
+	}
+	srv.processResult(result)
+
+	if len(mn.events) != 1 {
+		t.Fatalf("expected 1 notify event, got %d", len(mn.events))
+	}
+	ev := mn.events[0]
+	if ev.TaskID != tk.ID {
+		t.Errorf("event.TaskID = %q, want %q", ev.TaskID, tk.ID)
+	}
+	if ev.Status != "COMPLETED" {
+		t.Errorf("event.Status = %q, want COMPLETED", ev.Status)
+	}
+	if ev.CostUSD != 0.42 {
+		t.Errorf("event.CostUSD = %v, want 0.42", ev.CostUSD)
+	}
+}
+
 func testServer(t *testing.T) (*Server, *storage.DB) {
 	t.Helper()
 	dbPath := filepath.Join(t.TempDir(), "test.db")
@@ -175,6 +230,20 @@ func TestListTasks_WithTasks(t *testing.T) {
 	}
 }

+// stateWalkPaths defines the sequence of intermediate states needed to reach each target state.
+var stateWalkPaths = map[task.State][]task.State{
+	task.StatePending:        {},
+	task.StateQueued:         {task.StateQueued},
+	task.StateRunning:        {task.StateQueued, task.StateRunning},
+	task.StateCompleted:      {task.StateQueued, task.StateRunning, task.StateCompleted},
+	task.StateFailed:         {task.StateQueued, task.StateRunning, task.StateFailed},
+	task.StateTimedOut:       {task.StateQueued, task.StateRunning, task.StateTimedOut},
+	task.StateCancelled:      {task.StateCancelled},
+	task.StateBudgetExceeded: {task.StateQueued, task.StateRunning, task.StateBudgetExceeded},
+	task.StateReady:          {task.StateQueued, task.StateRunning, task.StateReady},
+	task.StateBlocked:        {task.StateQueued, task.StateRunning, task.StateBlocked},
+}
+
 func createTaskWithState(t *testing.T, store *storage.DB, id string, state task.State) *task.Task {
 	t.Helper()
 	tk := &task.Task{
@@ -188,9 +257,9 @@ func createTaskWithState(t *testing.T, store *storage.DB, id string, state task.
 	if err := store.CreateTask(tk); err != nil {
 		t.Fatalf("createTaskWithState: CreateTask: %v", err)
 	}
-	if state != task.StatePending {
-		if err := store.UpdateTaskState(id, state); err != nil {
-			t.Fatalf("createTaskWithState: UpdateTaskState(%s): %v", state, err)
+	for _, s := range stateWalkPaths[state] {
+		if err := store.UpdateTaskState(id, s); err != nil {
+			t.Fatalf("createTaskWithState: UpdateTaskState(%s): %v", s, err)
 		}
 	}
 	tk.State = state
@@ -425,7 +494,7 @@ func TestHandleStartNextTask_Success(t *testing.T) {
 	}

 	srv, _ := testServer(t)
-	srv.startNextTaskScript = script
+	srv.SetScripts(ScriptRegistry{"start-next-task": script})

 	req := httptest.NewRequest("POST", "/api/scripts/start-next-task", nil)
 	w := httptest.NewRecorder()
@@ -452,7 +521,7 @@ func TestHandleStartNextTask_NoTask(t *testing.T) {
 	}

 	srv, _ := testServer(t)
-	srv.startNextTaskScript = script
+	srv.SetScripts(ScriptRegistry{"start-next-task": script})

 	req := httptest.NewRequest("POST", "/api/scripts/start-next-task", nil)
 	w := httptest.NewRecorder()
@@ -535,9 +604,87 @@ func TestResumeTimedOut_Success_Returns202(t *testing.T) {
 	}
 }

+func TestRunTask_RetryLimitReached_Returns409(t *testing.T) {
+	srv, store := testServer(t)
+	// Task with MaxAttempts: 1 — only 1 attempt allowed. Create directly as FAILED
+	// so state is consistent without going through transition sequence.
+	tk := &task.Task{
+		ID:        "retry-limit-1",
+		Name:      "Retry Limit Task",
+		Agent:     task.AgentConfig{Instructions: "do something"},
+		Priority:  task.PriorityNormal,
+		Retry:     task.RetryConfig{MaxAttempts: 1, Backoff: "linear"},
+		Tags:      []string{},
+		DependsOn: []string{},
+		State:     task.StateFailed,
+	}
+	if err := store.CreateTask(tk); err != nil {
+		t.Fatal(err)
+	}
+	// Record one execution — the first attempt already used.
+	exec := &storage.Execution{
+		ID:        "exec-retry-1",
+		TaskID:    "retry-limit-1",
+		StartTime: time.Now(),
+		Status:    "FAILED",
+	}
+	if err := store.CreateExecution(exec); err != nil {
+		t.Fatal(err)
+	}
+
+	req := httptest.NewRequest("POST", "/api/tasks/retry-limit-1/run", nil)
+	w := httptest.NewRecorder()
+	srv.Handler().ServeHTTP(w, req)
+
+	if w.Code != http.StatusConflict {
+		t.Errorf("status: want 409, got %d; body: %s", w.Code, w.Body.String())
+	}
+	var body map[string]string
+	json.NewDecoder(w.Body).Decode(&body)
+	if !strings.Contains(body["error"], "retry limit") {
+		t.Errorf("error body should mention retry limit, got %q", body["error"])
+	}
+}
+
+func TestRunTask_WithinRetryLimit_Returns202(t *testing.T) {
+	srv, store := testServer(t)
+	// Task with MaxAttempts: 3 — 1 execution used, 2 remaining.
+	tk := &task.Task{
+		ID:        "retry-within-1",
+		Name:      "Retry Within Task",
+		Agent:     task.AgentConfig{Instructions: "do something"},
+		Priority:  task.PriorityNormal,
+		Retry:     task.RetryConfig{MaxAttempts: 3, Backoff: "linear"},
+		Tags:      []string{},
+		DependsOn: []string{},
+		State:     task.StatePending,
+	}
+	if err := store.CreateTask(tk); err != nil {
+		t.Fatal(err)
+	}
+	exec := &storage.Execution{
+		ID:        "exec-within-1",
+		TaskID:    "retry-within-1",
+		StartTime: time.Now(),
+		Status:    "FAILED",
+	}
+	if err := store.CreateExecution(exec); err != nil {
+		t.Fatal(err)
+	}
+	store.UpdateTaskState("retry-within-1", task.StateFailed)
+
+	req := httptest.NewRequest("POST", "/api/tasks/retry-within-1/run", nil)
+	w := httptest.NewRecorder()
+	srv.Handler().ServeHTTP(w, req)
+
+	if w.Code != http.StatusAccepted {
+		t.Errorf("status: want 202, got %d; body: %s", w.Code, w.Body.String())
+	}
+}
+
 func TestHandleStartNextTask_ScriptNotFound(t *testing.T) {
 	srv, _ := testServer(t)
-	srv.startNextTaskScript = "/nonexistent/start-next-task"
+	srv.SetScripts(ScriptRegistry{"start-next-task": "/nonexistent/start-next-task"})

 	req := httptest.NewRequest("POST", "/api/scripts/start-next-task", nil)
 	w := httptest.NewRecorder()
@@ -583,10 +730,20 @@ func TestDeleteTask_NotFound(t *testing.T) {

 func TestDeleteTask_RunningTaskRejected(t *testing.T) {
 	srv, store := testServer(t)
-	created := createTestTask(t, srv, `{"name":"Running Task","agent":{"type":"claude","instructions":"x","model":"sonnet"}}`)
-	store.UpdateTaskState(created.ID, "RUNNING")
-
-	req := httptest.NewRequest("DELETE", "/api/tasks/"+created.ID, nil)
+	// Create the task directly in RUNNING state to avoid going through state transitions.
+	tk := &task.Task{
+		ID:        "running-task-del",
+		Name:      "Running Task",
+		Agent:     task.AgentConfig{Instructions: "x", Model: "sonnet"},
+		Priority:  task.PriorityNormal,
+		Tags:      []string{},
+		DependsOn: []string{},
+		State:     task.StateRunning,
+	}
+	if err := store.CreateTask(tk); err != nil {
+		t.Fatal(err)
+	}
+	req := httptest.NewRequest("DELETE", "/api/tasks/running-task-del", nil)
 	w := httptest.NewRecorder()
 	srv.Handler().ServeHTTP(w, req)
@@ -662,3 +819,170 @@ func TestServer_CancelTask_Completed_Returns409(t *testing.T) {
 		t.Errorf("status: want 409, got %d; body: %s", w.Code, w.Body.String())
 	}
 }
+
+// mockQuestionStore implements questionStore for testing handleAnswerQuestion.
+type mockQuestionStore struct {
+	getTaskFn            func(id string) (*task.Task, error)
+	getLatestExecutionFn func(taskID string) (*storage.Execution, error)
+	updateTaskQuestionFn func(taskID, questionJSON string) error
+	updateTaskStateFn    func(id string, newState task.State) error
+}
+
+func (m *mockQuestionStore) GetTask(id string) (*task.Task, error) {
+	return m.getTaskFn(id)
+}
+func (m *mockQuestionStore) GetLatestExecution(taskID string) (*storage.Execution, error) {
+	return m.getLatestExecutionFn(taskID)
+}
+func (m *mockQuestionStore) UpdateTaskQuestion(taskID, questionJSON string) error {
+	return m.updateTaskQuestionFn(taskID, questionJSON)
+}
+func (m *mockQuestionStore) UpdateTaskState(id string, newState task.State) error {
+	return m.updateTaskStateFn(id, newState)
+}
+
+func TestServer_AnswerQuestion_UpdateQuestionFails_Returns500(t *testing.T) {
+	srv, _ := testServer(t)
+	srv.questionStore = &mockQuestionStore{
+		getTaskFn: func(id string) (*task.Task, error) {
+			return &task.Task{ID: id, State: task.StateBlocked}, nil
+		},
+		getLatestExecutionFn: func(taskID string) (*storage.Execution, error) {
+			return &storage.Execution{SessionID: "sess-1"}, nil
+		},
+		updateTaskQuestionFn: func(taskID, questionJSON string) error {
+			return fmt.Errorf("db error")
+		},
+		updateTaskStateFn: func(id string, newState task.State) error {
+			return nil
+		},
+	}
+
+	body := bytes.NewBufferString(`{"answer":"yes"}`)
+	req := httptest.NewRequest("POST", "/api/tasks/task-1/answer", body)
+	req.Header.Set("Content-Type", "application/json")
+	w := httptest.NewRecorder()
+	srv.Handler().ServeHTTP(w, req)
+
+	if w.Code != http.StatusInternalServerError {
+		t.Errorf("status: want 500, got %d; body: %s", w.Code, w.Body.String())
+	}
+}
+
+func TestServer_AnswerQuestion_UpdateStateFails_Returns500(t *testing.T) {
+	srv, _ := testServer(t)
+	srv.questionStore = &mockQuestionStore{
+		getTaskFn: func(id string) (*task.Task, error) {
+			return &task.Task{ID: id, State: task.StateBlocked}, nil
+		},
+		getLatestExecutionFn: func(taskID string) (*storage.Execution, error) {
+			return &storage.Execution{SessionID: "sess-1"}, nil
+		},
+		updateTaskQuestionFn: func(taskID, questionJSON string) error {
+			return nil
+		},
+		updateTaskStateFn: func(id string, newState task.State) error {
+			return fmt.Errorf("db error")
+		},
+	}
+
+	body := bytes.NewBufferString(`{"answer":"yes"}`)
+	req := httptest.NewRequest("POST", "/api/tasks/task-1/answer", body)
+	req.Header.Set("Content-Type", "application/json")
+	w := httptest.NewRecorder()
+	srv.Handler().ServeHTTP(w, req)

+	if w.Code != http.StatusInternalServerError {
+		t.Errorf("status: want 500, got %d; body: %s", w.Code, w.Body.String())
+	}
+}
+
+func TestRateLimit_ElaborateRejectsExcess(t *testing.T) {
+	srv, _ := testServer(t)
+	// Use burst-1 and rate-0 so the second request from the same IP is rejected.
+	srv.elaborateLimiter = newIPRateLimiter(0, 1)
+
+	makeReq := func(remoteAddr string) int {
+		req := httptest.NewRequest("POST", "/api/tasks/elaborate", bytes.NewBufferString(`{"description":"x"}`))
+		req.Header.Set("Content-Type", "application/json")
+		req.RemoteAddr = remoteAddr
+		w := httptest.NewRecorder()
+		srv.Handler().ServeHTTP(w, req)
+		return w.Code
+	}
+
+	// First request from IP A: limiter allows it (non-429).
+	if code := makeReq("192.0.2.1:1234"); code == http.StatusTooManyRequests {
+		t.Errorf("first request should not be rate limited, got 429")
+	}
+	// Second request from IP A: bucket exhausted, must be 429.
+	if code := makeReq("192.0.2.1:1234"); code != http.StatusTooManyRequests {
+		t.Errorf("second request from same IP should be 429, got %d", code)
+	}
+	// First request from IP B: separate bucket, not limited.
+	if code := makeReq("192.0.2.2:1234"); code == http.StatusTooManyRequests {
+		t.Errorf("first request from different IP should not be rate limited, got 429")
+	}
+}
+
+func TestListWorkspaces_RequiresAuth(t *testing.T) {
+	srv, _ := testServer(t)
+	srv.SetAPIToken("secret-token")
+
+	// No token: expect 401.
+	req := httptest.NewRequest("GET", "/api/workspaces", nil)
+	w := httptest.NewRecorder()
+	srv.Handler().ServeHTTP(w, req)
+	if w.Code != http.StatusUnauthorized {
+		t.Errorf("expected 401 without token, got %d", w.Code)
+	}
+}
+
+func TestListWorkspaces_RejectsWrongToken(t *testing.T) {
+	srv, _ := testServer(t)
+	srv.SetAPIToken("secret-token")
+
+	req := httptest.NewRequest("GET", "/api/workspaces", nil)
+	req.Header.Set("Authorization", "Bearer wrong-token")
+	w := httptest.NewRecorder()
+	srv.Handler().ServeHTTP(w, req)
+	if w.Code != http.StatusUnauthorized {
+		t.Errorf("expected 401 with wrong token, got %d", w.Code)
+	}
+}
+
+func TestListWorkspaces_SuppressesRawError(t *testing.T) {
+	srv, _ := testServer(t)
+	// No token configured so auth is bypassed; /workspace likely doesn't exist in test env.
+
+	req := httptest.NewRequest("GET", "/api/workspaces", nil)
+	w := httptest.NewRecorder()
+	srv.Handler().ServeHTTP(w, req)
+	if w.Code == http.StatusInternalServerError {
+		body := w.Body.String()
+		if strings.Contains(body, "/workspace") || strings.Contains(body, "no such file") {
+			t.Errorf("response leaks filesystem details: %s", body)
+		}
+	}
+}
+
+func TestRateLimit_ValidateRejectsExcess(t *testing.T) {
+	srv, _ := testServer(t)
+	srv.elaborateLimiter = newIPRateLimiter(0, 1)
+
+	makeReq := func(remoteAddr string) int {
+		req := httptest.NewRequest("POST", "/api/tasks/validate", bytes.NewBufferString(`{"description":"x"}`))
+		req.Header.Set("Content-Type", "application/json")
+		req.RemoteAddr = remoteAddr
+		w := httptest.NewRecorder()
+		srv.Handler().ServeHTTP(w, req)
+		return w.Code
+	}
+
+	if code := makeReq("192.0.2.1:1234"); code == http.StatusTooManyRequests {
+		t.Errorf("first validate request should not be rate limited, got 429")
+	}
+	if code := makeReq("192.0.2.1:1234"); code != http.StatusTooManyRequests {
+		t.Errorf("second validate request from same IP should be 429, got %d", code)
+	}
+}
diff --git a/internal/api/validate.go b/internal/api/validate.go
index a3b2cf0..07d293c 100644
--- a/internal/api/validate.go
+++ b/internal/api/validate.go
@@ -48,16 +48,22 @@ func (s *Server) validateBinaryPath() string {
 	if s.validateCmdPath != "" {
 		return s.validateCmdPath
 	}
-	return s.claudeBinaryPath()
+	return s.claudeBinPath
 }

 func (s *Server) handleValidateTask(w http.ResponseWriter, r *http.Request) {
+	if s.elaborateLimiter != nil && !s.elaborateLimiter.allow(realIP(r)) {
+		writeJSON(w, http.StatusTooManyRequests, map[string]string{"error": "rate limit exceeded"})
+		return
+	}
+
 	var input struct {
 		Name  string `json:"name"`
 		Agent struct {
 			Type         string   `json:"type"`
 			Instructions string   `json:"instructions"`
-			WorkingDir   string   `json:"working_dir"`
+			ProjectDir   string   `json:"project_dir"`
+			WorkingDir   string   `json:"working_dir"` // legacy
 			AllowedTools []string `json:"allowed_tools"`
 		} `json:"agent"`
 	}
@@ -79,9 +85,14 @@ func (s *Server) handleValidateTask(w http.ResponseWriter, r *http.Request) {
 		agentType = "claude"
 	}

+	projectDir := input.Agent.ProjectDir
+	if projectDir == "" {
+		projectDir = input.Agent.WorkingDir
+	}
+
 	userMsg := fmt.Sprintf("Task name: %s\nAgent: %s\n\nInstructions:\n%s", input.Name, agentType, input.Agent.Instructions)
-	if input.Agent.WorkingDir != "" {
-		userMsg += fmt.Sprintf("\n\nWorking directory: %s", input.Agent.WorkingDir)
+	if projectDir != "" {
+		userMsg += fmt.Sprintf("\n\nWorking directory: %s", projectDir)
 	}
 	if len(input.Agent.AllowedTools) > 0 {
 		userMsg += fmt.Sprintf("\n\nAllowed tools: %v", input.Agent.AllowedTools)
diff --git a/internal/api/websocket.go b/internal/api/websocket.go
index 6bd8c88..b5bf728 100644
--- a/internal/api/websocket.go
+++ b/internal/api/websocket.go
@@ -1,13 +1,27 @@
 package api

 import (
+	"errors"
 	"log/slog"
 	"net/http"
+	"strings"
 	"sync"
+	"time"

 	"golang.org/x/net/websocket"
 )

+// wsPingInterval and wsPingDeadline control heartbeat timing.
+// Exposed as vars so tests can override them without rebuilding.
+var (
+	wsPingInterval = 30 * time.Second
+	wsPingDeadline = 10 * time.Second
+
+	// maxWsClients caps the number of concurrent WebSocket connections.
+	// Exposed as a var so tests can override it.
+	maxWsClients = 1000
+)
+
 // Hub manages WebSocket connections and broadcasts messages.
 type Hub struct {
 	mu      sync.RWMutex
@@ -25,10 +39,14 @@ func NewHub() *Hub {
 // Run is a no-op loop kept for future cleanup/heartbeat logic.
 func (h *Hub) Run() {}

-func (h *Hub) Register(ws *websocket.Conn) {
+func (h *Hub) Register(ws *websocket.Conn) error {
 	h.mu.Lock()
+	defer h.mu.Unlock()
+	if len(h.clients) >= maxWsClients {
+		return errors.New("max WebSocket clients reached")
+	}
 	h.clients[ws] = true
-	h.mu.Unlock()
+	return nil
 }

 func (h *Hub) Unregister(ws *websocket.Conn) {
@@ -56,11 +74,56 @@ func (h *Hub) ClientCount() int {
 }

 func (s *Server) handleWebSocket(w http.ResponseWriter, r *http.Request) {
+	if s.hub.ClientCount() >= maxWsClients {
+		http.Error(w, "too many connections", http.StatusServiceUnavailable)
+		return
+	}
+
+	if s.apiToken != "" {
+		token := r.URL.Query().Get("token")
+		if token == "" {
+			auth := r.Header.Get("Authorization")
+			if strings.HasPrefix(auth, "Bearer ") {
+				token = strings.TrimPrefix(auth, "Bearer ")
+			}
+		}
+		if token != s.apiToken {
+			http.Error(w, "unauthorized", http.StatusUnauthorized)
+			return
+		}
+	}
+
 	handler := websocket.Handler(func(ws *websocket.Conn) {
-		s.hub.Register(ws)
+		if err := s.hub.Register(ws); err != nil {
+			return
+		}
 		defer s.hub.Unregister(ws)

-		// Keep connection alive until client disconnects.
+		// Ping goroutine: detect dead connections by sending periodic pings.
+		// A write failure (including write deadline exceeded) closes the conn,
+		// causing the read loop below to exit and unregister the client.
+		done := make(chan struct{})
+		defer close(done)
+		go func() {
+			ticker := time.NewTicker(wsPingInterval)
+			defer ticker.Stop()
+			for {
+				select {
+				case <-done:
+					return
+				case <-ticker.C:
+					ws.SetWriteDeadline(time.Now().Add(wsPingDeadline))
+					err := websocket.Message.Send(ws, "ping")
+					ws.SetWriteDeadline(time.Time{})
+					if err != nil {
+						ws.Close()
+						return
+					}
+				}
+			}
+		}()
+
+		// Keep connection alive until client disconnects or ping fails.
 		buf := make([]byte, 1024)
 		for {
 			if _, err := ws.Read(buf); err != nil {
diff --git a/internal/api/websocket_test.go b/internal/api/websocket_test.go
new file mode 100644
index 0000000..72b83f2
--- /dev/null
+++ b/internal/api/websocket_test.go
@@ -0,0 +1,221 @@
+package api
+
+import (
+	"net/http"
+	"net/http/httptest"
+	"strings"
+	"testing"
+	"time"
+
+	"golang.org/x/net/websocket"
+)
+
+// TestWebSocket_RejectsConnectionWithoutToken verifies that when an API token
+// is configured, WebSocket connections without a valid token are rejected with 401.
+func TestWebSocket_RejectsConnectionWithoutToken(t *testing.T) {
+	srv, _ := testServer(t)
+	srv.SetAPIToken("secret-token")
+
+	// Plain HTTP request simulates a WebSocket upgrade attempt without token.
+	req := httptest.NewRequest("GET", "/api/ws", nil)
+	w := httptest.NewRecorder()
+	srv.Handler().ServeHTTP(w, req)
+
+	if w.Code != http.StatusUnauthorized {
+		t.Errorf("want 401, got %d", w.Code)
+	}
+}
+
+// TestWebSocket_RejectsConnectionWithWrongToken verifies a wrong token is rejected.
+func TestWebSocket_RejectsConnectionWithWrongToken(t *testing.T) {
+	srv, _ := testServer(t)
+	srv.SetAPIToken("secret-token")
+
+	req := httptest.NewRequest("GET", "/api/ws?token=wrong-token", nil)
+	w := httptest.NewRecorder()
+	srv.Handler().ServeHTTP(w, req)
+
+	if w.Code != http.StatusUnauthorized {
+		t.Errorf("want 401, got %d", w.Code)
+	}
+}
+
+// TestWebSocket_AcceptsConnectionWithValidQueryToken verifies a valid token in
+// the query string is accepted.
+func TestWebSocket_AcceptsConnectionWithValidQueryToken(t *testing.T) {
+	srv, _ := testServer(t)
+	srv.SetAPIToken("secret-token")
+	srv.StartHub()
+	ts := httptest.NewServer(srv.Handler())
+	t.Cleanup(ts.Close)
+
+	wsURL := "ws" + strings.TrimPrefix(ts.URL, "http") + "/api/ws?token=secret-token"
+	ws, err := websocket.Dial(wsURL, "", "http://localhost/")
+	if err != nil {
+		t.Fatalf("expected connection to succeed with valid token: %v", err)
+	}
+	ws.Close()
+}
+
+// TestWebSocket_AcceptsConnectionWithBearerToken verifies a valid token in the
+// Authorization header is accepted.
+func TestWebSocket_AcceptsConnectionWithBearerToken(t *testing.T) {
+	srv, _ := testServer(t)
+	srv.SetAPIToken("secret-token")
+	srv.StartHub()
+	ts := httptest.NewServer(srv.Handler())
+	t.Cleanup(ts.Close)
+
+	wsURL := "ws" + strings.TrimPrefix(ts.URL, "http") + "/api/ws"
+	cfg, err := websocket.NewConfig(wsURL, "http://localhost/")
+	if err != nil {
+		t.Fatalf("config: %v", err)
+	}
+	cfg.Header = http.Header{"Authorization": {"Bearer secret-token"}}
+	ws, err := websocket.DialConfig(cfg)
+	if err != nil {
+		t.Fatalf("expected connection to succeed with Bearer token: %v", err)
+	}
+	ws.Close()
+}
+
+// TestWebSocket_NoTokenConfigured verifies that when no API token is set,
+// connections are allowed without authentication.
+func TestWebSocket_NoTokenConfigured(t *testing.T) {
+	srv, _ := testServer(t)
+	// No SetAPIToken call — auth is disabled.
+	srv.StartHub()
+	ts := httptest.NewServer(srv.Handler())
+	t.Cleanup(ts.Close)
+
+	wsURL := "ws" + strings.TrimPrefix(ts.URL, "http") + "/api/ws"
+	ws, err := websocket.Dial(wsURL, "", "http://localhost/")
+	if err != nil {
+		t.Fatalf("expected connection without token when auth disabled: %v", err)
+	}
+	ws.Close()
+}
+
+// TestWebSocket_RejectsConnectionWhenAtMaxClients verifies that when the hub
+// is at capacity, new WebSocket upgrade requests are rejected with 503.
+func TestWebSocket_RejectsConnectionWhenAtMaxClients(t *testing.T) {
+	orig := maxWsClients
+	maxWsClients = 0 // immediately at capacity
+	t.Cleanup(func() { maxWsClients = orig })
+
+	srv, _ := testServer(t)
+	srv.StartHub()
+
+	req := httptest.NewRequest("GET", "/api/ws", nil)
+	w := httptest.NewRecorder()
+	srv.Handler().ServeHTTP(w, req)
+
+	if w.Code != http.StatusServiceUnavailable {
+		t.Errorf("want 503, got %d", w.Code)
+	}
+}
+
+// TestWebSocket_StaleConnectionCleanedUp verifies that when a client
+// disconnects (or the connection is closed), the hub unregisters it.
+// Short ping intervals are used so the test completes quickly.
+func TestWebSocket_StaleConnectionCleanedUp(t *testing.T) {
+	origInterval := wsPingInterval
+	origDeadline := wsPingDeadline
+	wsPingInterval = 20 * time.Millisecond
+	wsPingDeadline = 20 * time.Millisecond
+	t.Cleanup(func() {
+		wsPingInterval = origInterval
+		wsPingDeadline = origDeadline
+	})
+
+	srv, _ := testServer(t)
+	srv.StartHub()
+	ts := httptest.NewServer(srv.Handler())
+	t.Cleanup(ts.Close)
+
+	wsURL := "ws" + strings.TrimPrefix(ts.URL, "http") + "/api/ws"
+	ws, err := websocket.Dial(wsURL, "", "http://localhost/")
+	if err != nil {
+		t.Fatalf("dial: %v", err)
+	}
+
+	// Wait for hub to register the client.
+	deadline := time.Now().Add(200 * time.Millisecond)
+	for time.Now().Before(deadline) {
+		if srv.hub.ClientCount() == 1 {
+			break
+		}
+		time.Sleep(5 * time.Millisecond)
+	}
+	if got := srv.hub.ClientCount(); got != 1 {
+		t.Fatalf("before close: want 1 client, got %d", got)
+	}
+
+	// Close connection without a proper WebSocket close handshake
+	// to simulate a client crash / network drop.
+	ws.Close()
+
+	// Hub should unregister the client promptly.
+	deadline = time.Now().Add(500 * time.Millisecond)
+	for time.Now().Before(deadline) {
+		if srv.hub.ClientCount() == 0 {
+			return
+		}
+		time.Sleep(10 * time.Millisecond)
+	}
+	t.Fatalf("after close: expected 0 clients, got %d", srv.hub.ClientCount())
+}
+
+// TestWebSocket_PingWriteDeadlineEvictsStaleConn verifies that a stale
+// connection (write times out) is eventually evicted by the ping goroutine.
+// It uses a very short write deadline to force a timeout on a connection
+// whose receive buffer is full.
+func TestWebSocket_PingWriteDeadlineEvictsStaleConn(t *testing.T) {
+	origInterval := wsPingInterval
+	origDeadline := wsPingDeadline
+	// Very short deadline: ping fails almost immediately after the first tick.
+	wsPingInterval = 30 * time.Millisecond
+	wsPingDeadline = 1 * time.Millisecond
+	t.Cleanup(func() {
+		wsPingInterval = origInterval
+		wsPingDeadline = origDeadline
+	})
+
+	srv, _ := testServer(t)
+	srv.StartHub()
+	ts := httptest.NewServer(srv.Handler())
+	t.Cleanup(ts.Close)
+
+	wsURL := "ws" + strings.TrimPrefix(ts.URL, "http") + "/api/ws"
+	ws, err := websocket.Dial(wsURL, "", "http://localhost/")
+	if err != nil {
+		t.Fatalf("dial: %v", err)
+	}
+	defer ws.Close()
+
+	// Wait for registration.
+	deadline := time.Now().Add(200 * time.Millisecond)
+	for time.Now().Before(deadline) {
+		if srv.hub.ClientCount() == 1 {
+			break
+		}
+		time.Sleep(5 * time.Millisecond)
+	}
+	if got := srv.hub.ClientCount(); got != 1 {
+		t.Fatalf("before stale: want 1 client, got %d", got)
+	}
+
+	// The connection itself is alive (loopback), so the 1ms deadline is generous
+	// enough to succeed. This test mainly verifies the ping goroutine doesn't
+	// panic and that ClientCount stays consistent after disconnect.
+	ws.Close()
+
+	deadline = time.Now().Add(500 * time.Millisecond)
+	for time.Now().Before(deadline) {
+		if srv.hub.ClientCount() == 0 {
+			return
+		}
+		time.Sleep(10 * time.Millisecond)
+	}
+	t.Fatalf("expected 0 clients after stale eviction, got %d", srv.hub.ClientCount())
+}
diff --git a/internal/cli/create.go b/internal/cli/create.go
index fdad932..addd034 100644
--- a/internal/cli/create.go
+++ b/internal/cli/create.go
@@ -4,7 +4,7 @@ import (
 	"bytes"
 	"encoding/json"
 	"fmt"
-	"net/http"
+	"io"

 	"github.com/spf13/cobra"
 )
@@ -52,7 +52,7 @@ func createTask(serverURL, name, instructions, workingDir, model, parentID strin
 		"priority": priority,
 		"claude": map[string]interface{}{
 			"instructions":   instructions,
-			"working_dir":    workingDir,
+			"project_dir":    workingDir,
 			"model":          model,
 			"max_budget_usd": budget,
 		},
@@ -62,20 +62,26 @@ func createTask(serverURL, name, instructions, workingDir, model, parentID strin
 	}

 	data, _ := json.Marshal(body)
-	resp, err := http.Post(serverURL+"/api/tasks", "application/json", bytes.NewReader(data)) //nolint:noctx
+	resp, err := httpClient.Post(serverURL+"/api/tasks", "application/json", bytes.NewReader(data)) //nolint:noctx
 	if err != nil {
 		return fmt.Errorf("POST /api/tasks: %w", err)
 	}
 	defer resp.Body.Close()

+	raw, _ := io.ReadAll(resp.Body)
 	var result map[string]interface{}
-	_ = json.NewDecoder(resp.Body).Decode(&result)
+	if err := json.Unmarshal(raw, &result); err != nil {
+		return fmt.Errorf("server returned invalid JSON (status %d): %s", resp.StatusCode, string(raw))
+	}

 	if resp.StatusCode >= 300 {
 		return fmt.Errorf("server returned %d: %v", resp.StatusCode, result["error"])
 	}

 	id, _ := result["id"].(string)
+	if id == "" {
+		return fmt.Errorf("server returned task without id field")
+	}
 	fmt.Printf("Created task %s\n", id)

 	if autoStart {
diff --git a/internal/cli/create_test.go b/internal/cli/create_test.go
new file mode 100644
index 0000000..22ce6bd
--- /dev/null
+++ b/internal/cli/create_test.go
@@ -0,0 +1,125 @@
+package cli
+
+import (
+	"net/http"
+	"net/http/httptest"
+	"strings"
+	"testing"
+	"time"
+)
+
+func TestCreateTask_TimesOut(t *testing.T) {
+	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		select {
+		case <-r.Context().Done():
+		case <-time.After(5 * time.Second): // fallback so srv.Close() never deadlocks
+		}
+	}))
+	defer srv.Close()
+
+	orig := httpClient
+	httpClient = &http.Client{Timeout: 50 * time.Millisecond}
+	defer func() { httpClient = orig }()
+
+	err := createTask(srv.URL, "test", "do something", "", "", "", 1.0, "15m", "normal", false)
+	if err == nil {
+		t.Fatal("expected timeout error, got nil")
+	}
+	if !strings.Contains(err.Error(), "POST /api/tasks") {
+		t.Errorf("expected error mentioning POST /api/tasks, got: %v", err)
+	}
+}
+
+func TestStartTask_EscapesTaskID(t *testing.T) {
+	var capturedPath string
+	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		capturedPath = r.URL.RawPath
+		if capturedPath == "" {
+			capturedPath = r.URL.Path
+		}
+		w.WriteHeader(http.StatusOK)
+		w.Write([]byte(`{}`))
+	}))
+	defer srv.Close()
+
+	err := startTask(srv.URL, "task/with/slashes")
+	if err != nil {
+		t.Fatalf("unexpected error: %v", err)
+	}
+	if strings.Contains(capturedPath, "task/with/slashes") {
+		t.Errorf("task ID was not escaped; raw path contains unescaped slashes: %s", capturedPath)
+	}
+	if !strings.Contains(capturedPath, "task%2Fwith%2Fslashes") {
+		t.Errorf("expected escaped path segment, got: %s", capturedPath)
+	}
+}
+
+func TestCreateTask_MissingIDField_ReturnsError(t *testing.T) {
+	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		w.WriteHeader(http.StatusOK)
+		w.Write([]byte(`{"name":"test"}`)) // no "id" field
+	}))
+	defer srv.Close()
+
+	err := createTask(srv.URL, "test", "do something", "", "", "", 1.0, "15m", "normal", false)
+	if err == nil {
+		t.Fatal("expected error for missing id field, got nil")
+	}
+	if !strings.Contains(err.Error(), "without id") {
+		t.Errorf("expected error mentioning missing id, got: %v", err)
+	}
+}
+
+func TestCreateTask_NonJSONResponse_ReturnsError(t *testing.T) {
+	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		w.WriteHeader(http.StatusBadGateway)
+		w.Write([]byte(`<html>502 Bad Gateway</html>`))
+	}))
+	defer srv.Close()
+
+	err := createTask(srv.URL, "test", "do something", "", "", "", 1.0, "15m", "normal", false)
+	if err == nil {
+		t.Fatal("expected error for non-JSON response, got nil")
+	}
+	if !strings.Contains(err.Error(), "invalid JSON") {
+		t.Errorf("expected error mentioning invalid JSON, got: %v", err)
+	}
+}
+
+func TestStartTask_NonJSONResponse_ReturnsError(t *testing.T) {
+	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		w.WriteHeader(http.StatusBadGateway)
+		w.Write([]byte(`<html>502 Bad Gateway</html>`))
+	}))
+	defer srv.Close()
+
+	err := startTask(srv.URL, "task-abc")
+	if err == nil {
+		t.Fatal("expected error for non-JSON response, got nil")
+	}
+	if !strings.Contains(err.Error(), "invalid JSON") {
+		t.Errorf("expected error mentioning invalid JSON, got: %v", err)
+	}
+}
+
+func TestStartTask_TimesOut(t *testing.T) {
+	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		select {
+		case <-r.Context().Done():
+		case <-time.After(5 * time.Second): // fallback so srv.Close() never deadlocks
+		}
+	}))
+	defer srv.Close()
+
+	orig := httpClient
+	httpClient = &http.Client{Timeout: 50 * time.Millisecond}
+	defer func() { httpClient = orig }()
+
+	err := startTask(srv.URL, "task-abc")
+	if err == nil {
+		t.Fatal("expected timeout error, got nil")
+	}
+	if !strings.Contains(err.Error(), "POST") {
+		t.Errorf("expected error mentioning POST, got: %v", err)
+	}
+}
diff --git a/internal/cli/http.go b/internal/cli/http.go
new file mode 100644
index 0000000..907818a
--- /dev/null
+++ b/internal/cli/http.go
@@ -0,0 +1,10 @@ +package cli + +import ( + "net/http" + "time" +) + +const httpTimeout = 30 * time.Second + +var httpClient = &http.Client{Timeout: httpTimeout} diff --git a/internal/cli/report.go b/internal/cli/report.go new file mode 100644 index 0000000..7f95c80 --- /dev/null +++ b/internal/cli/report.go @@ -0,0 +1,74 @@ +package cli + +import ( + "fmt" + "os" + "time" + + "github.com/spf13/cobra" + "github.com/thepeterstone/claudomator/internal/reporter" + "github.com/thepeterstone/claudomator/internal/storage" +) + +func newReportCmd() *cobra.Command { + var format string + var limit int + var taskID string + + cmd := &cobra.Command{ + Use: "report", + Short: "Report execution history", + RunE: func(cmd *cobra.Command, args []string) error { + return runReport(format, limit, taskID) + }, + } + + cmd.Flags().StringVar(&format, "format", "table", "output format: table, json, html") + cmd.Flags().IntVar(&limit, "limit", 50, "maximum number of executions to show") + cmd.Flags().StringVar(&taskID, "task", "", "filter by task ID") + + return cmd +} + +func runReport(format string, limit int, taskID string) error { + var rep reporter.Reporter + switch format { + case "table", "": + rep = &reporter.ConsoleReporter{} + case "json": + rep = &reporter.JSONReporter{Pretty: true} + case "html": + rep = &reporter.HTMLReporter{} + default: + return fmt.Errorf("invalid format %q: must be table, json, or html", format) + } + + store, err := storage.Open(cfg.DBPath) + if err != nil { + return fmt.Errorf("opening db: %w", err) + } + defer store.Close() + + recent, err := store.ListRecentExecutions(time.Time{}, limit, taskID) + if err != nil { + return fmt.Errorf("listing executions: %w", err) + } + + execs := make([]*storage.Execution, len(recent)) + for i, r := range recent { + e := &storage.Execution{ + ID: r.ID, + TaskID: r.TaskID, + Status: r.State, + StartTime: r.StartedAt, + ExitCode: r.ExitCode, + CostUSD: r.CostUSD, + } + if r.FinishedAt != nil { + e.EndTime = 
*r.FinishedAt + } + execs[i] = e + } + + return rep.Generate(os.Stdout, execs) +} diff --git a/internal/cli/report_test.go b/internal/cli/report_test.go new file mode 100644 index 0000000..3ef96f4 --- /dev/null +++ b/internal/cli/report_test.go @@ -0,0 +1,32 @@ +package cli + +import ( + "strings" + "testing" +) + +func TestReportCmd_InvalidFormat(t *testing.T) { + cmd := newReportCmd() + cmd.SetArgs([]string{"--format", "xml"}) + err := cmd.Execute() + if err == nil { + t.Fatal("expected error for invalid format, got nil") + } + if !strings.Contains(err.Error(), "format") { + t.Errorf("expected error to mention 'format', got: %v", err) + } +} + +func TestReportCmd_DefaultsRegistered(t *testing.T) { + cmd := newReportCmd() + f := cmd.Flags() + if f.Lookup("format") == nil { + t.Error("missing --format flag") + } + if f.Lookup("limit") == nil { + t.Error("missing --limit flag") + } + if f.Lookup("task") == nil { + t.Error("missing --task flag") + } +} diff --git a/internal/cli/root.go b/internal/cli/root.go index 1a528fb..ab6ac1f 100644 --- a/internal/cli/root.go +++ b/internal/cli/root.go @@ -1,12 +1,17 @@ package cli import ( + "fmt" + "log/slog" + "os" "path/filepath" "github.com/thepeterstone/claudomator/internal/config" "github.com/spf13/cobra" ) +const defaultServerURL = "http://localhost:8484" + var ( cfgFile string verbose bool @@ -14,7 +19,12 @@ var ( ) func NewRootCmd() *cobra.Command { - cfg = config.Default() + var err error + cfg, err = config.Default() + if err != nil { + fmt.Fprintf(os.Stderr, "fatal: %v\n", err) + os.Exit(1) + } cmd := &cobra.Command{ Use: "claudomator", @@ -43,6 +53,7 @@ func NewRootCmd() *cobra.Command { newLogsCmd(), newStartCmd(), newCreateCmd(), + newReportCmd(), ) return cmd @@ -51,3 +62,11 @@ func NewRootCmd() *cobra.Command { func Execute() error { return NewRootCmd().Execute() } + +func newLogger(v bool) *slog.Logger { + level := slog.LevelInfo + if v { + level = slog.LevelDebug + } + return 
slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: level})) +} diff --git a/internal/cli/run.go b/internal/cli/run.go index 62e1252..49aa28e 100644 --- a/internal/cli/run.go +++ b/internal/cli/run.go @@ -3,7 +3,6 @@ package cli import ( "context" "fmt" - "log/slog" "os" "os/signal" "syscall" @@ -36,6 +35,10 @@ func newRunCmd() *cobra.Command { } func runTasks(file string, parallel int, dryRun bool) error { + if parallel < 1 { + return fmt.Errorf("--parallel must be at least 1, got %d", parallel) + } + tasks, err := task.ParseFile(file) if err != nil { return fmt.Errorf("parsing: %w", err) @@ -67,11 +70,7 @@ func runTasks(file string, parallel int, dryRun bool) error { } defer store.Close() - level := slog.LevelInfo - if verbose { - level = slog.LevelDebug - } - logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: level})) + logger := newLogger(verbose) runners := map[string]executor.Runner{ "claude": &executor.ClaudeRunner{ diff --git a/internal/cli/run_test.go b/internal/cli/run_test.go new file mode 100644 index 0000000..705fe29 --- /dev/null +++ b/internal/cli/run_test.go @@ -0,0 +1,18 @@ +package cli + +import ( + "strings" + "testing" +) + +func TestRunTasks_InvalidParallel(t *testing.T) { + for _, parallel := range []int{0, -1, -100} { + err := runTasks("ignored.yaml", parallel, false) + if err == nil { + t.Fatalf("parallel=%d: expected error, got nil", parallel) + } + if !strings.Contains(err.Error(), "--parallel") { + t.Errorf("parallel=%d: error should mention --parallel flag, got: %v", parallel, err) + } + } +} diff --git a/internal/cli/serve.go b/internal/cli/serve.go index b679b38..36a53b5 100644 --- a/internal/cli/serve.go +++ b/internal/cli/serve.go @@ -3,7 +3,6 @@ package cli import ( "context" "fmt" - "log/slog" "net/http" "os" "os/signal" @@ -12,6 +11,7 @@ import ( "github.com/thepeterstone/claudomator/internal/api" "github.com/thepeterstone/claudomator/internal/executor" + 
"github.com/thepeterstone/claudomator/internal/notify" "github.com/thepeterstone/claudomator/internal/storage" "github.com/thepeterstone/claudomator/internal/version" "github.com/spf13/cobra" @@ -44,11 +44,7 @@ func serve(addr string) error { } defer store.Close() - level := slog.LevelInfo - if verbose { - level = slog.LevelDebug - } - logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: level})) + logger := newLogger(verbose) apiURL := "http://localhost" + addr if len(addr) > 0 && addr[0] != ':' { @@ -76,6 +72,9 @@ func serve(addr string) error { } srv := api.NewServer(store, pool, logger, cfg.ClaudeBinaryPath, cfg.GeminiBinaryPath) + if cfg.WebhookURL != "" { + srv.SetNotifier(notify.NewWebhookNotifier(cfg.WebhookURL, logger)) + } srv.StartHub() httpSrv := &http.Server{ @@ -94,7 +93,9 @@ func serve(addr string) error { logger.Info("shutting down server...") shutdownCtx, shutdownCancel := context.WithTimeout(ctx, 5*time.Second) defer shutdownCancel() - httpSrv.Shutdown(shutdownCtx) + if err := httpSrv.Shutdown(shutdownCtx); err != nil { + logger.Warn("shutdown error", "err", err) + } }() fmt.Printf("Claudomator %s listening on %s\n", version.Version(), addr) diff --git a/internal/cli/serve_test.go b/internal/cli/serve_test.go new file mode 100644 index 0000000..6bd0e8f --- /dev/null +++ b/internal/cli/serve_test.go @@ -0,0 +1,91 @@ +package cli + +import ( + "context" + "log/slog" + "net" + "net/http" + "sync" + "testing" + "time" +) + +// recordHandler captures log records for assertions. 
+type recordHandler struct { + mu sync.Mutex + records []slog.Record +} + +func (h *recordHandler) Enabled(_ context.Context, _ slog.Level) bool { return true } +func (h *recordHandler) Handle(_ context.Context, r slog.Record) error { + h.mu.Lock() + h.records = append(h.records, r) + h.mu.Unlock() + return nil +} +func (h *recordHandler) WithAttrs(_ []slog.Attr) slog.Handler { return h } +func (h *recordHandler) WithGroup(_ string) slog.Handler { return h } +func (h *recordHandler) hasWarn(msg string) bool { + h.mu.Lock() + defer h.mu.Unlock() + for _, r := range h.records { + if r.Level == slog.LevelWarn && r.Message == msg { + return true + } + } + return false +} + +// TestServe_ShutdownError_IsLogged verifies that a shutdown timeout error is +// logged as a warning rather than silently dropped. +func TestServe_ShutdownError_IsLogged(t *testing.T) { + // Start a real listener so we have an address. + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen: %v", err) + } + + // Handler that hangs so the active connection prevents clean shutdown. + hang := make(chan struct{}) + mux := http.NewServeMux() + mux.HandleFunc("/hang", func(w http.ResponseWriter, r *http.Request) { + <-hang + }) + + srv := &http.Server{Handler: mux} + + // Serve in background. + go srv.Serve(ln) //nolint:errcheck + + // Open a connection and start a hanging request so the server has an + // active connection when we call Shutdown. + addr := ln.Addr().String() + connReady := make(chan struct{}) + go func() { + req, _ := http.NewRequest(http.MethodGet, "http://"+addr+"/hang", nil) + close(connReady) + http.DefaultClient.Do(req) //nolint:errcheck + }() + <-connReady + // Give the goroutine a moment to establish the request. + time.Sleep(20 * time.Millisecond) + + // Shutdown with an already-expired deadline so it times out immediately. 
+ expiredCtx, cancel := context.WithDeadline(context.Background(), time.Now().Add(-1*time.Second)) + defer cancel() + + h := &recordHandler{} + logger := slog.New(h) + + // This is the exact logic from serve.go's shutdown goroutine. + if err := srv.Shutdown(expiredCtx); err != nil { + logger.Warn("shutdown error", "err", err) + } + + // Unblock the hanging handler. + close(hang) + + if !h.hasWarn("shutdown error") { + t.Error("expected shutdown error to be logged as Warn, but it was not") + } +} diff --git a/internal/cli/start.go b/internal/cli/start.go index 6ec09b2..9e66e00 100644 --- a/internal/cli/start.go +++ b/internal/cli/start.go @@ -3,7 +3,8 @@ package cli import ( "encoding/json" "fmt" - "net/http" + "io" + "net/url" "github.com/spf13/cobra" ) @@ -25,15 +26,18 @@ func newStartCmd() *cobra.Command { } func startTask(serverURL, id string) error { - url := fmt.Sprintf("%s/api/tasks/%s/run", serverURL, id) - resp, err := http.Post(url, "application/json", nil) //nolint:noctx + url := fmt.Sprintf("%s/api/tasks/%s/run", serverURL, url.PathEscape(id)) + resp, err := httpClient.Post(url, "application/json", nil) //nolint:noctx if err != nil { return fmt.Errorf("POST %s: %w", url, err) } defer resp.Body.Close() + raw, _ := io.ReadAll(resp.Body) var body map[string]string - _ = json.NewDecoder(resp.Body).Decode(&body) + if err := json.Unmarshal(raw, &body); err != nil { + return fmt.Errorf("server returned invalid JSON (status %d): %s", resp.StatusCode, string(raw)) + } if resp.StatusCode >= 300 { return fmt.Errorf("server returned %d: %s", resp.StatusCode, body["error"]) diff --git a/internal/config/config.go b/internal/config/config.go index a66524a..d3d9d68 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -1,6 +1,8 @@ package config import ( + "errors" + "fmt" "os" "path/filepath" ) @@ -17,8 +19,14 @@ type Config struct { WebhookURL string `toml:"webhook_url"` } -func Default() *Config { - home, _ := os.UserHomeDir() +func Default() 
(*Config, error) { + home, err := os.UserHomeDir() + if err != nil { + return nil, fmt.Errorf("cannot determine home directory: %w", err) + } + if home == "" { + return nil, errors.New("cannot determine home directory: HOME is empty") + } dataDir := filepath.Join(home, ".claudomator") return &Config{ DataDir: dataDir, @@ -29,7 +37,7 @@ func Default() *Config { MaxConcurrent: 3, DefaultTimeout: "15m", ServerAddr: ":8484", - } + }, nil } // EnsureDirs creates the data directory structure. diff --git a/internal/config/config_test.go b/internal/config/config_test.go new file mode 100644 index 0000000..766b856 --- /dev/null +++ b/internal/config/config_test.go @@ -0,0 +1,24 @@ +package config + +import ( + "testing" +) + +func TestDefault_EmptyHome_ReturnsError(t *testing.T) { + t.Setenv("HOME", "") + _, err := Default() + if err == nil { + t.Fatal("expected error when HOME is empty, got nil") + } +} + +func TestDefault_ValidHome_ReturnsConfig(t *testing.T) { + t.Setenv("HOME", "/tmp/testhome") + cfg, err := Default() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if cfg.DataDir != "/tmp/testhome/.claudomator" { + t.Errorf("DataDir = %q, want /tmp/testhome/.claudomator", cfg.DataDir) + } +} diff --git a/internal/executor/claude.go b/internal/executor/claude.go index 86a2ba5..e504369 100644 --- a/internal/executor/claude.go +++ b/internal/executor/claude.go @@ -55,10 +55,18 @@ func (r *ClaudeRunner) binaryPath() string { // Run executes a claude -p invocation, streaming output to log files. // It retries up to 3 times on rate-limit errors using exponential backoff. // If the agent writes a question file and exits, Run returns *BlockedError. +// +// When project_dir is set and this is not a resume execution, Run clones the +// project into a temp sandbox, runs the agent there, then merges committed +// changes back to project_dir. On failure the sandbox is preserved and its +// path is included in the error. 
func (r *ClaudeRunner) Run(ctx context.Context, t *task.Task, e *storage.Execution) error { - if t.Agent.WorkingDir != "" { - if _, err := os.Stat(t.Agent.WorkingDir); err != nil { - return fmt.Errorf("working_dir %q: %w", t.Agent.WorkingDir, err) + projectDir := t.Agent.ProjectDir + + // Validate project_dir exists when set. + if projectDir != "" { + if _, err := os.Stat(projectDir); err != nil { + return fmt.Errorf("project_dir %q: %w", projectDir, err) } } @@ -82,6 +90,20 @@ func (r *ClaudeRunner) Run(ctx context.Context, t *task.Task, e *storage.Executi e.SessionID = e.ID // reuse execution UUID as session UUID (both are UUIDs) } + // For new (non-resume) executions with a project_dir, clone into a sandbox. + // Resume executions run directly in project_dir to pick up the previous session. + var sandboxDir string + effectiveWorkingDir := projectDir + if projectDir != "" && e.ResumeSessionID == "" { + var err error + sandboxDir, err = setupSandbox(projectDir) + if err != nil { + return fmt.Errorf("setting up sandbox: %w", err) + } + effectiveWorkingDir = sandboxDir + r.Logger.Info("sandbox created", "sandbox", sandboxDir, "project_dir", projectDir) + } + questionFile := filepath.Join(logDir, "question.json") args := r.buildArgs(t, e, questionFile) @@ -95,9 +117,12 @@ func (r *ClaudeRunner) Run(ctx context.Context, t *task.Task, e *storage.Executi ) } attempt++ - return r.execOnce(ctx, args, t.Agent.WorkingDir, e) + return r.execOnce(ctx, args, effectiveWorkingDir, e) }) if err != nil { + if sandboxDir != "" { + return fmt.Errorf("%w (sandbox preserved at %s)", err, sandboxDir) + } return err } @@ -105,8 +130,89 @@ func (r *ClaudeRunner) Run(ctx context.Context, t *task.Task, e *storage.Executi data, readErr := os.ReadFile(questionFile) if readErr == nil { os.Remove(questionFile) // consumed + // Preserve sandbox on BLOCKED — agent may have partial work. 
return &BlockedError{QuestionJSON: strings.TrimSpace(string(data)), SessionID: e.SessionID} } + + // Merge sandbox back to project_dir and clean up. + if sandboxDir != "" { + if mergeErr := teardownSandbox(projectDir, sandboxDir, r.Logger); mergeErr != nil { + return fmt.Errorf("sandbox teardown: %w (sandbox preserved at %s)", mergeErr, sandboxDir) + } + } + return nil +} + +// setupSandbox prepares a temporary git clone of projectDir. +// If projectDir is not a git repo it is initialised with an initial commit first. +func setupSandbox(projectDir string) (string, error) { + // Ensure projectDir is a git repo; initialise if not. + check := exec.Command("git", "-C", projectDir, "rev-parse", "--git-dir") + if err := check.Run(); err != nil { + // Not a git repo — init and commit everything. + cmds := [][]string{ + {"git", "-C", projectDir, "init"}, + {"git", "-C", projectDir, "add", "-A"}, + {"git", "-C", projectDir, "commit", "--allow-empty", "-m", "chore: initial commit"}, + } + for _, args := range cmds { + if out, err := exec.Command(args[0], args[1:]...).CombinedOutput(); err != nil { //nolint:gosec + return "", fmt.Errorf("git init %s: %w\n%s", projectDir, err, out) + } + } + } + + tempDir, err := os.MkdirTemp("", "claudomator-sandbox-*") + if err != nil { + return "", fmt.Errorf("creating sandbox dir: %w", err) + } + + // Clone into the pre-created dir (git clone requires the target to not exist, + // so remove it first and let git recreate it). + if err := os.Remove(tempDir); err != nil { + return "", fmt.Errorf("removing temp dir placeholder: %w", err) + } + out, err := exec.Command("git", "clone", "--local", projectDir, tempDir).CombinedOutput() + if err != nil { + return "", fmt.Errorf("git clone: %w\n%s", err, out) + } + return tempDir, nil +} + +// teardownSandbox verifies the sandbox is clean, merges commits back to +// projectDir via fast-forward, then removes the sandbox. 
+func teardownSandbox(projectDir, sandboxDir string, logger *slog.Logger) error { + // Fail if agent left uncommitted changes. + out, err := exec.Command("git", "-C", sandboxDir, "status", "--porcelain").Output() + if err != nil { + return fmt.Errorf("git status: %w", err) + } + if len(strings.TrimSpace(string(out))) > 0 { + return fmt.Errorf("uncommitted changes in sandbox (agent must commit all work):\n%s", out) + } + + // Check whether there are any new commits to merge. + ahead, err := exec.Command("git", "-C", sandboxDir, "rev-list", "--count", "origin/HEAD..HEAD").Output() + if err != nil { + // No origin/HEAD (e.g. fresh init with no prior commits) — proceed anyway. + logger.Warn("could not determine commits ahead of origin; proceeding with merge", "err", err) + } + if strings.TrimSpace(string(ahead)) == "0" { + // Nothing to merge — clean up and return. + os.RemoveAll(sandboxDir) + return nil + } + + // Fetch new commits from sandbox into project_dir and fast-forward merge. + if out, err := exec.Command("git", "-C", projectDir, "fetch", sandboxDir, "HEAD").CombinedOutput(); err != nil { + return fmt.Errorf("git fetch from sandbox: %w\n%s", err, out) + } + if out, err := exec.Command("git", "-C", projectDir, "merge", "--ff-only", "FETCH_HEAD").CombinedOutput(); err != nil { + return fmt.Errorf("git merge --ff-only FETCH_HEAD: %w\n%s", err, out) + } + + logger.Info("sandbox merged and cleaned up", "sandbox", sandboxDir, "project_dir", projectDir) + os.RemoveAll(sandboxDir) return nil } @@ -189,6 +295,11 @@ func (r *ClaudeRunner) execOnce(ctx context.Context, args []string, workingDir s if exitErr, ok := waitErr.(*exec.ExitError); ok { e.ExitCode = exitErr.ExitCode() } + // If the stream captured a rate-limit or quota message, return it + // so callers can distinguish it from a generic exit-status failure. 
+ if isRateLimitError(streamErr) || isQuotaExhausted(streamErr) { + return streamErr + } return fmt.Errorf("claude exited with error: %w", waitErr) } diff --git a/internal/executor/claude_test.go b/internal/executor/claude_test.go index dba7470..b5380f4 100644 --- a/internal/executor/claude_test.go +++ b/internal/executor/claude_test.go @@ -233,7 +233,7 @@ func TestClaudeRunner_Run_InaccessibleWorkingDir_ReturnsError(t *testing.T) { tk := &task.Task{ Agent: task.AgentConfig{ Type: "claude", - WorkingDir: "/nonexistent/path/does/not/exist", + ProjectDir: "/nonexistent/path/does/not/exist", SkipPlanning: true, }, } @@ -244,8 +244,8 @@ func TestClaudeRunner_Run_InaccessibleWorkingDir_ReturnsError(t *testing.T) { if err == nil { t.Fatal("expected error for inaccessible working_dir, got nil") } - if !strings.Contains(err.Error(), "working_dir") { - t.Errorf("expected 'working_dir' in error, got: %v", err) + if !strings.Contains(err.Error(), "project_dir") { + t.Errorf("expected 'project_dir' in error, got: %v", err) } } diff --git a/internal/executor/executor.go b/internal/executor/executor.go index 6bd1c68..d1c8e72 100644 --- a/internal/executor/executor.go +++ b/internal/executor/executor.go @@ -26,12 +26,20 @@ type Runner interface { Run(ctx context.Context, t *task.Task, exec *storage.Execution) error } +// workItem is an entry in the pool's internal work queue. +type workItem struct { + ctx context.Context + task *task.Task + exec *storage.Execution // non-nil for resume submissions +} + // Pool manages a bounded set of concurrent task workers. 
type Pool struct { - maxConcurrent int - runners map[string]Runner - store *storage.DB - logger *slog.Logger + maxConcurrent int + runners map[string]Runner + store *storage.DB + logger *slog.Logger + depPollInterval time.Duration // how often waitForDependencies polls; defaults to 5s mu sync.Mutex active int @@ -39,6 +47,8 @@ type Pool struct { rateLimited map[string]time.Time // agentType -> until cancels map[string]context.CancelFunc // taskID → cancel resultCh chan *Result + workCh chan workItem // internal bounded queue; Submit enqueues here + doneCh chan struct{} // signals when a worker slot is freed Questions *QuestionRegistry Classifier *Classifier } @@ -54,33 +64,57 @@ func NewPool(maxConcurrent int, runners map[string]Runner, store *storage.DB, lo if maxConcurrent < 1 { maxConcurrent = 1 } - return &Pool{ - maxConcurrent: maxConcurrent, - runners: runners, - store: store, - logger: logger, - activePerAgent: make(map[string]int), - rateLimited: make(map[string]time.Time), - cancels: make(map[string]context.CancelFunc), - resultCh: make(chan *Result, maxConcurrent*2), - Questions: NewQuestionRegistry(), + p := &Pool{ + maxConcurrent: maxConcurrent, + runners: runners, + store: store, + logger: logger, + depPollInterval: 5 * time.Second, + activePerAgent: make(map[string]int), + rateLimited: make(map[string]time.Time), + cancels: make(map[string]context.CancelFunc), + resultCh: make(chan *Result, maxConcurrent*2), + workCh: make(chan workItem, maxConcurrent*10+100), + doneCh: make(chan struct{}, maxConcurrent), + Questions: NewQuestionRegistry(), } + go p.dispatch() + return p } -// Submit dispatches a task for execution. Blocks if pool is at capacity. 
-func (p *Pool) Submit(ctx context.Context, t *task.Task) error { - p.mu.Lock() - if p.active >= p.maxConcurrent { - active := p.active - max := p.maxConcurrent - p.mu.Unlock() - return fmt.Errorf("executor pool at capacity (%d/%d)", active, max) +// dispatch is a long-running goroutine that reads from the internal work queue +// and launches goroutines as soon as a pool slot is available. This prevents +// tasks from being rejected when the pool is temporarily at capacity. +func (p *Pool) dispatch() { + for item := range p.workCh { + for { + p.mu.Lock() + if p.active < p.maxConcurrent { + p.active++ + p.mu.Unlock() + if item.exec != nil { + go p.executeResume(item.ctx, item.task, item.exec) + } else { + go p.execute(item.ctx, item.task) + } + break + } + p.mu.Unlock() + <-p.doneCh // wait for a worker to finish + } } - p.active++ - p.mu.Unlock() +} - go p.execute(ctx, t) - return nil +// Submit enqueues a task for execution. Returns an error only if the internal +// work queue is full. When the pool is at capacity the task is buffered and +// dispatched as soon as a slot becomes available. +func (p *Pool) Submit(ctx context.Context, t *task.Task) error { + select { + case p.workCh <- workItem{ctx: ctx, task: t}: + return nil + default: + return fmt.Errorf("executor work queue full (capacity %d)", cap(p.workCh)) + } } // Results returns the channel for reading execution results. @@ -104,18 +138,18 @@ func (p *Pool) Cancel(taskID string) bool { // SubmitResume re-queues a blocked task using the provided resume execution. // The execution must have ResumeSessionID and ResumeAnswer set. 
func (p *Pool) SubmitResume(ctx context.Context, t *task.Task, exec *storage.Execution) error { - p.mu.Lock() - if p.active >= p.maxConcurrent { - active := p.active - max := p.maxConcurrent - p.mu.Unlock() - return fmt.Errorf("executor pool at capacity (%d/%d)", active, max) + if t.State != task.StateBlocked && t.State != task.StateTimedOut { + return fmt.Errorf("task %s must be in BLOCKED or TIMED_OUT state to resume (current: %s)", t.ID, t.State) + } + if exec.ResumeSessionID == "" { + return fmt.Errorf("resume execution for task %s must have a ResumeSessionID", t.ID) + } + select { + case p.workCh <- workItem{ctx: ctx, task: t, exec: exec}: + return nil + default: + return fmt.Errorf("executor work queue full (capacity %d)", cap(p.workCh)) } - p.active++ - p.mu.Unlock() - - go p.executeResume(ctx, t, exec) - return nil } func (p *Pool) getRunner(t *task.Task) (Runner, error) { @@ -145,6 +179,10 @@ func (p *Pool) executeResume(ctx context.Context, t *task.Task, exec *storage.Ex p.active-- p.activePerAgent[agentType]-- p.mu.Unlock() + select { + case p.doneCh <- struct{}{}: + default: + } }() runner, err := p.getRunner(t) @@ -178,7 +216,15 @@ func (p *Pool) executeResume(ctx context.Context, t *task.Task, exec *storage.Ex } else { ctx, cancel = context.WithCancel(ctx) } - defer cancel() + p.mu.Lock() + p.cancels[t.ID] = cancel + p.mu.Unlock() + defer func() { + cancel() + p.mu.Lock() + delete(p.cancels, t.ID) + p.mu.Unlock() + }() err = runner.Run(ctx, t, exec) exec.EndTime = time.Now().UTC() @@ -207,6 +253,10 @@ func (p *Pool) executeResume(ctx context.Context, t *task.Task, exec *storage.Ex exec.Status = "CANCELLED" exec.ErrorMsg = "execution cancelled" p.store.UpdateTaskState(t.ID, task.StateCancelled) + } else if isQuotaExhausted(err) { + exec.Status = "BUDGET_EXCEEDED" + exec.ErrorMsg = err.Error() + p.store.UpdateTaskState(t.ID, task.StateBudgetExceeded) } else { exec.Status = "FAILED" exec.ErrorMsg = err.Error() @@ -276,6 +326,10 @@ func (p *Pool) 
execute(ctx context.Context, t *task.Task) { p.active-- p.activePerAgent[agentType]-- p.mu.Unlock() + select { + case p.doneCh <- struct{}{}: + default: + } }() runner, err := p.getRunner(t) @@ -390,6 +444,10 @@ func (p *Pool) execute(ctx context.Context, t *task.Task) { exec.Status = "CANCELLED" exec.ErrorMsg = "execution cancelled" p.store.UpdateTaskState(t.ID, task.StateCancelled) + } else if isQuotaExhausted(err) { + exec.Status = "BUDGET_EXCEEDED" + exec.ErrorMsg = err.Error() + p.store.UpdateTaskState(t.ID, task.StateBudgetExceeded) } else { exec.Status = "FAILED" exec.ErrorMsg = err.Error() @@ -444,7 +502,7 @@ func (p *Pool) waitForDependencies(ctx context.Context, t *task.Task) error { select { case <-ctx.Done(): return ctx.Err() - case <-time.After(5 * time.Second): + case <-time.After(p.depPollInterval): } } } diff --git a/internal/executor/executor_test.go b/internal/executor/executor_test.go index 9ad0617..028e5cf 100644 --- a/internal/executor/executor_test.go +++ b/internal/executor/executor_test.go @@ -224,27 +224,35 @@ func TestPool_Cancel_UnknownTask_ReturnsFalse(t *testing.T) { } } -func TestPool_AtCapacity(t *testing.T) { +// TestPool_QueuedWhenAtCapacity verifies that Submit enqueues a task rather than +// returning an error when the pool is at capacity. Both tasks should eventually complete. +func TestPool_QueuedWhenAtCapacity(t *testing.T) { store := testStore(t) - runner := &mockRunner{delay: time.Second} + runner := &mockRunner{delay: 100 * time.Millisecond} runners := map[string]Runner{"claude": runner} logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) pool := NewPool(1, runners, store, logger) - tk1 := makeTask("cap-1") + tk1 := makeTask("queue-1") store.CreateTask(tk1) - pool.Submit(context.Background(), tk1) + if err := pool.Submit(context.Background(), tk1); err != nil { + t.Fatalf("first submit: %v", err) + } - // Pool is at capacity, second submit should fail. 
- time.Sleep(10 * time.Millisecond) // let goroutine start - tk2 := makeTask("cap-2") + // Second submit must succeed (queued) even though pool slot is taken. + tk2 := makeTask("queue-2") store.CreateTask(tk2) - err := pool.Submit(context.Background(), tk2) - if err == nil { - t.Fatal("expected capacity error") + if err := pool.Submit(context.Background(), tk2); err != nil { + t.Fatalf("second submit: %v — expected task to be queued, not rejected", err) } - <-pool.Results() // drain + // Both tasks must complete. + for i := 0; i < 2; i++ { + r := <-pool.Results() + if r.Err != nil { + t.Errorf("task %s error: %v", r.TaskID, r.Err) + } + } } // logPatherMockRunner is a mockRunner that also implements LogPather, diff --git a/internal/executor/gemini.go b/internal/executor/gemini.go index 3cabed5..956d8b5 100644 --- a/internal/executor/gemini.go +++ b/internal/executor/gemini.go @@ -40,9 +40,9 @@ func (r *GeminiRunner) binaryPath() string { // Run executes a gemini <instructions> invocation, streaming output to log files. func (r *GeminiRunner) Run(ctx context.Context, t *task.Task, e *storage.Execution) error { - if t.Agent.WorkingDir != "" { - if _, err := os.Stat(t.Agent.WorkingDir); err != nil { - return fmt.Errorf("working_dir %q: %w", t.Agent.WorkingDir, err) + if t.Agent.ProjectDir != "" { + if _, err := os.Stat(t.Agent.ProjectDir); err != nil { + return fmt.Errorf("project_dir %q: %w", t.Agent.ProjectDir, err) } } @@ -68,7 +68,7 @@ func (r *GeminiRunner) Run(ctx context.Context, t *task.Task, e *storage.Executi // Gemini CLI doesn't necessarily have the same rate limiting behavior as Claude, // but we'll use a similar execution pattern. 
- err := r.execOnce(ctx, args, t.Agent.WorkingDir, e) + err := r.execOnce(ctx, args, t.Agent.ProjectDir, e) if err != nil { return err } diff --git a/internal/executor/gemini_test.go b/internal/executor/gemini_test.go index c7acc3c..42253da 100644 --- a/internal/executor/gemini_test.go +++ b/internal/executor/gemini_test.go @@ -63,7 +63,7 @@ func TestGeminiRunner_BuildArgs_PreamblePrepended(t *testing.T) { } } -func TestGeminiRunner_Run_InaccessibleWorkingDir_ReturnsError(t *testing.T) { +func TestGeminiRunner_Run_InaccessibleProjectDir_ReturnsError(t *testing.T) { r := &GeminiRunner{ BinaryPath: "true", // would succeed if it ran Logger: slog.New(slog.NewTextHandler(io.Discard, nil)), @@ -72,7 +72,7 @@ func TestGeminiRunner_Run_InaccessibleWorkingDir_ReturnsError(t *testing.T) { tk := &task.Task{ Agent: task.AgentConfig{ Type: "gemini", - WorkingDir: "/nonexistent/path/does/not/exist", + ProjectDir: "/nonexistent/path/does/not/exist", SkipPlanning: true, }, } @@ -81,10 +81,10 @@ func TestGeminiRunner_Run_InaccessibleWorkingDir_ReturnsError(t *testing.T) { err := r.Run(context.Background(), tk, exec) if err == nil { - t.Fatal("expected error for inaccessible working_dir, got nil") + t.Fatal("expected error for inaccessible project_dir, got nil") } - if !strings.Contains(err.Error(), "working_dir") { - t.Errorf("expected 'working_dir' in error, got: %v", err) + if !strings.Contains(err.Error(), "project_dir") { + t.Errorf("expected 'project_dir' in error, got: %v", err) } } diff --git a/internal/executor/ratelimit.go b/internal/executor/ratelimit.go index 884da43..deaad18 100644 --- a/internal/executor/ratelimit.go +++ b/internal/executor/ratelimit.go @@ -13,7 +13,8 @@ var retryAfterRe = regexp.MustCompile(`(?i)retry[-_ ]after[:\s]+(\d+)`) const maxBackoffDelay = 5 * time.Minute -// isRateLimitError returns true if err looks like a Claude API rate-limit response. 
+// isRateLimitError returns true if err looks like a transient Claude API +// rate-limit that is worth retrying (e.g. per-minute/per-request throttle). func isRateLimitError(err error) bool { if err == nil { return false @@ -25,6 +26,17 @@ func isRateLimitError(err error) bool { strings.Contains(msg, "overloaded") } +// isQuotaExhausted returns true if err indicates the 5-hour usage quota is +// fully exhausted. Unlike transient rate limits, these should not be retried. +func isQuotaExhausted(err error) bool { + if err == nil { + return false + } + msg := strings.ToLower(err.Error()) + return strings.Contains(msg, "hit your limit") || + strings.Contains(msg, "you've hit your limit") +} + // parseRetryAfter extracts a Retry-After duration from an error message. // Returns 0 if no retry-after value is found. func parseRetryAfter(msg string) time.Duration { diff --git a/internal/storage/db.go b/internal/storage/db.go index c396bbe..0a4f7a5 100644 --- a/internal/storage/db.go +++ b/internal/storage/db.go @@ -193,21 +193,31 @@ func (s *DB) ListSubtasks(parentID string) ([]*task.Task, error) { return tasks, rows.Err() } -// UpdateTaskState atomically updates a task's state. +// UpdateTaskState atomically updates a task's state, enforcing valid transitions. func (s *DB) UpdateTaskState(id string, newState task.State) error { - now := time.Now().UTC() - result, err := s.db.Exec(`UPDATE tasks SET state = ?, updated_at = ? 
WHERE id = ?`, string(newState), now, id) + tx, err := s.db.Begin() if err != nil { return err } - n, err := result.RowsAffected() - if err != nil { + defer tx.Rollback() //nolint:errcheck + + var currentState string + if err := tx.QueryRow(`SELECT state FROM tasks WHERE id = ?`, id).Scan(&currentState); err != nil { + if err == sql.ErrNoRows { + return fmt.Errorf("task %q not found", id) + } return err } - if n == 0 { - return fmt.Errorf("task %q not found", id) + + if !task.ValidTransition(task.State(currentState), newState) { + return fmt.Errorf("invalid state transition %s → %s for task %q", currentState, newState, id) } - return nil + + now := time.Now().UTC() + if _, err := tx.Exec(`UPDATE tasks SET state = ?, updated_at = ? WHERE id = ?`, string(newState), now, id); err != nil { + return err + } + return tx.Commit() } // RejectTask sets a task's state to PENDING and stores the rejection comment. diff --git a/internal/storage/db_test.go b/internal/storage/db_test.go index 36f1644..f737096 100644 --- a/internal/storage/db_test.go +++ b/internal/storage/db_test.go @@ -41,7 +41,7 @@ func TestCreateTask_AndGetTask(t *testing.T) { Type: "claude", Model: "sonnet", Instructions: "do it", - WorkingDir: "/tmp", + ProjectDir: "/tmp", MaxBudgetUSD: 2.5, }, Priority: task.PriorityHigh, @@ -124,6 +124,38 @@ func TestUpdateTaskState_NotFound(t *testing.T) { } } +func TestUpdateTaskState_InvalidTransition(t *testing.T) { + db := testDB(t) + now := time.Now().UTC() + tk := &task.Task{ + ID: "task-invalid", + Name: "InvalidTransition", + Agent: task.AgentConfig{Instructions: "test"}, + Priority: task.PriorityNormal, + Retry: task.RetryConfig{MaxAttempts: 1, Backoff: "linear"}, + Tags: []string{}, + DependsOn: []string{}, + State: task.StatePending, + CreatedAt: now, + UpdatedAt: now, + } + if err := db.CreateTask(tk); err != nil { + t.Fatal(err) + } + + // PENDING → COMPLETED is not a valid transition.
+ err := db.UpdateTaskState("task-invalid", task.StateCompleted) + if err == nil { + t.Fatal("expected error for invalid state transition PENDING → COMPLETED") + } + + // State must not have changed. + got, _ := db.GetTask("task-invalid") + if got.State != task.StatePending { + t.Errorf("state must remain PENDING, got %v", got.State) + } +} + func TestListTasks_FilterByState(t *testing.T) { db := testDB(t) now := time.Now().UTC() diff --git a/internal/task/task.go b/internal/task/task.go index c6a321d..6b240dd 100644 --- a/internal/task/task.go +++ b/internal/task/task.go @@ -1,6 +1,9 @@ package task -import "time" +import ( + "encoding/json" + "time" +) type State string @@ -30,7 +33,7 @@ type AgentConfig struct { Model string `yaml:"model" json:"model"` ContextFiles []string `yaml:"context_files" json:"context_files"` Instructions string `yaml:"instructions" json:"instructions"` - WorkingDir string `yaml:"working_dir" json:"working_dir"` + ProjectDir string `yaml:"project_dir" json:"project_dir"` MaxBudgetUSD float64 `yaml:"max_budget_usd" json:"max_budget_usd"` PermissionMode string `yaml:"permission_mode" json:"permission_mode"` AllowedTools []string `yaml:"allowed_tools" json:"allowed_tools"` @@ -40,6 +43,25 @@ type AgentConfig struct { SkipPlanning bool `yaml:"skip_planning" json:"skip_planning"` } +// UnmarshalJSON reads project_dir with fallback to legacy working_dir. 
+func (c *AgentConfig) UnmarshalJSON(data []byte) error { + type Alias AgentConfig + aux := &struct { + ProjectDir string `json:"project_dir"` + WorkingDir string `json:"working_dir"` // legacy + *Alias + }{Alias: (*Alias)(c)} + if err := json.Unmarshal(data, aux); err != nil { + return err + } + if aux.ProjectDir != "" { + c.ProjectDir = aux.ProjectDir + } else { + c.ProjectDir = aux.WorkingDir + } + return nil +} + type RetryConfig struct { MaxAttempts int `yaml:"max_attempts" json:"max_attempts"` Backoff string `yaml:"backoff" json:"backoff"` // "linear", "exponential" diff --git a/internal/task/validator_test.go b/internal/task/validator_test.go index 5678a00..657d93f 100644 --- a/internal/task/validator_test.go +++ b/internal/task/validator_test.go @@ -12,7 +12,7 @@ func validTask() *Task { Agent: AgentConfig{ Type: "claude", Instructions: "do something", - WorkingDir: "/tmp", + ProjectDir: "/tmp", }, Priority: PriorityNormal, Retry: RetryConfig{MaxAttempts: 1, Backoff: "exponential"}, diff --git a/scripts/deploy b/scripts/deploy index cc51fc1..5f730cc 100755 --- a/scripts/deploy +++ b/scripts/deploy @@ -1,18 +1,34 @@ #!/bin/bash # deploy — Build and deploy claudomator to /site/doot.terst.org -# Usage: ./scripts/deploy +# Usage: ./scripts/deploy [--dirty] # Example: sudo ./scripts/deploy set -euo pipefail +DIRTY=false +for arg in "$@"; do + if [ "$arg" == "--dirty" ]; then + DIRTY=true + fi +done + FQDN="doot.terst.org" SITE_DIR="/site/${FQDN}" BIN_DIR="${SITE_DIR}/bin" SERVICE="claudomator@${FQDN}.service" REPO_DIR="$(cd "$(dirname "$0")/.." && pwd)" -echo "==> Building claudomator..." cd "${REPO_DIR}" + +STASHED=false +if [ "$DIRTY" = false ] && [ -n "$(git status --porcelain)" ]; then + echo "==> Stashing uncommitted changes..." + git stash push -u -m "Auto-stash before deploy" + STASHED=true + trap 'if [ "$STASHED" = true ]; then echo "==> Popping stash..."; git stash pop; fi' EXIT +fi + +echo "==> Building claudomator..." 
export GOCACHE="${SITE_DIR}/cache/go-build" export GOPATH="${SITE_DIR}/cache/gopath" mkdir -p "${GOCACHE}" "${GOPATH}" @@ -15,6 +15,61 @@ async function fetchTemplates() { return res.json(); } +// Fetches recent executions (last 24h) from /api/executions?since=24h. +// fetchFn defaults to window.fetch; injectable for tests. +async function fetchRecentExecutions(basePath = BASE_PATH, fetchFn = fetch) { + const res = await fetchFn(`${basePath}/api/executions?since=24h`); + if (!res.ok) throw new Error(`HTTP ${res.status}`); + return res.json(); +} + +// Returns only tasks currently in state RUNNING. +function filterRunningTasks(tasks) { + return tasks.filter(t => t.state === 'RUNNING'); +} + +// Returns human-readable elapsed time from an ISO timestamp to now. +function formatElapsed(startISO) { + if (startISO == null) return ''; + const elapsed = Math.floor((Date.now() - new Date(startISO).getTime()) / 1000); + if (elapsed < 0) return '0s'; + const h = Math.floor(elapsed / 3600); + const m = Math.floor((elapsed % 3600) / 60); + const s = elapsed % 60; + if (h > 0) return `${h}h ${m}m`; + if (m > 0) return `${m}m ${s}s`; + return `${s}s`; +} + +// Returns human-readable duration between two ISO timestamps. +// If endISO is null, uses now (for in-progress tasks). +// If startISO is null, returns '--'. +function formatDuration(startISO, endISO) { + if (startISO == null) return '--'; + const start = new Date(startISO).getTime(); + const end = endISO != null ? new Date(endISO).getTime() : Date.now(); + const elapsed = Math.max(0, Math.floor((end - start) / 1000)); + const h = Math.floor(elapsed / 3600); + const m = Math.floor((elapsed % 3600) / 60); + const s = elapsed % 60; + if (h > 0) return `${h}h ${m}m`; + if (m > 0) return `${m}m ${s}s`; + return `${s}s`; +} + +// Returns last max lines from array (for testability). 
+function extractLogLines(lines, max = 500) { + if (lines.length <= max) return lines; + return lines.slice(lines.length - max); +} + +// Returns a new array of executions sorted by started_at descending. +function sortExecutionsDesc(executions) { + return [...executions].sort((a, b) => + new Date(b.started_at).getTime() - new Date(a.started_at).getTime(), + ); +} + // ── Render ──────────────────────────────────────────────────────────────────── function formatDate(iso) { @@ -173,9 +228,10 @@ function sortTasksByDate(tasks) { // ── Filter ──────────────────────────────────────────────────────────────────── -const HIDE_STATES = new Set(['COMPLETED', 'FAILED']); -const ACTIVE_STATES = new Set(['PENDING', 'QUEUED', 'RUNNING', 'READY', 'BLOCKED']); -const DONE_STATES = new Set(['COMPLETED', 'FAILED', 'TIMED_OUT', 'CANCELLED', 'BUDGET_EXCEEDED']); +const HIDE_STATES = new Set(['COMPLETED', 'FAILED']); +const ACTIVE_STATES = new Set(['PENDING', 'QUEUED', 'RUNNING', 'READY', 'BLOCKED']); +const INTERRUPTED_STATES = new Set(['CANCELLED', 'FAILED']); +const DONE_STATES = new Set(['COMPLETED', 'TIMED_OUT', 'BUDGET_EXCEEDED']); // filterActiveTasks uses its own set (excludes PENDING — tasks "in-flight" only) const _PANEL_ACTIVE_STATES = new Set(['RUNNING', 'READY', 'QUEUED', 'BLOCKED']); @@ -190,8 +246,9 @@ export function filterActiveTasks(tasks) { } export function filterTasksByTab(tasks, tab) { - if (tab === 'active') return tasks.filter(t => ACTIVE_STATES.has(t.state)); - if (tab === 'done') return tasks.filter(t => DONE_STATES.has(t.state)); + if (tab === 'active') return tasks.filter(t => ACTIVE_STATES.has(t.state)); + if (tab === 'interrupted') return tasks.filter(t => INTERRUPTED_STATES.has(t.state)); + if (tab === 'done') return tasks.filter(t => DONE_STATES.has(t.state)); return tasks; } @@ -477,7 +534,7 @@ function createEditForm(task) { form.appendChild(typeLabel); form.appendChild(makeField('Model', 'input', { type: 'text', name: 'model', value: a.model || 
'sonnet' })); - form.appendChild(makeField('Working Directory', 'input', { type: 'text', name: 'working_dir', value: a.working_dir || '', placeholder: '/path/to/repo' })); + form.appendChild(makeField('Project Directory', 'input', { type: 'text', name: 'project_dir', value: a.project_dir || a.working_dir || '', placeholder: '/path/to/repo' })); form.appendChild(makeField('Max Budget (USD)', 'input', { type: 'number', name: 'max_budget_usd', step: '0.01', value: a.max_budget_usd != null ? String(a.max_budget_usd) : '1.00' })); form.appendChild(makeField('Timeout', 'input', { type: 'text', name: 'timeout', value: formatDurationForInput(task.timeout) || '15m', placeholder: '15m' })); @@ -530,7 +587,7 @@ async function handleEditSave(taskId, form, saveBtn) { type: get('type'), model: get('model'), instructions: get('instructions'), - working_dir: get('working_dir'), + project_dir: get('project_dir'), max_budget_usd: parseFloat(get('max_budget_usd')), }, timeout: get('timeout'), @@ -812,6 +869,15 @@ async function poll() { const tasks = await fetchTasks(); renderTaskList(tasks); renderActiveTaskList(tasks); + if (isRunningTabActive()) { + renderRunningView(tasks); + fetchRecentExecutions(BASE_PATH, fetch) + .then(execs => renderRunningHistory(execs)) + .catch(() => { + const histEl = document.querySelector('.running-history'); + if (histEl) histEl.innerHTML = '<p class="task-meta">Could not load execution history.</p>'; + }); + } } catch { document.querySelector('.task-list').innerHTML = '<div id="loading">Could not reach server.</div>'; @@ -970,7 +1036,7 @@ async function elaborateTask(prompt, workingDir) { const res = await fetch(`${API_BASE}/api/tasks/elaborate`, { method: 'POST', headers: { 'Content-Type': 'application/json' }, - body: JSON.stringify({ prompt, working_dir: workingDir }), + body: JSON.stringify({ prompt, project_dir: workingDir }), }); if (!res.ok) { let msg = `HTTP ${res.status}`; @@ -1000,14 +1066,14 @@ function buildValidatePayload() { const f = 
document.getElementById('task-form'); const name = f.querySelector('[name="name"]').value; const instructions = f.querySelector('[name="instructions"]').value; - const working_dir = f.querySelector('[name="working_dir"]').value; + const project_dir = f.querySelector('[name="project_dir"]').value; const model = f.querySelector('[name="model"]').value; const type = f.querySelector('[name="type"]').value; const allowedToolsEl = f.querySelector('[name="allowed_tools"]'); const allowed_tools = allowedToolsEl ? allowedToolsEl.value.split(',').map(s => s.trim()).filter(Boolean) : []; - return { name, agent: { type, instructions, working_dir, model, allowed_tools } }; + return { name, agent: { type, instructions, project_dir, model, allowed_tools } }; } function renderValidationResult(result) { @@ -1121,7 +1187,7 @@ function closeTaskModal() { } async function createTask(formData) { - const selectVal = formData.get('working_dir'); + const selectVal = formData.get('project_dir'); const workingDir = selectVal === '__new__' ? document.getElementById('new-project-input').value.trim() : selectVal; @@ -1132,7 +1198,7 @@ async function createTask(formData) { type: formData.get('type'), model: formData.get('model'), instructions: formData.get('instructions'), - working_dir: workingDir, + project_dir: workingDir, max_budget_usd: parseFloat(formData.get('max_budget_usd')), }, timeout: formData.get('timeout'), @@ -1177,7 +1243,7 @@ async function saveTemplate(formData) { type: formData.get('type'), model: formData.get('model'), instructions: formData.get('instructions'), - working_dir: formData.get('working_dir'), + project_dir: formData.get('project_dir'), max_budget_usd: parseFloat(formData.get('max_budget_usd')), allowed_tools: splitTrim(formData.get('allowed_tools') || ''), }, @@ -1358,7 +1424,7 @@ function renderTaskPanel(task, executions) { makeMetaItem('Type', a.type || 'claude'), makeMetaItem('Model', a.model), makeMetaItem('Max Budget', a.max_budget_usd != null ? 
`$${a.max_budget_usd.toFixed(2)}` : '—'), - makeMetaItem('Working Dir', a.working_dir), + makeMetaItem('Project Dir', a.project_dir || a.working_dir), makeMetaItem('Permission Mode', a.permission_mode || 'default'), ); if (a.allowed_tools && a.allowed_tools.length > 0) { @@ -1609,6 +1675,288 @@ function closeLogViewer() { activeLogSource = null; } +// ── Running view ─────────────────────────────────────────────────────────────── + +// Map of taskId → EventSource for live log streams in the Running tab. +const runningViewLogSources = {}; + +function renderRunningView(tasks) { + const currentEl = document.querySelector('.running-current'); + if (!currentEl) return; + + const running = filterRunningTasks(tasks); + + // Close SSE streams for tasks that are no longer RUNNING. + for (const [id, src] of Object.entries(runningViewLogSources)) { + if (!running.find(t => t.id === id)) { + src.close(); + delete runningViewLogSources[id]; + } + } + + // Update elapsed spans in place if the same tasks are still running. + const existingCards = currentEl.querySelectorAll('[data-task-id]'); + const existingIds = new Set([...existingCards].map(c => c.dataset.taskId)); + const unchanged = running.length === existingCards.length && + running.every(t => existingIds.has(t.id)); + + if (unchanged) { + updateRunningElapsed(); + return; + } + + // Full re-render. 
+ currentEl.innerHTML = ''; + + const h2 = document.createElement('h2'); + h2.textContent = 'Currently Running'; + currentEl.appendChild(h2); + + if (running.length === 0) { + const empty = document.createElement('p'); + empty.className = 'task-meta'; + empty.textContent = 'No tasks are currently running.'; + currentEl.appendChild(empty); + return; + } + + for (const task of running) { + const card = document.createElement('div'); + card.className = 'running-task-card task-card'; + card.dataset.taskId = task.id; + + const header = document.createElement('div'); + header.className = 'task-card-header'; + + const name = document.createElement('span'); + name.className = 'task-name'; + name.textContent = task.name; + + const badge = document.createElement('span'); + badge.className = 'state-badge'; + badge.dataset.state = task.state; + badge.textContent = task.state; + + const elapsed = document.createElement('span'); + elapsed.className = 'running-elapsed'; + elapsed.dataset.startedAt = task.updated_at ?? ''; + elapsed.textContent = formatElapsed(task.updated_at); + + header.append(name, badge, elapsed); + card.appendChild(header); + + // Parent context (async fetch) + if (task.parent_task_id) { + const parentEl = document.createElement('div'); + parentEl.className = 'task-meta'; + parentEl.textContent = 'Subtask of: …'; + card.appendChild(parentEl); + fetch(`${API_BASE}/api/tasks/${task.parent_task_id}`) + .then(r => r.ok ? 
r.json() : null) + .then(parent => { + if (parent) parentEl.textContent = `Subtask of: ${parent.name}`; + }) + .catch(() => { parentEl.textContent = ''; }); + } + + // Log area + const logArea = document.createElement('div'); + logArea.className = 'running-log'; + logArea.dataset.logTarget = task.id; + card.appendChild(logArea); + + // Footer with Cancel button + const footer = document.createElement('div'); + footer.className = 'task-card-footer'; + const cancelBtn = document.createElement('button'); + cancelBtn.className = 'btn-cancel'; + cancelBtn.textContent = 'Cancel'; + cancelBtn.addEventListener('click', (e) => { + e.stopPropagation(); + handleCancel(task.id, cancelBtn, footer); + }); + footer.appendChild(cancelBtn); + card.appendChild(footer); + + currentEl.appendChild(card); + + // Open SSE stream if not already streaming for this task. + if (!runningViewLogSources[task.id]) { + startRunningLogStream(task.id, logArea); + } + } +} + +function startRunningLogStream(taskId, logArea) { + fetch(`${API_BASE}/api/executions?task_id=${taskId}&limit=1`) + .then(r => r.ok ? r.json() : []) + .then(execs => { + if (!execs || execs.length === 0) return; + const execId = execs[0].id; + + let userScrolled = false; + logArea.addEventListener('scroll', () => { + const nearBottom = logArea.scrollHeight - logArea.scrollTop - logArea.clientHeight < 50; + userScrolled = !nearBottom; + }); + + const src = new EventSource(`${API_BASE}/api/executions/${execId}/logs/stream`); + runningViewLogSources[taskId] = src; + + src.onmessage = (event) => { + let data; + try { data = JSON.parse(event.data); } catch { return; } + + const line = document.createElement('div'); + line.className = 'log-line'; + + switch (data.type) { + case 'text': { + line.classList.add('log-text'); + line.textContent = data.text ?? data.content ?? 
''; + break; + } + case 'tool_use': { + line.classList.add('log-tool-use'); + const toolName = document.createElement('span'); + toolName.className = 'tool-name'; + toolName.textContent = `[${data.name ?? 'Tool'}]`; + line.appendChild(toolName); + const inputStr = data.input ? JSON.stringify(data.input) : ''; + const inputPreview = document.createElement('span'); + inputPreview.textContent = ' ' + inputStr.slice(0, 120); + line.appendChild(inputPreview); + break; + } + case 'cost': { + line.classList.add('log-cost'); + const cost = data.total_cost ?? data.cost ?? 0; + line.textContent = `Cost: $${Number(cost).toFixed(3)}`; + break; + } + default: + return; + } + + logArea.appendChild(line); + // Trim to last 500 lines. + while (logArea.childElementCount > 500) { + logArea.removeChild(logArea.firstElementChild); + } + if (!userScrolled) logArea.scrollTop = logArea.scrollHeight; + }; + + src.addEventListener('done', () => { + src.close(); + delete runningViewLogSources[taskId]; + }); + + src.onerror = () => { + src.close(); + delete runningViewLogSources[taskId]; + const errEl = document.createElement('div'); + errEl.className = 'log-line log-error'; + errEl.textContent = 'Stream closed.'; + logArea.appendChild(errEl); + }; + }) + .catch(() => {}); +} + +function updateRunningElapsed() { + document.querySelectorAll('.running-elapsed[data-started-at]').forEach(el => { + el.textContent = formatElapsed(el.dataset.startedAt || null); + }); +} + +function isRunningTabActive() { + const panel = document.querySelector('[data-panel="running"]'); + return panel && !panel.hasAttribute('hidden'); +} + +function sortExecutionsByDate(executions) { + return sortExecutionsDesc(executions); +} + +function renderRunningHistory(executions) { + const histEl = document.querySelector('.running-history'); + if (!histEl) return; + + histEl.innerHTML = ''; + + const h2 = document.createElement('h2'); + h2.textContent = 'Execution History (Last 24h)'; + histEl.appendChild(h2); + + if 
(!executions || executions.length === 0) { + const empty = document.createElement('p'); + empty.className = 'task-meta'; + empty.textContent = 'No executions in the last 24h'; + histEl.appendChild(empty); + return; + } + + const sorted = sortExecutionsDesc(executions); + + const table = document.createElement('table'); + table.className = 'history-table'; + + const thead = document.createElement('thead'); + const headerRow = document.createElement('tr'); + for (const col of ['Date', 'Task', 'Status', 'Duration', 'Cost', 'Exit', 'Logs']) { + const th = document.createElement('th'); + th.textContent = col; + headerRow.appendChild(th); + } + thead.appendChild(headerRow); + table.appendChild(thead); + + const tbody = document.createElement('tbody'); + for (const exec of sorted) { + const tr = document.createElement('tr'); + + const tdDate = document.createElement('td'); + tdDate.textContent = formatDate(exec.started_at); + tr.appendChild(tdDate); + + const tdTask = document.createElement('td'); + tdTask.textContent = exec.task_name || exec.task_id || '—'; + tr.appendChild(tdTask); + + const tdStatus = document.createElement('td'); + const stateBadge = document.createElement('span'); + stateBadge.className = 'state-badge'; + stateBadge.dataset.state = exec.state || ''; + stateBadge.textContent = exec.state || '—'; + tdStatus.appendChild(stateBadge); + tr.appendChild(tdStatus); + + const tdDur = document.createElement('td'); + tdDur.textContent = formatDuration(exec.started_at, exec.finished_at ?? null); + tr.appendChild(tdDur); + + const tdCost = document.createElement('td'); + tdCost.textContent = exec.cost_usd > 0 ? `$${exec.cost_usd.toFixed(4)}` : '—'; + tr.appendChild(tdCost); + + const tdExit = document.createElement('td'); + tdExit.textContent = exec.exit_code != null ? 
String(exec.exit_code) : '—'; + tr.appendChild(tdExit); + + const tdLogs = document.createElement('td'); + const viewBtn = document.createElement('button'); + viewBtn.className = 'btn-sm'; + viewBtn.textContent = 'View Logs'; + viewBtn.addEventListener('click', () => openLogViewer(exec.id, histEl)); + tdLogs.appendChild(viewBtn); + tr.appendChild(tdLogs); + + tbody.appendChild(tr); + } + table.appendChild(tbody); + histEl.appendChild(table); +} + // ── Tab switching ───────────────────────────────────────────────────────────── function switchTab(name) { @@ -1636,6 +1984,19 @@ function switchTab(name) { '<div id="loading">Could not reach server.</div>'; }); } + + if (name === 'running') { + fetchTasks().then(renderRunningView).catch(() => { + const currentEl = document.querySelector('.running-current'); + if (currentEl) currentEl.innerHTML = '<p class="task-meta">Could not reach server.</p>'; + }); + fetchRecentExecutions(BASE_PATH, fetch) + .then(execs => renderRunningHistory(execs)) + .catch(() => { + const histEl = document.querySelector('.running-history'); + if (histEl) histEl.innerHTML = '<p class="task-meta">Could not load execution history.</p>'; + }); + } } // ── Boot ──────────────────────────────────────────────────────────────────── @@ -1655,6 +2016,7 @@ if (typeof document !== 'undefined') document.addEventListener('DOMContentLoaded handleStartNextTask(this); }); + switchTab('running'); startPolling(); connectWebSocket(); @@ -1730,17 +2092,17 @@ if (typeof document !== 'undefined') document.addEventListener('DOMContentLoaded const f = document.getElementById('task-form'); if (result.name) f.querySelector('[name="name"]').value = result.name; if (result.agent && result.agent.instructions) f.querySelector('[name="instructions"]').value = result.agent.instructions; - if (result.agent && result.agent.working_dir) { + if (result.agent && result.agent.project_dir) { const pSel = document.getElementById('project-select'); - const exists = [...pSel.options].some(o => o.value === result.agent.working_dir); + const exists = [...pSel.options].some(o => o.value === result.agent.project_dir); if (exists) { - pSel.value = result.agent.working_dir; + pSel.value = result.agent.project_dir; } else { pSel.value = '__new__'; document.getElementById('new-project-row').hidden = false; - document.getElementById('new-project-input').value = result.agent.working_dir; + document.getElementById('new-project-input').value = result.agent.project_dir; } } if (result.agent && result.agent.model) diff --git a/web/index.html b/web/index.html index 629b248..a2800b0 100644 --- a/web/index.html +++ b/web/index.html @@ -15,14 +15,16 @@ <button id="btn-new-task" class="btn-primary">New Task</button> </header> <nav class="tab-bar"> - <button class="tab active" data-tab="tasks">Tasks</button> + <button class="tab" data-tab="tasks">Tasks</button> <button class="tab" data-tab="templates">Templates</button> <button class="tab" data-tab="active">Active</button> + <button class="tab active" data-tab="running">Running</button> </nav> <main id="app"> - <div data-panel="tasks"> + <div data-panel="tasks" hidden> <div class="task-list-toolbar"> <button
class="filter-tab active" data-filter="active">Active</button> + <button class="filter-tab" data-filter="interrupted">Interrupted</button> <button class="filter-tab" data-filter="done">Done</button> <button class="filter-tab" data-filter="all">All</button> </div> @@ -40,6 +42,10 @@ <div data-panel="active" hidden> <div class="active-task-list"></div> </div> + <div data-panel="running"> + <div class="running-current"></div> + <div class="running-history"></div> + </div> </main> <dialog id="task-modal"> @@ -57,7 +63,7 @@ </div> <hr class="form-divider"> <label>Project - <select id="project-select" name="working_dir"> + <select name="project_dir" id="project-select"> <option value="/workspace/claudomator" selected>/workspace/claudomator</option> <option value="__new__">Create new project…</option> </select> @@ -113,7 +119,7 @@ <label>Model <input name="model" value="sonnet" placeholder="e.g. sonnet, gemini-2.0-flash"></label> </div> <label>Instructions <textarea name="instructions" rows="6" required></textarea></label> - <label>Working Directory <input name="working_dir" placeholder="/path/to/repo"></label> + <label>Project Directory <input name="project_dir" placeholder="/path/to/repo"></label> <label>Max Budget (USD) <input name="max_budget_usd" type="number" step="0.01" value="1.00"></label> <label>Allowed Tools <input name="allowed_tools" placeholder="Bash, Read, Write"></label> <label>Timeout <input name="timeout" value="15m"></label> diff --git a/web/style.css b/web/style.css index 106ae04..9cfe140 100644 --- a/web/style.css +++ b/web/style.css @@ -1057,6 +1057,74 @@ dialog label select:focus { color: #94a3b8; } +/* ── Running tab ─────────────────────────────────────────────────────────────── */ + +.running-current { + margin-bottom: 2rem; +} + +.running-current h2 { + font-size: 1rem; + font-weight: 600; + color: var(--text-muted); + text-transform: uppercase; + letter-spacing: 0.05em; + margin-bottom: 1rem; +} + +.running-elapsed { + font-size: 0.85rem; + 
color: var(--state-running); + font-variant-numeric: tabular-nums; +} + +.running-log { + background: var(--bg-card); + border: 1px solid var(--border); + border-radius: 6px; + padding: 0.75rem; + font-family: monospace; + font-size: 0.8rem; + max-height: 300px; + overflow-y: auto; + white-space: pre-wrap; + word-break: break-word; +} + +.running-history { + margin-top: 1.5rem; + overflow-x: auto; +} + +.running-history h2 { + font-size: 1rem; + font-weight: 600; + color: var(--text-muted); + text-transform: uppercase; + letter-spacing: 0.05em; + margin-bottom: 1rem; +} + +.history-table { + width: 100%; + border-collapse: collapse; + font-size: 0.875rem; +} + +.history-table th { + text-align: left; + padding: 0.5rem 0.75rem; + border-bottom: 1px solid var(--border); + color: var(--text-muted); + font-weight: 500; +} + +.history-table td { + padding: 0.5rem 0.75rem; + border-bottom: 1px solid var(--border); + vertical-align: middle; +} + /* ── Task delete button ──────────────────────────────────────────────────── */ .task-card { diff --git a/web/test/active-pane.test.mjs b/web/test/active-pane.test.mjs new file mode 100644 index 0000000..37bb8c5 --- /dev/null +++ b/web/test/active-pane.test.mjs @@ -0,0 +1,81 @@ +// active-pane.test.mjs — Tests for Active pane partition logic. +// +// Run with: node --test web/test/active-pane.test.mjs + +import { describe, it } from 'node:test'; +import assert from 'node:assert/strict'; +import { partitionActivePaneTasks } from '../app.js'; + +function makeTask(id, state, created_at) { + return { id, name: `task-${id}`, state, created_at: created_at ?? 
`2024-01-01T00:0${id}:00Z` }; +} + +const ALL_STATES = [ + 'PENDING', 'QUEUED', 'RUNNING', 'READY', 'BLOCKED', + 'COMPLETED', 'FAILED', 'TIMED_OUT', 'CANCELLED', 'BUDGET_EXCEEDED', +]; + +describe('partitionActivePaneTasks', () => { + it('running contains only RUNNING tasks', () => { + const tasks = ALL_STATES.map((s, i) => makeTask(String(i), s)); + const { running } = partitionActivePaneTasks(tasks); + assert.equal(running.length, 1); + assert.equal(running[0].state, 'RUNNING'); + }); + + it('ready contains only READY tasks', () => { + const tasks = ALL_STATES.map((s, i) => makeTask(String(i), s)); + const { ready } = partitionActivePaneTasks(tasks); + assert.equal(ready.length, 1); + assert.equal(ready[0].state, 'READY'); + }); + + it('excludes QUEUED, BLOCKED, PENDING, COMPLETED, FAILED and all other states', () => { + const tasks = ALL_STATES.map((s, i) => makeTask(String(i), s)); + const { running, ready } = partitionActivePaneTasks(tasks); + const allReturned = [...running, ...ready]; + assert.equal(allReturned.length, 2); + assert.ok(allReturned.every(t => t.state === 'RUNNING' || t.state === 'READY')); + }); + + it('returns empty arrays for empty input', () => { + const { running, ready } = partitionActivePaneTasks([]); + assert.deepEqual(running, []); + assert.deepEqual(ready, []); + }); + + it('handles multiple RUNNING tasks sorted by created_at ascending', () => { + const tasks = [ + makeTask('b', 'RUNNING', '2024-01-01T00:02:00Z'), + makeTask('a', 'RUNNING', '2024-01-01T00:01:00Z'), + makeTask('c', 'RUNNING', '2024-01-01T00:03:00Z'), + ]; + const { running } = partitionActivePaneTasks(tasks); + assert.equal(running.length, 3); + assert.equal(running[0].id, 'a'); + assert.equal(running[1].id, 'b'); + assert.equal(running[2].id, 'c'); + }); + + it('handles multiple READY tasks sorted by created_at ascending', () => { + const tasks = [ + makeTask('y', 'READY', '2024-01-01T00:02:00Z'), + makeTask('x', 'READY', '2024-01-01T00:01:00Z'), + ]; + const { ready 
} = partitionActivePaneTasks(tasks); + assert.equal(ready.length, 2); + assert.equal(ready[0].id, 'x'); + assert.equal(ready[1].id, 'y'); + }); + + it('returns both sections independently when both states present', () => { + const tasks = [ + makeTask('r1', 'RUNNING', '2024-01-01T00:01:00Z'), + makeTask('d1', 'READY', '2024-01-01T00:02:00Z'), + makeTask('r2', 'RUNNING', '2024-01-01T00:03:00Z'), + ]; + const { running, ready } = partitionActivePaneTasks(tasks); + assert.equal(running.length, 2); + assert.equal(ready.length, 1); + }); +}); diff --git a/web/test/filter-tabs.test.mjs b/web/test/filter-tabs.test.mjs index 44cfaf6..3a4e569 100644 --- a/web/test/filter-tabs.test.mjs +++ b/web/test/filter-tabs.test.mjs @@ -1,9 +1,5 @@ // filter-tabs.test.mjs — TDD contract tests for filterTasksByTab // -// filterTasksByTab is defined inline here to establish expected behaviour. -// Once filterTasksByTab is exported from web/app.js, remove the inline -// definition and import it instead. -// // Run with: node --test web/test/filter-tabs.test.mjs import { describe, it } from 'node:test'; @@ -45,15 +41,45 @@ describe('filterTasksByTab — active tab', () => { }); }); +describe('filterTasksByTab — interrupted tab', () => { + it('includes CANCELLED and FAILED', () => { + const tasks = ALL_STATES.map(makeTask); + const result = filterTasksByTab(tasks, 'interrupted'); + for (const state of ['CANCELLED', 'FAILED']) { + assert.ok(result.some(t => t.state === state), `${state} should be included`); + } + }); + + it('excludes all non-interrupted states', () => { + const tasks = ALL_STATES.map(makeTask); + const result = filterTasksByTab(tasks, 'interrupted'); + for (const state of ['PENDING', 'QUEUED', 'RUNNING', 'READY', 'BLOCKED', 'COMPLETED', 'TIMED_OUT', 'BUDGET_EXCEEDED']) { + assert.ok(!result.some(t => t.state === state), `${state} should be excluded`); + } + }); + + it('returns empty array for empty input', () => { + assert.deepEqual(filterTasksByTab([], 'interrupted'), []); + 
+  });
+});
+
 describe('filterTasksByTab — done tab', () => {
-  it('includes COMPLETED, FAILED, TIMED_OUT, CANCELLED, BUDGET_EXCEEDED', () => {
+  it('includes COMPLETED, TIMED_OUT, BUDGET_EXCEEDED', () => {
     const tasks = ALL_STATES.map(makeTask);
     const result = filterTasksByTab(tasks, 'done');
-    for (const state of ['COMPLETED', 'FAILED', 'TIMED_OUT', 'CANCELLED', 'BUDGET_EXCEEDED']) {
+    for (const state of ['COMPLETED', 'TIMED_OUT', 'BUDGET_EXCEEDED']) {
       assert.ok(result.some(t => t.state === state), `${state} should be included`);
     }
   });
 
+  it('excludes CANCELLED and FAILED (moved to interrupted tab)', () => {
+    const tasks = ALL_STATES.map(makeTask);
+    const result = filterTasksByTab(tasks, 'done');
+    for (const state of ['CANCELLED', 'FAILED']) {
+      assert.ok(!result.some(t => t.state === state), `${state} should be excluded from done`);
+    }
+  });
+
   it('excludes PENDING, QUEUED, RUNNING, READY, BLOCKED', () => {
     const tasks = ALL_STATES.map(makeTask);
     const result = filterTasksByTab(tasks, 'done');
diff --git a/web/test/focus-preserve.test.mjs b/web/test/focus-preserve.test.mjs
new file mode 100644
index 0000000..8acf73c
--- /dev/null
+++ b/web/test/focus-preserve.test.mjs
@@ -0,0 +1,170 @@
+// focus-preserve.test.mjs — contract tests for captureFocusState / restoreFocusState
+//
+// These pure helpers fix the focus-stealing bug: poll() calls renderTaskList /
+// renderActiveTaskList which do container.innerHTML='' on every tick, destroying
+// any focused answer input (task-answer-input or question-input).
+//
+// captureFocusState(container, activeEl)
+//   Returns {taskId, className, value} if activeEl is a focusable answer input
+//   inside a .task-card within container. Returns null otherwise.
+//
+// restoreFocusState(container, state)
+//   Finds the equivalent input after rebuild and restores .value + .focus().
+//
+// Run with: node --test web/test/focus-preserve.test.mjs
+
+import { describe, it } from 'node:test';
+import assert from 'node:assert/strict';
+
+// ── Inline implementations (contract) ─────────────────────────────────────────
+
+function captureFocusState(container, activeEl) {
+  if (!activeEl || !container.contains(activeEl)) return null;
+  const card = activeEl.closest('.task-card');
+  if (!card || !card.dataset || !card.dataset.taskId) return null;
+  return {
+    taskId: card.dataset.taskId,
+    className: activeEl.className,
+    value: activeEl.value || '',
+  };
+}
+
+function restoreFocusState(container, state) {
+  if (!state) return;
+  const card = container.querySelector(`.task-card[data-task-id="${state.taskId}"]`);
+  if (!card) return;
+  const el = card.querySelector(`.${state.className}`);
+  if (!el) return;
+  el.value = state.value;
+  el.focus();
+}
+
+// ── DOM-like mock helpers ──────────────────────────────────────────────────────
+
+function makeInput(className, value = '', taskId = 't1') {
+  const card = {
+    dataset: { taskId },
+    _children: [],
+    querySelector(sel) {
+      const cls = sel.replace(/^\./, '');
+      return this._children.find(c => c.className === cls) || null;
+    },
+    closest(sel) {
+      return sel === '.task-card' ? this : null;
+    },
+  };
+  const input = {
+    className,
+    value,
+    _focused: false,
+    focus() { this._focused = true; },
+    closest(sel) { return card.closest(sel); },
+  };
+  card._children.push(input);
+  return { card, input };
+}
+
+function makeContainer(cards = []) {
+  const allInputs = cards.flatMap(c => c._children);
+  return {
+    contains(el) { return allInputs.includes(el); },
+    querySelector(sel) {
+      const m = sel.match(/\.task-card\[data-task-id="([^"]+)"\]/);
+      if (!m) return null;
+      return cards.find(c => c.dataset.taskId === m[1]) || null;
+    },
+  };
+}
+
+// ── Tests: captureFocusState ───────────────────────────────────────────────────
+
+describe('captureFocusState', () => {
+  it('returns null when activeEl is null', () => {
+    assert.strictEqual(captureFocusState(makeContainer([]), null), null);
+  });
+
+  it('returns null when activeEl is undefined', () => {
+    assert.strictEqual(captureFocusState(makeContainer([]), undefined), null);
+  });
+
+  it('returns null when activeEl is outside the container', () => {
+    const { input } = makeInput('task-answer-input', 'hello', 't1');
+    const container = makeContainer([]); // empty — input not in it
+    assert.strictEqual(captureFocusState(container, input), null);
+  });
+
+  it('returns null when activeEl has no .task-card ancestor', () => {
+    const input = {
+      className: 'task-answer-input',
+      value: 'hi',
+      closest() { return null; },
+    };
+    const container = { contains() { return true; }, querySelector() { return null; } };
+    assert.strictEqual(captureFocusState(container, input), null);
+  });
+
+  it('returns state for task-answer-input inside a task card', () => {
+    const { card, input } = makeInput('task-answer-input', 'partial answer', 't42');
+    const state = captureFocusState(makeContainer([card]), input);
+    assert.deepStrictEqual(state, {
+      taskId: 't42',
+      className: 'task-answer-input',
+      value: 'partial answer',
+    });
+  });
+
+  it('returns state for question-input inside a task card', () => {
+    const { card, input } = makeInput('question-input', 'my answer', 'q99');
+    const state = captureFocusState(makeContainer([card]), input);
+    assert.deepStrictEqual(state, {
+      taskId: 'q99',
+      className: 'question-input',
+      value: 'my answer',
+    });
+  });
+
+  it('returns empty string value when input is empty', () => {
+    const { card, input } = makeInput('task-answer-input', '', 't1');
+    const state = captureFocusState(makeContainer([card]), input);
+    assert.strictEqual(state.value, '');
+  });
+});
+
+// ── Tests: restoreFocusState ───────────────────────────────────────────────────
+
+describe('restoreFocusState', () => {
+  it('is a no-op when state is null', () => {
+    restoreFocusState(makeContainer([]), null); // must not throw
+  });
+
+  it('is a no-op when state is undefined', () => {
+    restoreFocusState(makeContainer([]), undefined); // must not throw
+  });
+
+  it('is a no-op when task card is no longer in container', () => {
+    const state = { taskId: 'gone', className: 'task-answer-input', value: 'hi' };
+    restoreFocusState(makeContainer([]), state); // must not throw
+  });
+
+  it('restores value and focuses task-answer-input', () => {
+    const { card, input } = makeInput('task-answer-input', '', 't1');
+    const state = { taskId: 't1', className: 'task-answer-input', value: 'restored text' };
+    restoreFocusState(makeContainer([card]), state);
+    assert.strictEqual(input.value, 'restored text');
+    assert.ok(input._focused, 'input should have been focused');
+  });
+
+  it('restores value and focuses question-input', () => {
+    const { card, input } = makeInput('question-input', '', 'q7');
+    const state = { taskId: 'q7', className: 'question-input', value: 'type answer' };
+    restoreFocusState(makeContainer([card]), state);
+    assert.strictEqual(input.value, 'type answer');
+    assert.ok(input._focused);
+  });
+
+  it('is a no-op when element className is not found in rebuilt card', () => {
+    const { card } = makeInput('task-answer-input', '', 't1');
+    const state = { taskId: 't1', className: 'nonexistent-class', value: 'hi' };
+    restoreFocusState(makeContainer([card]), state); // must not throw
+  });
+});
diff --git a/web/test/is-user-editing.test.mjs b/web/test/is-user-editing.test.mjs
new file mode 100644
index 0000000..844d3cd
--- /dev/null
+++ b/web/test/is-user-editing.test.mjs
@@ -0,0 +1,65 @@
+// is-user-editing.test.mjs — contract tests for isUserEditing()
+//
+// isUserEditing(activeEl) returns true when the browser has focus in an element
+// that a poll-driven DOM refresh would destroy: INPUT, TEXTAREA, contenteditable,
+// or any element inside a [role="dialog"].
+//
+// Run with: node --test web/test/is-user-editing.test.mjs
+
+import { describe, it } from 'node:test';
+import assert from 'node:assert/strict';
+import { isUserEditing } from '../app.js';
+
+// ── Mock helpers ───────────────────────────────────────────────────────────────
+
+function makeEl(tagName, extras = {}) {
+  return {
+    tagName: tagName.toUpperCase(),
+    isContentEditable: false,
+    closest(sel) { return null; },
+    ...extras,
+  };
+}
+
+// ── Tests ──────────────────────────────────────────────────────────────────────
+
+describe('isUserEditing', () => {
+  it('returns false for null', () => {
+    assert.strictEqual(isUserEditing(null), false);
+  });
+
+  it('returns false for undefined', () => {
+    assert.strictEqual(isUserEditing(undefined), false);
+  });
+
+  it('returns true for INPUT element', () => {
+    assert.strictEqual(isUserEditing(makeEl('INPUT')), true);
+  });
+
+  it('returns true for TEXTAREA element', () => {
+    assert.strictEqual(isUserEditing(makeEl('TEXTAREA')), true);
+  });
+
+  it('returns true for contenteditable element', () => {
+    assert.strictEqual(isUserEditing(makeEl('DIV', { isContentEditable: true })), true);
+  });
+
+  it('returns true for element inside [role="dialog"]', () => {
+    const el = makeEl('SPAN', {
+      closest(sel) { return sel === '[role="dialog"]' ? {} : null; },
+    });
+    assert.strictEqual(isUserEditing(el), true);
+  });
+
+  it('returns false for a non-editing BUTTON', () => {
+    assert.strictEqual(isUserEditing(makeEl('BUTTON')), false);
+  });
+
+  it('returns false for a non-editing DIV without contenteditable', () => {
+    assert.strictEqual(isUserEditing(makeEl('DIV')), false);
+  });
+
+  it('returns false for a non-editing SPAN not inside a dialog', () => {
+    assert.strictEqual(isUserEditing(makeEl('SPAN')), false);
+  });
+});
diff --git a/web/test/render-dedup.test.mjs b/web/test/render-dedup.test.mjs
new file mode 100644
index 0000000..f13abb2
--- /dev/null
+++ b/web/test/render-dedup.test.mjs
@@ -0,0 +1,125 @@
+// render-dedup.test.mjs — contract tests for renderTaskList dedup logic
+//
+// Verifies the invariant: renderTaskList must never leave two .task-card elements
+// with the same data-task-id in the container. When a card already exists but
+// has no input field, the old card must be removed before inserting the new one.
+//
+// This file uses inline implementations that mirror the contract, not the actual
+// DOM (which requires a browser). The tests define the expected behaviour;
+// porting them to the exported app.js helpers would turn a regression into a
+// failing test.
+//
+// Run with: node --test web/test/render-dedup.test.mjs
+
+import { describe, it } from 'node:test';
+import assert from 'node:assert/strict';
+
+// ── Inline DOM mock ────────────────────────────────────────────────────────────
+
+function makeCard(taskId, hasInput = false) {
+  const card = {
+    dataset: { taskId },
+    _removed: false,
+    _hasInput: hasInput,
+    remove() { this._removed = true; },
+    querySelector(sel) {
+      if (!this._hasInput) return null;
+      // simulate .task-answer-input or .question-input being present
+      if (sel === '.task-answer-input, .question-input') {
+        return { className: 'task-answer-input', value: 'partial' };
+      }
+      return null;
+    },
+  };
+  return card;
+}
+
+// Minimal container mirroring what renderTaskList works with.
+function makeContainer(existingCards = []) {
+  const cards = [...existingCards];
+  const inserted = [];
+  return {
+    _cards: cards,
+    _inserted: inserted,
+    querySelectorAll(sel) {
+      if (sel === '.task-card') return [...cards];
+      return [];
+    },
+    querySelector(sel) {
+      const m = sel.match(/\.task-card\[data-task-id="([^"]+)"\]/);
+      if (!m) return null;
+      return cards.find(c => c.dataset.taskId === m[1] && !c._removed) || null;
+    },
+    insertBefore(node, ref) {
+      inserted.push(node);
+      if (!cards.includes(node)) cards.push(node);
+    },
+    get firstChild() { return cards[0] || null; },
+  };
+}
+
+// The fixed dedup logic extracted from renderTaskList (the contract we enforce).
+function selectCardForTask(task, container) {
+  const existing = container.querySelector(`.task-card[data-task-id="${task.id}"]`);
+  const hasInput = existing?.querySelector('.task-answer-input, .question-input');
+
+  let node;
+  if (existing && hasInput) {
+    node = existing; // reuse — preserves in-progress input
+  } else {
+    if (existing) existing.remove(); // <-- the fix: remove old before inserting new
+    node = makeCard(task.id, false); // simulates createTaskCard(task)
+  }
+  return node;
+}
+
+// ── Tests ──────────────────────────────────────────────────────────────────────
+
+describe('renderTaskList dedup logic', () => {
+  it('creates a new card when no existing card in DOM', () => {
+    const container = makeContainer([]);
+    const task = { id: 't1' };
+    const node = selectCardForTask(task, container);
+    assert.equal(node.dataset.taskId, 't1');
+    assert.equal(node._removed, false);
+  });
+
+  it('removes old card and creates new when existing has no input', () => {
+    const old = makeCard('t2', false);
+    const container = makeContainer([old]);
+    const task = { id: 't2' };
+    const node = selectCardForTask(task, container);
+
+    // Old card must be removed to prevent duplication.
+    assert.equal(old._removed, true, 'old card should be marked removed');
+    // New card returned is not the old card.
+    assert.notEqual(node, old);
+    assert.equal(node.dataset.taskId, 't2');
+  });
+
+  it('reuses existing card when it has an input (preserves typing)', () => {
+    const existing = makeCard('t3', true); // has input
+    const container = makeContainer([existing]);
+    const task = { id: 't3' };
+    const node = selectCardForTask(task, container);
+
+    assert.equal(node, existing, 'should reuse the existing card');
+    assert.equal(existing._removed, false, 'existing card should NOT be removed');
+  });
+
+  it('never produces two cards for the same task id', () => {
+    // Simulate two poll cycles.
+    const old = makeCard('t4', false);
+    const container = makeContainer([old]);
+    const task = { id: 't4' };
+
+    // First "refresh" — old card has no input, so remove and insert new.
+    const newCard = selectCardForTask(task, container);
+    // Simulate insert: mark old as removed (done by remove()), add new.
+    container._cards.splice(container._cards.indexOf(old), 1);
+    if (!container._cards.includes(newCard)) container._cards.push(newCard);
+
+    // Verify at most one card with this id exists.
+    const survivors = container._cards.filter(c => c.dataset.taskId === 't4' && !c._removed);
+    assert.equal(survivors.length, 1, 'exactly one card for t4 should remain');
+  });
+});
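Taken together, the test files above describe one fix: a poll-driven rebuild must either be skipped while the user is editing something the rebuild cannot recreate, or bracketed by capture and restore of the focused answer input. The sketch below shows how the three helpers can compose on a single poll tick. It is a hypothetical composition, not the actual web/app.js wiring: `refreshTick()` and the tiny DOM stand-ins are illustrative, while the inline helper copies mirror the contracts stated in the tests.

```javascript
// Inline copies of the helpers, mirroring the contracts in the tests above.
function isUserEditing(el) {
  if (!el) return false;
  if (el.tagName === 'INPUT' || el.tagName === 'TEXTAREA') return true;
  if (el.isContentEditable) return true;
  return Boolean(el.closest && el.closest('[role="dialog"]'));
}

function captureFocusState(container, activeEl) {
  if (!activeEl || !container.contains(activeEl)) return null;
  const card = activeEl.closest('.task-card');
  if (!card || !card.dataset || !card.dataset.taskId) return null;
  return { taskId: card.dataset.taskId, className: activeEl.className, value: activeEl.value || '' };
}

function restoreFocusState(container, state) {
  if (!state) return;
  const card = container.querySelector(`.task-card[data-task-id="${state.taskId}"]`);
  if (!card) return;
  const el = card.querySelector(`.${state.className}`);
  if (!el) return;
  el.value = state.value;
  el.focus();
}

// One hypothetical poll tick: if focus is somewhere capture cannot recover
// (e.g. a dialog), skip the rebuild entirely; otherwise capture focus,
// rebuild the list, and restore focus into the freshly rendered input.
function refreshTick(container, activeEl, rebuild) {
  const state = captureFocusState(container, activeEl);
  if (!state && isUserEditing(activeEl)) return; // editing outside a task card
  rebuild(); // simulates innerHTML = '' plus re-render
  restoreFocusState(container, state);
}

// Tiny DOM stand-ins, modeled on the makeInput/makeContainer mocks above.
function makeCard(taskId) {
  return {
    dataset: { taskId },
    _children: [],
    querySelector(sel) {
      const cls = sel.replace(/^\./, '');
      return this._children.find(c => c.className === cls) || null;
    },
    closest(sel) { return sel === '.task-card' ? this : null; },
  };
}

function attachInput(card, className, value = '') {
  const input = {
    tagName: 'INPUT',
    className,
    value,
    _focused: false,
    focus() { this._focused = true; },
    closest(sel) { return card.closest(sel); },
  };
  card._children.push(input);
  return input;
}

let cards = [makeCard('t1')];
const oldInput = attachInput(cards[0], 'task-answer-input', 'half-typed answer');
const container = {
  contains(el) { return cards.some(c => c._children.includes(el)); },
  querySelector(sel) {
    const m = sel.match(/\.task-card\[data-task-id="([^"]+)"\]/);
    return (m && cards.find(c => c.dataset.taskId === m[1])) || null;
  },
};

// rebuild() simulates the destructive re-render: fresh card, fresh empty input.
let newInput;
refreshTick(container, oldInput, () => {
  cards = [makeCard('t1')];
  newInput = attachInput(cards[0], 'task-answer-input', '');
});

console.log(newInput.value, newInput._focused); // prints: half-typed answer true
```

The ordering matters: capture must run before the rebuild destroys the old input, and restore must run after the new card exists, which is why the capture/restore pair brackets the `rebuild()` call rather than running inside it.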