diff options
Diffstat (limited to 'internal')
54 files changed, 7194 insertions, 727 deletions
diff --git a/internal/api/changestats.go b/internal/api/changestats.go deleted file mode 100644 index 4f18f7f..0000000 --- a/internal/api/changestats.go +++ /dev/null @@ -1,15 +0,0 @@ -package api - -import "github.com/thepeterstone/claudomator/internal/task" - -// parseChangestatFromOutput delegates to task.ParseChangestatFromOutput. -// Kept as a package-local wrapper for use within the api package. -func parseChangestatFromOutput(output string) *task.Changestats { - return task.ParseChangestatFromOutput(output) -} - -// parseChangestatFromFile delegates to task.ParseChangestatFromFile. -// Kept as a package-local wrapper for use within the api package. -func parseChangestatFromFile(path string) *task.Changestats { - return task.ParseChangestatFromFile(path) -} diff --git a/internal/api/deployment.go b/internal/api/deployment.go index d927545..8972fe2 100644 --- a/internal/api/deployment.go +++ b/internal/api/deployment.go @@ -23,7 +23,7 @@ func (s *Server) handleGetDeploymentStatus(w http.ResponseWriter, r *http.Reques if err != nil { if err == sql.ErrNoRows { // No execution yet — return status with no fix commits. - status := deployment.Check(nil, tk.Agent.ProjectDir) + status := deployment.Check(nil, tk.RepositoryURL) writeJSON(w, http.StatusOK, status) return } @@ -31,6 +31,6 @@ func (s *Server) handleGetDeploymentStatus(w http.ResponseWriter, r *http.Reques return } - status := deployment.Check(exec.Commits, tk.Agent.ProjectDir) + status := deployment.Check(exec.Commits, tk.RepositoryURL) writeJSON(w, http.StatusOK, status) } diff --git a/internal/api/drops.go b/internal/api/drops.go new file mode 100644 index 0000000..a5000f1 --- /dev/null +++ b/internal/api/drops.go @@ -0,0 +1,165 @@ +package api + +import ( + "fmt" + "io" + "net/http" + "os" + "path/filepath" + "strings" + "time" +) + +// handleListDrops returns a JSON array of files in the drops directory. 
+func (s *Server) handleListDrops(w http.ResponseWriter, r *http.Request) { + if s.dropsDir == "" { + writeJSON(w, http.StatusServiceUnavailable, map[string]string{"error": "drops directory not configured"}) + return + } + + entries, err := os.ReadDir(s.dropsDir) + if err != nil { + if os.IsNotExist(err) { + writeJSON(w, http.StatusOK, []map[string]interface{}{}) + return + } + writeJSON(w, http.StatusInternalServerError, map[string]string{"error": "failed to list drops"}) + return + } + + type fileEntry struct { + Name string `json:"name"` + Size int64 `json:"size"` + Modified time.Time `json:"modified"` + } + files := []fileEntry{} + for _, e := range entries { + if e.IsDir() { + continue + } + info, err := e.Info() + if err != nil { + continue + } + files = append(files, fileEntry{ + Name: e.Name(), + Size: info.Size(), + Modified: info.ModTime().UTC(), + }) + } + writeJSON(w, http.StatusOK, files) +} + +// handleGetDrop serves a file from the drops directory as an attachment. +func (s *Server) handleGetDrop(w http.ResponseWriter, r *http.Request) { + if s.dropsDir == "" { + writeJSON(w, http.StatusServiceUnavailable, map[string]string{"error": "drops directory not configured"}) + return + } + + filename := r.PathValue("filename") + if strings.Contains(filename, "/") || strings.Contains(filename, "..") { + writeJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid filename"}) + return + } + + path := filepath.Join(s.dropsDir, filepath.Clean(filename)) + // Extra safety: ensure the resolved path is still inside dropsDir. 
+ if !strings.HasPrefix(path, s.dropsDir) { + writeJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid filename"}) + return + } + + f, err := os.Open(path) + if err != nil { + if os.IsNotExist(err) { + writeJSON(w, http.StatusNotFound, map[string]string{"error": "file not found"}) + return + } + writeJSON(w, http.StatusInternalServerError, map[string]string{"error": "failed to open file"}) + return + } + defer f.Close() + + w.Header().Set("Content-Disposition", fmt.Sprintf(`attachment; filename="%s"`, filename)) + w.Header().Set("Content-Type", "application/octet-stream") + io.Copy(w, f) //nolint:errcheck +} + +// handlePostDrop accepts a file upload (multipart/form-data or raw body with ?filename=). +func (s *Server) handlePostDrop(w http.ResponseWriter, r *http.Request) { + if s.dropsDir == "" { + writeJSON(w, http.StatusServiceUnavailable, map[string]string{"error": "drops directory not configured"}) + return + } + + if err := os.MkdirAll(s.dropsDir, 0700); err != nil { + writeJSON(w, http.StatusInternalServerError, map[string]string{"error": "failed to create drops directory"}) + return + } + + ct := r.Header.Get("Content-Type") + if strings.Contains(ct, "multipart/form-data") { + s.handleMultipartDrop(w, r) + return + } + + // Raw body with ?filename= query param. 
+ filename := r.URL.Query().Get("filename") + if filename == "" { + writeJSON(w, http.StatusBadRequest, map[string]string{"error": "filename query param required for raw upload"}) + return + } + if strings.Contains(filename, "/") || strings.Contains(filename, "..") { + writeJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid filename"}) + return + } + path := filepath.Join(s.dropsDir, filename) + data, err := io.ReadAll(r.Body) + if err != nil { + writeJSON(w, http.StatusInternalServerError, map[string]string{"error": "failed to read body"}) + return + } + if err := os.WriteFile(path, data, 0600); err != nil { + writeJSON(w, http.StatusInternalServerError, map[string]string{"error": "failed to save file"}) + return + } + writeJSON(w, http.StatusCreated, map[string]interface{}{"name": filename, "size": len(data)}) +} + +func (s *Server) handleMultipartDrop(w http.ResponseWriter, r *http.Request) { + if err := r.ParseMultipartForm(32 << 20); err != nil { // 32 MB limit + writeJSON(w, http.StatusBadRequest, map[string]string{"error": "failed to parse multipart form: " + err.Error()}) + return + } + + file, header, err := r.FormFile("file") + if err != nil { + writeJSON(w, http.StatusBadRequest, map[string]string{"error": "missing 'file' field: " + err.Error()}) + return + } + defer file.Close() + + filename := filepath.Base(header.Filename) + if filename == "" || filename == "." 
{ + writeJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid filename"}) + return + } + + path := filepath.Join(s.dropsDir, filename) + dst, err := os.Create(path) + if err != nil { + writeJSON(w, http.StatusInternalServerError, map[string]string{"error": "failed to create file"}) + return + } + defer dst.Close() + + n, err := io.Copy(dst, file) + if err != nil { + writeJSON(w, http.StatusInternalServerError, map[string]string{"error": "failed to write file"}) + return + } + + writeJSON(w, http.StatusCreated, map[string]interface{}{"name": filename, "size": n}) +} + diff --git a/internal/api/drops_test.go b/internal/api/drops_test.go new file mode 100644 index 0000000..ab67489 --- /dev/null +++ b/internal/api/drops_test.go @@ -0,0 +1,159 @@ +package api + +import ( + "bytes" + "encoding/json" + "mime/multipart" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "strings" + "testing" +) + +func testServerWithDrops(t *testing.T) (*Server, string) { + t.Helper() + srv, _ := testServer(t) + dropsDir := t.TempDir() + srv.SetDropsDir(dropsDir) + return srv, dropsDir +} + +func TestHandleListDrops_Empty(t *testing.T) { + srv, _ := testServerWithDrops(t) + + req := httptest.NewRequest("GET", "/api/drops", nil) + rec := httptest.NewRecorder() + srv.mux.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("want 200, got %d", rec.Code) + } + + var files []map[string]interface{} + if err := json.NewDecoder(rec.Body).Decode(&files); err != nil { + t.Fatalf("decode response: %v", err) + } + if len(files) != 0 { + t.Errorf("want empty list, got %d entries", len(files)) + } +} + +func TestHandleListDrops_WithFile(t *testing.T) { + srv, dropsDir := testServerWithDrops(t) + + // Create a file in the drops dir. 
+ if err := os.WriteFile(filepath.Join(dropsDir, "hello.txt"), []byte("world"), 0600); err != nil { + t.Fatal(err) + } + + req := httptest.NewRequest("GET", "/api/drops", nil) + rec := httptest.NewRecorder() + srv.mux.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("want 200, got %d: %s", rec.Code, rec.Body.String()) + } + + var files []map[string]interface{} + if err := json.NewDecoder(rec.Body).Decode(&files); err != nil { + t.Fatalf("decode response: %v", err) + } + if len(files) != 1 { + t.Fatalf("want 1 file, got %d", len(files)) + } + if files[0]["name"] != "hello.txt" { + t.Errorf("name: want %q, got %v", "hello.txt", files[0]["name"]) + } +} + +func TestHandlePostDrop_Multipart(t *testing.T) { + srv, dropsDir := testServerWithDrops(t) + + var buf bytes.Buffer + w := multipart.NewWriter(&buf) + fw, err := w.CreateFormFile("file", "test.txt") + if err != nil { + t.Fatal(err) + } + fw.Write([]byte("hello world")) //nolint:errcheck + w.Close() + + req := httptest.NewRequest("POST", "/api/drops", &buf) + req.Header.Set("Content-Type", w.FormDataContentType()) + rec := httptest.NewRecorder() + srv.mux.ServeHTTP(rec, req) + + if rec.Code != http.StatusCreated { + t.Fatalf("want 201, got %d: %s", rec.Code, rec.Body.String()) + } + + var resp map[string]interface{} + if err := json.NewDecoder(rec.Body).Decode(&resp); err != nil { + t.Fatalf("decode response: %v", err) + } + if resp["name"] != "test.txt" { + t.Errorf("name: want %q, got %v", "test.txt", resp["name"]) + } + + // Verify file was created on disk. 
+ content, err := os.ReadFile(filepath.Join(dropsDir, "test.txt")) + if err != nil { + t.Fatalf("reading uploaded file: %v", err) + } + if string(content) != "hello world" { + t.Errorf("content: want %q, got %q", "hello world", content) + } +} + +func TestHandleGetDrop_Download(t *testing.T) { + srv, dropsDir := testServerWithDrops(t) + + if err := os.WriteFile(filepath.Join(dropsDir, "download.txt"), []byte("download me"), 0600); err != nil { + t.Fatal(err) + } + + req := httptest.NewRequest("GET", "/api/drops/download.txt", nil) + rec := httptest.NewRecorder() + srv.mux.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("want 200, got %d", rec.Code) + } + + cd := rec.Header().Get("Content-Disposition") + if !strings.Contains(cd, "attachment") { + t.Errorf("want Content-Disposition: attachment, got %q", cd) + } + if rec.Body.String() != "download me" { + t.Errorf("body: want %q, got %q", "download me", rec.Body.String()) + } +} + +func TestHandleGetDrop_PathTraversal(t *testing.T) { + srv, _ := testServerWithDrops(t) + + // Attempt path traversal — should be rejected. + req := httptest.NewRequest("GET", "/api/drops/..%2Fetc%2Fpasswd", nil) + rec := httptest.NewRecorder() + srv.mux.ServeHTTP(rec, req) + + // The Go net/http router will handle %2F-encoded slashes as literal characters, + // so the filename becomes "../etc/passwd". Our handler should reject it. 
+ if rec.Code == http.StatusOK { + t.Error("expected non-200 for path traversal attempt") + } +} + +func TestHandleGetDrop_NotFound(t *testing.T) { + srv, _ := testServerWithDrops(t) + + req := httptest.NewRequest("GET", "/api/drops/notexist.txt", nil) + rec := httptest.NewRecorder() + srv.mux.ServeHTTP(rec, req) + + if rec.Code != http.StatusNotFound { + t.Fatalf("want 404, got %d", rec.Code) + } +} diff --git a/internal/api/elaborate.go b/internal/api/elaborate.go index 30095c8..8676b36 100644 --- a/internal/api/elaborate.go +++ b/internal/api/elaborate.go @@ -172,7 +172,7 @@ func readProjectContext(workDir string) string { return "" } var sb strings.Builder - for _, filename := range []string{"CLAUDE.md", "SESSION_STATE.md"} { + for _, filename := range []string{"CLAUDE.md", ".agent/worklog.md"} { path := filepath.Join(workDir, filename) if data, err := os.ReadFile(path); err == nil { if sb.Len() > 0 { @@ -303,6 +303,197 @@ func (s *Server) elaborateWithGemini(ctx context.Context, workDir, fullPrompt st return &result, nil } +// elaboratedStorySubtask is a leaf unit within a story task. +type elaboratedStorySubtask struct { + Name string `json:"name"` + Instructions string `json:"instructions"` +} + +// elaboratedStoryTask is one independently-deployable unit in a story plan. +type elaboratedStoryTask struct { + Name string `json:"name"` + Instructions string `json:"instructions"` + AcceptanceCriteria string `json:"acceptance_criteria"` + Subtasks []elaboratedStorySubtask `json:"subtasks"` +} + +// elaboratedStoryValidation describes how to verify the story was successful. +type elaboratedStoryValidation struct { + Type string `json:"type"` + Steps []string `json:"steps"` + SuccessCriteria string `json:"success_criteria"` +} + +// elaboratedStory is the full implementation plan produced by story elaboration. 
+type elaboratedStory struct { + Name string `json:"name"` + BranchName string `json:"branch_name"` + Tasks []elaboratedStoryTask `json:"tasks"` + Validation elaboratedStoryValidation `json:"validation"` +} + +func buildStoryElaboratePrompt() string { + return `You are a software architect. Given a goal, analyze the codebase at /workspace and produce a structured implementation plan as JSON. + +Output ONLY valid JSON matching this schema: +{ + "name": "story name", + "branch_name": "story/kebab-case-name", + "tasks": [ + { + "name": "task name", + "instructions": "detailed instructions including file paths and what to change", + "acceptance_criteria": "specific, verifiable conditions a separate reviewer can check — e.g. 'run go test ./... and verify all pass; confirm GET /api/foo returns 200 with expected JSON shape'", + "subtasks": [ + { "name": "subtask name", "instructions": "..." } + ] + } + ], + "validation": { + "type": "build|test|smoke", + "steps": ["step1", "step2"], + "success_criteria": "what success looks like" + } +} + +Rules: +- Tasks must be independently buildable (each can be deployed alone) +- Subtasks within a task are order-dependent and run sequentially +- Instructions must include specific file paths, function names, and exact changes +- Instructions must end with: git add -A && git commit -m "..." 
&& git push origin <branch> +- Validation should match the scope: small change = build check; new feature = smoke test +- acceptance_criteria must be concrete and verifiable by a separate agent — no vague assertions like "code looks good"` +} + +func (s *Server) elaborateStoryWithClaude(ctx context.Context, workDir, goal string) (*elaboratedStory, error) { + cmd := exec.CommandContext(ctx, s.claudeBinaryPath(), + "-p", goal, + "--system-prompt", buildStoryElaboratePrompt(), + "--output-format", "json", + "--model", "haiku", + ) + if workDir != "" { + cmd.Dir = workDir + } + + var stdout, stderr bytes.Buffer + cmd.Stdout = &stdout + cmd.Stderr = &stderr + + err := cmd.Run() + + output := stdout.Bytes() + if len(output) == 0 { + if err != nil { + return nil, fmt.Errorf("claude failed: %w (stderr: %s)", err, stderr.String()) + } + return nil, fmt.Errorf("claude returned no output") + } + + var wrapper claudeJSONResult + if jerr := json.Unmarshal(output, &wrapper); jerr != nil { + return nil, fmt.Errorf("failed to parse claude JSON wrapper: %w (output: %s)", jerr, string(output)) + } + if wrapper.IsError { + return nil, fmt.Errorf("claude error: %s", wrapper.Result) + } + + var result elaboratedStory + if jerr := json.Unmarshal([]byte(extractJSON(wrapper.Result)), &result); jerr != nil { + return nil, fmt.Errorf("failed to parse elaborated story JSON: %w (result: %s)", jerr, wrapper.Result) + } + return &result, nil +} + +func (s *Server) elaborateStoryWithGemini(ctx context.Context, workDir, goal string) (*elaboratedStory, error) { + combinedPrompt := fmt.Sprintf("%s\n\n%s", buildStoryElaboratePrompt(), goal) + cmd := exec.CommandContext(ctx, s.geminiBinaryPath(), + "-p", combinedPrompt, + "--output-format", "json", + "--model", "gemini-2.5-flash-lite", + ) + if workDir != "" { + cmd.Dir = workDir + } + + var stdout, stderr bytes.Buffer + cmd.Stdout = &stdout + cmd.Stderr = &stderr + + if err := cmd.Run(); err != nil { + return nil, fmt.Errorf("gemini failed: %w 
(stderr: %s)", err, stderr.String()) + } + + var wrapper geminiJSONResult + if err := json.Unmarshal(stdout.Bytes(), &wrapper); err != nil { + return nil, fmt.Errorf("failed to parse gemini JSON wrapper: %w (output: %s)", err, stdout.String()) + } + + var result elaboratedStory + if err := json.Unmarshal([]byte(extractJSON(wrapper.Response)), &result); err != nil { + return nil, fmt.Errorf("failed to parse elaborated story JSON: %w (response: %s)", err, wrapper.Response) + } + return &result, nil +} + +func (s *Server) handleElaborateStory(w http.ResponseWriter, r *http.Request) { + var input struct { + Goal string `json:"goal"` + ProjectID string `json:"project_id"` + } + if err := json.NewDecoder(r.Body).Decode(&input); err != nil { + writeJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid JSON: " + err.Error()}) + return + } + if input.Goal == "" { + writeJSON(w, http.StatusBadRequest, map[string]string{"error": "goal is required"}) + return + } + if input.ProjectID == "" { + writeJSON(w, http.StatusBadRequest, map[string]string{"error": "project_id is required"}) + return + } + + proj, err := s.store.GetProject(input.ProjectID) + if err != nil { + writeJSON(w, http.StatusNotFound, map[string]string{"error": "project not found"}) + return + } + + // Update git refs without modifying the working tree. 
+ if proj.LocalPath != "" { + gitCmd := exec.Command("git", "-C", proj.LocalPath, "fetch", "origin") + if err := gitCmd.Run(); err != nil { + s.logger.Warn("story elaborate: git fetch failed", "error", err, "path", proj.LocalPath) + } + } + + ctx, cancel := context.WithTimeout(r.Context(), elaborateTimeout) + defer cancel() + + result, err := s.elaborateStoryWithClaude(ctx, proj.LocalPath, input.Goal) + if err != nil { + s.logger.Warn("story elaborate: claude failed, falling back to gemini", "error", err) + result, err = s.elaborateStoryWithGemini(ctx, proj.LocalPath, input.Goal) + if err != nil { + s.logger.Error("story elaborate: fallback gemini also failed", "error", err) + writeJSON(w, http.StatusBadGateway, map[string]string{ + "error": fmt.Sprintf("elaboration failed: %v", err), + }) + return + } + } + + if result.Name == "" { + writeJSON(w, http.StatusBadGateway, map[string]string{ + "error": "elaboration failed: missing required fields in response", + }) + return + } + + writeJSON(w, http.StatusOK, result) +} + func (s *Server) handleElaborateTask(w http.ResponseWriter, r *http.Request) { if s.elaborateLimiter != nil && !s.elaborateLimiter.allow(realIP(r)) { writeJSON(w, http.StatusTooManyRequests, map[string]string{"error": "rate limit exceeded"}) @@ -310,7 +501,9 @@ func (s *Server) handleElaborateTask(w http.ResponseWriter, r *http.Request) { } var input struct { - Prompt string `json:"prompt"` + Prompt string `json:"prompt"` + ProjectID string `json:"project_id"` + // project_dir kept for backward compat; project_id takes precedence ProjectDir string `json:"project_dir"` } if err := json.NewDecoder(r.Body).Decode(&input); err != nil { @@ -323,11 +516,15 @@ func (s *Server) handleElaborateTask(w http.ResponseWriter, r *http.Request) { } workDir := s.workDir - if input.ProjectDir != "" { + if input.ProjectID != "" { + if proj, err := s.store.GetProject(input.ProjectID); err == nil { + workDir = proj.LocalPath + } + } else if input.ProjectDir != "" { 
workDir = input.ProjectDir } - if input.ProjectDir != "" { + if workDir != s.workDir { go s.appendRawNarrative(workDir, input.Prompt) } diff --git a/internal/api/elaborate_test.go b/internal/api/elaborate_test.go index 0b5c706..32cec3c 100644 --- a/internal/api/elaborate_test.go +++ b/internal/api/elaborate_test.go @@ -350,6 +350,8 @@ func TestElaborateTask_InvalidJSONFromClaude(t *testing.T) { // Fake Claude returns something that is not valid JSON. srv.elaborateCmdPath = createFakeClaude(t, "not valid json at all", 0) + // Ensure Gemini fallback also fails so we get the expected 502. + srv.geminiBinPath = "/nonexistent/gemini" body := `{"prompt":"do something"}` req := httptest.NewRequest("POST", "/api/tasks/elaborate", bytes.NewBufferString(body)) @@ -388,14 +390,17 @@ func createFakeClaudeCapturingArgs(t *testing.T, output string, exitCode int, ar func TestElaborateTask_WithProjectContext(t *testing.T) { srv, _ := testServer(t) - // Create a temporary workspace with CLAUDE.md and SESSION_STATE.md + // Create a temporary workspace with CLAUDE.md and .agent/worklog.md workDir := t.TempDir() claudeContent := "Claude context info" sessionContent := "Session state info" if err := os.WriteFile(filepath.Join(workDir, "CLAUDE.md"), []byte(claudeContent), 0600); err != nil { t.Fatal(err) } - if err := os.WriteFile(filepath.Join(workDir, "SESSION_STATE.md"), []byte(sessionContent), 0600); err != nil { + if err := os.MkdirAll(filepath.Join(workDir, ".agent"), 0700); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(filepath.Join(workDir, ".agent", "worklog.md"), []byte(sessionContent), 0600); err != nil { t.Fatal(err) } @@ -436,7 +441,7 @@ func TestElaborateTask_WithProjectContext(t *testing.T) { t.Errorf("expected arguments to contain CLAUDE.md content, got %s", argsStr) } if !strings.Contains(argsStr, sessionContent) { - t.Errorf("expected arguments to contain SESSION_STATE.md content, got %s", argsStr) + t.Errorf("expected arguments to contain .agent/worklog.md 
content, got %s", argsStr) } } diff --git a/internal/api/executions.go b/internal/api/executions.go index 114425e..4d8ba9c 100644 --- a/internal/api/executions.go +++ b/internal/api/executions.go @@ -86,6 +86,48 @@ func (s *Server) handleGetExecutionLog(w http.ResponseWriter, r *http.Request) { fmt.Fprint(w, content) } +// handleGetDashboardStats returns pre-aggregated error, throughput, and billing stats. +// GET /api/stats?window=7d|24h +func (s *Server) handleGetDashboardStats(w http.ResponseWriter, r *http.Request) { + window := 7 * 24 * time.Hour + if r.URL.Query().Get("window") == "24h" { + window = 24 * time.Hour + } + since := time.Now().Add(-window) + + stats, err := s.store.QueryDashboardStats(since) + if err != nil { + writeJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) + return + } + writeJSON(w, http.StatusOK, stats) +} + +// handleGetAgentStatus returns the current status of all agents and recent rate-limit events. +// GET /api/agents/status?since=<RFC3339> +func (s *Server) handleGetAgentStatus(w http.ResponseWriter, r *http.Request) { + since := time.Now().Add(-24 * time.Hour) + if v := r.URL.Query().Get("since"); v != "" { + if t, err := time.Parse(time.RFC3339, v); err == nil { + since = t + } + } + + events, err := s.store.ListAgentEvents(since) + if err != nil { + writeJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) + return + } + if events == nil { + events = []storage.AgentEvent{} + } + + writeJSON(w, http.StatusOK, map[string]interface{}{ + "agents": s.pool.AgentStatuses(), + "events": events, + }) +} + // tailLogFile reads the last n lines from the file at path. 
func tailLogFile(path string, n int) (string, error) { data, err := os.ReadFile(path) diff --git a/internal/api/projects.go b/internal/api/projects.go new file mode 100644 index 0000000..d3dbbf9 --- /dev/null +++ b/internal/api/projects.go @@ -0,0 +1,71 @@ +package api + +import ( + "encoding/json" + "net/http" + + "github.com/google/uuid" + "github.com/thepeterstone/claudomator/internal/task" +) + +func (s *Server) handleListProjects(w http.ResponseWriter, r *http.Request) { + projects, err := s.store.ListProjects() + if err != nil { + writeJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) + return + } + if projects == nil { + projects = []*task.Project{} + } + writeJSON(w, http.StatusOK, projects) +} + +func (s *Server) handleCreateProject(w http.ResponseWriter, r *http.Request) { + var p task.Project + if err := json.NewDecoder(r.Body).Decode(&p); err != nil { + writeJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid JSON: " + err.Error()}) + return + } + if p.Name == "" { + writeJSON(w, http.StatusBadRequest, map[string]string{"error": "name is required"}) + return + } + if p.ID == "" { + p.ID = uuid.New().String() + } + if err := s.store.CreateProject(&p); err != nil { + writeJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) + return + } + writeJSON(w, http.StatusCreated, p) +} + +func (s *Server) handleGetProject(w http.ResponseWriter, r *http.Request) { + id := r.PathValue("id") + p, err := s.store.GetProject(id) + if err != nil { + writeJSON(w, http.StatusNotFound, map[string]string{"error": "project not found"}) + return + } + writeJSON(w, http.StatusOK, p) +} + +func (s *Server) handleUpdateProject(w http.ResponseWriter, r *http.Request) { + id := r.PathValue("id") + existing, err := s.store.GetProject(id) + if err != nil { + writeJSON(w, http.StatusNotFound, map[string]string{"error": "project not found"}) + return + } + if err := json.NewDecoder(r.Body).Decode(existing); err 
!= nil { + writeJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid JSON: " + err.Error()}) + return + } + existing.ID = id // ensure ID cannot be changed via body + if err := s.store.UpdateProject(existing); err != nil { + writeJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) + return + } + writeJSON(w, http.StatusOK, existing) +} + diff --git a/internal/api/push.go b/internal/api/push.go new file mode 100644 index 0000000..dde5441 --- /dev/null +++ b/internal/api/push.go @@ -0,0 +1,120 @@ +package api + +import ( + "encoding/json" + "net/http" + + "github.com/google/uuid" + "github.com/thepeterstone/claudomator/internal/storage" + webui "github.com/thepeterstone/claudomator/web" +) + +// pushSubscriptionStore is the minimal interface needed by push handlers. +type pushSubscriptionStore interface { + SavePushSubscription(sub storage.PushSubscription) error + DeletePushSubscription(endpoint string) error + ListPushSubscriptions() ([]storage.PushSubscription, error) +} + +// SetVAPIDConfig configures VAPID keys and email for web push notifications. +func (s *Server) SetVAPIDConfig(pub, priv, email string) { + s.vapidPublicKey = pub + s.vapidPrivateKey = priv + s.vapidEmail = email +} + +// SetPushStore configures the push subscription store. +func (s *Server) SetPushStore(store pushSubscriptionStore) { + s.pushStore = store +} + +// SetDropsDir configures the file drop directory. +func (s *Server) SetDropsDir(dir string) { + s.dropsDir = dir +} + +// handleGetVAPIDKey returns the VAPID public key for client-side push subscription. +func (s *Server) handleGetVAPIDKey(w http.ResponseWriter, r *http.Request) { + writeJSON(w, http.StatusOK, map[string]string{"public_key": s.vapidPublicKey}) +} + +// handleServiceWorker serves sw.js with a Service-Worker-Allowed: / header so +// the SW can control the full origin even though it is registered from /api/push/sw.js. 
+func (s *Server) handleServiceWorker(w http.ResponseWriter, r *http.Request) { + data, err := webui.Files.ReadFile("sw.js") + if err != nil { + http.Error(w, "service worker not found", http.StatusNotFound) + return + } + w.Header().Set("Content-Type", "application/javascript") + w.Header().Set("Service-Worker-Allowed", "/") + w.WriteHeader(http.StatusOK) + w.Write(data) +} + +// handlePushSubscribe saves a new push subscription. +func (s *Server) handlePushSubscribe(w http.ResponseWriter, r *http.Request) { + var input struct { + Endpoint string `json:"endpoint"` + Keys struct { + P256DH string `json:"p256dh"` + Auth string `json:"auth"` + } `json:"keys"` + } + if err := json.NewDecoder(r.Body).Decode(&input); err != nil { + writeJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid JSON: " + err.Error()}) + return + } + if input.Endpoint == "" { + writeJSON(w, http.StatusBadRequest, map[string]string{"error": "endpoint is required"}) + return + } + if input.Keys.P256DH == "" || input.Keys.Auth == "" { + writeJSON(w, http.StatusBadRequest, map[string]string{"error": "keys.p256dh and keys.auth are required"}) + return + } + + sub := storage.PushSubscription{ + ID: uuid.New().String(), + Endpoint: input.Endpoint, + P256DHKey: input.Keys.P256DH, + AuthKey: input.Keys.Auth, + } + + store := s.pushStore + if store == nil { + store = s.store + } + + if err := store.SavePushSubscription(sub); err != nil { + writeJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) + return + } + writeJSON(w, http.StatusCreated, map[string]string{"id": sub.ID}) +} + +// handlePushUnsubscribe deletes a push subscription. 
+func (s *Server) handlePushUnsubscribe(w http.ResponseWriter, r *http.Request) { + var input struct { + Endpoint string `json:"endpoint"` + } + if err := json.NewDecoder(r.Body).Decode(&input); err != nil { + writeJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid JSON: " + err.Error()}) + return + } + if input.Endpoint == "" { + writeJSON(w, http.StatusBadRequest, map[string]string{"error": "endpoint is required"}) + return + } + + store := s.pushStore + if store == nil { + store = s.store + } + + if err := store.DeletePushSubscription(input.Endpoint); err != nil { + writeJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) + return + } + w.WriteHeader(http.StatusNoContent) +} diff --git a/internal/api/push_test.go b/internal/api/push_test.go new file mode 100644 index 0000000..dfd5a3a --- /dev/null +++ b/internal/api/push_test.go @@ -0,0 +1,159 @@ +package api + +import ( + "bytes" + "encoding/json" + "net/http" + "net/http/httptest" + "sync" + "testing" + + "github.com/thepeterstone/claudomator/internal/storage" +) + +// mockPushStore implements pushSubscriptionStore for testing. +type mockPushStore struct { + mu sync.Mutex + subs []storage.PushSubscription +} + +func (m *mockPushStore) SavePushSubscription(sub storage.PushSubscription) error { + m.mu.Lock() + defer m.mu.Unlock() + // Upsert by endpoint. 
+ for i, s := range m.subs { + if s.Endpoint == sub.Endpoint { + m.subs[i] = sub + return nil + } + } + m.subs = append(m.subs, sub) + return nil +} + +func (m *mockPushStore) DeletePushSubscription(endpoint string) error { + m.mu.Lock() + defer m.mu.Unlock() + filtered := m.subs[:0] + for _, s := range m.subs { + if s.Endpoint != endpoint { + filtered = append(filtered, s) + } + } + m.subs = filtered + return nil +} + +func (m *mockPushStore) ListPushSubscriptions() ([]storage.PushSubscription, error) { + m.mu.Lock() + defer m.mu.Unlock() + cp := make([]storage.PushSubscription, len(m.subs)) + copy(cp, m.subs) + return cp, nil +} + +func testServerWithPush(t *testing.T) (*Server, *mockPushStore) { + t.Helper() + srv, _ := testServer(t) + ps := &mockPushStore{} + srv.SetVAPIDConfig("testpub", "testpriv", "mailto:test@example.com") + srv.SetPushStore(ps) + return srv, ps +} + +func TestHandleGetVAPIDKey(t *testing.T) { + srv, _ := testServerWithPush(t) + + req := httptest.NewRequest("GET", "/api/push/vapid-key", nil) + rec := httptest.NewRecorder() + srv.mux.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("want 200, got %d", rec.Code) + } + + var resp map[string]string + if err := json.NewDecoder(rec.Body).Decode(&resp); err != nil { + t.Fatalf("decode response: %v", err) + } + if resp["public_key"] != "testpub" { + t.Errorf("want public_key %q, got %q", "testpub", resp["public_key"]) + } +} + +func TestHandlePushSubscribe_CreatesSub(t *testing.T) { + srv, ps := testServerWithPush(t) + + body := `{"endpoint":"https://push.example.com/sub1","keys":{"p256dh":"key1","auth":"auth1"}}` + req := httptest.NewRequest("POST", "/api/push/subscribe", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + srv.mux.ServeHTTP(rec, req) + + if rec.Code != http.StatusCreated { + t.Fatalf("want 201, got %d: %s", rec.Code, rec.Body.String()) + } + + subs, _ := ps.ListPushSubscriptions() + if len(subs) != 1 
{ + t.Fatalf("want 1 subscription, got %d", len(subs)) + } + if subs[0].Endpoint != "https://push.example.com/sub1" { + t.Errorf("endpoint: want %q, got %q", "https://push.example.com/sub1", subs[0].Endpoint) + } +} + +func TestHandlePushSubscribe_MissingEndpoint(t *testing.T) { + srv, _ := testServerWithPush(t) + + body := `{"keys":{"p256dh":"key1","auth":"auth1"}}` + req := httptest.NewRequest("POST", "/api/push/subscribe", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + srv.mux.ServeHTTP(rec, req) + + if rec.Code != http.StatusBadRequest { + t.Fatalf("want 400, got %d", rec.Code) + } +} + +func TestHandlePushUnsubscribe_DeletesSub(t *testing.T) { + srv, ps := testServerWithPush(t) + + // Add a subscription. + ps.SavePushSubscription(storage.PushSubscription{ //nolint:errcheck + ID: "sub-1", + Endpoint: "https://push.example.com/todelete", + P256DHKey: "key", + AuthKey: "auth", + }) + + body := `{"endpoint":"https://push.example.com/todelete"}` + req := httptest.NewRequest("DELETE", "/api/push/subscribe", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + srv.mux.ServeHTTP(rec, req) + + if rec.Code != http.StatusNoContent { + t.Fatalf("want 204, got %d: %s", rec.Code, rec.Body.String()) + } + + subs, _ := ps.ListPushSubscriptions() + if len(subs) != 0 { + t.Errorf("want 0 subscriptions after delete, got %d", len(subs)) + } +} + +func TestHandlePushUnsubscribe_MissingEndpoint(t *testing.T) { + srv, _ := testServerWithPush(t) + + body := `{}` + req := httptest.NewRequest("DELETE", "/api/push/subscribe", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + srv.mux.ServeHTTP(rec, req) + + if rec.Code != http.StatusBadRequest { + t.Fatalf("want 400, got %d", rec.Code) + } +} diff --git a/internal/api/server.go b/internal/api/server.go index 33048e4..28cfe4a 100644 --- 
a/internal/api/server.go +++ b/internal/api/server.go @@ -16,6 +16,7 @@ import ( "github.com/thepeterstone/claudomator/internal/notify" "github.com/thepeterstone/claudomator/internal/storage" "github.com/thepeterstone/claudomator/internal/task" + "github.com/thepeterstone/claudomator/internal/version" webui "github.com/thepeterstone/claudomator/web" "github.com/google/uuid" ) @@ -31,6 +32,7 @@ type questionStore interface { // Server provides the REST API and WebSocket endpoint for Claudomator. type Server struct { + ctx context.Context // server lifecycle context; used for pool submissions store *storage.DB logStore logStore // injectable for tests; defaults to store taskLogStore taskLogStore // injectable for tests; defaults to store @@ -51,7 +53,12 @@ type Server struct { elaborateLimiter *ipRateLimiter // per-IP rate limiter for elaborate/validate endpoints webhookSecret string // HMAC-SHA256 secret for GitHub webhook validation projects []config.Project // configured projects for webhook routing - llm *llm.Client // optional local LLM client; when set, elaboration prefers it + vapidPublicKey string + vapidPrivateKey string + vapidEmail string + pushStore pushSubscriptionStore + dropsDir string + llm *llm.Client } // SetAPIToken configures a bearer token that must be supplied to access the API. @@ -59,6 +66,12 @@ func (s *Server) SetAPIToken(token string) { s.apiToken = token } +// SetContext replaces the server's lifecycle context used for pool submissions. +// Call this before StartHub to tie task submissions to the server's shutdown signal. +func (s *Server) SetContext(ctx context.Context) { + s.ctx = ctx +} + // SetNotifier configures a notifier that is called on every task completion. func (s *Server) SetNotifier(n notify.Notifier) { s.notifier = n @@ -75,6 +88,9 @@ func (s *Server) SetWorkspaceRoot(path string) { s.workspaceRoot = path } +// Pool returns the executor pool, for graceful shutdown by the caller. 
+func (s *Server) Pool() *executor.Pool { return s.pool } + // SetLLM wires a local OpenAI-compatible LLM client for use by elaboration // (and future internal helpers). When non-nil, elaboration will prefer it // over the Claude CLI; on failure it falls back to claude → gemini. @@ -82,9 +98,11 @@ func (s *Server) SetLLM(c *llm.Client) { s.llm = c } + func NewServer(store *storage.DB, pool *executor.Pool, logger *slog.Logger, claudeBinPath, geminiBinPath string) *Server { wd, _ := os.Getwd() s := &Server{ + ctx: context.Background(), store: store, logStore: store, taskLogStore: store, @@ -125,6 +143,8 @@ func (s *Server) routes() { s.mux.HandleFunc("GET /api/tasks/{id}/subtasks", s.handleListSubtasks) s.mux.HandleFunc("GET /api/tasks/{id}/executions", s.handleListExecutions) s.mux.HandleFunc("GET /api/executions", s.handleListRecentExecutions) + s.mux.HandleFunc("GET /api/stats", s.handleGetDashboardStats) + s.mux.HandleFunc("GET /api/agents/status", s.handleGetAgentStatus) s.mux.HandleFunc("GET /api/executions/{id}", s.handleGetExecution) s.mux.HandleFunc("GET /api/executions/{id}/log", s.handleGetExecutionLog) s.mux.HandleFunc("GET /api/tasks/{id}/logs/stream", s.handleStreamTaskLogs) @@ -135,29 +155,53 @@ func (s *Server) routes() { s.mux.HandleFunc("GET /api/ws", s.handleWebSocket) s.mux.HandleFunc("GET /api/workspaces", s.handleListWorkspaces) s.mux.HandleFunc("GET /api/tasks/{id}/deployment-status", s.handleGetDeploymentStatus) + s.mux.HandleFunc("GET /api/projects", s.handleListProjects) + s.mux.HandleFunc("POST /api/projects", s.handleCreateProject) + s.mux.HandleFunc("GET /api/projects/{id}", s.handleGetProject) + s.mux.HandleFunc("PUT /api/projects/{id}", s.handleUpdateProject) + s.mux.HandleFunc("POST /api/stories/elaborate", s.handleElaborateStory) + s.mux.HandleFunc("POST /api/stories/approve", s.handleApproveStory) + s.mux.HandleFunc("GET /api/stories", s.handleListStories) + s.mux.HandleFunc("POST /api/stories", s.handleCreateStory) + 
s.mux.HandleFunc("GET /api/stories/{id}", s.handleGetStory) + s.mux.HandleFunc("GET /api/stories/{id}/tasks", s.handleListStoryTasks) + s.mux.HandleFunc("POST /api/stories/{id}/tasks", s.handleAddTaskToStory) + s.mux.HandleFunc("PUT /api/stories/{id}/status", s.handleUpdateStoryStatus) + s.mux.HandleFunc("POST /api/stories/{id}/ship", s.handleShipStory) + s.mux.HandleFunc("GET /api/stories/{id}/deployment-status", s.handleStoryDeploymentStatus) s.mux.HandleFunc("GET /api/health", s.handleHealth) + s.mux.HandleFunc("GET /api/version", s.handleVersion) s.mux.HandleFunc("POST /api/webhooks/github", s.handleGitHubWebhook) + s.mux.HandleFunc("GET /api/push/vapid-key", s.handleGetVAPIDKey) + s.mux.HandleFunc("GET /api/push/sw.js", s.handleServiceWorker) + s.mux.HandleFunc("POST /api/push/subscribe", s.handlePushSubscribe) + s.mux.HandleFunc("DELETE /api/push/subscribe", s.handlePushUnsubscribe) + s.mux.HandleFunc("GET /api/drops", s.handleListDrops) + s.mux.HandleFunc("GET /api/drops/{filename}", s.handleGetDrop) + s.mux.HandleFunc("POST /api/drops", s.handlePostDrop) s.mux.Handle("GET /", http.FileServerFS(webui.Files)) } -// forwardResults listens on the executor pool's result channel and broadcasts via WebSocket. +// forwardResults listens on the executor pool's result and started channels and broadcasts via WebSocket. func (s *Server) forwardResults() { + go func() { + for taskID := range s.pool.Started() { + event := map[string]interface{}{ + "type": "task_started", + "task_id": taskID, + "timestamp": time.Now().UTC(), + } + data, _ := json.Marshal(event) + s.hub.Broadcast(data) + } + }() for result := range s.pool.Results() { s.processResult(result) } } // processResult broadcasts a task completion event via WebSocket and calls the notifier if set. -// It also parses git diff stats from the execution stdout log and persists them. 
func (s *Server) processResult(result *executor.Result) { - if result.Execution.StdoutPath != "" { - if stats := parseChangestatFromFile(result.Execution.StdoutPath); stats != nil { - if err := s.store.UpdateExecutionChangestats(result.Execution.ID, stats); err != nil { - s.logger.Error("failed to store changestats", "execID", result.Execution.ID, "error", err) - } - } - } - event := map[string]interface{}{ "type": "task_completed", "task_id": result.TaskID, @@ -318,7 +362,7 @@ func (s *Server) handleAnswerQuestion(w http.ResponseWriter, r *http.Request) { ResumeAnswer: input.Answer, SandboxDir: latest.SandboxDir, } - if err := s.pool.SubmitResume(context.Background(), tk, resumeExec); err != nil { + if err := s.pool.SubmitResume(s.ctx, tk, resumeExec); err != nil { writeJSON(w, http.StatusServiceUnavailable, map[string]string{"error": err.Error()}) return } @@ -363,7 +407,7 @@ func (s *Server) handleResumeTimedOutTask(w http.ResponseWriter, r *http.Request ResumeSessionID: latest.SessionID, ResumeAnswer: resumeMsg, } - if err := s.pool.SubmitResume(context.Background(), tk, resumeExec); err != nil { + if err := s.pool.SubmitResume(s.ctx, tk, resumeExec); err != nil { writeJSON(w, http.StatusServiceUnavailable, map[string]string{"error": err.Error()}) return } @@ -415,11 +459,17 @@ func (s *Server) handleHealth(w http.ResponseWriter, r *http.Request) { }) } +func (s *Server) handleVersion(w http.ResponseWriter, r *http.Request) { + writeJSON(w, http.StatusOK, map[string]string{"version": version.Version()}) +} + func (s *Server) handleCreateTask(w http.ResponseWriter, r *http.Request) { var input struct { Name string `json:"name"` Description string `json:"description"` ElaborationInput string `json:"elaboration_input"` + Project string `json:"project"` + RepositoryURL string `json:"repository_url"` Agent task.AgentConfig `json:"agent"` Claude task.AgentConfig `json:"claude"` // legacy alias Timeout string `json:"timeout"` @@ -443,6 +493,8 @@ func (s *Server) 
handleCreateTask(w http.ResponseWriter, r *http.Request) { Name: input.Name, Description: input.Description, ElaborationInput: input.ElaborationInput, + Project: input.Project, + RepositoryURL: input.RepositoryURL, Agent: input.Agent, Priority: task.Priority(input.Priority), Tags: input.Tags, @@ -453,6 +505,7 @@ func (s *Server) handleCreateTask(w http.ResponseWriter, r *http.Request) { UpdatedAt: now, ParentTaskID: input.ParentTaskID, } + if t.Agent.Type == "" { t.Agent.Type = "claude" } @@ -523,7 +576,11 @@ func (s *Server) handleListTasks(w http.ResponseWriter, r *http.Request) { if tasks == nil { tasks = []*task.Task{} } - writeJSON(w, http.StatusOK, tasks) + views := make([]*taskView, len(tasks)) + for i, tk := range tasks { + views[i] = s.enrichTask(tk) + } + writeJSON(w, http.StatusOK, views) } func (s *Server) handleGetTask(w http.ResponseWriter, r *http.Request) { @@ -533,8 +590,43 @@ func (s *Server) handleGetTask(w http.ResponseWriter, r *http.Request) { writeJSON(w, http.StatusNotFound, map[string]string{"error": "task not found"}) return } - writeJSON(w, http.StatusOK, t) + writeJSON(w, http.StatusOK, s.enrichTask(t)) +} +// retryableDepStates are the states from which a dependency can be retried +// when cascading a retry from a dependent task. +var retryableDepStates = map[task.State]bool{ + task.StateFailed: true, + task.StateTimedOut: true, + task.StateCancelled: true, + task.StateBudgetExceeded: true, +} + +// cascadeRetryDeps resets any dependency (recursively) that is in a retryable +// terminal state, and submits it to the pool. This ensures that retrying a +// CANCELLED task that was blocked by a failed dep will also restart that dep. 
+func (s *Server) cascadeRetryDeps(ctx context.Context, t *task.Task) { + for _, depID := range t.DependsOn { + dep, err := s.store.GetTask(depID) + if err != nil { + s.logger.Warn("cascadeRetryDeps: dep not found", "depID", depID) + continue + } + if !retryableDepStates[dep.State] { + continue + } + // Recursively cascade first (depth-first so root deps go first). + s.cascadeRetryDeps(ctx, dep) + reset, err := s.store.ResetTaskForRetry(depID) + if err != nil { + s.logger.Warn("cascadeRetryDeps: reset failed", "depID", depID, "error", err) + continue + } + if submitErr := s.pool.Submit(ctx, reset); submitErr != nil { + s.logger.Warn("cascadeRetryDeps: submit failed", "depID", depID, "error", submitErr) + } + } } + func (s *Server) handleRunTask(w http.ResponseWriter, r *http.Request) { id := r.PathValue("id") agentParam := r.URL.Query().Get("agent") // Use a different name to avoid confusion @@ -583,7 +675,11 @@ func (s *Server) handleRunTask(w http.ResponseWriter, r *http.Request) { } // The task `t` now has the correct agent configuration. - if err := s.pool.Submit(context.Background(), t); err != nil { + // 6. Cascade-retry any deps that are in a terminal failure state so the + // task isn't immediately re-cancelled by checkDepsReady. 
+ s.cascadeRetryDeps(r.Context(), originalTask) + + if err := s.pool.Submit(s.ctx, t); err != nil { writeJSON(w, http.StatusServiceUnavailable, map[string]string{"error": fmt.Sprintf("executor pool: %v", err)}) return } @@ -611,6 +707,9 @@ func (s *Server) handleAcceptTask(w http.ResponseWriter, r *http.Request) { writeJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) return } + if t.StoryID != "" { + go s.pool.CheckStoryCompletion(r.Context(), t.StoryID) + } writeJSON(w, http.StatusOK, map[string]string{"message": "task accepted", "task_id": id}) } diff --git a/internal/api/server_test.go b/internal/api/server_test.go index 2139e36..2530d55 100644 --- a/internal/api/server_test.go +++ b/internal/api/server_test.go @@ -16,6 +16,7 @@ import ( "context" + "github.com/google/uuid" "github.com/thepeterstone/claudomator/internal/executor" "github.com/thepeterstone/claudomator/internal/notify" "github.com/thepeterstone/claudomator/internal/storage" @@ -89,6 +90,9 @@ func testServerWithRunner(t *testing.T, runner executor.Runner) (*Server, *stora t.Cleanup(func() { store.Close() }) logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) + if mr, ok := runner.(*mockRunner); ok { + mr.logDir = t.TempDir() + } runners := map[string]executor.Runner{ "claude": runner, "gemini": runner, @@ -99,11 +103,39 @@ func testServerWithRunner(t *testing.T, runner executor.Runner) (*Server, *stora } type mockRunner struct { - err error - sleep time.Duration + err error + sleep time.Duration + logDir string + onRun func(*task.Task, *storage.Execution) error +} + +func (m *mockRunner) ExecLogDir(execID string) string { + if m.logDir == "" { + return "" + } + return filepath.Join(m.logDir, execID) } -func (m *mockRunner) Run(ctx context.Context, _ *task.Task, _ *storage.Execution) error { +func (m *mockRunner) Run(ctx context.Context, t *task.Task, e *storage.Execution) error { + if e.ID == "" { + e.ID = 
uuid.New().String() + } + if m.logDir != "" { + dir := m.ExecLogDir(e.ID) + if err := os.MkdirAll(dir, 0755); err != nil { + return err + } + e.StdoutPath = filepath.Join(dir, "stdout.log") + e.StderrPath = filepath.Join(dir, "stderr.log") + e.ArtifactDir = dir + // Create an empty file at least + os.WriteFile(e.StdoutPath, []byte(""), 0644) + } + if m.onRun != nil { + if err := m.onRun(t, e); err != nil { + return err + } + } if m.sleep > 0 { select { case <-time.After(m.sleep): @@ -143,41 +175,26 @@ func testServerWithGeminiMockRunner(t *testing.T) (*Server, *storage.DB) { logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelDebug})) - // Create the mock gemini binary script. Use single-quoted heredoc so - // bash does not try to evaluate the literal backticks as command - // substitution. - mockBinDir := t.TempDir() - mockGeminiPath := filepath.Join(mockBinDir, "mock-gemini-binary.sh") - mockScriptContent := `#!/bin/bash -cat <<'EOF' -` + "```json" + ` -{"type":"content_block_start","content_block":{"text":"Hello, Gemini!","type":"text"}} -{"type":"content_block_delta","content_block":{"text":" How are you?"}} -{"type":"content_block_end"} -{"type":"message_delta","message":{"role":"model"}} -{"type":"message_end"} -` + "```" + ` -EOF -exit 0 -` - if err := os.WriteFile(mockGeminiPath, []byte(mockScriptContent), 0755); err != nil { - t.Fatalf("writing mock gemini script: %v", err) - } - - // Configure GeminiRunner to use the mock script. 
- geminiRunner := &executor.GeminiRunner{ - BinaryPath: mockGeminiPath, - Logger: logger, - LogDir: t.TempDir(), // Ensure log directory is temporary for test - APIURL: "http://localhost:8080", // Placeholder, not used by this mock + mr := &mockRunner{ + logDir: t.TempDir(), + onRun: func(t *task.Task, e *storage.Execution) error { + lines := []string{ + `{"type":"content_block_start","content_block":{"text":"Hello, Gemini!","type":"text"}}`, + `{"type":"content_block_delta","content_block":{"text":" How are you?"}}`, + `{"type":"content_block_end"}`, + `{"type":"message_delta","message":{"role":"model"}}`, + `{"type":"message_end"}`, + } + return os.WriteFile(e.StdoutPath, []byte(strings.Join(lines, "\n")), 0644) + }, } runners := map[string]executor.Runner{ - "claude": &mockRunner{}, // Keep mock for claude to not interfere - "gemini": geminiRunner, + "claude": mr, + "gemini": mr, } pool := executor.NewPool(2, runners, store, logger) - srv := NewServer(store, pool, logger, "claude", "gemini") // Pass original binary paths + srv := NewServer(store, pool, logger, "claude", "gemini") return srv, store } @@ -200,6 +217,7 @@ func TestGeminiLogs_ParsedCorrectly(t *testing.T) { tk := createTestTask(t, srv, `{ "name": "Gemini Log Test Task", "description": "Test Gemini log parsing", + "repository_url": "https://github.com/user/repo", "agent": { "type": "gemini", "instructions": "generate some output", @@ -346,6 +364,7 @@ func TestCreateTask_Success(t *testing.T) { payload := `{ "name": "API Task", "description": "Created via API", + "repository_url": "https://github.com/user/repo", "agent": { "type": "claude", "instructions": "do the thing", @@ -399,6 +418,50 @@ func TestCreateTask_ValidationFailure(t *testing.T) { } } +func TestProject_RoundTrip(t *testing.T) { + srv, _ := testServer(t) + + payload := `{ + "name": "Project Task", + "project": "test-project", + "repository_url": "https://github.com/user/repo", + "agent": { + "type": "claude", + "instructions": "do the 
thing", + "model": "sonnet" + } + }` + req := httptest.NewRequest("POST", "/api/tasks", bytes.NewBufferString(payload)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + srv.Handler().ServeHTTP(w, req) + + if w.Code != http.StatusCreated { + t.Fatalf("create: want 201, got %d; body: %s", w.Code, w.Body.String()) + } + + var created task.Task + json.NewDecoder(w.Body).Decode(&created) + if created.Project != "test-project" { + t.Errorf("create response: project want 'test-project', got %q", created.Project) + } + + // GET the task and verify project is persisted + req2 := httptest.NewRequest("GET", "/api/tasks/"+created.ID, nil) + w2 := httptest.NewRecorder() + srv.Handler().ServeHTTP(w2, req2) + + if w2.Code != http.StatusOK { + t.Fatalf("get: want 200, got %d; body: %s", w2.Code, w2.Body.String()) + } + + var fetched task.Task + json.NewDecoder(w2.Body).Decode(&fetched) + if fetched.Project != "test-project" { + t.Errorf("get response: project want 'test-project', got %q", fetched.Project) + } +} + func TestListTasks_Empty(t *testing.T) { srv, _ := testServer(t) @@ -436,6 +499,7 @@ func TestListTasks_WithTasks(t *testing.T) { for i := 0; i < 3; i++ { tk := &task.Task{ ID: fmt.Sprintf("lt-%d", i), Name: fmt.Sprintf("T%d", i), + RepositoryURL: "https://github.com/user/repo", Agent: task.AgentConfig{Type: "claude", Instructions: "x"}, Priority: task.PriorityNormal, Retry: task.RetryConfig{MaxAttempts: 1, Backoff: "linear"}, Tags: []string{}, DependsOn: []string{}, State: task.StatePending, @@ -473,6 +537,7 @@ func createTaskWithState(t *testing.T, store *storage.DB, id string, state task. 
tk := &task.Task{ ID: id, Name: "test-task-" + id, + RepositoryURL: "https://github.com/user/repo", Agent: task.AgentConfig{Type: "claude", Instructions: "do something"}, Priority: task.PriorityNormal, Retry: task.RetryConfig{MaxAttempts: 1, Backoff: "linear"}, @@ -851,6 +916,7 @@ func TestRunTask_ManualRunIgnoresRetryLimit(t *testing.T) { tk := &task.Task{ ID: "retry-limit-manual", Name: "Retry Limit Task", + RepositoryURL: "https://github.com/user/repo", Agent: task.AgentConfig{Instructions: "do something"}, Priority: task.PriorityNormal, Retry: task.RetryConfig{MaxAttempts: 1, Backoff: "linear"}, @@ -888,6 +954,7 @@ func TestRunTask_WithinRetryLimit_Returns202(t *testing.T) { tk := &task.Task{ ID: "retry-within-1", Name: "Retry Within Task", + RepositoryURL: "https://github.com/user/repo", Agent: task.AgentConfig{Instructions: "do something"}, Priority: task.PriorityNormal, Retry: task.RetryConfig{MaxAttempts: 3, Backoff: "linear"}, @@ -935,7 +1002,7 @@ func TestDeleteTask_Success(t *testing.T) { srv, store := testServer(t) // Create a task to delete. 
- created := createTestTask(t, srv, `{"name":"Delete Me","agent":{"type":"claude","instructions":"x","model":"sonnet"}}`) + created := createTestTask(t, srv, `{"name":"Delete Me","repository_url":"https://github.com/user/repo","agent":{"type":"claude","instructions":"x","model":"sonnet"}}`) req := httptest.NewRequest("DELETE", "/api/tasks/"+created.ID, nil) w := httptest.NewRecorder() @@ -970,6 +1037,7 @@ func TestDeleteTask_RunningTaskRejected(t *testing.T) { tk := &task.Task{ ID: "running-task-del", Name: "Running Task", + RepositoryURL: "https://github.com/user/repo", Agent: task.AgentConfig{Instructions: "x", Model: "sonnet"}, Priority: task.PriorityNormal, Tags: []string{}, @@ -1524,6 +1592,7 @@ func TestRunTask_AgentTimesOut_TaskSetToTimedOut(t *testing.T) { tk := &task.Task{ ID: "async-timeout-1", Name: "timeout-test", + RepositoryURL: "https://github.com/user/repo", Agent: task.AgentConfig{Type: "claude", Instructions: "do something"}, Priority: task.PriorityNormal, Retry: task.RetryConfig{MaxAttempts: 1, Backoff: "linear"}, @@ -1581,34 +1650,31 @@ func TestRunTask_AgentCancelled_TaskSetToCancelled(t *testing.T) { } } -// TestGetTask_IncludesChangestats verifies that after processResult parses git diff stats -// from the execution stdout log, they appear in the execution history response. +// TestGetTask_IncludesChangestats verifies that changestats stored on an execution +// are returned correctly by GET /api/tasks/{id}/executions. func TestGetTask_IncludesChangestats(t *testing.T) { srv, store := testServer(t) tk := createTaskWithState(t, store, "cs-task-1", task.StateCompleted) - // Write a stdout log with a git diff --stat summary line. 
- dir := t.TempDir() - stdoutPath := filepath.Join(dir, "stdout.log") - logContent := "Agent output line 1\n3 files changed, 50 insertions(+), 10 deletions(-)\nAgent output line 2\n" - if err := os.WriteFile(stdoutPath, []byte(logContent), 0600); err != nil { - t.Fatal(err) - } - exec := &storage.Execution{ - ID: "cs-exec-1", - TaskID: tk.ID, - StartTime: time.Now().UTC(), - EndTime: time.Now().UTC().Add(time.Minute), - Status: "COMPLETED", - StdoutPath: stdoutPath, + ID: "cs-exec-1", + TaskID: tk.ID, + StartTime: time.Now().UTC(), + EndTime: time.Now().UTC().Add(time.Minute), + Status: "COMPLETED", } if err := store.CreateExecution(exec); err != nil { t.Fatal(err) } - // processResult should parse changestats from the stdout log and store them. + // Pool stores changestats after execution; simulate by calling UpdateExecutionChangestats directly. + cs := &task.Changestats{FilesChanged: 3, LinesAdded: 50, LinesRemoved: 10} + if err := store.UpdateExecutionChangestats(exec.ID, cs); err != nil { + t.Fatal(err) + } + + // processResult broadcasts but does NOT parse changestats (that's the pool's job). result := &executor.Result{ TaskID: tk.ID, Execution: exec, @@ -1782,3 +1848,299 @@ func TestDeploymentStatus_NotFound(t *testing.T) { t.Fatalf("want 404, got %d", w.Code) } } + +// TestListTasks_ReadyTask_IncludesDeploymentStatus verifies that GET /api/tasks +// returns a deployment_status field for READY tasks containing deployed_commit, +// fix_commits, and includes_fix. 
+func TestListTasks_ReadyTask_IncludesDeploymentStatus(t *testing.T) { + srv, store := testServer(t) + + tk := createTaskWithState(t, store, "enrich-list-ready-1", task.StateReady) + exec := &storage.Execution{ + ID: "enrich-list-exec-1", + TaskID: tk.ID, + StartTime: time.Now(), + EndTime: time.Now(), + Status: "COMPLETED", + Commits: []task.GitCommit{{Hash: "aabbcc", Message: "fix: list test"}}, + } + if err := store.CreateExecution(exec); err != nil { + t.Fatal(err) + } + + req := httptest.NewRequest("GET", "/api/tasks", nil) + w := httptest.NewRecorder() + srv.Handler().ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("want 200, got %d; body: %s", w.Code, w.Body.String()) + } + + var tasks []map[string]interface{} + if err := json.NewDecoder(w.Body).Decode(&tasks); err != nil { + t.Fatalf("decode: %v", err) + } + + var found map[string]interface{} + for _, tsk := range tasks { + if tsk["id"] == tk.ID { + found = tsk + break + } + } + if found == nil { + t.Fatalf("task %q not found in list response", tk.ID) + } + + ds, ok := found["deployment_status"].(map[string]interface{}) + if !ok { + t.Fatalf("READY task missing deployment_status field; got: %v", found["deployment_status"]) + } + if _, ok := ds["deployed_commit"]; !ok { + t.Error("deployment_status missing deployed_commit") + } + if _, ok := ds["includes_fix"]; !ok { + t.Error("deployment_status missing includes_fix") + } +} + +// TestGetTask_ReadyTask_IncludesDeploymentStatus verifies that GET /api/tasks/{id} +// returns a deployment_status field for a READY task. 
+func TestGetTask_ReadyTask_IncludesDeploymentStatus(t *testing.T) { + srv, store := testServer(t) + + tk := createTaskWithState(t, store, "enrich-get-ready-1", task.StateReady) + exec := &storage.Execution{ + ID: "enrich-get-exec-1", + TaskID: tk.ID, + StartTime: time.Now(), + EndTime: time.Now(), + Status: "COMPLETED", + Commits: []task.GitCommit{{Hash: "ddeeff", Message: "fix: get test"}}, + } + if err := store.CreateExecution(exec); err != nil { + t.Fatal(err) + } + + req := httptest.NewRequest("GET", "/api/tasks/"+tk.ID, nil) + w := httptest.NewRecorder() + srv.Handler().ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("want 200, got %d", w.Code) + } + + var resp map[string]interface{} + if err := json.NewDecoder(w.Body).Decode(&resp); err != nil { + t.Fatalf("decode: %v", err) + } + + ds, ok := resp["deployment_status"].(map[string]interface{}) + if !ok { + t.Fatalf("READY task GET response missing deployment_status; got: %v", resp["deployment_status"]) + } + if _, ok := ds["deployed_commit"]; !ok { + t.Error("deployment_status missing deployed_commit") + } + if _, ok := ds["includes_fix"]; !ok { + t.Error("deployment_status missing includes_fix") + } +} + +// TestListTasks_NonReadyTask_OmitsDeploymentStatus verifies that non-READY tasks +// (e.g. PENDING) do not include a deployment_status field. 
+func TestListTasks_NonReadyTask_OmitsDeploymentStatus(t *testing.T) { + srv, store := testServer(t) + + tk := createTaskWithState(t, store, "enrich-list-pending-1", task.StatePending) + + req := httptest.NewRequest("GET", "/api/tasks", nil) + w := httptest.NewRecorder() + srv.Handler().ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("want 200, got %d", w.Code) + } + + var tasks []map[string]interface{} + if err := json.NewDecoder(w.Body).Decode(&tasks); err != nil { + t.Fatalf("decode: %v", err) + } + + var found map[string]interface{} + for _, tsk := range tasks { + if tsk["id"] == tk.ID { + found = tsk + break + } + } + if found == nil { + t.Fatalf("task %q not found in list", tk.ID) + } + + if _, ok := found["deployment_status"]; ok { + t.Error("PENDING task should not include deployment_status field") + } +} + +func TestProjects_CRUD(t *testing.T) { + srv, _ := testServer(t) + + // Create + body := `{"name":"testproj","local_path":"/workspace/testproj","type":"web"}` + req := httptest.NewRequest("POST", "/api/projects", strings.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + srv.Handler().ServeHTTP(w, req) + if w.Code != http.StatusCreated { + t.Fatalf("POST /api/projects: want 201, got %d; body: %s", w.Code, w.Body.String()) + } + var created map[string]interface{} + json.NewDecoder(w.Body).Decode(&created) + id, _ := created["id"].(string) + if id == "" { + t.Fatal("created project has no id") + } + + // Get + req = httptest.NewRequest("GET", "/api/projects/"+id, nil) + w = httptest.NewRecorder() + srv.Handler().ServeHTTP(w, req) + if w.Code != http.StatusOK { + t.Fatalf("GET /api/projects/%s: want 200, got %d", id, w.Code) + } + + // List + req = httptest.NewRequest("GET", "/api/projects", nil) + w = httptest.NewRecorder() + srv.Handler().ServeHTTP(w, req) + if w.Code != http.StatusOK { + t.Fatalf("GET /api/projects: want 200, got %d", w.Code) + } + var list []interface{} + 
json.NewDecoder(w.Body).Decode(&list) + if len(list) == 0 { + t.Error("expected at least one project in list") + } +} + +func TestHandleRunTask_CascadesRetryToFailedDeps(t *testing.T) { + // Use a blocking runner so tasks stay QUEUED long enough to assert state. + block := make(chan struct{}) + t.Cleanup(func() { close(block) }) + srv, store := testServerWithRunner(t, &mockRunner{onRun: func(*task.Task, *storage.Execution) error { + <-block + return nil + }}) + + now := time.Now().UTC() + + // Task A: the dependency, in FAILED state. + taskA := &task.Task{ + ID: "cascade-dep-a", + Name: "Dep A", + State: task.StateFailed, + DependsOn: []string{}, + Priority: task.PriorityNormal, + Tags: []string{}, + Agent: task.AgentConfig{Type: "claude", Instructions: "do A"}, + Retry: task.RetryConfig{MaxAttempts: 1, Backoff: "exponential"}, + CreatedAt: now, + UpdatedAt: now, + } + if err := store.CreateTask(taskA); err != nil { + t.Fatalf("CreateTask A: %v", err) + } + + // Task B: depends on A, in CANCELLED state (was cancelled because A failed). + taskB := &task.Task{ + ID: "cascade-task-b", + Name: "Task B", + State: task.StateCancelled, + DependsOn: []string{taskA.ID}, + Priority: task.PriorityNormal, + Tags: []string{}, + Agent: task.AgentConfig{Type: "claude", Instructions: "do B"}, + Retry: task.RetryConfig{MaxAttempts: 1, Backoff: "exponential"}, + CreatedAt: now, + UpdatedAt: now, + } + if err := store.CreateTask(taskB); err != nil { + t.Fatalf("CreateTask B: %v", err) + } + + // Run task B — should cascade-retry dep A. + req := httptest.NewRequest("POST", "/api/tasks/cascade-task-b/run", nil) + w := httptest.NewRecorder() + srv.mux.ServeHTTP(w, req) + + if w.Code != http.StatusAccepted { + t.Fatalf("expected 202, got %d: %s", w.Code, w.Body.String()) + } + + // Dep A should now be QUEUED. 
+ a, err := store.GetTask(taskA.ID) + if err != nil { + t.Fatalf("GetTask A: %v", err) + } + if a.State != task.StateQueued { + t.Errorf("dep A: want QUEUED after cascade, got %s", a.State) + } + + // Task B itself should be QUEUED. + b, err := store.GetTask(taskB.ID) + if err != nil { + t.Fatalf("GetTask B: %v", err) + } + if b.State != task.StateQueued { + t.Errorf("task B: want QUEUED, got %s", b.State) + } +} + +func TestShipStory_ShippableStory_Returns202(t *testing.T) { + srv, store := testServer(t) + + proj := &task.Project{ + ID: "ship-proj-1", Name: "test", RemoteURL: "https://github.com/x/y", + Type: "web", DeployScript: "", + } + if err := store.CreateProject(proj); err != nil { + t.Fatalf("CreateProject: %v", err) + } + + story := &task.Story{ + ID: "ship-story-1", Name: "Ship Test", ProjectID: "ship-proj-1", + Status: task.StoryShippable, CreatedAt: time.Now().UTC(), UpdatedAt: time.Now().UTC(), + } + if err := store.CreateStory(story); err != nil { + t.Fatalf("CreateStory: %v", err) + } + + req := httptest.NewRequest("POST", "/api/stories/ship-story-1/ship", nil) + w := httptest.NewRecorder() + srv.Handler().ServeHTTP(w, req) + + if w.Code != http.StatusAccepted { + t.Errorf("expected 202, got %d: %s", w.Code, w.Body.String()) + } +} + +func TestShipStory_NonShippable_Returns409(t *testing.T) { + srv, store := testServer(t) + + story := &task.Story{ + ID: "nonship-1", Name: "Not Ready", ProjectID: "", + Status: task.StoryInProgress, CreatedAt: time.Now().UTC(), UpdatedAt: time.Now().UTC(), + } + if err := store.CreateStory(story); err != nil { + t.Fatalf("CreateStory: %v", err) + } + + req := httptest.NewRequest("POST", "/api/stories/nonship-1/ship", nil) + w := httptest.NewRecorder() + srv.Handler().ServeHTTP(w, req) + + if w.Code != http.StatusConflict { + t.Errorf("expected 409, got %d", w.Code) + } +} diff --git a/internal/api/stories.go b/internal/api/stories.go new file mode 100644 index 0000000..fa10ccd --- /dev/null +++ 
b/internal/api/stories.go @@ -0,0 +1,378 @@ +package api + +import ( + "database/sql" + "encoding/json" + "fmt" + "net/http" + "os/exec" + "strings" + "time" + + "github.com/google/uuid" + "github.com/thepeterstone/claudomator/internal/deployment" + "github.com/thepeterstone/claudomator/internal/task" +) + +// createStoryBranch creates a new git branch in localPath from the latest main +// and pushes it to remoteURL (the bare repo). Idempotent: treats "already exists" as success. +func createStoryBranch(localPath, branchName, remoteURL string) error { + // Fetch latest from the bare repo so we have an up-to-date base. + if out, err := exec.Command("git", "-C", localPath, "fetch", remoteURL, "main").CombinedOutput(); err != nil { + return fmt.Errorf("git fetch: %w (output: %s)", err, string(out)) + } + base := "FETCH_HEAD" + out, err := exec.Command("git", "-C", localPath, "checkout", "-b", branchName, base).CombinedOutput() + if err != nil { + if !strings.Contains(string(out), "already exists") { + return fmt.Errorf("git checkout -b: %w (output: %s)", err, string(out)) + } + } + if out, err := exec.Command("git", "-C", localPath, "push", remoteURL, branchName).CombinedOutput(); err != nil { + return fmt.Errorf("git push: %w (output: %s)", err, string(out)) + } + return nil +} + +func (s *Server) handleListStories(w http.ResponseWriter, r *http.Request) { + stories, err := s.store.ListStories() + if err != nil { + writeJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) + return + } + if stories == nil { + stories = []*task.Story{} + } + writeJSON(w, http.StatusOK, stories) +} + +func (s *Server) handleCreateStory(w http.ResponseWriter, r *http.Request) { + var st task.Story + if err := json.NewDecoder(r.Body).Decode(&st); err != nil { + writeJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid JSON: " + err.Error()}) + return + } + if st.Name == "" { + writeJSON(w, http.StatusBadRequest, map[string]string{"error": "name is 
required"}) + return + } + if st.ID == "" { + st.ID = uuid.New().String() + } + if st.Status == "" { + st.Status = task.StoryPending + } + if err := s.store.CreateStory(&st); err != nil { + writeJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) + return + } + writeJSON(w, http.StatusCreated, st) +} + +func (s *Server) handleGetStory(w http.ResponseWriter, r *http.Request) { + id := r.PathValue("id") + st, err := s.store.GetStory(id) + if err != nil { + writeJSON(w, http.StatusNotFound, map[string]string{"error": "story not found"}) + return + } + writeJSON(w, http.StatusOK, st) +} + +func (s *Server) handleListStoryTasks(w http.ResponseWriter, r *http.Request) { + id := r.PathValue("id") + if _, err := s.store.GetStory(id); err != nil { + writeJSON(w, http.StatusNotFound, map[string]string{"error": "story not found"}) + return + } + tasks, err := s.store.ListTasksByStory(id) + if err != nil { + writeJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) + return + } + if tasks == nil { + tasks = []*task.Task{} + } + writeJSON(w, http.StatusOK, tasks) +} + +func (s *Server) handleAddTaskToStory(w http.ResponseWriter, r *http.Request) { + storyID := r.PathValue("id") + st, err := s.store.GetStory(storyID) + if err != nil { + writeJSON(w, http.StatusNotFound, map[string]string{"error": "story not found"}) + return + } + _ = st + + var input struct { + Name string `json:"name"` + Description string `json:"description"` + Project string `json:"project"` + RepositoryURL string `json:"repository_url"` + Agent task.AgentConfig `json:"agent"` + Claude task.AgentConfig `json:"claude"` + Timeout string `json:"timeout"` + Priority string `json:"priority"` + Tags []string `json:"tags"` + ParentTaskID string `json:"parent_task_id"` + } + if err := json.NewDecoder(r.Body).Decode(&input); err != nil { + writeJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid JSON: " + err.Error()}) + return + } + if 
input.Agent.Instructions == "" && input.Claude.Instructions != "" { + input.Agent = input.Claude + } + + existing, err := s.store.ListTasksByStory(storyID) + if err != nil { + writeJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) + return + } + + now := time.Now().UTC() + t := &task.Task{ + ID: uuid.New().String(), + Name: input.Name, + Description: input.Description, + Project: input.Project, + RepositoryURL: input.RepositoryURL, + Agent: input.Agent, + Priority: task.Priority(input.Priority), + Tags: input.Tags, + DependsOn: []string{}, + Retry: task.RetryConfig{MaxAttempts: 1, Backoff: "exponential"}, + State: task.StatePending, + StoryID: storyID, + ParentTaskID: input.ParentTaskID, + CreatedAt: now, + UpdatedAt: now, + } + + if t.Agent.Type == "" { + t.Agent.Type = "claude" + } + if t.Priority == "" { + t.Priority = task.PriorityNormal + } + if t.Tags == nil { + t.Tags = []string{} + } + if input.Timeout != "" { + dur, err := time.ParseDuration(input.Timeout) + if err != nil { + writeJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid timeout: " + err.Error()}) + return + } + t.Timeout.Duration = dur + } + + // Auto-wire depends_on: new task depends on the last existing task (sorted ASC by created_at). 
+ if len(existing) > 0 { + lastTask := existing[len(existing)-1] + t.DependsOn = []string{lastTask.ID} + } + + if err := s.store.CreateTask(t); err != nil { + writeJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) + return + } + writeJSON(w, http.StatusCreated, t) +} + +func (s *Server) handleUpdateStoryStatus(w http.ResponseWriter, r *http.Request) { + id := r.PathValue("id") + st, err := s.store.GetStory(id) + if err != nil { + writeJSON(w, http.StatusNotFound, map[string]string{"error": "story not found"}) + return + } + + var input struct { + Status task.StoryState `json:"status"` + } + if err := json.NewDecoder(r.Body).Decode(&input); err != nil { + writeJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid JSON: " + err.Error()}) + return + } + if !task.ValidStoryTransition(st.Status, input.Status) { + writeJSON(w, http.StatusConflict, map[string]string{ + "error": "invalid story status transition from " + string(st.Status) + " to " + string(input.Status), + }) + return + } + if err := s.store.UpdateStoryStatus(id, input.Status); err != nil { + writeJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) + return + } + writeJSON(w, http.StatusOK, map[string]string{"message": "story status updated", "story_id": id, "status": string(input.Status)}) +} + +func (s *Server) handleApproveStory(w http.ResponseWriter, r *http.Request) { + var input struct { + Name string `json:"name"` + BranchName string `json:"branch_name"` + ProjectID string `json:"project_id"` + Tasks []elaboratedStoryTask `json:"tasks"` + Validation elaboratedStoryValidation `json:"validation"` + } + if err := json.NewDecoder(r.Body).Decode(&input); err != nil { + writeJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid JSON: " + err.Error()}) + return + } + if input.Name == "" { + writeJSON(w, http.StatusBadRequest, map[string]string{"error": "name is required"}) + return + } + + validationJSON, _ := 
json.Marshal(input.Validation) + now := time.Now().UTC() + story := &task.Story{ + ID: uuid.New().String(), + Name: input.Name, + ProjectID: input.ProjectID, + BranchName: input.BranchName, + ValidationJSON: string(validationJSON), + Status: task.StoryPending, + CreatedAt: now, + UpdatedAt: now, + } + if err := s.store.CreateStory(story); err != nil { + writeJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) + return + } + + var repoURL string + if input.ProjectID != "" { + if proj, err := s.store.GetProject(input.ProjectID); err == nil { + repoURL = proj.RemoteURL + } + } + + taskIDs := make([]string, 0, len(input.Tasks)) + var prevTaskID string + for _, tp := range input.Tasks { + t := &task.Task{ + ID: uuid.New().String(), + Name: tp.Name, + Project: input.ProjectID, + RepositoryURL: repoURL, + StoryID: story.ID, + Agent: task.AgentConfig{Type: "claude", Instructions: tp.Instructions}, + AcceptanceCriteria: tp.AcceptanceCriteria, + Priority: task.PriorityNormal, + Tags: []string{}, + DependsOn: []string{}, + Retry: task.RetryConfig{MaxAttempts: 1, Backoff: "exponential"}, + State: task.StatePending, + CreatedAt: time.Now().UTC(), + UpdatedAt: time.Now().UTC(), + } + if prevTaskID != "" { + t.DependsOn = []string{prevTaskID} + } + if err := s.store.CreateTask(t); err != nil { + writeJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) + return + } + taskIDs = append(taskIDs, t.ID) + + var prevSubtaskID string + for _, sub := range tp.Subtasks { + st := &task.Task{ + ID: uuid.New().String(), + Name: sub.Name, + Project: input.ProjectID, + RepositoryURL: repoURL, + StoryID: story.ID, + ParentTaskID: t.ID, + Agent: task.AgentConfig{Type: "claude", Instructions: sub.Instructions}, + Priority: task.PriorityNormal, + Tags: []string{}, + DependsOn: []string{}, + Retry: task.RetryConfig{MaxAttempts: 1, Backoff: "exponential"}, + State: task.StatePending, + CreatedAt: time.Now().UTC(), + UpdatedAt: 
time.Now().UTC(), + } + if prevSubtaskID != "" { + st.DependsOn = []string{prevSubtaskID} + } + if err := s.store.CreateTask(st); err != nil { + writeJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) + return + } + prevSubtaskID = st.ID + } + prevTaskID = t.ID + } + + // Create the story branch (non-fatal if it fails). + if input.BranchName != "" && input.ProjectID != "" { + if proj, err := s.store.GetProject(input.ProjectID); err == nil && proj.LocalPath != "" && proj.RemoteURL != "" { + if err := createStoryBranch(proj.LocalPath, input.BranchName, proj.RemoteURL); err != nil { + s.logger.Warn("story approve: failed to create branch", "error", err, "branch", input.BranchName) + } + } + } + + writeJSON(w, http.StatusCreated, map[string]interface{}{ + "story": story, + "task_ids": taskIDs, + }) +} + +// handleShipStory triggers the merge + deploy for a SHIPPABLE story. +// POST /api/stories/{id}/ship +func (s *Server) handleShipStory(w http.ResponseWriter, r *http.Request) { + id := r.PathValue("id") + if err := s.pool.ShipStory(r.Context(), id); err != nil { + writeJSON(w, http.StatusConflict, map[string]string{"error": err.Error()}) + return + } + writeJSON(w, http.StatusAccepted, map[string]string{"message": "story shipping initiated", "story_id": id}) +} + +// handleStoryDeploymentStatus aggregates the deployment status across all tasks in a story. +// GET /api/stories/{id}/deployment-status +func (s *Server) handleStoryDeploymentStatus(w http.ResponseWriter, r *http.Request) { + id := r.PathValue("id") + + story, err := s.store.GetStory(id) + if err != nil { + writeJSON(w, http.StatusNotFound, map[string]string{"error": "story not found"}) + return + } + + tasks, err := s.store.ListTasksByStory(id) + if err != nil { + writeJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) + return + } + + // Collect all commits from the latest execution of each task. 
+ var allCommits []task.GitCommit + for _, t := range tasks { + exec, err := s.store.GetLatestExecution(t.ID) + if err != nil { + if err == sql.ErrNoRows { + continue + } + writeJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) + return + } + allCommits = append(allCommits, exec.Commits...) + } + + // Determine project remote URL for the deployment check. + projectRemoteURL := "" + if story.ProjectID != "" { + if proj, err := s.store.GetProject(story.ProjectID); err == nil { + projectRemoteURL = proj.RemoteURL + } + } + + status := deployment.Check(allCommits, projectRemoteURL) + writeJSON(w, http.StatusOK, status) +} diff --git a/internal/api/stories_test.go b/internal/api/stories_test.go new file mode 100644 index 0000000..f43ad86 --- /dev/null +++ b/internal/api/stories_test.go @@ -0,0 +1,351 @@ +package api + +import ( + "bytes" + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/thepeterstone/claudomator/internal/deployment" + "github.com/thepeterstone/claudomator/internal/task" +) + +func TestCreateStory_API(t *testing.T) { + srv, _ := testServer(t) + + body := `{"name":"My Story","project_id":"proj-1"}` + req := httptest.NewRequest("POST", "/api/stories", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + srv.mux.ServeHTTP(w, req) + + if w.Code != http.StatusCreated { + t.Fatalf("expected 201, got %d: %s", w.Code, w.Body.String()) + } + var st task.Story + if err := json.NewDecoder(w.Body).Decode(&st); err != nil { + t.Fatalf("decode response: %v", err) + } + if st.Name != "My Story" { + t.Errorf("Name: want 'My Story', got %q", st.Name) + } + if st.ID == "" { + t.Error("ID should be auto-generated") + } + if st.Status != task.StoryPending { + t.Errorf("Status: want PENDING, got %q", st.Status) + } +} + +func TestGetStory_API(t *testing.T) { + srv, _ := testServer(t) + + // Create a story first. 
+ body := `{"name":"Get Me"}` + req := httptest.NewRequest("POST", "/api/stories", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + srv.mux.ServeHTTP(w, req) + if w.Code != http.StatusCreated { + t.Fatalf("create story: expected 201, got %d", w.Code) + } + var created task.Story + json.NewDecoder(w.Body).Decode(&created) + + // Fetch it. + req2 := httptest.NewRequest("GET", "/api/stories/"+created.ID, nil) + w2 := httptest.NewRecorder() + srv.mux.ServeHTTP(w2, req2) + + if w2.Code != http.StatusOK { + t.Fatalf("get story: expected 200, got %d: %s", w2.Code, w2.Body.String()) + } + var got task.Story + if err := json.NewDecoder(w2.Body).Decode(&got); err != nil { + t.Fatalf("decode: %v", err) + } + if got.ID != created.ID { + t.Errorf("ID: want %q, got %q", created.ID, got.ID) + } + if got.Name != "Get Me" { + t.Errorf("Name: want 'Get Me', got %q", got.Name) + } +} + +func TestAddTaskToStory_AutoWiresDependsOn(t *testing.T) { + srv, _ := testServer(t) + + // Create a story. 
+ storyBody := `{"name":"Story For Tasks"}` + req := httptest.NewRequest("POST", "/api/stories", bytes.NewBufferString(storyBody)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + srv.mux.ServeHTTP(w, req) + if w.Code != http.StatusCreated { + t.Fatalf("create story: %d %s", w.Code, w.Body.String()) + } + var story task.Story + json.NewDecoder(w.Body).Decode(&story) + + addTask := func(name string) *task.Task { + body := `{"name":"` + name + `","agent":{"type":"claude","instructions":"do it"}}` + r := httptest.NewRequest("POST", "/api/stories/"+story.ID+"/tasks", bytes.NewBufferString(body)) + r.Header.Set("Content-Type", "application/json") + wr := httptest.NewRecorder() + srv.mux.ServeHTTP(wr, r) + if wr.Code != http.StatusCreated { + t.Fatalf("add task %s: expected 201, got %d: %s", name, wr.Code, wr.Body.String()) + } + var tk task.Task + json.NewDecoder(wr.Body).Decode(&tk) + return &tk + } + + task1 := addTask("Task 1") + task2 := addTask("Task 2") + task3 := addTask("Task 3") + + // task1 has no dependencies. + if len(task1.DependsOn) != 0 { + t.Errorf("task1.DependsOn: want [], got %v", task1.DependsOn) + } + // task2 depends on task1. + if len(task2.DependsOn) != 1 || task2.DependsOn[0] != task1.ID { + t.Errorf("task2.DependsOn: want [%s], got %v", task1.ID, task2.DependsOn) + } + // task3 depends on task2. 
+ if len(task3.DependsOn) != 1 || task3.DependsOn[0] != task2.ID { + t.Errorf("task3.DependsOn: want [%s], got %v", task2.ID, task3.DependsOn) + } +} + +func TestBuildStoryElaboratePrompt(t *testing.T) { + prompt := buildStoryElaboratePrompt() + checks := []struct { + label string + want string + }{ + {"schema: name field", `"name"`}, + {"schema: branch_name field", `"branch_name"`}, + {"schema: tasks field", `"tasks"`}, + {"schema: validation field", `"validation"`}, + {"rule: git push", "git push origin"}, + {"rule: sequential subtasks", "sequentially"}, + {"rule: specific file paths", "file paths"}, + } + for _, c := range checks { + if !strings.Contains(prompt, c.want) { + t.Errorf("%s: prompt should contain %q", c.label, c.want) + } + } +} + +func TestHandleStoryApprove_WiresDepends(t *testing.T) { + srv, _ := testServer(t) + + body := `{ + "name": "My Story", + "branch_name": "story/my-story", + "tasks": [ + {"name": "Task 1", "instructions": "do task 1", "subtasks": []}, + {"name": "Task 2", "instructions": "do task 2", "subtasks": []}, + {"name": "Task 3", "instructions": "do task 3", "subtasks": []} + ], + "validation": {"type": "build", "steps": ["go build ./..."], "success_criteria": "compiles"} + }` + req := httptest.NewRequest("POST", "/api/stories/approve", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + srv.mux.ServeHTTP(w, req) + + if w.Code != http.StatusCreated { + t.Fatalf("expected 201, got %d: %s", w.Code, w.Body.String()) + } + + var resp struct { + Story task.Story `json:"story"` + TaskIDs []string `json:"task_ids"` + } + if err := json.NewDecoder(w.Body).Decode(&resp); err != nil { + t.Fatalf("decode response: %v", err) + } + if len(resp.TaskIDs) != 3 { + t.Fatalf("expected 3 task IDs, got %d", len(resp.TaskIDs)) + } + if resp.Story.Name != "My Story" { + t.Errorf("story name: want 'My Story', got %q", resp.Story.Name) + } + + // Verify depends_on chain via the store. 
+ store := srv.store + task1, err := store.GetTask(resp.TaskIDs[0]) + if err != nil { + t.Fatalf("GetTask[0]: %v", err) + } + task2, err := store.GetTask(resp.TaskIDs[1]) + if err != nil { + t.Fatalf("GetTask[1]: %v", err) + } + task3, err := store.GetTask(resp.TaskIDs[2]) + if err != nil { + t.Fatalf("GetTask[2]: %v", err) + } + + if len(task1.DependsOn) != 0 { + t.Errorf("task1.DependsOn: want [], got %v", task1.DependsOn) + } + if len(task2.DependsOn) != 1 || task2.DependsOn[0] != task1.ID { + t.Errorf("task2.DependsOn: want [%s], got %v", task1.ID, task2.DependsOn) + } + if len(task3.DependsOn) != 1 || task3.DependsOn[0] != task2.ID { + t.Errorf("task3.DependsOn: want [%s], got %v", task2.ID, task3.DependsOn) + } +} + +func TestHandleStoryApprove_SetsRepositoryURL(t *testing.T) { + srv, store := testServer(t) + + proj := &task.Project{ + ID: "proj-repo", + Name: "claudomator", + RemoteURL: "/site/git.terst.org/repos/claudomator.git", + // LocalPath intentionally empty: branch creation is a non-fatal side effect, + // omitting it keeps the test fast and free of real git operations. 
+ } + if err := store.CreateProject(proj); err != nil { + t.Fatalf("CreateProject: %v", err) + } + + body := `{ + "name": "Repo URL Story", + "branch_name": "story/repo-url", + "project_id": "proj-repo", + "tasks": [ + {"name": "Task A", "instructions": "do A", "subtasks": []}, + {"name": "Task B", "instructions": "do B", "subtasks": [ + {"name": "Sub B1", "instructions": "do B1"} + ]} + ], + "validation": {"type": "build", "steps": ["go build ./..."], "success_criteria": "ok"} + }` + req := httptest.NewRequest("POST", "/api/stories/approve", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + srv.mux.ServeHTTP(w, req) + + if w.Code != http.StatusCreated { + t.Fatalf("expected 201, got %d: %s", w.Code, w.Body.String()) + } + + var resp struct { + TaskIDs []string `json:"task_ids"` + } + if err := json.NewDecoder(w.Body).Decode(&resp); err != nil { + t.Fatalf("decode: %v", err) + } + + for _, id := range resp.TaskIDs { + tk, err := store.GetTask(id) + if err != nil { + t.Fatalf("GetTask %s: %v", id, err) + } + if tk.RepositoryURL != proj.RemoteURL { + t.Errorf("task %s RepositoryURL: want %q, got %q", tk.ID, proj.RemoteURL, tk.RepositoryURL) + } + } +} + +func TestApproveStory_AcceptanceCriteriaStored(t *testing.T) { + srv, store := testServer(t) + + proj := &task.Project{ + ID: "ac-proj", Name: "test", RemoteURL: "https://github.com/x/y", + Type: "web", DeployScript: "", + } + store.CreateProject(proj) + + body := `{ + "name": "AC Story", + "branch_name": "story/ac-test", + "project_id": "ac-proj", + "tasks": [ + { + "name": "Add feature", + "instructions": "implement the thing", + "acceptance_criteria": "run go test ./... 
and verify all pass", + "subtasks": [] + } + ], + "validation": {"type": "test", "steps": [], "success_criteria": "tests pass"} + }` + req := httptest.NewRequest("POST", "/api/stories/approve", strings.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + srv.mux.ServeHTTP(w, req) + + if w.Code != http.StatusCreated { + t.Fatalf("expected 201, got %d: %s", w.Code, w.Body.String()) + } + + var resp struct { + TaskIDs []string `json:"task_ids"` + } + json.NewDecoder(w.Body).Decode(&resp) + if len(resp.TaskIDs) == 0 { + t.Fatal("expected task_ids in response") + } + + tk, err := store.GetTask(resp.TaskIDs[0]) + if err != nil { + t.Fatalf("GetTask: %v", err) + } + if tk.AcceptanceCriteria != "run go test ./... and verify all pass" { + t.Errorf("expected acceptance criteria stored on task, got %q", tk.AcceptanceCriteria) + } +} + +func TestHandleStoryDeploymentStatus(t *testing.T) { + srv, store := testServer(t) + + // Create a story. + now := time.Now().UTC() + story := &task.Story{ + ID: "deploy-story-1", + Name: "Deploy Status Story", + Status: task.StoryInProgress, + CreatedAt: now, + UpdatedAt: now, + } + if err := store.CreateStory(story); err != nil { + t.Fatalf("CreateStory: %v", err) + } + + // Request deployment status — no tasks yet. + req := httptest.NewRequest("GET", "/api/stories/deploy-story-1/deployment-status", nil) + w := httptest.NewRecorder() + srv.mux.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String()) + } + + var status deployment.Status + if err := json.NewDecoder(w.Body).Decode(&status); err != nil { + t.Fatalf("decode: %v", err) + } + // No tasks → no commits → IncludesFix = false (nothing to check). + if status.IncludesFix { + t.Error("expected IncludesFix=false when no commits") + } + + // 404 for unknown story. 
+ req2 := httptest.NewRequest("GET", "/api/stories/nonexistent/deployment-status", nil) + w2 := httptest.NewRecorder() + srv.mux.ServeHTTP(w2, req2) + if w2.Code != http.StatusNotFound { + t.Errorf("expected 404 for unknown story, got %d", w2.Code) + } +} diff --git a/internal/api/task_view.go b/internal/api/task_view.go new file mode 100644 index 0000000..6a4b58e --- /dev/null +++ b/internal/api/task_view.go @@ -0,0 +1,47 @@ +package api + +import ( + "database/sql" + + "github.com/thepeterstone/claudomator/internal/deployment" + "github.com/thepeterstone/claudomator/internal/task" +) + +// taskView wraps a task with computed fields that are derived from execution +// history and deployment state. It is used as the JSON response type for task +// list and get endpoints so that callers receive enriched data in one request. +type taskView struct { + *task.Task + Changestats *task.Changestats `json:"changestats,omitempty"` + DeploymentStatus *deployment.Status `json:"deployment_status,omitempty"` + ErrorMsg string `json:"error_msg,omitempty"` +} + +var failedStates = map[task.State]bool{ + task.StateFailed: true, + task.StateBudgetExceeded: true, + task.StateTimedOut: true, +} + +// enrichTask fetches the latest execution for the given task and attaches +// changestats, deployment_status, and error_msg fields. 
+func (s *Server) enrichTask(tk *task.Task) *taskView { + view := &taskView{Task: tk} + + exec, err := s.store.GetLatestExecution(tk.ID) + if err != nil { + if err == sql.ErrNoRows && tk.State == task.StateReady { + view.DeploymentStatus = deployment.Check(nil, tk.RepositoryURL) + } + return view + } + + if failedStates[tk.State] && exec.ErrorMsg != "" { + view.ErrorMsg = exec.ErrorMsg + } + if tk.State == task.StateReady { + view.Changestats = exec.Changestats + view.DeploymentStatus = deployment.Check(exec.Commits, tk.RepositoryURL) + } + return view +} diff --git a/internal/api/webhook.go b/internal/api/webhook.go index 9437f7d..3af4cc8 100644 --- a/internal/api/webhook.go +++ b/internal/api/webhook.go @@ -8,6 +8,7 @@ import ( "encoding/json" "fmt" "io" + "log/slog" "net/http" "path/filepath" "strings" @@ -18,17 +19,26 @@ import ( "github.com/thepeterstone/claudomator/internal/task" ) +// prRef is a minimal pull request entry used to extract branch names. +type prRef struct { + Head struct { + Ref string `json:"ref"` + } `json:"head"` +} + // checkRunPayload is the GitHub check_run webhook payload. 
type checkRunPayload struct { Action string `json:"action"` CheckRun struct { - Name string `json:"name"` - Conclusion string `json:"conclusion"` - HTMLURL string `json:"html_url"` - HeadSHA string `json:"head_sha"` - CheckSuite struct { + Name string `json:"name"` + Conclusion string `json:"conclusion"` + HTMLURL string `json:"html_url"` + DetailsURL string `json:"details_url"` + HeadSHA string `json:"head_sha"` + CheckSuite struct { HeadBranch string `json:"head_branch"` } `json:"check_suite"` + PullRequests []prRef `json:"pull_requests"` } `json:"check_run"` Repository struct { Name string `json:"name"` @@ -40,11 +50,12 @@ type checkRunPayload struct { type workflowRunPayload struct { Action string `json:"action"` WorkflowRun struct { - Name string `json:"name"` - Conclusion string `json:"conclusion"` - HTMLURL string `json:"html_url"` - HeadSHA string `json:"head_sha"` - HeadBranch string `json:"head_branch"` + Name string `json:"name"` + Conclusion string `json:"conclusion"` + HTMLURL string `json:"html_url"` + HeadSHA string `json:"head_sha"` + HeadBranch string `json:"head_branch"` + PullRequests []prRef `json:"pull_requests"` } `json:"workflow_run"` Repository struct { Name string `json:"name"` @@ -98,6 +109,7 @@ func (s *Server) handleGitHubWebhook(w http.ResponseWriter, r *http.Request) { } eventType := r.Header.Get("X-GitHub-Event") + slog.Info("github webhook received", "event", eventType, "bytes", len(body)) switch eventType { case "check_run": s.handleCheckRunEvent(w, body) @@ -118,13 +130,22 @@ func (s *Server) handleCheckRunEvent(w http.ResponseWriter, body []byte) { w.WriteHeader(http.StatusNoContent) return } + branch := p.CheckRun.CheckSuite.HeadBranch + if branch == "" && len(p.CheckRun.PullRequests) > 0 { + branch = p.CheckRun.PullRequests[0].Head.Ref + } + htmlURL := p.CheckRun.HTMLURL + if htmlURL == "" { + htmlURL = p.CheckRun.DetailsURL + } + slog.Info("check_run webhook", "repo", p.Repository.FullName, "conclusion", p.CheckRun.Conclusion, 
"branch", branch, "html_url", htmlURL) s.createCIFailureTask(w, p.Repository.Name, p.Repository.FullName, - p.CheckRun.CheckSuite.HeadBranch, + branch, p.CheckRun.HeadSHA, p.CheckRun.Name, - p.CheckRun.HTMLURL, + htmlURL, ) } @@ -142,10 +163,15 @@ func (s *Server) handleWorkflowRunEvent(w http.ResponseWriter, body []byte) { w.WriteHeader(http.StatusNoContent) return } + branch := p.WorkflowRun.HeadBranch + if branch == "" && len(p.WorkflowRun.PullRequests) > 0 { + branch = p.WorkflowRun.PullRequests[0].Head.Ref + } + slog.Info("workflow_run webhook", "repo", p.Repository.FullName, "conclusion", p.WorkflowRun.Conclusion, "branch", branch, "html_url", p.WorkflowRun.HTMLURL) s.createCIFailureTask(w, p.Repository.Name, p.Repository.FullName, - p.WorkflowRun.HeadBranch, + branch, p.WorkflowRun.HeadSHA, p.WorkflowRun.Name, p.WorkflowRun.HTMLURL, @@ -155,6 +181,10 @@ func (s *Server) handleWorkflowRunEvent(w http.ResponseWriter, body []byte) { func (s *Server) createCIFailureTask(w http.ResponseWriter, repoName, fullName, branch, sha, checkName, htmlURL string) { project := matchProject(s.projects, repoName) + if htmlURL == "" && fullName != "" && sha != "" { + htmlURL = fmt.Sprintf("https://github.com/%s/commit/%s", fullName, sha) + } + fallback := fmt.Sprintf( "A CI failure has been detected and requires investigation.\n\n"+ "Repository: %s\n"+ @@ -188,20 +218,22 @@ func (s *Server) createCIFailureTask(w http.ResponseWriter, repoName, fullName, Name: fmt.Sprintf("Fix CI failure: %s on %s", checkName, branch), Agent: task.AgentConfig{ Type: "claude", + Model: "sonnet", Instructions: instructions, MaxBudgetUSD: 3.0, AllowedTools: []string{"Read", "Edit", "Bash", "Glob", "Grep"}, }, - Priority: task.PriorityNormal, - Tags: []string{"ci", "auto"}, - DependsOn: []string{}, - Retry: task.RetryConfig{MaxAttempts: 1, Backoff: "exponential"}, - State: task.StatePending, - CreatedAt: now, - UpdatedAt: now, + Priority: task.PriorityNormal, + Tags: []string{"ci", "auto"}, + 
DependsOn: []string{}, + Retry: task.RetryConfig{MaxAttempts: 1, Backoff: "exponential"}, + State: task.StatePending, + CreatedAt: now, + UpdatedAt: now, + RepositoryURL: fmt.Sprintf("https://github.com/%s.git", fullName), } if project != nil { - t.Agent.ProjectDir = project.Dir + t.Project = project.Name } if err := s.store.CreateTask(t); err != nil { diff --git a/internal/api/webhook_test.go b/internal/api/webhook_test.go index 8b0599a..967b62b 100644 --- a/internal/api/webhook_test.go +++ b/internal/api/webhook_test.go @@ -124,8 +124,8 @@ func TestGitHubWebhook_CheckRunFailure_CreatesTask(t *testing.T) { if !strings.Contains(tk.Name, "main") { t.Errorf("task name %q does not contain branch", tk.Name) } - if tk.Agent.ProjectDir != "/workspace/myrepo" { - t.Errorf("task project dir = %q, want /workspace/myrepo", tk.Agent.ProjectDir) + if tk.RepositoryURL != "https://github.com/owner/myrepo.git" { + t.Errorf("task repository url = %q, want https://github.com/owner/myrepo.git", tk.RepositoryURL) } if !contains(tk.Tags, "ci") || !contains(tk.Tags, "auto") { t.Errorf("task tags %v missing expected ci/auto tags", tk.Tags) @@ -237,6 +237,104 @@ func TestGitHubWebhook_NoSecretConfigured_SkipsHMACCheck(t *testing.T) { } } +func TestGitHubWebhook_CreatesTask_WithDefaultModel(t *testing.T) { + srv, store := testServer(t) + srv.projects = []config.Project{{Name: "myrepo", Dir: "/workspace/myrepo"}} + + w := webhookPost(t, srv, "check_run", checkRunFailurePayload, "") + + if w.Code != http.StatusOK { + t.Fatalf("want 200, got %d", w.Code) + } + var resp map[string]string + json.NewDecoder(w.Body).Decode(&resp) + tk, err := store.GetTask(resp["task_id"]) + if err != nil { + t.Fatalf("task not found: %v", err) + } + if tk.Agent.Model == "" { + t.Error("expected model to be set, got empty string") + } +} + +const checkRunNullBranchPayload = `{ + "action": "completed", + "check_run": { + "name": "CI Build", + "conclusion": "failure", + "html_url": "", + "details_url": 
"https://github.com/owner/myrepo/actions/runs/999/jobs/123", + "head_sha": "abc123def", + "check_suite": { + "head_branch": null + }, + "pull_requests": [{"head": {"ref": "feature/my-branch"}}] + }, + "repository": { + "name": "myrepo", + "full_name": "owner/myrepo" + } +}` + +func TestGitHubWebhook_CheckRunNullBranch_UsesPRRefAndDetailsURL(t *testing.T) { + srv, store := testServer(t) + srv.projects = []config.Project{{Name: "myrepo", Dir: "/workspace/myrepo"}} + + w := webhookPost(t, srv, "check_run", checkRunNullBranchPayload, "") + + if w.Code != http.StatusOK { + t.Fatalf("want 200, got %d; body: %s", w.Code, w.Body.String()) + } + var resp map[string]string + json.NewDecoder(w.Body).Decode(&resp) + tk, err := store.GetTask(resp["task_id"]) + if err != nil { + t.Fatalf("task not found: %v", err) + } + if !strings.Contains(tk.Name, "feature/my-branch") { + t.Errorf("task name %q should contain PR branch", tk.Name) + } + if !strings.Contains(tk.Agent.Instructions, "actions/runs/999") { + t.Errorf("instructions should contain details_url fallback, got: %s", tk.Agent.Instructions) + } +} + +const workflowRunNullBranchPayload = `{ + "action": "completed", + "workflow_run": { + "name": "CI Pipeline", + "conclusion": "failure", + "html_url": "", + "head_sha": "def456abc", + "head_branch": null, + "pull_requests": [{"head": {"ref": "fix/something"}}] + }, + "repository": { + "name": "myrepo", + "full_name": "owner/myrepo" + } +}` + +func TestGitHubWebhook_WorkflowRunNullBranch_UsesPRRef(t *testing.T) { + srv, store := testServer(t) + srv.projects = []config.Project{{Name: "myrepo", Dir: "/workspace/myrepo"}} + + w := webhookPost(t, srv, "workflow_run", workflowRunNullBranchPayload, "") + + if w.Code != http.StatusOK { + t.Fatalf("want 200, got %d; body: %s", w.Code, w.Body.String()) + } + var resp map[string]string + json.NewDecoder(w.Body).Decode(&resp) + tk, err := store.GetTask(resp["task_id"]) + if err != nil { + t.Fatalf("task not found: %v", err) + } + if 
!strings.Contains(tk.Name, "fix/something") { + t.Errorf("task name %q should contain PR branch", tk.Name) + } +} + func TestGitHubWebhook_UnknownEvent_Returns204(t *testing.T) { srv, _ := testServer(t) @@ -277,14 +375,14 @@ func TestGitHubWebhook_FallbackToSingleProject(t *testing.T) { if err != nil { t.Fatalf("task not found: %v", err) } - if tk.Agent.ProjectDir != "/workspace/someapp" { - t.Errorf("expected fallback to /workspace/someapp, got %q", tk.Agent.ProjectDir) + if tk.RepositoryURL != "https://github.com/owner/myrepo.git" { + t.Errorf("expected fallback repository url, got %q", tk.RepositoryURL) } } -func TestGitHubWebhook_NoProjectsConfigured_CreatesTaskWithoutProjectDir(t *testing.T) { +func TestGitHubWebhook_NoProjectsConfigured_CreatesTaskWithGitHubURL(t *testing.T) { srv, store := testServer(t) - // No projects configured — task should still be created, just no project dir set. + // No projects configured — task should still be created with the GitHub remote URL. w := webhookPost(t, srv, "check_run", checkRunFailurePayload, "") @@ -297,8 +395,8 @@ func TestGitHubWebhook_NoProjectsConfigured_CreatesTaskWithoutProjectDir(t *test if err != nil { t.Fatalf("task not found: %v", err) } - if tk.Agent.ProjectDir != "" { - t.Errorf("expected empty project dir, got %q", tk.Agent.ProjectDir) + if tk.RepositoryURL == "" { + t.Error("expected non-empty repository_url from GitHub webhook payload") } } diff --git a/internal/cli/list.go b/internal/cli/list.go index 3425388..ab80868 100644 --- a/internal/cli/list.go +++ b/internal/cli/list.go @@ -49,10 +49,10 @@ func listTasks(state string) error { } w := tabwriter.NewWriter(os.Stdout, 0, 0, 2, ' ', 0) - fmt.Fprintln(w, "ID\tNAME\tSTATE\tPRIORITY\tCREATED") + fmt.Fprintln(w, "ID\tNAME\tPROJECT\tSTATE\tPRIORITY\tCREATED") for _, t := range tasks { - fmt.Fprintf(w, "%.8s\t%s\t%s\t%s\t%s\n", - t.ID, t.Name, t.State, t.Priority, t.CreatedAt.Format("2006-01-02 15:04")) + fmt.Fprintf(w, "%.8s\t%s\t%s\t%s\t%s\t%s\n", + 
t.ID, t.Name, t.Project, t.State, t.Priority, t.CreatedAt.Format("2006-01-02 15:04")) } w.Flush() return nil diff --git a/internal/cli/project_test.go b/internal/cli/project_test.go new file mode 100644 index 0000000..c62e181 --- /dev/null +++ b/internal/cli/project_test.go @@ -0,0 +1,102 @@ +package cli + +import ( + "bytes" + "io" + "os" + "path/filepath" + "strings" + "testing" + "time" + + "github.com/thepeterstone/claudomator/internal/config" + "github.com/thepeterstone/claudomator/internal/storage" + "github.com/thepeterstone/claudomator/internal/task" +) + +func makeProjectTask(t *testing.T, dir string) *task.Task { + t.Helper() + db, err := storage.Open(filepath.Join(dir, "test.db")) + if err != nil { + t.Fatalf("storage.Open: %v", err) + } + defer db.Close() + + now := time.Now().UTC() + tk := &task.Task{ + ID: "proj-task-id", + Name: "Project Task", + Project: "test-project", + Agent: task.AgentConfig{Type: "claude", Instructions: "do it", Model: "sonnet"}, + Priority: task.PriorityNormal, + Tags: []string{}, + DependsOn: []string{}, + Retry: task.RetryConfig{MaxAttempts: 1, Backoff: "exponential"}, + State: task.StatePending, + CreatedAt: now, + UpdatedAt: now, + } + if err := db.CreateTask(tk); err != nil { + t.Fatalf("CreateTask: %v", err) + } + return tk +} + +func captureStdout(fn func()) string { + old := os.Stdout + r, w, _ := os.Pipe() + os.Stdout = w + + fn() + + w.Close() + os.Stdout = old + var buf bytes.Buffer + io.Copy(&buf, r) + return buf.String() +} + +func withDB(t *testing.T, dbPath string, fn func()) { + t.Helper() + origCfg := cfg + if cfg == nil { + cfg = &config.Config{} + } + cfg.DBPath = dbPath + defer func() { cfg = origCfg }() + fn() +} + +func TestListTasks_ShowsProject(t *testing.T) { + dir := t.TempDir() + dbPath := filepath.Join(dir, "test.db") + makeProjectTask(t, dir) + + withDB(t, dbPath, func() { + out := captureStdout(func() { + if err := listTasks(""); err != nil { + t.Fatalf("listTasks: %v", err) + } + }) + if 
!strings.Contains(out, "test-project") { + t.Errorf("list output missing project 'test-project':\n%s", out) + } + }) +} + +func TestStatusCmd_ShowsProject(t *testing.T) { + dir := t.TempDir() + dbPath := filepath.Join(dir, "test.db") + tk := makeProjectTask(t, dir) + + withDB(t, dbPath, func() { + out := captureStdout(func() { + if err := showStatus(tk.ID); err != nil { + t.Fatalf("showStatus: %v", err) + } + }) + if !strings.Contains(out, "test-project") { + t.Errorf("status output missing project 'test-project':\n%s", out) + } + }) +} diff --git a/internal/cli/root.go b/internal/cli/root.go index 7c4f2ff..e57a9d9 100644 --- a/internal/cli/root.go +++ b/internal/cli/root.go @@ -60,6 +60,7 @@ func NewRootCmd() *cobra.Command { } cfg.DBPath = filepath.Join(cfg.DataDir, "claudomator.db") cfg.LogDir = filepath.Join(cfg.DataDir, "executions") + cfg.DropsDir = filepath.Join(cfg.DataDir, "drops") return nil } @@ -73,6 +74,7 @@ func NewRootCmd() *cobra.Command { newStartCmd(), newCreateCmd(), newReportCmd(), + newVersionCmd(), ) return cmd diff --git a/internal/cli/run.go b/internal/cli/run.go index 2d7c3d7..48f34b7 100644 --- a/internal/cli/run.go +++ b/internal/cli/run.go @@ -72,16 +72,31 @@ func runTasks(file string, parallel int, dryRun bool) error { logger := newLogger(verbose) + apiURL := "http://localhost" + cfg.ServerAddr + if len(cfg.ServerAddr) > 0 && cfg.ServerAddr[0] != ':' { + apiURL = "http://" + cfg.ServerAddr + } + runners := map[string]executor.Runner{ - "claude": &executor.ClaudeRunner{ - BinaryPath: cfg.ClaudeBinaryPath, - Logger: logger, - LogDir: cfg.LogDir, + "claude": &executor.ContainerRunner{ + Image: cfg.ClaudeImage, + Logger: logger, + LogDir: cfg.LogDir, + APIURL: apiURL, + DropsDir: cfg.DropsDir, + SSHAuthSock: cfg.SSHAuthSock, + ClaudeBinary: cfg.ClaudeBinaryPath, + GeminiBinary: cfg.GeminiBinaryPath, }, - "gemini": &executor.GeminiRunner{ - BinaryPath: cfg.GeminiBinaryPath, - Logger: logger, - LogDir: cfg.LogDir, + "gemini": 
&executor.ContainerRunner{ + Image: cfg.GeminiImage, + Logger: logger, + LogDir: cfg.LogDir, + APIURL: apiURL, + DropsDir: cfg.DropsDir, + SSHAuthSock: cfg.SSHAuthSock, + ClaudeBinary: cfg.ClaudeBinaryPath, + GeminiBinary: cfg.GeminiBinaryPath, }, } @@ -95,6 +110,7 @@ func runTasks(file string, parallel int, dryRun bool) error { } } + pool := executor.NewPool(parallel, runners, store, logger) pool.Classifier = &executor.Classifier{ LLM: localClient, diff --git a/internal/cli/serve.go b/internal/cli/serve.go index 5101b81..459c35b 100644 --- a/internal/cli/serve.go +++ b/internal/cli/serve.go @@ -35,6 +35,8 @@ func newServeCmd() *cobra.Command { cmd.Flags().StringVar(&addr, "addr", ":8484", "listen address") cmd.Flags().StringVar(&workspaceRoot, "workspace-root", "/workspace", "root directory for listing workspaces") + cmd.Flags().StringVar(&cfg.ClaudeImage, "claude-image", cfg.ClaudeImage, "docker image for claude agents") + cmd.Flags().StringVar(&cfg.GeminiImage, "gemini-image", cfg.GeminiImage, "docker image for gemini agents") return cmd } @@ -50,25 +52,68 @@ func serve(addr string) error { } defer store.Close() + // Load VAPID keys from DB; generate and persist if missing. + if cfg.VAPIDPublicKey == "" || cfg.VAPIDPrivateKey == "" { + pub, _ := store.GetSetting("vapid_public_key") + priv, _ := store.GetSetting("vapid_private_key") + if pub == "" || priv == "" || !notify.ValidateVAPIDPublicKey(pub) { + pub, priv, err = notify.GenerateVAPIDKeys() + if err != nil { + return fmt.Errorf("generating VAPID keys: %w", err) + } + _ = store.SetSetting("vapid_public_key", pub) + _ = store.SetSetting("vapid_private_key", priv) + } + cfg.VAPIDPublicKey = pub + cfg.VAPIDPrivateKey = priv + } + logger := newLogger(verbose) apiURL := "http://localhost" + addr if len(addr) > 0 && addr[0] != ':' { apiURL = "http://" + addr } - + + // Use configured credentials dir; sync-credentials keeps this populated. 
+ claudeConfigDir := cfg.ClaudeConfigDir + + repoDir, _ := os.Getwd() runners := map[string]executor.Runner{ - "claude": &executor.ClaudeRunner{ - BinaryPath: cfg.ClaudeBinaryPath, - Logger: logger, - LogDir: cfg.LogDir, - APIURL: apiURL, + // ContainerRunner: binaries are resolved via PATH inside the container image, + // so ClaudeBinary/GeminiBinary are left empty (host paths would not exist inside). + "claude": &executor.ContainerRunner{ + Image: cfg.ClaudeImage, + Logger: logger, + LogDir: cfg.LogDir, + APIURL: apiURL, + DropsDir: cfg.DropsDir, + SSHAuthSock: cfg.SSHAuthSock, + ClaudeConfigDir: claudeConfigDir, + CredentialSyncCmd: filepath.Join(repoDir, "scripts", "sync-credentials"), + Store: store, + }, + "gemini": &executor.ContainerRunner{ + Image: cfg.GeminiImage, + Logger: logger, + LogDir: cfg.LogDir, + APIURL: apiURL, + DropsDir: cfg.DropsDir, + SSHAuthSock: cfg.SSHAuthSock, + ClaudeConfigDir: claudeConfigDir, + CredentialSyncCmd: filepath.Join(repoDir, "scripts", "sync-credentials"), + Store: store, }, - "gemini": &executor.GeminiRunner{ - BinaryPath: cfg.GeminiBinaryPath, - Logger: logger, - LogDir: cfg.LogDir, - APIURL: apiURL, + "container": &executor.ContainerRunner{ + Image: "claudomator-agent:latest", + Logger: logger, + LogDir: cfg.LogDir, + APIURL: apiURL, + DropsDir: cfg.DropsDir, + SSHAuthSock: cfg.SSHAuthSock, + ClaudeConfigDir: claudeConfigDir, + CredentialSyncCmd: filepath.Join(repoDir, "scripts", "sync-credentials"), + Store: store, }, } @@ -83,6 +128,7 @@ func serve(addr string) error { logger.Info("local runner registered", "endpoint", cfg.LocalModel.Endpoint, "model", cfg.LocalModel.Model) } + pool := executor.NewPool(cfg.MaxConcurrent, runners, store, logger) pool.Classifier = &executor.Classifier{ LLM: localClient, @@ -91,14 +137,36 @@ func serve(addr string) error { if localClient != nil { pool.LLM = localClient } + + if err := store.SeedProjects(); err != nil { + logger.Error("failed to seed projects", "error", err) + } + 
pool.RecoverStaleRunning(context.Background()) pool.RecoverStaleQueued(context.Background()) pool.RecoverStaleBlocked() srv := api.NewServer(store, pool, logger, cfg.ClaudeBinaryPath, cfg.GeminiBinaryPath) + + // Configure notifiers: combine webhook (if set) with web push. + notifiers := []notify.Notifier{} if cfg.WebhookURL != "" { - srv.SetNotifier(notify.NewWebhookNotifier(cfg.WebhookURL, logger)) + notifiers = append(notifiers, notify.NewWebhookNotifier(cfg.WebhookURL, logger)) + } + webPushNotifier := &notify.WebPushNotifier{ + Store: store, + VAPIDPublicKey: cfg.VAPIDPublicKey, + VAPIDPrivateKey: cfg.VAPIDPrivateKey, + VAPIDEmail: cfg.VAPIDEmail, + Logger: logger, } + notifiers = append(notifiers, webPushNotifier) + srv.SetNotifier(notify.NewMultiNotifier(logger, notifiers...)) + + srv.SetVAPIDConfig(cfg.VAPIDPublicKey, cfg.VAPIDPrivateKey, cfg.VAPIDEmail) + srv.SetPushStore(store) + srv.SetDropsDir(cfg.DropsDir) + if cfg.WorkspaceRoot != "" { srv.SetWorkspaceRoot(cfg.WorkspaceRoot) } @@ -115,6 +183,11 @@ func serve(addr string) error { "deploy": filepath.Join(wd, "scripts", "deploy"), }) + // Graceful shutdown. + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + srv.SetContext(ctx) srv.StartHub() httpSrv := &http.Server{ @@ -122,19 +195,31 @@ func serve(addr string) error { Handler: srv.Handler(), } - // Graceful shutdown. 
- ctx, cancel := context.WithCancel(context.Background()) - defer cancel() + workerTimeout := 3 * time.Minute + if cfg.ShutdownTimeout > 0 { + workerTimeout = cfg.ShutdownTimeout + } sigCh := make(chan os.Signal, 1) signal.Notify(sigCh, os.Interrupt, syscall.SIGTERM) go func() { <-sigCh - logger.Info("shutting down server...") - shutdownCtx, shutdownCancel := context.WithTimeout(ctx, 5*time.Second) - defer shutdownCancel() - if err := httpSrv.Shutdown(shutdownCtx); err != nil { - logger.Warn("shutdown error", "err", err) + logger.Info("shutting down: draining workers...", "timeout", workerTimeout) + + // Stop the HTTP server so no new requests come in. + httpCtx, httpCancel := context.WithTimeout(ctx, 5*time.Second) + defer httpCancel() + if err := httpSrv.Shutdown(httpCtx); err != nil { + logger.Warn("http shutdown error", "err", err) + } + + // Wait for in-flight task workers to finish. + workerCtx, workerCancel := context.WithTimeout(context.Background(), workerTimeout) + defer workerCancel() + if err := srv.Pool().Shutdown(workerCtx); err != nil { + logger.Warn("worker drain timed out", "err", err) + } else { + logger.Info("all workers finished cleanly") } }() @@ -144,3 +229,4 @@ func serve(addr string) error { } return nil } + diff --git a/internal/cli/status.go b/internal/cli/status.go index 16b88b0..77a30d5 100644 --- a/internal/cli/status.go +++ b/internal/cli/status.go @@ -39,6 +39,9 @@ func showStatus(id string) error { fmt.Printf("State: %s\n", t.State) fmt.Printf("Priority: %s\n", t.Priority) fmt.Printf("Model: %s\n", t.Agent.Model) + if t.Project != "" { + fmt.Printf("Project: %s\n", t.Project) + } if t.Description != "" { fmt.Printf("Description: %s\n", t.Description) } diff --git a/internal/cli/version.go b/internal/cli/version.go new file mode 100644 index 0000000..789416a --- /dev/null +++ b/internal/cli/version.go @@ -0,0 +1,18 @@ +package cli + +import ( + "fmt" + + "github.com/thepeterstone/claudomator/internal/version" + 
"github.com/spf13/cobra" +) + +func newVersionCmd() *cobra.Command { + return &cobra.Command{ + Use: "version", + Short: "Show the version of claudomator", + Run: func(cmd *cobra.Command, args []string) { + fmt.Printf("claudomator version %s\n", version.Version()) + }, + } +} diff --git a/internal/config/config.go b/internal/config/config.go index 5801239..25187cf 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -5,6 +5,7 @@ import ( "fmt" "os" "path/filepath" + "time" "github.com/BurntSushi/toml" ) @@ -45,19 +46,28 @@ func (m LocalModel) UseForElaborate() bool { } type Config struct { - DataDir string `toml:"data_dir"` - DBPath string `toml:"-"` - LogDir string `toml:"-"` - ClaudeBinaryPath string `toml:"claude_binary_path"` - GeminiBinaryPath string `toml:"gemini_binary_path"` - MaxConcurrent int `toml:"max_concurrent"` - DefaultTimeout string `toml:"default_timeout"` - ServerAddr string `toml:"server_addr"` - WebhookURL string `toml:"webhook_url"` - WorkspaceRoot string `toml:"workspace_root"` - WebhookSecret string `toml:"webhook_secret"` - Projects []Project `toml:"projects"` - LocalModel LocalModel `toml:"local_model"` + DataDir string `toml:"data_dir"` + DBPath string `toml:"-"` + LogDir string `toml:"-"` + DropsDir string `toml:"-"` + SSHAuthSock string `toml:"ssh_auth_sock"` + ClaudeBinaryPath string `toml:"claude_binary_path"` + GeminiBinaryPath string `toml:"gemini_binary_path"` + ClaudeImage string `toml:"claude_image"` + GeminiImage string `toml:"gemini_image"` + MaxConcurrent int `toml:"max_concurrent"` + ShutdownTimeout time.Duration `toml:"shutdown_timeout"` + DefaultTimeout string `toml:"default_timeout"` + ServerAddr string `toml:"server_addr"` + WebhookURL string `toml:"webhook_url"` + WorkspaceRoot string `toml:"workspace_root"` + WebhookSecret string `toml:"webhook_secret"` + Projects []Project `toml:"projects"` + VAPIDPublicKey string `toml:"vapid_public_key"` + VAPIDPrivateKey string `toml:"vapid_private_key"` + 
VAPIDEmail string `toml:"vapid_email"` + ClaudeConfigDir string `toml:"claude_config_dir"` + LocalModel LocalModel `toml:"local_model"` } func Default() (*Config, error) { @@ -73,12 +83,17 @@ func Default() (*Config, error) { DataDir: dataDir, DBPath: filepath.Join(dataDir, "claudomator.db"), LogDir: filepath.Join(dataDir, "executions"), + DropsDir: filepath.Join(dataDir, "drops"), + SSHAuthSock: os.Getenv("SSH_AUTH_SOCK"), ClaudeBinaryPath: "claude", GeminiBinaryPath: "gemini", + ClaudeImage: "claudomator-agent:latest", + GeminiImage: "claudomator-agent:latest", MaxConcurrent: 3, DefaultTimeout: "15m", ServerAddr: ":8484", WorkspaceRoot: "/workspace", + ClaudeConfigDir: "/workspace/claudomator/credentials/claude", }, nil } @@ -97,7 +112,7 @@ func LoadFile(path string) (*Config, error) { // EnsureDirs creates the data directory structure. func (c *Config) EnsureDirs() error { - for _, dir := range []string{c.DataDir, c.LogDir} { + for _, dir := range []string{c.DataDir, c.LogDir, c.DropsDir} { if err := os.MkdirAll(dir, 0700); err != nil { return err } diff --git a/internal/executor/claude.go b/internal/executor/claude.go index fa68382..3c87f26 100644 --- a/internal/executor/claude.go +++ b/internal/executor/claude.go @@ -1,11 +1,8 @@ package executor import ( - "bufio" "context" - "encoding/json" "fmt" - "io" "log/slog" "os" "os/exec" @@ -30,14 +27,6 @@ type ClaudeRunner struct { // BlockedError is returned by Run when the agent wrote a question file and exited. // The pool transitions the task to BLOCKED and stores the question for the user. -type BlockedError struct { - QuestionJSON string // raw JSON from the question file - SessionID string // claude session to resume once the user answers - SandboxDir string // preserved sandbox path; resume must run here so Claude finds its session files -} - -func (e *BlockedError) Error() string { return fmt.Sprintf("task blocked: %s", e.QuestionJSON) } - // ExecLogDir returns the log directory for the given execution ID. 
// Implements LogPather so the pool can persist paths before execution starts. func (r *ClaudeRunner) ExecLogDir(execID string) string { @@ -200,50 +189,6 @@ func (r *ClaudeRunner) Run(ctx context.Context, t *task.Task, e *storage.Executi return nil } -// isCompletionReport returns true when a question-file JSON looks like a -// completion report rather than a real user question. Heuristic: no options -// (or empty options) and no "?" anywhere in the text. -func isCompletionReport(questionJSON string) bool { - var q struct { - Text string `json:"text"` - Options []string `json:"options"` - } - if err := json.Unmarshal([]byte(questionJSON), &q); err != nil { - return false - } - return len(q.Options) == 0 && !strings.Contains(q.Text, "?") -} - -// extractQuestionText returns the "text" field from a question-file JSON, or -// the raw string if parsing fails. -func extractQuestionText(questionJSON string) string { - var q struct { - Text string `json:"text"` - } - if err := json.Unmarshal([]byte(questionJSON), &q); err != nil { - return questionJSON - } - return strings.TrimSpace(q.Text) -} - -// gitSafe returns git arguments that prepend safety overrides so that -// commands succeed regardless of the repository owner or the host's global -// git configuration. Specifically: -// -// - "-c safe.directory=*" lets us operate on directories owned by a -// different OS user. -// - "-c commit.gpgsign=false" / "-c tag.gpgsign=false" stop git from -// trying to sign commits via the host's signing tooling. Sandbox commits -// are internal and don't need to be signed; an unconfigured or broken -// signing setup on the host should never block a sandbox merge. -func gitSafe(args ...string) []string { - return append([]string{ - "-c", "safe.directory=*", - "-c", "commit.gpgsign=false", - "-c", "tag.gpgsign=false", - }, args...) -} - // sandboxCloneSource returns the URL to clone the sandbox from. 
It prefers a // remote named "local" (a local bare repo that accepts pushes cleanly), then // falls back to "origin", then to the working copy path itself. @@ -497,7 +442,7 @@ func (r *ClaudeRunner) execOnce(ctx context.Context, args []string, workingDir, wg.Add(1) go func() { defer wg.Done() - costUSD, streamErr = parseStream(stdoutR, stdoutFile, r.Logger) + costUSD, _, streamErr = parseStream(stdoutR, stdoutFile, r.Logger) stdoutR.Close() }() @@ -605,116 +550,3 @@ func (r *ClaudeRunner) buildArgs(t *task.Task, e *storage.Execution, questionFil return args } -// parseStream reads streaming JSON from claude, writes to w, and returns -// (costUSD, error). error is non-nil if the stream signals task failure: -// - result message has is_error:true -// - a tool_result was denied due to missing permissions -func parseStream(r io.Reader, w io.Writer, logger *slog.Logger) (float64, error) { - tee := io.TeeReader(r, w) - scanner := bufio.NewScanner(tee) - scanner.Buffer(make([]byte, 1024*1024), 1024*1024) // 1MB buffer for large lines - - var totalCost float64 - var streamErr error - - for scanner.Scan() { - line := scanner.Bytes() - var msg map[string]interface{} - if err := json.Unmarshal(line, &msg); err != nil { - continue - } - - msgType, _ := msg["type"].(string) - switch msgType { - case "rate_limit_event": - if info, ok := msg["rate_limit_info"].(map[string]interface{}); ok { - status, _ := info["status"].(string) - if status == "rejected" { - streamErr = fmt.Errorf("claude rate limit reached (rejected): %v", msg) - // Immediately break since we can't continue anyway - break - } - } - case "assistant": - if errStr, ok := msg["error"].(string); ok && errStr == "rate_limit" { - streamErr = fmt.Errorf("claude rate limit reached: %v", msg) - } - case "result": - if isErr, _ := msg["is_error"].(bool); isErr { - result, _ := msg["result"].(string) - if result != "" { - streamErr = fmt.Errorf("claude task failed: %s", result) - } else { - streamErr = fmt.Errorf("claude 
task failed (is_error=true in result)") - } - } - // Prefer total_cost_usd from result message; fall through to legacy check below. - if cost, ok := msg["total_cost_usd"].(float64); ok { - totalCost = cost - } - case "user": - // Detect permission-denial tool_results. These occur when permission_mode - // is not bypassPermissions and claude exits 0 without completing its task. - if err := permissionDenialError(msg); err != nil && streamErr == nil { - streamErr = err - } - } - - // Legacy cost field used by older claude versions. - if cost, ok := msg["cost_usd"].(float64); ok { - totalCost = cost - } - } - - return totalCost, streamErr -} - -// permissionDenialError inspects a "user" stream message for tool_result entries -// that were denied due to missing permissions. Returns an error if found. -func permissionDenialError(msg map[string]interface{}) error { - message, ok := msg["message"].(map[string]interface{}) - if !ok { - return nil - } - content, ok := message["content"].([]interface{}) - if !ok { - return nil - } - for _, item := range content { - itemMap, ok := item.(map[string]interface{}) - if !ok { - continue - } - if itemMap["type"] != "tool_result" { - continue - } - if isErr, _ := itemMap["is_error"].(bool); !isErr { - continue - } - text, _ := itemMap["content"].(string) - if strings.Contains(text, "requested permissions") || strings.Contains(text, "haven't granted") { - return fmt.Errorf("permission denied by host: %s", text) - } - } - return nil -} - -// tailFile returns the last n lines of the file at path, or empty string if -// the file cannot be read. Used to surface subprocess stderr on failure. 
-func tailFile(path string, n int) string { - f, err := os.Open(path) - if err != nil { - return "" - } - defer f.Close() - - var lines []string - scanner := bufio.NewScanner(f) - for scanner.Scan() { - lines = append(lines, scanner.Text()) - if len(lines) > n { - lines = lines[1:] - } - } - return strings.Join(lines, "\n") -} diff --git a/internal/executor/claude_test.go b/internal/executor/claude_test.go index cbb5947..c01e160 100644 --- a/internal/executor/claude_test.go +++ b/internal/executor/claude_test.go @@ -2,7 +2,6 @@ package executor import ( "context" - "errors" "fmt" "io" "log/slog" @@ -697,57 +696,6 @@ func TestTeardownSandbox_CleanSandboxWithNoNewCommits_RemovesSandbox(t *testing. } } -// TestBlockedError_IncludesSandboxDir verifies that when a task is blocked in a -// sandbox, the BlockedError carries the sandbox path so the resume execution can -// run in the same directory (where Claude's session files are stored). -func TestBlockedError_IncludesSandboxDir(t *testing.T) { - src := t.TempDir() - initGitRepo(t, src) - - logDir := t.TempDir() - - // Use a script that writes question.json to the env-var path and exits 0 - // (simulating a blocked agent that asks a question before exiting). 
- scriptPath := filepath.Join(t.TempDir(), "fake-claude.sh") - if err := os.WriteFile(scriptPath, []byte(`#!/bin/sh -if [ -n "$CLAUDOMATOR_QUESTION_FILE" ]; then - printf '{"text":"Should I continue?"}' > "$CLAUDOMATOR_QUESTION_FILE" -fi -`), 0755); err != nil { - t.Fatalf("write script: %v", err) - } - - r := &ClaudeRunner{ - BinaryPath: scriptPath, - Logger: slog.New(slog.NewTextHandler(io.Discard, nil)), - LogDir: logDir, - } - tk := &task.Task{ - Agent: task.AgentConfig{ - Type: "claude", - Instructions: "do something", - ProjectDir: src, - SkipPlanning: true, - }, - } - exec := &storage.Execution{ID: "blocked-exec-uuid", TaskID: "task-1"} - - err := r.Run(context.Background(), tk, exec) - - var blocked *BlockedError - if !errors.As(err, &blocked) { - t.Fatalf("expected BlockedError, got: %v", err) - } - if blocked.SandboxDir == "" { - t.Error("BlockedError.SandboxDir should be set when task runs in a sandbox") - } - // Sandbox should still exist (preserved for resume). - if _, statErr := os.Stat(blocked.SandboxDir); os.IsNotExist(statErr) { - t.Error("sandbox directory should be preserved when blocked") - } else { - os.RemoveAll(blocked.SandboxDir) // cleanup - } -} // TestClaudeRunner_Run_ResumeUsesStoredSandboxDir verifies that when a resume // execution has SandboxDir set, the runner uses that directory (not project_dir) @@ -853,69 +801,6 @@ func TestClaudeRunner_Run_StaleSandboxDir_ClonesAfresh(t *testing.T) { } } -func TestIsCompletionReport(t *testing.T) { - tests := []struct { - name string - json string - expected bool - }{ - { - name: "real question with options", - json: `{"text": "Should I proceed with implementation?", "options": ["Yes", "No"]}`, - expected: false, - }, - { - name: "real question no options", - json: `{"text": "Which approach do you prefer?"}`, - expected: false, - }, - { - name: "completion report no options no question mark", - json: `{"text": "All tests pass. Implementation complete. 
Summary written to CLAUDOMATOR_SUMMARY_FILE."}`, - expected: true, - }, - { - name: "completion report with empty options", - json: `{"text": "Feature implemented and committed.", "options": []}`, - expected: true, - }, - { - name: "invalid json treated as not a report", - json: `not json`, - expected: false, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got := isCompletionReport(tt.json) - if got != tt.expected { - t.Errorf("isCompletionReport(%q) = %v, want %v", tt.json, got, tt.expected) - } - }) - } -} - -func TestTailFile_ReturnsLastNLines(t *testing.T) { - f, err := os.CreateTemp("", "tailfile-*") - if err != nil { - t.Fatal(err) - } - defer os.Remove(f.Name()) - for i := 1; i <= 30; i++ { - fmt.Fprintf(f, "line %d\n", i) - } - f.Close() - - got := tailFile(f.Name(), 5) - lines := strings.Split(got, "\n") - if len(lines) != 5 { - t.Fatalf("want 5 lines, got %d: %q", len(lines), got) - } - if lines[0] != "line 26" || lines[4] != "line 30" { - t.Errorf("want lines 26-30, got: %q", got) - } -} - func TestTailFile_MissingFile_ReturnsEmpty(t *testing.T) { got := tailFile("/nonexistent/path/file.log", 10) if got != "" { @@ -923,15 +808,3 @@ func TestTailFile_MissingFile_ReturnsEmpty(t *testing.T) { } } -func TestGitSafe_PrependsSafeDirectory(t *testing.T) { - got := gitSafe("-C", "/some/path", "status") - want := []string{"-c", "safe.directory=*", "-c", "commit.gpgsign=false", "-c", "tag.gpgsign=false", "-C", "/some/path", "status"} - if len(got) != len(want) { - t.Fatalf("gitSafe() = %v, want %v", got, want) - } - for i := range want { - if got[i] != want[i] { - t.Errorf("gitSafe()[%d] = %q, want %q", i, got[i], want[i]) - } - } -} diff --git a/internal/executor/container.go b/internal/executor/container.go new file mode 100644 index 0000000..61ac29c --- /dev/null +++ b/internal/executor/container.go @@ -0,0 +1,549 @@ +package executor + +import ( + "context" + "errors" + "fmt" + "log/slog" + "os" + "os/exec" + "path/filepath" + 
"strings" + "sync" + "syscall" + + "github.com/thepeterstone/claudomator/internal/storage" + "github.com/thepeterstone/claudomator/internal/task" +) + +// ContainerRunner executes an agent inside a container. +type ContainerRunner struct { + Image string // default image if not specified in task + Logger *slog.Logger + LogDir string + APIURL string + DropsDir string + SSHAuthSock string // optional path to host SSH agent + ClaudeBinary string // optional path to claude binary in container + GeminiBinary string // optional path to gemini binary in container + ClaudeConfigDir string // host path to ~/.claude; mounted into container for auth credentials + CredentialSyncCmd string // optional path to sync-credentials script for auth-error auto-recovery + Store Store // optional; used to look up stories and projects for story-aware cloning + // Command allows mocking exec.CommandContext for tests. + Command func(ctx context.Context, name string, arg ...string) *exec.Cmd +} + +func isAuthError(err error) bool { + if err == nil { + return false + } + s := err.Error() + return strings.Contains(s, "Not logged in") || + strings.Contains(s, "OAuth token has expired") || + strings.Contains(s, "authentication_error") || + strings.Contains(s, "Please run /login") +} + +func (r *ContainerRunner) command(ctx context.Context, name string, arg ...string) *exec.Cmd { + if r.Command != nil { + return r.Command(ctx, name, arg...) + } + return exec.CommandContext(ctx, name, arg...) +} + +func (r *ContainerRunner) ExecLogDir(execID string) string { + if r.LogDir == "" { + return "" + } + return filepath.Join(r.LogDir, execID) +} + +// ensureStoryBranch checks whether branchName exists in remoteURL and creates +// it from main if not. Uses localPath as a reference clone for speed if set. +func (r *ContainerRunner) ensureStoryBranch(ctx context.Context, remoteURL, branchName, localPath string) error { + // Check if branch already exists. 
+ out, err := r.command(ctx, "git", "ls-remote", "--heads", remoteURL, branchName).CombinedOutput() + if err == nil && len(strings.TrimSpace(string(out))) > 0 { + return nil // already exists + } + + r.Logger.Info("story branch missing, creating from main", "branch", branchName, "remote", remoteURL) + + // Clone into a temp dir so we can create the branch. + tmp, err := os.MkdirTemp("", "claudomator-branchsetup-*") + if err != nil { + return fmt.Errorf("mktemp for branch setup: %w", err) + } + defer os.RemoveAll(tmp) + + // Remove the dir git clone expects to create. + if err := os.Remove(tmp); err != nil { + return fmt.Errorf("removing tmp dir before clone: %w", err) + } + + var cloneArgs []string + if localPath != "" { + cloneArgs = []string{"clone", "--reference", localPath, remoteURL, tmp} + } else { + cloneArgs = []string{"clone", remoteURL, tmp} + } + if out, err := r.command(ctx, "git", cloneArgs...).CombinedOutput(); err != nil { + return fmt.Errorf("git clone for branch setup: %w\n%s", err, string(out)) + } + if out, err := r.command(ctx, "git", "-C", tmp, "checkout", "-b", branchName).CombinedOutput(); err != nil { + return fmt.Errorf("git checkout -b %q: %w\n%s", branchName, err, string(out)) + } + if out, err := r.command(ctx, "git", "-C", tmp, "push", "origin", branchName).CombinedOutput(); err != nil { + return fmt.Errorf("git push %q: %w\n%s", branchName, err, string(out)) + } + r.Logger.Info("story branch created and pushed", "branch", branchName) + return nil +} + +func (r *ContainerRunner) Run(ctx context.Context, t *task.Task, e *storage.Execution) error { + var err error + repoURL := t.RepositoryURL + if repoURL == "" { + return fmt.Errorf("task %s has no repository_url", t.ID) + } + + // 1. 
Setup workspace on host + var workspace string + isResume := false + if e.SandboxDir != "" { + if _, err = os.Stat(e.SandboxDir); err == nil { + workspace = e.SandboxDir + isResume = true + r.Logger.Info("resuming in preserved workspace", "path", workspace) + } + } + + if workspace == "" { + workspace, err = os.MkdirTemp("", "claudomator-workspace-*") + if err != nil { + return fmt.Errorf("creating workspace: %w", err) + } + // chmod applied after clone; see step 2. + } + + // Note: workspace is only removed on success. On failure, it's preserved for debugging. + // If the task becomes BLOCKED, it's also preserved for resumption. + success := false + isBlocked := false + defer func() { + if success && !isBlocked { + os.RemoveAll(workspace) + } else { + r.Logger.Warn("preserving workspace", "path", workspace, "success", success, "blocked", isBlocked) + } + }() + + // Resolve story branch and project local path if this is a story task. + var storyBranch string + var storyLocalPath string + if t.StoryID != "" && r.Store != nil { + if story, err := r.Store.GetStory(t.StoryID); err == nil && story != nil { + storyBranch = story.BranchName + if story.ProjectID != "" { + if proj, err := r.Store.GetProject(story.ProjectID); err == nil && proj != nil { + storyLocalPath = proj.LocalPath + } + } + } + } + // Fall back to task-level BranchName (e.g. set explicitly by executor or tests). + if storyBranch == "" { + storyBranch = t.BranchName + } + + // 2. Ensure story branch exists in the remote before cloning. + // If the branch is missing (e.g. story approved before fix, or branch push failed), + // create it from main using the project local path as a reference repo. + if storyBranch != "" && !isResume { + if err := r.ensureStoryBranch(ctx, repoURL, storyBranch, storyLocalPath); err != nil { + r.Logger.Warn("ensureStoryBranch failed (will attempt checkout anyway)", "branch", storyBranch, "error", err) + } + } + + // 3. Clone repo into workspace if not resuming. 
+ // git clone requires the target directory to not exist; remove the MkdirTemp-created dir first. + if !isResume { + if err := os.Remove(workspace); err != nil { + return fmt.Errorf("removing workspace before clone: %w", err) + } + r.Logger.Info("cloning repository", "url", repoURL, "workspace", workspace) + var cloneArgs []string + if storyLocalPath != "" { + cloneArgs = []string{"clone", "--reference", storyLocalPath, repoURL, workspace} + } else { + cloneArgs = []string{"clone", repoURL, workspace} + } + if out, err := r.command(ctx, "git", cloneArgs...).CombinedOutput(); err != nil { + return fmt.Errorf("git clone failed: %w\n%s", err, string(out)) + } + if storyBranch != "" { + r.Logger.Info("checking out story branch", "branch", storyBranch) + if out, err := r.command(ctx, "git", "-C", workspace, "checkout", storyBranch).CombinedOutput(); err != nil { + return fmt.Errorf("git checkout story branch %q failed: %w\n%s", storyBranch, err, string(out)) + } + } + if err = os.Chmod(workspace, 0755); err != nil { + return fmt.Errorf("chmod cloned workspace: %w", err) + } + } + e.SandboxDir = workspace + + // Set up a writable $HOME staging dir so any agent tool (claude, gemini, etc.) + // can freely create subdirs (session-env, .gemini, .cache, …) without hitting + // a non-existent or read-only home. We copy only the claude credentials into it. 
+ agentHome := filepath.Join(workspace, ".agent-home") + if err := os.MkdirAll(filepath.Join(agentHome, ".claude"), 0755); err != nil { + return fmt.Errorf("creating agent home staging dir: %w", err) + } + if err := os.MkdirAll(filepath.Join(agentHome, ".gemini"), 0755); err != nil { + return fmt.Errorf("creating .gemini dir: %w", err) + } + if r.ClaudeConfigDir != "" { + // credentials + if srcData, readErr := os.ReadFile(filepath.Join(r.ClaudeConfigDir, ".credentials.json")); readErr == nil { + _ = os.WriteFile(filepath.Join(agentHome, ".claude", ".credentials.json"), srcData, 0600) + } + // settings (used by claude CLI; copy so it can write updates without hitting the host) + if srcData, readErr := os.ReadFile(filepath.Join(r.ClaudeConfigDir, ".claude.json")); readErr == nil { + _ = os.WriteFile(filepath.Join(agentHome, ".claude.json"), srcData, 0644) + } + } + + // Pre-flight: verify credentials were actually copied before spinning up a container. + if r.ClaudeConfigDir != "" { + credsPath := filepath.Join(agentHome, ".claude", ".credentials.json") + settingsPath := filepath.Join(agentHome, ".claude.json") + if _, err := os.Stat(credsPath); os.IsNotExist(err) { + return fmt.Errorf("credentials not found at %s — run sync-credentials", r.ClaudeConfigDir) + } + if _, err := os.Stat(settingsPath); os.IsNotExist(err) { + return fmt.Errorf("claude settings (.claude.json) not found at %s — run sync-credentials", r.ClaudeConfigDir) + } + } + + // Run container (with auth retry on failure). 
+ runErr := r.runContainer(ctx, t, e, workspace, agentHome, isResume, storyBranch) + if runErr != nil && isAuthError(runErr) && r.CredentialSyncCmd != "" { + r.Logger.Warn("auth failure detected, syncing credentials and retrying once", "taskID", t.ID) + syncOut, syncErr := r.command(ctx, r.CredentialSyncCmd).CombinedOutput() + if syncErr != nil { + r.Logger.Warn("sync-credentials failed", "error", syncErr, "output", string(syncOut)) + } + // Re-copy credentials into agentHome with fresh files. + if srcData, readErr := os.ReadFile(filepath.Join(r.ClaudeConfigDir, ".credentials.json")); readErr == nil { + _ = os.WriteFile(filepath.Join(agentHome, ".claude", ".credentials.json"), srcData, 0600) + } + if srcData, readErr := os.ReadFile(filepath.Join(r.ClaudeConfigDir, ".claude.json")); readErr == nil { + _ = os.WriteFile(filepath.Join(agentHome, ".claude.json"), srcData, 0644) + } + runErr = r.runContainer(ctx, t, e, workspace, agentHome, isResume, storyBranch) + } + + if runErr == nil { + success = true + } + var blockedErr *BlockedError + if errors.As(runErr, &blockedErr) { + isBlocked = true + success = true // preserve workspace for resumption + } + return runErr +} + +// runContainer runs the docker container for the given task and handles log setup, +// environment files, instructions, and post-execution git operations. +func (r *ContainerRunner) runContainer(ctx context.Context, t *task.Task, e *storage.Execution, workspace, agentHome string, isResume bool, storyBranch string) error { + repoURL := t.RepositoryURL + + image := t.Agent.ContainerImage + if image == "" { + image = r.Image + } + if image == "" { + image = "claudomator-agent:latest" + } + + // 3. 
Prepare logs + logDir := r.ExecLogDir(e.ID) + if logDir == "" { + logDir = filepath.Join(workspace, ".claudomator-logs") + } + if err := os.MkdirAll(logDir, 0700); err != nil { + return fmt.Errorf("creating log dir: %w", err) + } + e.StdoutPath = filepath.Join(logDir, "stdout.log") + e.StderrPath = filepath.Join(logDir, "stderr.log") + e.ArtifactDir = logDir + + stdoutFile, err := os.Create(e.StdoutPath) + if err != nil { + return fmt.Errorf("creating stdout log: %w", err) + } + defer stdoutFile.Close() + + stderrFile, err := os.Create(e.StderrPath) + if err != nil { + return fmt.Errorf("creating stderr log: %w", err) + } + defer stderrFile.Close() + + // 4. Run container + + // Write API keys to a temporary env file to avoid exposure in 'ps' or 'docker inspect' + envFile := filepath.Join(workspace, ".claudomator-env") + envContent := fmt.Sprintf("ANTHROPIC_API_KEY=%s\nGOOGLE_API_KEY=%s\nGEMINI_API_KEY=%s\n", os.Getenv("ANTHROPIC_API_KEY"), os.Getenv("GOOGLE_API_KEY"), os.Getenv("GEMINI_API_KEY")) + if err := os.WriteFile(envFile, []byte(envContent), 0600); err != nil { + return fmt.Errorf("writing env file: %w", err) + } + + // Inject custom instructions via file to avoid CLI length limits + instructionsFile := filepath.Join(workspace, ".claudomator-instructions.txt") + if err := os.WriteFile(instructionsFile, []byte(t.Agent.Instructions), 0644); err != nil { + return fmt.Errorf("writing instructions: %w", err) + } + + args := r.buildDockerArgs(workspace, agentHome, e.TaskID) + innerCmd := r.buildInnerCmd(t, e, isResume) + + fullArgs := append(args, image) + fullArgs = append(fullArgs, innerCmd...) + + r.Logger.Info("starting container", "image", image, "taskID", t.ID) + cmd := r.command(ctx, "docker", fullArgs...) 
+ cmd.Stderr = stderrFile + cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true} + + // Use os.Pipe for stdout so we can parse it in real-time + var stdoutR, stdoutW *os.File + stdoutR, stdoutW, err = os.Pipe() + if err != nil { + return fmt.Errorf("creating stdout pipe: %w", err) + } + cmd.Stdout = stdoutW + + if err := cmd.Start(); err != nil { + stdoutW.Close() + stdoutR.Close() + return fmt.Errorf("starting container: %w", err) + } + stdoutW.Close() + + // Watch for context cancellation to kill the process group (Issue 1) + done := make(chan struct{}) + defer close(done) + go func() { + select { + case <-ctx.Done(): + r.Logger.Info("killing container process group due to context cancellation", "taskID", t.ID) + syscall.Kill(-cmd.Process.Pid, syscall.SIGKILL) + case <-done: + } + }() + + // Stream stdout to the log file and parse cost/errors. + var costUSD float64 + var sessionID string + var streamErr error + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + costUSD, sessionID, streamErr = parseStream(stdoutR, stdoutFile, r.Logger) + stdoutR.Close() + }() + + waitErr := cmd.Wait() + wg.Wait() + + e.CostUSD = costUSD + if sessionID != "" { + e.SessionID = sessionID + } + + // Check whether the agent left a question before exiting. + questionFile := filepath.Join(logDir, "question.json") + if data, readErr := os.ReadFile(questionFile); readErr == nil { + os.Remove(questionFile) // consumed + questionJSON := strings.TrimSpace(string(data)) + if isCompletionReport(questionJSON) { + r.Logger.Info("treating question file as completion report", "taskID", e.TaskID) + e.Summary = extractQuestionText(questionJSON) + } else { + if e.SessionID == "" { + r.Logger.Warn("missing session ID; resume will start fresh", "taskID", e.TaskID) + } + return &BlockedError{ + QuestionJSON: questionJSON, + SessionID: e.SessionID, + SandboxDir: workspace, + } + } + } + + // Read agent summary if written. 
+ summaryFile := filepath.Join(logDir, "summary.txt") + if summaryData, readErr := os.ReadFile(summaryFile); readErr == nil { + os.Remove(summaryFile) // consumed + e.Summary = strings.TrimSpace(string(summaryData)) + } + + // 5. Post-execution: push changes if successful + if waitErr == nil && streamErr == nil { + // Check if there are any commits to push (HEAD ahead of origin/HEAD). + // If origin/HEAD doesn't exist (e.g. fresh clone with no commits), we attempt push anyway. + hasCommits := true + if out, err := r.command(ctx, "git", "-C", workspace, "rev-list", "origin/HEAD..HEAD").CombinedOutput(); err == nil { + if len(strings.TrimSpace(string(out))) == 0 { + hasCommits = false + } + } + + if hasCommits { + pushRef := "HEAD" + if storyBranch != "" { + pushRef = storyBranch + } + r.Logger.Info("pushing changes back to remote", "url", repoURL, "ref", pushRef) + if out, err := r.command(ctx, "git", "-C", workspace, "push", "origin", pushRef).CombinedOutput(); err != nil { + r.Logger.Warn("git push failed", "error", err, "output", string(out)) + return fmt.Errorf("git push failed: %w\n%s", err, string(out)) + } + } else { + // No commits pushed — check whether the agent left uncommitted work behind. + // If so, fail loudly: the work would be silently lost when the sandbox is deleted. + if err := detectUncommittedChanges(workspace); err != nil { + return err + } + r.Logger.Info("no new commits to push", "taskID", t.ID) + } + } + + if waitErr != nil { + // Append the tail of stderr so error classifiers (isQuotaExhausted, isRateLimitError) + // can inspect agent-specific messages (e.g. Gemini TerminalQuotaError). 
+ stderrTail := readFileTail(e.StderrPath, 4096) + if stderrTail != "" { + return fmt.Errorf("container execution failed: %w\n%s", waitErr, stderrTail) + } + return fmt.Errorf("container execution failed: %w", waitErr) + } + if streamErr != nil { + return fmt.Errorf("stream parsing failed: %w", streamErr) + } + + return nil +} + +func (r *ContainerRunner) buildDockerArgs(workspace, claudeHome, taskID string) []string { + // --env-file takes a HOST path. + hostEnvFile := filepath.Join(workspace, ".claudomator-env") + + // Replace localhost with host.docker.internal so the container can reach the host API. + apiURL := strings.ReplaceAll(r.APIURL, "localhost", "host.docker.internal") + + args := []string{ + "run", "--rm", + // Allow container to reach the host via host.docker.internal. + "--add-host=host.docker.internal:host-gateway", + // Run as the current process UID:GID so the container can read host-owned files. + fmt.Sprintf("--user=%d:%d", os.Getuid(), os.Getgid()), + "-v", workspace + ":/workspace", + "-v", claudeHome + ":/home/agent", + "-w", "/workspace", + "--env-file", hostEnvFile, + "-e", "HOME=/home/agent", + "-e", "CLAUDOMATOR_API_URL=" + apiURL, + "-e", "CLAUDOMATOR_TASK_ID=" + taskID, + "-e", "CLAUDOMATOR_DROP_DIR=" + r.DropsDir, + } + if r.SSHAuthSock != "" { + args = append(args, "-v", r.SSHAuthSock+":/tmp/ssh-auth.sock", "-e", "SSH_AUTH_SOCK=/tmp/ssh-auth.sock") + } + return args +} + +func (r *ContainerRunner) buildInnerCmd(t *task.Task, e *storage.Execution, isResume bool) []string { + // Claude CLI uses -p for prompt text. To pass a file, we use a shell to cat it. + // We use a shell variable to capture the expansion to avoid quoting issues with instructions contents. + // The outer single quotes around the sh -c argument prevent host-side expansion. 
// scaffoldPrefixes are files/dirs written by the harness into the workspace.
// They are not part of the repo and must not trigger the uncommitted-changes
// check. ".claudomator-logs" is the fallback log directory the runner creates
// inside the workspace when no external log directory is configured.
var scaffoldPrefixes = []string{
	".claudomator-env",
	".claudomator-instructions.txt",
	".claudomator-logs",
	".agent-home",
}

// isScaffold reports whether path (relative to the workspace root, using "/"
// separators as git prints them) is one of the scaffoldPrefixes entries or
// lies beneath one of them.
func isScaffold(path string) bool {
	for _, p := range scaffoldPrefixes {
		if path == p || strings.HasPrefix(path, p+"/") {
			return true
		}
	}
	return false
}

// detectUncommittedChanges returns an error if the workspace contains modified
// or untracked files that the agent forgot to commit. Scaffold files written by
// the harness (see scaffoldPrefixes) are excluded from the check and from the
// error text.
//
// The check is best-effort: if a git invocation fails (workspace is not a
// repository, no HEAD yet, git not installed) that part of the check is
// skipped rather than reported.
func detectUncommittedChanges(workspace string) error {
	// nonScaffold runs git in the workspace and returns the non-scaffold paths
	// it prints. Only stdout is parsed so that warnings git writes to stderr
	// are never mistaken for file names.
	nonScaffold := func(gitArgs ...string) []string {
		argv := append([]string{"-c", "safe.directory=*", "-C", workspace}, gitArgs...)
		out, err := exec.Command("git", argv...).Output()
		if err != nil {
			return nil
		}
		var paths []string
		for _, line := range strings.Split(strings.TrimSpace(string(out)), "\n") {
			if line != "" && !isScaffold(line) {
				paths = append(paths, line)
			}
		}
		return paths
	}

	// Modified or staged tracked files.
	if modified := nonScaffold("diff", "--name-only", "HEAD"); len(modified) > 0 {
		return fmt.Errorf("agent left uncommitted changes (work would be lost on sandbox deletion):\n%s\nInstructions must include: git add -A && git commit && git push origin main", strings.Join(modified, "\n"))
	}

	// Untracked new files (excludes gitignored files).
	if untracked := nonScaffold("ls-files", "--others", "--exclude-standard"); len(untracked) > 0 {
		return fmt.Errorf("agent left untracked files not committed (work would be lost on sandbox deletion):\n%s\nInstructions must include: git add -A && git commit && git push origin main", strings.Join(untracked, "\n"))
	}

	return nil
}
runner.buildDockerArgs(workspace, agentHome, taskID) + + expected := []string{ + "run", "--rm", + "--add-host=host.docker.internal:host-gateway", + fmt.Sprintf("--user=%d:%d", os.Getuid(), os.Getgid()), + "-v", "/tmp/ws:/workspace", + "-v", "/tmp/ws/.agent-home:/home/agent", + "-w", "/workspace", + "--env-file", "/tmp/ws/.claudomator-env", + "-e", "HOME=/home/agent", + "-e", "CLAUDOMATOR_API_URL=http://host.docker.internal:8484", + "-e", "CLAUDOMATOR_TASK_ID=task-123", + "-e", "CLAUDOMATOR_DROP_DIR=/data/drops", + "-v", "/tmp/ssh.sock:/tmp/ssh-auth.sock", + "-e", "SSH_AUTH_SOCK=/tmp/ssh-auth.sock", + } + + if len(args) != len(expected) { + t.Fatalf("expected %d args, got %d. Got: %v", len(expected), len(args), args) + } + for i, v := range args { + if v != expected[i] { + t.Errorf("arg %d: expected %q, got %q", i, expected[i], v) + } + } +} + +func TestContainerRunner_BuildInnerCmd(t *testing.T) { + runner := &ContainerRunner{} + + t.Run("claude-fresh", func(t *testing.T) { + tk := &task.Task{Agent: task.AgentConfig{Type: "claude"}} + exec := &storage.Execution{} + cmd := runner.buildInnerCmd(tk, exec, false) + + cmdStr := strings.Join(cmd, " ") + if strings.Contains(cmdStr, "--resume") { + t.Errorf("unexpected --resume flag in fresh run: %q", cmdStr) + } + if !strings.Contains(cmdStr, "INST=$(cat /workspace/.claudomator-instructions.txt); claude -p \"$INST\"") { + t.Errorf("expected cat instructions in sh command, got %q", cmdStr) + } + }) + + t.Run("claude-resume", func(t *testing.T) { + tk := &task.Task{Agent: task.AgentConfig{Type: "claude"}} + exec := &storage.Execution{ResumeSessionID: "orig-session-123"} + cmd := runner.buildInnerCmd(tk, exec, true) + + cmdStr := strings.Join(cmd, " ") + if !strings.Contains(cmdStr, "--resume orig-session-123") { + t.Errorf("expected --resume flag with correct session ID, got %q", cmdStr) + } + }) + + t.Run("gemini", func(t *testing.T) { + tk := &task.Task{Agent: task.AgentConfig{Type: "gemini"}} + exec := 
&storage.Execution{} + cmd := runner.buildInnerCmd(tk, exec, false) + + cmdStr := strings.Join(cmd, " ") + if !strings.Contains(cmdStr, "gemini -p \"$INST\"") { + t.Errorf("expected gemini command with safer quoting, got %q", cmdStr) + } + }) + + t.Run("custom-binaries", func(t *testing.T) { + runnerCustom := &ContainerRunner{ + ClaudeBinary: "/usr/bin/claude-v2", + GeminiBinary: "/usr/local/bin/gemini-pro", + } + + tkClaude := &task.Task{Agent: task.AgentConfig{Type: "claude"}} + cmdClaude := runnerCustom.buildInnerCmd(tkClaude, &storage.Execution{}, false) + if !strings.Contains(strings.Join(cmdClaude, " "), "/usr/bin/claude-v2 -p") { + t.Errorf("expected custom claude binary, got %q", cmdClaude) + } + + tkGemini := &task.Task{Agent: task.AgentConfig{Type: "gemini"}} + cmdGemini := runnerCustom.buildInnerCmd(tkGemini, &storage.Execution{}, false) + if !strings.Contains(strings.Join(cmdGemini, " "), "/usr/local/bin/gemini-pro -p") { + t.Errorf("expected custom gemini binary, got %q", cmdGemini) + } + }) +} + +func TestContainerRunner_Run_PreservesWorkspaceOnFailure(t *testing.T) { + logger := slog.New(slog.NewTextHandler(io.Discard, nil)) + runner := &ContainerRunner{ + Logger: logger, + Image: "busybox", + Command: func(ctx context.Context, name string, arg ...string) *exec.Cmd { + // Mock docker run to exit 1 + if name == "docker" { + return exec.Command("sh", "-c", "exit 1") + } + // Mock git clone to succeed and create the directory + if name == "git" && len(arg) > 0 && arg[0] == "clone" { + dir := arg[len(arg)-1] + os.MkdirAll(dir, 0755) + return exec.Command("true") + } + return exec.Command("true") + }, + } + + tk := &task.Task{ + ID: "test-task", + RepositoryURL: "https://github.com/example/repo.git", + Agent: task.AgentConfig{Type: "claude"}, + } + exec := &storage.Execution{ID: "test-exec", TaskID: "test-task"} + + err := runner.Run(context.Background(), tk, exec) + if err == nil { + t.Fatal("expected error due to mocked docker failure") + } + + // 
Verify SandboxDir was set and directory exists. + if exec.SandboxDir == "" { + t.Fatal("expected SandboxDir to be set even on failure") + } + if _, statErr := os.Stat(exec.SandboxDir); statErr != nil { + t.Errorf("expected sandbox directory to be preserved, but stat failed: %v", statErr) + } else { + os.RemoveAll(exec.SandboxDir) + } +} + +func TestBlockedError_IncludesSandboxDir(t *testing.T) { + // This test requires mocking 'docker run' or the whole Run() which is hard. + // But we can test that returning BlockedError works. + err := &BlockedError{ + QuestionJSON: `{"text":"?"}`, + SessionID: "s1", + SandboxDir: "/tmp/s1", + } + if !strings.Contains(err.Error(), "task blocked") { + t.Errorf("wrong error message: %v", err) + } +} + +func TestIsCompletionReport(t *testing.T) { + tests := []struct { + name string + json string + expected bool + }{ + { + name: "real question with options", + json: `{"text": "Should I proceed with implementation?", "options": ["Yes", "No"]}`, + expected: false, + }, + { + name: "real question no options", + json: `{"text": "Which approach do you prefer?"}`, + expected: false, + }, + { + name: "completion report no options no question mark", + json: `{"text": "All tests pass. Implementation complete. 
Summary written to CLAUDOMATOR_SUMMARY_FILE."}`, + expected: true, + }, + { + name: "completion report with empty options", + json: `{"text": "Feature implemented and committed.", "options": []}`, + expected: true, + }, + { + name: "invalid json treated as not a report", + json: `not json`, + expected: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := isCompletionReport(tt.json) + if got != tt.expected { + t.Errorf("isCompletionReport(%q) = %v, want %v", tt.json, got, tt.expected) + } + }) + } +} + +func TestTailFile_ReturnsLastNLines(t *testing.T) { + f, err := os.CreateTemp("", "tailfile-*") + if err != nil { + t.Fatal(err) + } + defer os.Remove(f.Name()) + for i := 1; i <= 30; i++ { + fmt.Fprintf(f, "line %d\n", i) + } + f.Close() + + got := tailFile(f.Name(), 5) + lines := strings.Split(strings.TrimSpace(got), "\n") + if len(lines) != 5 { + t.Fatalf("want 5 lines, got %d: %q", len(lines), got) + } + if lines[0] != "line 26" || lines[4] != "line 30" { + t.Errorf("want lines 26-30, got: %q", got) + } +} + +func TestDetectUncommittedChanges_ModifiedFile(t *testing.T) { + dir := t.TempDir() + run := func(args ...string) { + cmd := exec.Command(args[0], args[1:]...) 
+ cmd.Dir = dir + if out, err := cmd.CombinedOutput(); err != nil { + t.Fatalf("%v: %s", args, out) + } + } + run("git", "init", dir) + run("git", "config", "user.email", "test@test.com") + run("git", "config", "user.name", "Test") + // Create and commit a file + if err := os.WriteFile(dir+"/main.go", []byte("package main"), 0644); err != nil { + t.Fatal(err) + } + run("git", "add", "main.go") + run("git", "commit", "-m", "init") + // Now modify without committing — simulates agent that forgot to commit + if err := os.WriteFile(dir+"/main.go", []byte("package main\n// changed"), 0644); err != nil { + t.Fatal(err) + } + err := detectUncommittedChanges(dir) + if err == nil { + t.Fatal("expected error for modified uncommitted file, got nil") + } + if !strings.Contains(err.Error(), "uncommitted") { + t.Errorf("error should mention uncommitted, got: %v", err) + } +} + +func TestDetectUncommittedChanges_NewUntrackedSourceFile(t *testing.T) { + dir := t.TempDir() + run := func(args ...string) { + cmd := exec.Command(args[0], args[1:]...) + cmd.Dir = dir + if out, err := cmd.CombinedOutput(); err != nil { + t.Fatalf("%v: %s", args, out) + } + } + run("git", "init", dir) + run("git", "config", "user.email", "test@test.com") + run("git", "config", "user.name", "Test") + run("git", "commit", "--allow-empty", "-m", "init") + // Agent wrote a new file but never committed it + if err := os.WriteFile(dir+"/newfile.go", []byte("package main"), 0644); err != nil { + t.Fatal(err) + } + err := detectUncommittedChanges(dir) + if err == nil { + t.Fatal("expected error for new untracked source file, got nil") + } +} + +func TestDetectUncommittedChanges_ScaffoldFilesIgnored(t *testing.T) { + dir := t.TempDir() + run := func(args ...string) { + cmd := exec.Command(args[0], args[1:]...) 
+ cmd.Dir = dir + if out, err := cmd.CombinedOutput(); err != nil { + t.Fatalf("%v: %s", args, out) + } + } + run("git", "init", dir) + run("git", "config", "user.email", "test@test.com") + run("git", "config", "user.name", "Test") + run("git", "commit", "--allow-empty", "-m", "init") + // Write only scaffold files that the harness injects — should not trigger error + _ = os.WriteFile(dir+"/.claudomator-env", []byte("KEY=val"), 0600) + _ = os.WriteFile(dir+"/.claudomator-instructions.txt", []byte("do stuff"), 0644) + _ = os.MkdirAll(dir+"/.agent-home/.claude", 0755) + err := detectUncommittedChanges(dir) + if err != nil { + t.Errorf("scaffold files should not trigger uncommitted error, got: %v", err) + } +} + +func TestDetectUncommittedChanges_CleanRepo(t *testing.T) { + dir := t.TempDir() + run := func(args ...string) { + cmd := exec.Command(args[0], args[1:]...) + cmd.Dir = dir + if out, err := cmd.CombinedOutput(); err != nil { + t.Fatalf("%v: %s", args, out) + } + } + run("git", "init", dir) + run("git", "config", "user.email", "test@test.com") + run("git", "config", "user.name", "Test") + if err := os.WriteFile(dir+"/main.go", []byte("package main"), 0644); err != nil { + t.Fatal(err) + } + run("git", "add", "main.go") + run("git", "commit", "-m", "init") + // No modifications — should pass + err := detectUncommittedChanges(dir) + if err != nil { + t.Errorf("clean repo should not error, got: %v", err) + } +} + +func TestGitSafe_PrependsSafeDirectory(t *testing.T) { + got := gitSafe("-C", "/some/path", "status") + want := []string{"-c", "safe.directory=*", "-c", "commit.gpgsign=false", "-c", "tag.gpgsign=false", "-C", "/some/path", "status"} + if len(got) != len(want) { + t.Fatalf("gitSafe() = %v, want %v", got, want) + } + for i := range want { + if got[i] != want[i] { + t.Errorf("gitSafe()[%d] = %q, want %q", i, got[i], want[i]) + } + } +} + +func TestContainerRunner_MissingCredentials_FailsFast(t *testing.T) { + logger := 
slog.New(slog.NewTextHandler(io.Discard, nil)) + + claudeConfigDir := t.TempDir() + + // Set up ClaudeConfigDir with MISSING credentials (so pre-flight fails) + // Don't create .credentials.json + // But DO create .claude.json so the test isolates the credentials check + if err := os.WriteFile(filepath.Join(claudeConfigDir, ".claude.json"), []byte("{}"), 0644); err != nil { + t.Fatal(err) + } + + runner := &ContainerRunner{ + Logger: logger, + Image: "busybox", + ClaudeConfigDir: claudeConfigDir, + Command: func(ctx context.Context, name string, arg ...string) *exec.Cmd { + if name == "git" && len(arg) > 0 && arg[0] == "clone" { + dir := arg[len(arg)-1] + os.MkdirAll(dir, 0755) + return exec.Command("true") + } + return exec.Command("true") + }, + } + + tk := &task.Task{ + ID: "test-missing-creds", + RepositoryURL: "https://github.com/example/repo.git", + Agent: task.AgentConfig{Type: "claude"}, + } + e := &storage.Execution{ID: "test-exec", TaskID: "test-missing-creds"} + + err := runner.Run(context.Background(), tk, e) + if err == nil { + t.Fatal("expected error due to missing credentials, got nil") + } + if !strings.Contains(err.Error(), "credentials not found") { + t.Errorf("expected 'credentials not found' error, got: %v", err) + } +} + +func TestContainerRunner_MissingSettings_FailsFast(t *testing.T) { + logger := slog.New(slog.NewTextHandler(io.Discard, nil)) + + claudeConfigDir := t.TempDir() + + // Only create credentials but NOT .claude.json + if err := os.WriteFile(filepath.Join(claudeConfigDir, ".credentials.json"), []byte("{}"), 0600); err != nil { + t.Fatal(err) + } + + runner := &ContainerRunner{ + Logger: logger, + Image: "busybox", + ClaudeConfigDir: claudeConfigDir, + Command: func(ctx context.Context, name string, arg ...string) *exec.Cmd { + if name == "git" && len(arg) > 0 && arg[0] == "clone" { + dir := arg[len(arg)-1] + os.MkdirAll(dir, 0755) + return exec.Command("true") + } + return exec.Command("true") + }, + } + + tk := &task.Task{ + ID: 
"test-missing-settings", + RepositoryURL: "https://github.com/example/repo.git", + Agent: task.AgentConfig{Type: "claude"}, + } + e := &storage.Execution{ID: "test-exec-2", TaskID: "test-missing-settings"} + + err := runner.Run(context.Background(), tk, e) + if err == nil { + t.Fatal("expected error due to missing settings, got nil") + } + if !strings.Contains(err.Error(), "claude settings") { + t.Errorf("expected 'claude settings' error, got: %v", err) + } +} + +func TestIsAuthError_DetectsAllVariants(t *testing.T) { + tests := []struct { + msg string + want bool + }{ + {"Not logged in", true}, + {"OAuth token has expired", true}, + {"authentication_error: invalid token", true}, + {"Please run /login to authenticate", true}, + {"container execution failed: exit status 1", false}, + {"git clone failed", false}, + {"", false}, + } + for _, tt := range tests { + var err error + if tt.msg != "" { + err = fmt.Errorf("%s", tt.msg) + } + got := isAuthError(err) + if got != tt.want { + t.Errorf("isAuthError(%q) = %v, want %v", tt.msg, got, tt.want) + } + } +} + +func TestContainerRunner_AuthError_SyncsAndRetries(t *testing.T) { + logger := slog.New(slog.NewTextHandler(io.Discard, nil)) + + // Create a sync script that creates a marker file + syncDir := t.TempDir() + syncMarker := filepath.Join(syncDir, "sync-called") + syncScript := filepath.Join(syncDir, "sync-creds") + os.WriteFile(syncScript, []byte("#!/bin/sh\ntouch "+syncMarker+"\n"), 0755) + + claudeConfigDir := t.TempDir() + // Create both credential files in ClaudeConfigDir + os.WriteFile(filepath.Join(claudeConfigDir, ".credentials.json"), []byte(`{"token":"fresh"}`), 0600) + os.WriteFile(filepath.Join(claudeConfigDir, ".claude.json"), []byte("{}"), 0644) + + callCount := 0 + runner := &ContainerRunner{ + Logger: logger, + Image: "busybox", + ClaudeConfigDir: claudeConfigDir, + CredentialSyncCmd: syncScript, + Command: func(ctx context.Context, name string, arg ...string) *exec.Cmd { + if name == "git" { + if 
len(arg) > 0 && arg[0] == "clone" { + dir := arg[len(arg)-1] + os.MkdirAll(dir, 0755) + } + return exec.Command("true") + } + if name == "docker" { + callCount++ + if callCount == 1 { + // First docker call fails with auth error + return exec.Command("sh", "-c", "echo 'Not logged in' >&2; exit 1") + } + // Second docker call "succeeds" + return exec.Command("sh", "-c", "exit 0") + } + if name == syncScript { + return exec.Command("sh", "-c", "touch "+syncMarker) + } + return exec.Command("true") + }, + } + + tk := &task.Task{ + ID: "auth-retry-test", + RepositoryURL: "https://github.com/example/repo.git", + Agent: task.AgentConfig{Type: "claude", Instructions: "test"}, + } + e := &storage.Execution{ID: "auth-retry-exec", TaskID: "auth-retry-test"} + + // Run — first attempt will fail with auth error, triggering sync+retry + runner.Run(context.Background(), tk, e) + // We don't check error strictly since second run may also fail (git push etc.) + // What we care about is that docker was called twice and sync was called + if callCount < 2 { + t.Errorf("expected docker to be called at least twice (original + retry), got %d", callCount) + } + if _, err := os.Stat(syncMarker); os.IsNotExist(err) { + t.Error("expected sync-credentials to be called, but marker file not found") + } +} + +func TestContainerRunner_ClonesStoryBranch(t *testing.T) { + logger := slog.New(slog.NewTextHandler(io.Discard, nil)) + + var checkoutArgs []string + runner := &ContainerRunner{ + Logger: logger, + Image: "busybox", + Command: func(ctx context.Context, name string, arg ...string) *exec.Cmd { + if name == "git" && len(arg) > 0 && arg[0] == "clone" { + dir := arg[len(arg)-1] + os.MkdirAll(dir, 0755) + return exec.Command("true") + } + // Capture checkout calls: both "git checkout <branch>" and "git -C <dir> checkout <branch>" + for i, a := range arg { + if a == "checkout" { + checkoutArgs = append([]string{}, arg[i:]...) 
+ break + } + } + if name == "docker" { + return exec.Command("sh", "-c", "exit 1") + } + return exec.Command("true") + }, + } + + tk := &task.Task{ + ID: "story-branch-test", + RepositoryURL: "https://example.com/repo.git", + BranchName: "story/my-feature", + Agent: task.AgentConfig{Type: "claude"}, + } + e := &storage.Execution{ID: "exec-1", TaskID: "story-branch-test"} + + runner.Run(context.Background(), tk, e) + os.RemoveAll(e.SandboxDir) + + // Assert git checkout was called with the story branch name. + if len(checkoutArgs) == 0 { + t.Fatal("expected git checkout to be called for story branch, but it was not") + } + found := false + for _, a := range checkoutArgs { + if a == "story/my-feature" { + found = true + break + } + } + if !found { + t.Errorf("expected git checkout story/my-feature, got args: %v", checkoutArgs) + } +} + +func TestContainerRunner_ClonesDefaultBranchWhenNoBranchName(t *testing.T) { + logger := slog.New(slog.NewTextHandler(io.Discard, nil)) + + var cloneArgs []string + runner := &ContainerRunner{ + Logger: logger, + Image: "busybox", + Command: func(ctx context.Context, name string, arg ...string) *exec.Cmd { + if name == "git" && len(arg) > 0 && arg[0] == "clone" { + cloneArgs = append([]string{}, arg...) 
+ dir := arg[len(arg)-1] + os.MkdirAll(dir, 0755) + return exec.Command("true") + } + if name == "docker" { + return exec.Command("sh", "-c", "exit 1") + } + return exec.Command("true") + }, + } + + tk := &task.Task{ + ID: "no-branch-test", + RepositoryURL: "https://example.com/repo.git", + Agent: task.AgentConfig{Type: "claude"}, + } + e := &storage.Execution{ID: "exec-2", TaskID: "no-branch-test"} + + runner.Run(context.Background(), tk, e) + os.RemoveAll(e.SandboxDir) + + for _, a := range cloneArgs { + if a == "--branch" { + t.Errorf("expected no --branch flag for task without BranchName, got args: %v", cloneArgs) + } + } +} + +func TestEnsureStoryBranch_CreatesMissingBranch(t *testing.T) { + // Set up a bare repo and a local clone to test branch creation. + dir := t.TempDir() + bare := filepath.Join(dir, "bare.git") + local := filepath.Join(dir, "local") + + // Create bare repo with an initial commit. + if out, err := exec.Command("git", "init", "--bare", bare).CombinedOutput(); err != nil { + t.Fatalf("git init bare: %v\n%s", err, out) + } + if out, err := exec.Command("git", "clone", bare, local).CombinedOutput(); err != nil { + t.Fatalf("git clone: %v\n%s", err, out) + } + if out, err := exec.Command("git", "-C", local, "commit", "--allow-empty", "-m", "init").CombinedOutput(); err != nil { + t.Fatalf("git commit: %v\n%s", err, out) + } + if out, err := exec.Command("git", "-C", local, "push", "origin", "main").CombinedOutput(); err != nil { + // try master + if out2, err2 := exec.Command("git", "-C", local, "push", "origin", "HEAD:main").CombinedOutput(); err2 != nil { + t.Fatalf("git push main: %v\n%s\n%s", err, out, out2) + } + } + + runner := &ContainerRunner{Logger: slog.Default()} + + branch := "story/test-branch" + + // Branch should not exist yet. 
+ out, _ := exec.Command("git", "ls-remote", "--heads", bare, branch).CombinedOutput() + if len(strings.TrimSpace(string(out))) > 0 { + t.Fatal("branch should not exist before ensureStoryBranch") + } + + if err := runner.ensureStoryBranch(context.Background(), bare, branch, ""); err != nil { + t.Fatalf("ensureStoryBranch: %v", err) + } + + // Branch should now exist in the bare repo. + out, err := exec.Command("git", "ls-remote", "--heads", bare, branch).CombinedOutput() + if err != nil || len(strings.TrimSpace(string(out))) == 0 { + t.Errorf("branch %q not found in bare repo after ensureStoryBranch: %s", branch, out) + } +} + +func TestEnsureStoryBranch_IdempotentIfExists(t *testing.T) { + dir := t.TempDir() + bare := filepath.Join(dir, "bare.git") + local := filepath.Join(dir, "local") + + if out, err := exec.Command("git", "init", "--bare", bare).CombinedOutput(); err != nil { + t.Fatalf("git init bare: %v\n%s", err, out) + } + if out, err := exec.Command("git", "clone", bare, local).CombinedOutput(); err != nil { + t.Fatalf("git clone: %v\n%s", err, out) + } + if out, err := exec.Command("git", "-C", local, "commit", "--allow-empty", "-m", "init").CombinedOutput(); err != nil { + t.Fatalf("git commit: %v\n%s", err, out) + } + if _, err := exec.Command("git", "-C", local, "push", "origin", "HEAD:main").CombinedOutput(); err != nil { + t.Fatalf("push main: %v", err) + } + + branch := "story/existing-branch" + // Pre-create the branch. + if out, err := exec.Command("git", "-C", local, "checkout", "-b", branch).CombinedOutput(); err != nil { + t.Fatalf("checkout -b: %v\n%s", err, out) + } + if out, err := exec.Command("git", "-C", local, "push", "origin", branch).CombinedOutput(); err != nil { + t.Fatalf("push branch: %v\n%s", err, out) + } + + runner := &ContainerRunner{Logger: slog.Default()} + + // Should be a no-op, not an error. 
+ if err := runner.ensureStoryBranch(context.Background(), bare, branch, ""); err != nil { + t.Fatalf("ensureStoryBranch on existing branch: %v", err) + } +} diff --git a/internal/executor/executor.go b/internal/executor/executor.go index 315030d..09169bd 100644 --- a/internal/executor/executor.go +++ b/internal/executor/executor.go @@ -2,9 +2,11 @@ package executor import ( "context" + "encoding/json" "errors" "fmt" "log/slog" + "os/exec" "path/filepath" "strings" "sync" @@ -25,6 +27,7 @@ type Store interface { ListSubtasks(parentID string) ([]*task.Task, error) ListExecutions(taskID string) ([]*storage.Execution, error) CreateExecution(e *storage.Execution) error + CreateExecutionAndSetRunning(e *storage.Execution) error UpdateExecution(e *storage.Execution) error UpdateTaskState(id string, newState task.State) error UpdateTaskQuestion(taskID, questionJSON string) error @@ -32,6 +35,14 @@ type Store interface { AppendTaskInteraction(taskID string, interaction task.Interaction) error UpdateTaskAgent(id string, agent task.AgentConfig) error UpdateExecutionChangestats(execID string, stats *task.Changestats) error + RecordAgentEvent(e storage.AgentEvent) error + GetProject(id string) (*task.Project, error) + GetStory(id string) (*task.Story, error) + ListTasksByStory(storyID string) ([]*task.Task, error) + UpdateStoryStatus(id string, status task.StoryState) error + CreateTask(t *task.Task) error + UpdateTaskCheckerReport(id, report string) error + GetCheckerTask(checkedTaskID string) (*task.Task, error) } // LogPather is an optional interface runners can implement to provide the log @@ -56,24 +67,28 @@ type workItem struct { // Pool manages a bounded set of concurrent task workers. 
type Pool struct { maxConcurrent int + maxPerAgent int runners map[string]Runner store Store logger *slog.Logger - depPollInterval time.Duration // how often waitForDependencies polls; defaults to 5s - - mu sync.Mutex - active int - activePerAgent map[string]int - rateLimited map[string]time.Time // agentType -> until - cancels map[string]context.CancelFunc // taskID → cancel - resultCh chan *Result - workCh chan workItem // internal bounded queue; Submit enqueues here - doneCh chan struct{} // signals when a worker slot is freed - Questions *QuestionRegistry - Classifier *Classifier - // LLM, when non-nil, enables LLM-synthesized summaries for executions - // whose stdout did not include a "## Summary" heading. - LLM *llm.Client + depPollInterval time.Duration // how often waitForDependencies polls; defaults to 5s + requeueDelay time.Duration // how long to wait before requeuing a blocked-per-agent task; defaults to 30s + + mu sync.Mutex + active int + activePerAgent map[string]int + rateLimited map[string]time.Time // agentType -> until + cancels map[string]context.CancelFunc // taskID → cancel + consecutiveFailures map[string]int // agentType -> count + closed bool // set to true when Shutdown has been called + resultCh chan *Result + startedCh chan string // task IDs that just transitioned to RUNNING + workCh chan workItem // internal bounded queue; Submit enqueues here + doneCh chan struct{} // signals when a worker slot is freed + workerWg sync.WaitGroup // tracks in-flight execute/executeResume goroutines + dispatchDone chan struct{} // closed when the dispatch goroutine exits + Classifier *Classifier + LLM *llm.Client } // Result is emitted when a task execution completes. 
@@ -88,18 +103,22 @@ func NewPool(maxConcurrent int, runners map[string]Runner, store Store, logger * maxConcurrent = 1 } p := &Pool{ - maxConcurrent: maxConcurrent, - runners: runners, - store: store, - logger: logger, - depPollInterval: 5 * time.Second, - activePerAgent: make(map[string]int), - rateLimited: make(map[string]time.Time), - cancels: make(map[string]context.CancelFunc), - resultCh: make(chan *Result, maxConcurrent*2), - workCh: make(chan workItem, maxConcurrent*10+100), - doneCh: make(chan struct{}, maxConcurrent), - Questions: NewQuestionRegistry(), + maxConcurrent: maxConcurrent, + maxPerAgent: 1, + runners: runners, + store: store, + logger: logger, + depPollInterval: 5 * time.Second, + requeueDelay: 30 * time.Second, + activePerAgent: make(map[string]int), + rateLimited: make(map[string]time.Time), + cancels: make(map[string]context.CancelFunc), + consecutiveFailures: make(map[string]int), + resultCh: make(chan *Result, maxConcurrent*2), + startedCh: make(chan string, maxConcurrent*2), + workCh: make(chan workItem, maxConcurrent*10+100), + doneCh: make(chan struct{}, maxConcurrent), + dispatchDone: make(chan struct{}), } go p.dispatch() return p @@ -109,6 +128,7 @@ func NewPool(maxConcurrent int, runners map[string]Runner, store Store, logger * // and launches goroutines as soon as a pool slot is available. This prevents // tasks from being rejected when the pool is temporarily at capacity. 
func (p *Pool) dispatch() { + defer close(p.dispatchDone) for item := range p.workCh { for { p.mu.Lock() @@ -116,9 +136,9 @@ func (p *Pool) dispatch() { p.active++ p.mu.Unlock() if item.exec != nil { - go p.executeResume(item.ctx, item.task, item.exec) + p.workerWg.Add(1); go func(i workItem) { defer p.workerWg.Done(); p.executeResume(i.ctx, i.task, i.exec) }(item) } else { - go p.execute(item.ctx, item.task) + p.workerWg.Add(1); go func(i workItem) { defer p.workerWg.Done(); p.execute(i.ctx, i.task) }(item) } break } @@ -132,19 +152,64 @@ func (p *Pool) dispatch() { // work queue is full. When the pool is at capacity the task is buffered and // dispatched as soon as a slot becomes available. func (p *Pool) Submit(ctx context.Context, t *task.Task) error { + p.mu.Lock() + if p.closed { + p.mu.Unlock() + return fmt.Errorf("executor pool is shut down") + } + // Send while holding the lock so that Shutdown cannot close workCh between + // the closed-check above and the send below. The dispatch goroutine never + // holds p.mu while receiving from workCh, so this cannot deadlock. select { case p.workCh <- workItem{ctx: ctx, task: t}: + p.mu.Unlock() return nil default: + p.mu.Unlock() return fmt.Errorf("executor work queue full (capacity %d)", cap(p.workCh)) } } +// Started returns a channel that emits task IDs when they transition to RUNNING. +func (p *Pool) Started() <-chan string { + return p.startedCh +} + // Results returns the channel for reading execution results. func (p *Pool) Results() <-chan *Result { return p.resultCh } +// Shutdown stops accepting new work and waits for all in-flight workers to +// finish. Returns ctx.Err() if the context deadline is exceeded before all +// workers complete. +func (p *Pool) Shutdown(ctx context.Context) error { + // Stop the dispatch goroutine. We must wait for it to exit before calling + // workerWg.Wait() to avoid a race between dispatch's Add(1) and Wait(). 
+ p.mu.Lock() + p.closed = true + p.mu.Unlock() + close(p.workCh) + select { + case <-p.dispatchDone: + case <-ctx.Done(): + return ctx.Err() + } + + done := make(chan struct{}) + go func() { + p.workerWg.Wait() + close(done) + }() + + select { + case <-done: + return nil + case <-ctx.Done(): + return ctx.Err() + } +} + // Cancel requests cancellation of a running task. Returns false if the task // is not currently running in this pool. func (p *Pool) Cancel(taskID string) bool { @@ -250,11 +315,12 @@ func (p *Pool) executeResume(ctx context.Context, t *task.Task, exec *storage.Ex exec.StartTime = time.Now().UTC() exec.Status = "RUNNING" - if err := p.store.CreateExecution(exec); err != nil { + if err := p.store.CreateExecutionAndSetRunning(exec); err != nil { p.logger.Error("failed to create resume execution record", "error", err) } - if err := p.store.UpdateTaskState(t.ID, task.StateRunning); err != nil { - p.logger.Error("failed to update task state", "error", err) + select { + case p.startedCh <- t.ID: + default: } var cancel context.CancelFunc @@ -273,6 +339,19 @@ func (p *Pool) executeResume(ctx context.Context, t *task.Task, exec *storage.Ex p.mu.Unlock() }() + // Populate RepositoryURL from Project registry if missing (ADR-007). + if t.RepositoryURL == "" && t.Project != "" { + if proj, err := p.store.GetProject(t.Project); err == nil && proj.RemoteURL != "" { + t.RepositoryURL = proj.RemoteURL + } + } + // Populate BranchName from Story if missing (ADR-007). + if t.BranchName == "" && t.StoryID != "" { + if story, err := p.store.GetStory(t.StoryID); err == nil && story.BranchName != "" { + t.BranchName = story.BranchName + } + } + err = runner.Run(ctx, t, exec) exec.EndTime = time.Now().UTC() @@ -289,16 +368,32 @@ func (p *Pool) handleRunResult(ctx context.Context, t *task.Task, exec *storage. 
if retry.IsRateLimitError(err) || isQuotaExhausted(err) { p.mu.Lock() retryAfter := retry.ParseRetryAfter(err.Error()) - if retryAfter == 0 { - if isQuotaExhausted(err) { + reason := "transient" + if isQuotaExhausted(err) { + reason = "quota" + if retryAfter == 0 { retryAfter = 5 * time.Hour - } else { - retryAfter = 1 * time.Minute } + } else if retryAfter == 0 { + retryAfter = 1 * time.Minute } - p.rateLimited[agentType] = time.Now().Add(retryAfter) + until := time.Now().Add(retryAfter) + p.rateLimited[agentType] = until p.logger.Info("agent rate limited", "agent", agentType, "retryAfter", retryAfter, "quotaExhausted", isQuotaExhausted(err)) p.mu.Unlock() + go func() { + ev := storage.AgentEvent{ + ID: uuid.New().String(), + Agent: agentType, + Event: "rate_limited", + Timestamp: time.Now(), + Until: &until, + Reason: reason, + } + if recErr := p.store.RecordAgentEvent(ev); recErr != nil { + p.logger.Warn("failed to record agent event", "error", recErr) + } + }() } var blockedErr *BlockedError @@ -335,9 +430,51 @@ func (p *Pool) handleRunResult(ctx context.Context, t *task.Task, exec *storage. if err := p.store.UpdateTaskState(t.ID, task.StateFailed); err != nil { p.logger.Error("failed to update task state", "taskID", t.ID, "state", task.StateFailed, "error", err) } + p.mu.Lock() + p.consecutiveFailures[agentType]++ + p.mu.Unlock() + } + // If this is a checker task, attach the failure report for any terminal + // failure state (FAILED, TIMED_OUT, CANCELLED, BUDGET_EXCEEDED). 
+ if t.CheckerForTaskID != "" && exec.ErrorMsg != "" { + if reportErr := p.store.UpdateTaskCheckerReport(t.CheckerForTaskID, exec.ErrorMsg); reportErr != nil { + p.logger.Error("handleRunResult: failed to set checker report", "taskID", t.CheckerForTaskID, "error", reportErr) + } + } + if t.StoryID != "" && exec.Status == "FAILED" { + storyID := t.StoryID + errMsg := exec.ErrorMsg + go func() { + story, getErr := p.store.GetStory(storyID) + if getErr != nil { + return + } + if story.Status == task.StoryValidating { + p.checkValidationResult(ctx, storyID, task.StateFailed, errMsg) + } + }() } } else { - if t.ParentTaskID == "" { + p.mu.Lock() + p.consecutiveFailures[agentType] = 0 + p.mu.Unlock() + if t.CheckerForTaskID != "" { + // Checker task succeeded — auto-accept the checked task. + exec.Status = "COMPLETED" + if err := p.store.UpdateTaskState(t.ID, task.StateCompleted); err != nil { + p.logger.Error("handleRunResult: failed to complete checker task", "taskID", t.ID, "error", err) + } + checkedTask, getErr := p.store.GetTask(t.CheckerForTaskID) + if getErr == nil { + if acceptErr := p.store.UpdateTaskState(t.CheckerForTaskID, task.StateCompleted); acceptErr != nil { + p.logger.Error("handleRunResult: failed to auto-accept checked task", "taskID", t.CheckerForTaskID, "error", acceptErr) + } else if checkedTask.StoryID != "" { + go p.checkStoryCompletion(context.Background(), checkedTask.StoryID) + } + } else { + p.logger.Error("handleRunResult: failed to get checked task", "taskID", t.CheckerForTaskID, "error", getErr) + } + } else if t.ParentTaskID == "" { subtasks, subErr := p.store.ListSubtasks(t.ID) if subErr != nil { p.logger.Error("failed to list subtasks", "taskID", t.ID, "error", subErr) @@ -352,6 +489,7 @@ func (p *Pool) handleRunResult(ctx context.Context, t *task.Task, exec *storage. 
if err := p.store.UpdateTaskState(t.ID, task.StateReady); err != nil { p.logger.Error("failed to update task state", "taskID", t.ID, "state", task.StateReady, "error", err) } + go p.spawnCheckerTask(context.Background(), t) } } else { exec.Status = "COMPLETED" @@ -360,6 +498,21 @@ func (p *Pool) handleRunResult(ctx context.Context, t *task.Task, exec *storage. } p.maybeUnblockParent(t.ParentTaskID) } + if t.StoryID != "" { + storyID := t.StoryID + go func() { + story, getErr := p.store.GetStory(storyID) + if getErr != nil { + p.logger.Error("handleRunResult: failed to get story", "storyID", storyID, "error", getErr) + return + } + if story.Status == task.StoryValidating { + p.checkValidationResult(ctx, storyID, task.StateCompleted, "") + } else { + p.checkStoryCompletion(ctx, storyID) + } + }() + } } summary := exec.Summary @@ -374,6 +527,13 @@ func (p *Pool) handleRunResult(ctx context.Context, t *task.Task, exec *storage. p.logger.Error("failed to update task summary", "taskID", t.ID, "error", summaryErr) } } + terminalFailure := exec.Status == "FAILED" || exec.Status == "TIMED_OUT" || exec.Status == "CANCELLED" || exec.Status == "BUDGET_EXCEEDED" + if t.CheckerForTaskID != "" && terminalFailure && summary != "" { + // Overwrite the initial error-message report with the richer summary. + if reportErr := p.store.UpdateTaskCheckerReport(t.CheckerForTaskID, summary); reportErr != nil { + p.logger.Error("handleRunResult: failed to update checker report with summary", "taskID", t.CheckerForTaskID, "error", reportErr) + } + } if exec.StdoutPath != "" { if cs := task.ParseChangestatFromFile(exec.StdoutPath); cs != nil { exec.Changestats = cs @@ -388,6 +548,256 @@ func (p *Pool) handleRunResult(ctx context.Context, t *task.Task, exec *storage. p.resultCh <- &Result{TaskID: t.ID, Execution: exec, Err: err} } +// checkStoryCompletion checks whether all top-level tasks in a story have reached +// a terminal success state and transitions the story to SHIPPABLE if so. 
+// Subtasks are intentionally excluded — a parent task reaching READY/COMPLETED +// already accounts for its subtasks. +// CheckStoryCompletion is the exported entry point for story completion checks +// called from outside the package (e.g. the API accept handler). +func (p *Pool) CheckStoryCompletion(ctx context.Context, storyID string) { + p.checkStoryCompletion(ctx, storyID) +} + +func (p *Pool) checkStoryCompletion(ctx context.Context, storyID string) { + story, err := p.store.GetStory(storyID) + if err != nil { + p.logger.Error("checkStoryCompletion: failed to get story", "storyID", storyID, "error", err) + return + } + if story.Status != task.StoryInProgress { + return // already SHIPPABLE or beyond — nothing to do + } + tasks, err := p.store.ListTasksByStory(storyID) + if err != nil { + p.logger.Error("checkStoryCompletion: failed to list tasks", "storyID", storyID, "error", err) + return + } + if len(tasks) == 0 { + return + } + topLevelCount := 0 + for _, t := range tasks { + if t.ParentTaskID != "" { + continue // subtasks are covered by their parent + } + topLevelCount++ + if t.State != task.StateCompleted { + return // not all top-level tasks done; READY alone is not sufficient (checker may be pending) + } + } + if topLevelCount == 0 { + return // no top-level tasks — don't auto-complete + } + if err := p.store.UpdateStoryStatus(storyID, task.StoryShippable); err != nil { + p.logger.Error("checkStoryCompletion: failed to update story status", "storyID", storyID, "error", err) + return + } + p.logger.Info("story transitioned to SHIPPABLE", "storyID", storyID) + // Deploy is now triggered explicitly by the human via POST /api/stories/{id}/ship. +} + +// ShipStory merges the story branch and runs the deploy script. +// Returns an error if the story is not in SHIPPABLE state. 
+func (p *Pool) ShipStory(ctx context.Context, storyID string) error { + story, err := p.store.GetStory(storyID) + if err != nil { + return fmt.Errorf("story not found: %w", err) + } + if story.Status != task.StoryShippable { + return fmt.Errorf("story is not SHIPPABLE (current status: %s)", story.Status) + } + go p.triggerStoryDeploy(context.Background(), storyID) + return nil +} + +// spawnCheckerTask creates and submits a checker task for the given completed task. +// Guards: not called for subtasks, checker tasks, tasks without a repository URL, +// or tasks that already have a checker. +func (p *Pool) spawnCheckerTask(ctx context.Context, checked *task.Task) { + // Never spawn a checker for subtasks, checker tasks, or tasks without a repository. + if checked.ParentTaskID != "" || checked.CheckerForTaskID != "" || checked.RepositoryURL == "" { + return + } + // Idempotent: don't create a second checker if one already exists. + existing, err := p.store.GetCheckerTask(checked.ID) + if err != nil { + p.logger.Error("spawnCheckerTask: GetCheckerTask failed", "taskID", checked.ID, "error", err) + return + } + if existing != nil { + return + } + + criteria := checked.AcceptanceCriteria + if criteria == "" { + criteria = checked.Agent.Instructions + } + + instructions := fmt.Sprintf(`You are validating a completed task. Do not make any changes to the code or repository. + +Task: %s +Instructions given to the implementor: +%s + +Acceptance criteria: +%s + +Steps: +1. Clone the repository and review the changes made. +2. Verify each acceptance criterion is met. Run tests or make HTTP requests as needed. +3. If all criteria are satisfied, exit normally (success). +4. 
If any criterion is not met, use the Bash tool to exit with a non-zero code: + bash -c "exit 1" + Before exiting, write a brief summary of what failed.`, checked.Name, checked.Agent.Instructions, criteria) + + now := time.Now().UTC() + checker := &task.Task{ + ID: uuid.New().String(), + Name: "Check: " + checked.Name, + CheckerForTaskID: checked.ID, + RepositoryURL: checked.RepositoryURL, + Agent: task.AgentConfig{ + Type: "claude", + Instructions: instructions, + MaxBudgetUSD: 0.50, + AllowedTools: []string{"Bash", "Read", "Glob", "Grep"}, + }, + Timeout: task.Duration{Duration: 10 * time.Minute}, + Priority: task.PriorityNormal, + Tags: []string{}, + DependsOn: []string{}, + Retry: task.RetryConfig{MaxAttempts: 1, Backoff: "linear"}, + State: task.StatePending, + CreatedAt: now, + UpdatedAt: now, + } + + if err := p.store.CreateTask(checker); err != nil { + p.logger.Error("spawnCheckerTask: CreateTask failed", "error", err) + return + } + checker.State = task.StateQueued + if err := p.store.UpdateTaskState(checker.ID, task.StateQueued); err != nil { + p.logger.Error("spawnCheckerTask: UpdateTaskState failed", "error", err) + return + } + if err := p.Submit(ctx, checker); err != nil { + p.logger.Error("spawnCheckerTask: Submit failed", "error", err) + } +} + +// triggerStoryDeploy runs the project deploy script for a SHIPPABLE story +// and advances it to DEPLOYED on success. +func (p *Pool) triggerStoryDeploy(ctx context.Context, storyID string) { + story, err := p.store.GetStory(storyID) + if err != nil { + p.logger.Error("triggerStoryDeploy: failed to get story", "storyID", storyID, "error", err) + return + } + if story.ProjectID == "" { + return + } + proj, err := p.store.GetProject(story.ProjectID) + if err != nil { + p.logger.Error("triggerStoryDeploy: failed to get project", "storyID", storyID, "projectID", story.ProjectID, "error", err) + return + } + if proj.DeployScript == "" { + return + } + // Merge story branch to main before deploying (ADR-007). 
+ if story.BranchName != "" && proj.LocalPath != "" { + mergeSteps := [][]string{ + {"git", "-C", proj.LocalPath, "fetch", "origin"}, + {"git", "-C", proj.LocalPath, "checkout", "main"}, + {"git", "-C", proj.LocalPath, "merge", "--no-ff", story.BranchName, "-m", "Merge " + story.BranchName}, + {"git", "-C", proj.LocalPath, "push", "origin", "main"}, + } + for _, args := range mergeSteps { + if mergeOut, mergeErr := exec.CommandContext(ctx, args[0], args[1:]...).CombinedOutput(); mergeErr != nil { + p.logger.Error("triggerStoryDeploy: merge failed", "cmd", args, "output", string(mergeOut), "error", mergeErr) + return + } + } + p.logger.Info("story branch merged to main", "storyID", storyID, "branch", story.BranchName) + } + out, err := exec.CommandContext(ctx, proj.DeployScript).CombinedOutput() + if err != nil { + p.logger.Error("triggerStoryDeploy: deploy script failed", "storyID", storyID, "script", proj.DeployScript, "output", string(out), "error", err) + return + } + if err := p.store.UpdateStoryStatus(storyID, task.StoryDeployed); err != nil { + p.logger.Error("triggerStoryDeploy: failed to update story status", "storyID", storyID, "error", err) + return + } + p.logger.Info("story transitioned to DEPLOYED", "storyID", storyID) + go p.createValidationTask(ctx, storyID) +} + +// createValidationTask creates a validation subtask from the story's ValidationJSON +// and transitions the story to VALIDATING. 
+func (p *Pool) createValidationTask(ctx context.Context, storyID string) { + story, err := p.store.GetStory(storyID) + if err != nil { + p.logger.Error("createValidationTask: failed to get story", "storyID", storyID, "error", err) + return + } + if story.ValidationJSON == "" { + p.logger.Warn("createValidationTask: story has no ValidationJSON, skipping", "storyID", storyID) + return + } + + var spec map[string]interface{} + if err := json.Unmarshal([]byte(story.ValidationJSON), &spec); err != nil { + p.logger.Error("createValidationTask: failed to parse ValidationJSON", "storyID", storyID, "error", err) + return + } + + instructions := fmt.Sprintf("Validate the deployment for story %q.\n\nValidation spec:\n%s", story.Name, story.ValidationJSON) + + now := time.Now().UTC() + vtask := &task.Task{ + ID: uuid.New().String(), + Name: fmt.Sprintf("validation: %s", story.Name), + StoryID: storyID, + State: task.StateQueued, + Agent: task.AgentConfig{Type: "claude", Instructions: instructions}, + Tags: []string{}, + DependsOn: []string{}, + CreatedAt: now, + UpdatedAt: now, + } + + if err := p.store.CreateTask(vtask); err != nil { + p.logger.Error("createValidationTask: failed to create task", "storyID", storyID, "error", err) + return + } + if err := p.store.UpdateStoryStatus(storyID, task.StoryValidating); err != nil { + p.logger.Error("createValidationTask: failed to update story status", "storyID", storyID, "error", err) + return + } + p.logger.Info("validation task created and story transitioned to VALIDATING", "storyID", storyID, "taskID", vtask.ID) + p.Submit(ctx, vtask) //nolint:errcheck +} + +// checkValidationResult inspects a completed validation task and transitions +// the story to REVIEW_READY or NEEDS_FIX accordingly. 
+func (p *Pool) checkValidationResult(ctx context.Context, storyID string, taskState task.State, errorMsg string) { + if taskState == task.StateCompleted { + if err := p.store.UpdateStoryStatus(storyID, task.StoryReviewReady); err != nil { + p.logger.Error("checkValidationResult: failed to update story status", "storyID", storyID, "error", err) + return + } + p.logger.Info("story transitioned to REVIEW_READY", "storyID", storyID) + } else { + if err := p.store.UpdateStoryStatus(storyID, task.StoryNeedsFix); err != nil { + p.logger.Error("checkValidationResult: failed to update story status", "storyID", storyID, "error", err) + return + } + p.logger.Info("story transitioned to NEEDS_FIX", "storyID", storyID, "error", errorMsg) + } +} + // ActiveCount returns the number of currently running tasks. func (p *Pool) ActiveCount() int { p.mu.Lock() @@ -395,6 +805,34 @@ func (p *Pool) ActiveCount() int { return p.active } +// AgentStatusInfo holds the current state of a single agent. +type AgentStatusInfo struct { + Agent string `json:"agent"` + ActiveTasks int `json:"active_tasks"` + RateLimited bool `json:"rate_limited"` + Until *time.Time `json:"until,omitempty"` +} + +// AgentStatuses returns the current status of all registered agents. +func (p *Pool) AgentStatuses() []AgentStatusInfo { + p.mu.Lock() + defer p.mu.Unlock() + now := time.Now() + var out []AgentStatusInfo + for agent := range p.runners { + info := AgentStatusInfo{ + Agent: agent, + ActiveTasks: p.activePerAgent[agent], + } + if deadline, ok := p.rateLimited[agent]; ok && now.Before(deadline) { + info.RateLimited = true + info.Until = &deadline + } + out = append(out, info) + } + return out +} + // pickAgent selects the best agent from the given SystemStatus using explicit // load balancing: prefer the available (non-rate-limited) agent with the fewest // active tasks. If all agents are rate-limited, fall back to fewest active. 
@@ -436,6 +874,18 @@ func (p *Pool) execute(ctx context.Context, t *task.Task) { activeTasks[agent] = p.activePerAgent[agent] if deadline, ok := p.rateLimited[agent]; ok && now.After(deadline) { delete(p.rateLimited, agent) + agentName := agent + go func() { + ev := storage.AgentEvent{ + ID: uuid.New().String(), + Agent: agentName, + Event: "available", + Timestamp: time.Now(), + } + if recErr := p.store.RecordAgentEvent(ev); recErr != nil { + p.logger.Warn("failed to record agent available event", "error", recErr) + } + }() } rateLimited[agent] = now.Before(p.rateLimited[agent]) } @@ -479,9 +929,58 @@ func (p *Pool) execute(ctx context.Context, t *task.Task) { agentType = "claude" } + // Check dependencies before taking the per-agent slot to avoid deadlock: + // if a dependent task holds the slot while waiting for its dependency to run, + // the dependency can never start (maxPerAgent=1). + if len(t.DependsOn) > 0 { + ready, depErr := p.checkDepsReady(t) + if depErr != nil { + // A dependency hit a terminal failure — cancel this task immediately. + now := time.Now().UTC() + exec := &storage.Execution{ + ID: uuid.New().String(), + TaskID: t.ID, + StartTime: now, + EndTime: now, + Status: "CANCELLED", + ErrorMsg: depErr.Error(), + } + if createErr := p.store.CreateExecution(exec); createErr != nil { + p.logger.Error("failed to create execution record", "error", createErr) + } + if err := p.store.UpdateTaskState(t.ID, task.StateCancelled); err != nil { + p.logger.Error("failed to update task state", "taskID", t.ID, "state", task.StateCancelled, "error", err) + } + p.resultCh <- &Result{TaskID: t.ID, Execution: exec, Err: depErr} + return + } + if !ready { + // Dependencies not yet done — requeue without holding the slot. 
+ time.AfterFunc(p.requeueDelay, func() { p.workCh <- workItem{ctx: ctx, task: t} }) + return + } + } p.mu.Lock() + + if p.activePerAgent[agentType] >= p.maxPerAgent { + p.mu.Unlock() + time.AfterFunc(p.requeueDelay, func() { p.workCh <- workItem{ctx: ctx, task: t} }) + return + } if deadline, ok := p.rateLimited[agentType]; ok && time.Now().After(deadline) { delete(p.rateLimited, agentType) + agentName := agentType + go func() { + ev := storage.AgentEvent{ + ID: uuid.New().String(), + Agent: agentName, + Event: "available", + Timestamp: time.Now(), + } + if recErr := p.store.RecordAgentEvent(ev); recErr != nil { + p.logger.Warn("failed to record agent available event", "error", recErr) + } + }() } p.activePerAgent[agentType]++ p.mu.Unlock() @@ -512,30 +1011,6 @@ func (p *Pool) execute(ctx context.Context, t *task.Task) { return } - // Wait for all dependencies to complete before starting execution. - if len(t.DependsOn) > 0 { - if err := p.waitForDependencies(ctx, t); err != nil { - now := time.Now().UTC() - exec := &storage.Execution{ - ID: uuid.New().String(), - TaskID: t.ID, - StartTime: now, - EndTime: now, - Status: "FAILED", - ErrorMsg: err.Error(), - } - if createErr := p.store.CreateExecution(exec); createErr != nil { - p.logger.Error("failed to create execution record", "error", createErr) - } - if err := p.store.UpdateTaskState(t.ID, task.StateFailed); err != nil { - p.logger.Error("failed to update task state", "taskID", t.ID, "state", task.StateFailed, "error", err) - } - p.decActiveAgent(agentType, &cleaned) - p.resultCh <- &Result{TaskID: t.ID, Execution: exec, Err: err} - return - } - } - execID := uuid.New().String() exec := &storage.Execution{ ID: execID, @@ -554,12 +1029,13 @@ func (p *Pool) execute(ctx context.Context, t *task.Task) { } } - // Record execution start. - if err := p.store.CreateExecution(exec); err != nil { + // Record execution start atomically with the RUNNING state transition. 
+ if err := p.store.CreateExecutionAndSetRunning(exec); err != nil { p.logger.Error("failed to create execution record", "error", err) } - if err := p.store.UpdateTaskState(t.ID, task.StateRunning); err != nil { - p.logger.Error("failed to update task state", "error", err) + select { + case p.startedCh <- t.ID: + default: } // Apply task timeout and register cancel so callers can stop this task. @@ -583,6 +1059,19 @@ func (p *Pool) execute(ctx context.Context, t *task.Task) { priorExecs, priorErr := p.store.ListExecutions(t.ID) t = withFailureHistory(t, priorExecs, priorErr) + // Populate RepositoryURL from Project registry if missing (ADR-007). + if t.RepositoryURL == "" && t.Project != "" { + if proj, err := p.store.GetProject(t.Project); err == nil && proj.RemoteURL != "" { + t.RepositoryURL = proj.RemoteURL + } + } + // Populate BranchName from Story if missing (ADR-007). + if t.BranchName == "" && t.StoryID != "" { + if story, err := p.store.GetStory(t.StoryID); err == nil && story.BranchName != "" { + t.BranchName = story.BranchName + } + } + // Run the task. err = runner.Run(ctx, t, exec) exec.EndTime = time.Now().UTC() @@ -650,18 +1139,31 @@ func (p *Pool) RecoverStaleQueued(ctx context.Context) { } } -// RecoverStaleBlocked promotes any BLOCKED parent task to READY when all of its -// subtasks are already COMPLETED. This handles the case where the server was -// restarted after subtasks finished but before maybeUnblockParent could fire. +// RecoverStaleBlocked promotes any BLOCKED or QUEUED parent task to READY when +// all of its subtasks are already COMPLETED. This handles the case where the +// server was restarted after subtasks finished but before maybeUnblockParent +// could fire, and also the case where story approval pre-created subtasks +// without ever running the parent task. // Call this once on server startup, after RecoverStaleRunning and RecoverStaleQueued. 
func (p *Pool) RecoverStaleBlocked() { - tasks, err := p.store.ListTasks(storage.TaskFilter{State: task.StateBlocked}) - if err != nil { - p.logger.Error("RecoverStaleBlocked: list tasks", "error", err) - return - } - for _, t := range tasks { - p.maybeUnblockParent(t.ID) + ctx := context.Background() + for _, state := range []task.State{task.StateBlocked, task.StateQueued} { + tasks, err := p.store.ListTasks(storage.TaskFilter{State: state}) + if err != nil { + p.logger.Error("RecoverStaleBlocked: list tasks", "error", err, "state", state) + continue + } + for _, t := range tasks { + if t.ParentTaskID != "" { + continue // only promote actual parents + } + before := t.State + p.maybeUnblockParent(t.ID) + // If the parent was promoted, check story completion. + if after, err := p.store.GetTask(t.ID); err == nil && after.State != before && t.StoryID != "" { + p.checkStoryCompletion(ctx, t.StoryID) + } + } } } @@ -673,6 +1175,32 @@ var terminalFailureStates = map[task.State]bool{ task.StateBudgetExceeded: true, } +// depDoneStates are task states that satisfy a DependsOn dependency. +var depDoneStates = map[task.State]bool{ + task.StateCompleted: true, + task.StateReady: true, // leaf tasks finish at READY +} + +// checkDepsReady does a single synchronous check of t.DependsOn. +// Returns (true, nil) if all deps are done, (false, nil) if any are still pending, +// or (false, err) if a dep entered a terminal failure state. 
+func (p *Pool) checkDepsReady(t *task.Task) (bool, error) { + for _, depID := range t.DependsOn { + dep, err := p.store.GetTask(depID) + if err != nil { + return false, fmt.Errorf("dependency %q not found: %w", depID, err) + } + if depDoneStates[dep.State] { + continue + } + if terminalFailureStates[dep.State] { + return false, fmt.Errorf("dependency %q ended in state %s", depID, dep.State) + } + return false, nil // still pending + } + return true, nil +} + // withFailureHistory returns a shallow copy of t with prior failed execution // error messages prepended to SystemPromptAppend so the agent knows what went // wrong in previous attempts. @@ -710,16 +1238,16 @@ func withFailureHistory(t *task.Task, execs []*storage.Execution, err error) *ta return &copy } -// maybeUnblockParent transitions the parent task from BLOCKED to READY if all -// of its subtasks are in the COMPLETED state. If any subtask is not COMPLETED -// (including FAILED, CANCELLED, RUNNING, etc.) the parent stays BLOCKED. +// maybeUnblockParent transitions the parent task to READY if all of its subtasks +// are in the COMPLETED state. Handles both BLOCKED parents (ran, created subtasks, +// paused) and QUEUED parents (story approval created subtasks without running parent). func (p *Pool) maybeUnblockParent(parentID string) { parent, err := p.store.GetTask(parentID) if err != nil { p.logger.Error("maybeUnblockParent: get parent", "parentID", parentID, "error", err) return } - if parent.State != task.StateBlocked { + if parent.State != task.StateBlocked && parent.State != task.StateQueued { return } subtasks, err := p.store.ListSubtasks(parentID) @@ -727,6 +1255,11 @@ func (p *Pool) maybeUnblockParent(parentID string) { p.logger.Error("maybeUnblockParent: list subtasks", "parentID", parentID, "error", err) return } + // A task with no subtasks was never blocked by subtask delegation — don't promote it. + // This prevents incorrectly promoting leaf tasks that are stuck in QUEUED to READY. 
+ if len(subtasks) == 0 { + return + } for _, sub := range subtasks { if sub.State != task.StateCompleted { return @@ -747,7 +1280,7 @@ func (p *Pool) waitForDependencies(ctx context.Context, t *task.Task) error { if err != nil { return fmt.Errorf("dependency %q not found: %w", depID, err) } - if dep.State == task.StateCompleted { + if depDoneStates[dep.State] { continue } if terminalFailureStates[dep.State] { diff --git a/internal/executor/executor_test.go b/internal/executor/executor_test.go index b1173cb..9214872 100644 --- a/internal/executor/executor_test.go +++ b/internal/executor/executor_test.go @@ -5,6 +5,7 @@ import ( "fmt" "log/slog" "os" + "os/exec" "path/filepath" "strings" "sync" @@ -600,10 +601,17 @@ func TestPool_RecoverStaleRunning(t *testing.T) { // Execution record should be closed as FAILED. execs, _ := store.ListExecutions(tk.ID) - if len(execs) == 0 || execs[0].Status != "FAILED" { + var failedExec *storage.Execution + for _, e := range execs { + if e.ID == "exec-stale-1" { + failedExec = e + break + } + } + if failedExec == nil || failedExec.Status != "FAILED" { t.Errorf("execution status: want FAILED, got %+v", execs) } - if execs[0].ErrorMsg == "" { + if failedExec.ErrorMsg == "" { t.Error("expected non-empty error message on recovered execution") } @@ -739,6 +747,119 @@ func TestPool_RecoverStaleBlocked_KeepsBlockedWhenSubtaskIncomplete(t *testing.T } } +func TestPool_RecoverStaleBlocked_PromotesQueuedParentWithAllSubtasksDone(t *testing.T) { + store := testStore(t) + logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) + pool := NewPool(2, map[string]Runner{"claude": &mockRunner{}}, store, logger) + + now := time.Now().UTC() + story := &task.Story{ + ID: "story-queued-parent", Name: "Queued Parent Story", + Status: task.StoryInProgress, CreatedAt: now, UpdatedAt: now, + } + store.CreateStory(story) + + // Parent task stuck QUEUED (approved with pre-created subtasks, never run). 
+ parent := makeTask("queued-parent-1") + parent.State = task.StateQueued + parent.StoryID = story.ID + store.CreateTask(parent) + + for i := 0; i < 2; i++ { + sub := makeTask(fmt.Sprintf("queued-sub-%d", i)) + sub.ParentTaskID = parent.ID + sub.StoryID = story.ID + sub.State = task.StateCompleted + store.CreateTask(sub) + } + + pool.RecoverStaleBlocked() + + got, err := store.GetTask(parent.ID) + if err != nil { + t.Fatalf("GetTask: %v", err) + } + if got.State != task.StateReady { + t.Errorf("parent state: want READY, got %s", got.State) + } + + // Story should still be IN_PROGRESS — READY tasks don't satisfy the completion check; + // the task must be accepted (READY → COMPLETED) before the story advances to SHIPPABLE. + s, err := store.GetStory(story.ID) + if err != nil { + t.Fatalf("GetStory: %v", err) + } + if s.Status != task.StoryInProgress { + t.Errorf("story status: want IN_PROGRESS, got %s", s.Status) + } +} + +// TestPool_RecoverStaleBlocked_DoesNotPromoteQueuedLeafTask verifies that a top-level +// QUEUED task with NO subtasks is not promoted to READY by RecoverStaleBlocked. +// This guards against the bug where a task that failed to start (stuck in QUEUED due +// to a DB error) was incorrectly promoted to READY because the "all subtasks done" +// check is vacuously true when there are no subtasks. +func TestPool_RecoverStaleBlocked_DoesNotPromoteQueuedLeafTask(t *testing.T) { + store := testStore(t) + logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) + pool := NewPool(2, map[string]Runner{"claude": &mockRunner{}}, store, logger) + + // A top-level task stuck in QUEUED with no subtasks (e.g. DB lock prevented RUNNING transition). 
+ leaf := makeTask("queued-leaf-no-subtasks") + leaf.State = task.StateQueued + store.CreateTask(leaf) + + pool.RecoverStaleBlocked() + + got, err := store.GetTask(leaf.ID) + if err != nil { + t.Fatalf("GetTask: %v", err) + } + if got.State != task.StateQueued { + t.Errorf("leaf task state: want QUEUED (unchanged), got %s", got.State) + } +} + +// TestPool_CheckStoryCompletion_ReadyTasksNotSufficient verifies that READY tasks +// alone do not advance a story to SHIPPABLE — tasks must be COMPLETED. +func TestPool_CheckStoryCompletion_ReadyTasksNotSufficient(t *testing.T) { + store := testStore(t) + logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) + pool := NewPool(2, map[string]Runner{"claude": &mockRunner{}}, store, logger) + + now := time.Now().UTC() + story := &task.Story{ + ID: "story-ready-only", + Name: "Ready Only Story", + Status: task.StoryInProgress, + CreatedAt: now, + UpdatedAt: now, + } + store.CreateStory(story) + + // One task driven to READY (checker pending), one COMPLETED. 
+ tk1 := makeTask("ro-task-1") + tk1.StoryID = story.ID + store.CreateTask(tk1) + for _, s := range []task.State{task.StateQueued, task.StateRunning, task.StateReady} { + store.UpdateTaskState(tk1.ID, s) + } + + tk2 := makeTask("ro-task-2") + tk2.StoryID = story.ID + store.CreateTask(tk2) + for _, s := range []task.State{task.StateQueued, task.StateRunning, task.StateReady, task.StateCompleted} { + store.UpdateTaskState(tk2.ID, s) + } + + pool.checkStoryCompletion(context.Background(), story.ID) + + got, _ := store.GetStory(story.ID) + if got.Status != task.StoryInProgress { + t.Errorf("story status: want IN_PROGRESS (tk1 still READY/checker pending), got %s", got.Status) + } +} + func TestPool_ActivePerAgent_DeletesZeroEntries(t *testing.T) { store := testStore(t) runner := &mockRunner{} @@ -1014,7 +1135,10 @@ func (m *minimalMockStore) ListSubtasks(parentID string) ([]*task.Task, error) { return nil, nil } func (m *minimalMockStore) ListExecutions(_ string) ([]*storage.Execution, error) { return nil, nil } -func (m *minimalMockStore) CreateExecution(e *storage.Execution) error { return nil } +func (m *minimalMockStore) CreateExecution(e *storage.Execution) error { return nil } +func (m *minimalMockStore) CreateExecutionAndSetRunning(e *storage.Execution) error { + return nil +} func (m *minimalMockStore) UpdateExecution(e *storage.Execution) error { return m.updateExecErr } @@ -1064,6 +1188,14 @@ func (m *minimalMockStore) UpdateExecutionChangestats(execID string, stats *task m.mu.Unlock() return nil } +func (m *minimalMockStore) RecordAgentEvent(_ storage.AgentEvent) error { return nil } +func (m *minimalMockStore) GetProject(_ string) (*task.Project, error) { return nil, nil } +func (m *minimalMockStore) GetStory(_ string) (*task.Story, error) { return nil, nil } +func (m *minimalMockStore) ListTasksByStory(_ string) ([]*task.Task, error) { return nil, nil } +func (m *minimalMockStore) UpdateStoryStatus(_ string, _ task.StoryState) error { return nil } +func (m 
*minimalMockStore) CreateTask(_ *task.Task) error { return nil } +func (m *minimalMockStore) UpdateTaskCheckerReport(_ string, _ string) error { return nil } +func (m *minimalMockStore) GetCheckerTask(_ string) (*task.Task, error) { return nil, nil } func (m *minimalMockStore) lastStateUpdate() (string, task.State, bool) { m.mu.Lock() @@ -1078,17 +1210,18 @@ func (m *minimalMockStore) lastStateUpdate() (string, task.State, bool) { func newPoolWithMockStore(store Store) *Pool { logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) return &Pool{ - maxConcurrent: 2, - runners: map[string]Runner{"claude": &mockRunner{}}, - store: store, - logger: logger, - activePerAgent: make(map[string]int), - rateLimited: make(map[string]time.Time), - cancels: make(map[string]context.CancelFunc), - resultCh: make(chan *Result, 4), - workCh: make(chan workItem, 4), - doneCh: make(chan struct{}, 2), - Questions: NewQuestionRegistry(), + maxConcurrent: 2, + maxPerAgent: 1, + runners: map[string]Runner{"claude": &mockRunner{}}, + store: store, + logger: logger, + activePerAgent: make(map[string]int), + rateLimited: make(map[string]time.Time), + cancels: make(map[string]context.CancelFunc), + consecutiveFailures: make(map[string]int), + resultCh: make(chan *Result, 4), + workCh: make(chan workItem, 4), + doneCh: make(chan struct{}, 2), } } @@ -1236,6 +1369,11 @@ func TestPool_SpecificAgent_SkipsLoadBalancing(t *testing.T) { logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) pool := NewPool(4, runners, store, logger) + // Raise per-agent limit so the concurrency gate doesn't interfere with this test. + // The injected activePerAgent is only to make pickAgent prefer "claude", + // verifying that explicit agent type bypasses load balancing. + pool.maxPerAgent = 10 + // Inject 2 active tasks for gemini, 0 for claude. // pickAgent would normally pick "claude". 
pool.mu.Lock() @@ -1425,3 +1563,748 @@ func TestExecute_MalformedChangestats(t *testing.T) { t.Errorf("expected nil changestats for malformed output, got %+v", execs[0].Changestats) } } + +func TestPool_MaxPerAgent_BlocksSecondTask(t *testing.T) { + store := testStore(t) + + var mu sync.Mutex + concurrentRuns := 0 + maxConcurrent := 0 + + runner := &mockRunner{ + delay: 100 * time.Millisecond, + onRun: func(tk *task.Task, e *storage.Execution) error { + mu.Lock() + concurrentRuns++ + if concurrentRuns > maxConcurrent { + maxConcurrent = concurrentRuns + } + mu.Unlock() + time.Sleep(100 * time.Millisecond) + mu.Lock() + concurrentRuns-- + mu.Unlock() + return nil + }, + } + runners := map[string]Runner{"claude": runner} + logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) + pool := NewPool(2, runners, store, logger) // pool size 2, but maxPerAgent=1 + pool.requeueDelay = 50 * time.Millisecond // speed up test + + tk1 := makeTask("mpa-1") + tk2 := makeTask("mpa-2") + store.CreateTask(tk1) + store.CreateTask(tk2) + + pool.Submit(context.Background(), tk1) + pool.Submit(context.Background(), tk2) + + for i := 0; i < 2; i++ { + select { + case <-pool.Results(): + case <-time.After(10 * time.Second): + t.Fatal("timed out waiting for result") + } + } + + mu.Lock() + got := maxConcurrent + mu.Unlock() + if got > 1 { + t.Errorf("maxPerAgent=1 violated: %d claude tasks ran concurrently", got) + } +} + +func TestPool_MaxPerAgent_AllowsDifferentAgents(t *testing.T) { + store := testStore(t) + + var mu sync.Mutex + concurrentRuns := 0 + maxConcurrent := 0 + + makeSlowRunner := func() *mockRunner { + return &mockRunner{ + onRun: func(tk *task.Task, e *storage.Execution) error { + mu.Lock() + concurrentRuns++ + if concurrentRuns > maxConcurrent { + maxConcurrent = concurrentRuns + } + mu.Unlock() + time.Sleep(80 * time.Millisecond) + mu.Lock() + concurrentRuns-- + mu.Unlock() + return nil + }, + } + } + + runners := map[string]Runner{ 
+ "claude": makeSlowRunner(), + "gemini": makeSlowRunner(), + } + logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) + pool := NewPool(2, runners, store, logger) + + tk1 := makeTask("da-1") + tk1.Agent.Type = "claude" + tk2 := makeTask("da-2") + tk2.Agent.Type = "gemini" + store.CreateTask(tk1) + store.CreateTask(tk2) + + pool.Submit(context.Background(), tk1) + pool.Submit(context.Background(), tk2) + + for i := 0; i < 2; i++ { + select { + case <-pool.Results(): + case <-time.After(5 * time.Second): + t.Fatal("timed out waiting for result") + } + } + + mu.Lock() + got := maxConcurrent + mu.Unlock() + if got < 2 { + t.Errorf("different agents should run concurrently; max concurrent was %d", got) + } +} + +func TestPool_ConsecutiveFailures_ResetOnSuccess(t *testing.T) { + store := testStore(t) + + callCount := 0 + runner := &mockRunner{ + onRun: func(tk *task.Task, e *storage.Execution) error { + callCount++ + if callCount == 1 { + return fmt.Errorf("first failure") + } + return nil // second call succeeds + }, + } + runners := map[string]Runner{"claude": runner} + logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) + pool := NewPool(2, runners, store, logger) + + // First task fails + tk1 := makeTask("rs-1") + store.CreateTask(tk1) + pool.Submit(context.Background(), tk1) + <-pool.Results() + + pool.mu.Lock() + failsBefore := pool.consecutiveFailures["claude"] + pool.mu.Unlock() + if failsBefore != 1 { + t.Errorf("expected 1 failure after first task, got %d", failsBefore) + } + + // Second task succeeds — counter resets. 
+ tk2 := makeTask("rs-2") + store.CreateTask(tk2) + pool.Submit(context.Background(), tk2) + <-pool.Results() + + pool.mu.Lock() + failsAfter := pool.consecutiveFailures["claude"] + pool.mu.Unlock() + + if failsAfter != 0 { + t.Errorf("expected consecutiveFailures reset to 0 after success, got %d", failsAfter) + } +} + +func TestPool_CheckStoryCompletion_AllComplete(t *testing.T) { + store := testStore(t) + logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) + pool := NewPool(2, map[string]Runner{"claude": &mockRunner{}}, store, logger) + + // Create a story in IN_PROGRESS state. + now := time.Now().UTC() + story := &task.Story{ + ID: "story-comp-1", + Name: "Completion Test", + Status: task.StoryInProgress, + CreatedAt: now, + UpdatedAt: now, + } + if err := store.CreateStory(story); err != nil { + t.Fatalf("CreateStory: %v", err) + } + + // Create two top-level story tasks and drive them through valid transitions to COMPLETED. + for i, id := range []string{"sctask-1", "sctask-2"} { + tk := makeTask(id) + tk.StoryID = "story-comp-1" + tk.State = task.StatePending + if err := store.CreateTask(tk); err != nil { + t.Fatalf("CreateTask %d: %v", i, err) + } + for _, s := range []task.State{task.StateQueued, task.StateRunning, task.StateReady, task.StateCompleted} { + if err := store.UpdateTaskState(id, s); err != nil { + t.Fatalf("UpdateTaskState %s → %s: %v", id, s, err) + } + } + } + + pool.checkStoryCompletion(context.Background(), "story-comp-1") + + got, err := store.GetStory("story-comp-1") + if err != nil { + t.Fatalf("GetStory: %v", err) + } + if got.Status != task.StoryShippable { + t.Errorf("story status: want SHIPPABLE, got %v", got.Status) + } +} + +func TestPool_CheckStoryCompletion_PartialComplete(t *testing.T) { + store := testStore(t) + logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) + pool := NewPool(2, map[string]Runner{"claude": &mockRunner{}}, store, 
logger) + + now := time.Now().UTC() + story := &task.Story{ + ID: "story-partial-1", + Name: "Partial Test", + Status: task.StoryInProgress, + CreatedAt: now, + UpdatedAt: now, + } + if err := store.CreateStory(story); err != nil { + t.Fatalf("CreateStory: %v", err) + } + + // First top-level task driven to READY. + tk1 := makeTask("sptask-1") + tk1.StoryID = "story-partial-1" + store.CreateTask(tk1) + for _, s := range []task.State{task.StateQueued, task.StateRunning, task.StateReady} { + store.UpdateTaskState("sptask-1", s) + } + + // Second top-level task still in PENDING (not done). + tk2 := makeTask("sptask-2") + tk2.StoryID = "story-partial-1" + store.CreateTask(tk2) + + pool.checkStoryCompletion(context.Background(), "story-partial-1") + + got, err := store.GetStory("story-partial-1") + if err != nil { + t.Fatalf("GetStory: %v", err) + } + if got.Status != task.StoryInProgress { + t.Errorf("story status: want IN_PROGRESS (no transition), got %v", got.Status) + } +} + +func TestPool_StoryDeploy_RunsDeployScript(t *testing.T) { + store := testStore(t) + runner := &mockRunner{} + runners := map[string]Runner{"claude": runner} + logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) + pool := NewPool(2, runners, store, logger) + + // Create a deploy script that writes a marker file. 
+ tmpDir := t.TempDir() + markerFile := filepath.Join(tmpDir, "deployed.marker") + scriptPath := filepath.Join(tmpDir, "deploy.sh") + scriptContent := "#!/bin/sh\ntouch " + markerFile + "\n" + if err := os.WriteFile(scriptPath, []byte(scriptContent), 0755); err != nil { + t.Fatalf("write deploy script: %v", err) + } + + proj := &task.Project{ + ID: "proj-deploy-1", + Name: "Deploy Test Project", + DeployScript: scriptPath, + } + if err := store.CreateProject(proj); err != nil { + t.Fatalf("create project: %v", err) + } + + story := &task.Story{ + ID: "story-deploy-1", + Name: "Deploy Test Story", + ProjectID: proj.ID, + Status: task.StoryShippable, + } + if err := store.CreateStory(story); err != nil { + t.Fatalf("create story: %v", err) + } + + pool.triggerStoryDeploy(context.Background(), story.ID) + + if _, err := os.Stat(markerFile); os.IsNotExist(err) { + t.Error("deploy script did not run: marker file not found") + } + + got, err := store.GetStory(story.ID) + if err != nil { + t.Fatalf("get story: %v", err) + } + if got.Status != task.StoryDeployed { + t.Errorf("story status: want DEPLOYED, got %q", got.Status) + } +} + +func runGit(t *testing.T, dir string, args ...string) { + t.Helper() + cmd := exec.Command("git", args...) + if dir != "" { + cmd.Dir = dir + } + if out, err := cmd.CombinedOutput(); err != nil { + t.Fatalf("git %v: %v\n%s", args, err, out) + } +} + +func TestPool_StoryDeploy_MergesStoryBranch(t *testing.T) { + tmpDir := t.TempDir() + + // Set up bare repo + working copy with a story branch. + bareDir := filepath.Join(tmpDir, "bare.git") + localDir := filepath.Join(tmpDir, "local") + runGit(t, "", "init", "--bare", bareDir) + runGit(t, "", "clone", bareDir, localDir) + runGit(t, localDir, "config", "user.email", "test@test.com") + runGit(t, localDir, "config", "user.name", "Test") + + // Initial commit on main. 
+ runGit(t, localDir, "checkout", "-b", "main") + os.WriteFile(filepath.Join(localDir, "README.md"), []byte("initial"), 0644) + runGit(t, localDir, "add", ".") + runGit(t, localDir, "commit", "-m", "initial") + runGit(t, localDir, "push", "-u", "origin", "main") + + // Story branch with a feature commit. + runGit(t, localDir, "checkout", "-b", "story/test-feature") + os.WriteFile(filepath.Join(localDir, "feature.go"), []byte("package main"), 0644) + runGit(t, localDir, "add", ".") + runGit(t, localDir, "commit", "-m", "feature work") + runGit(t, localDir, "push", "origin", "story/test-feature") + runGit(t, localDir, "checkout", "main") + + store := testStore(t) + logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) + pool := NewPool(2, map[string]Runner{"claude": &mockRunner{}}, store, logger) + + scriptPath := filepath.Join(tmpDir, "deploy.sh") + os.WriteFile(scriptPath, []byte("#!/bin/sh\nexit 0\n"), 0755) + + proj := &task.Project{ + ID: "proj-merge-1", Name: "Merge Test", + LocalPath: localDir, DeployScript: scriptPath, + } + if err := store.CreateProject(proj); err != nil { + t.Fatalf("create project: %v", err) + } + story := &task.Story{ + ID: "story-merge-1", Name: "Merge Test Story", + ProjectID: proj.ID, BranchName: "story/test-feature", + Status: task.StoryShippable, + } + if err := store.CreateStory(story); err != nil { + t.Fatalf("create story: %v", err) + } + + pool.triggerStoryDeploy(context.Background(), story.ID) + + // feature.go should now be on main in the working copy. 
+ if _, err := os.Stat(filepath.Join(localDir, "feature.go")); os.IsNotExist(err) { + t.Error("story branch was not merged to main: feature.go missing") + } + got, _ := store.GetStory(story.ID) + if got.Status != task.StoryDeployed { + t.Errorf("story status: want DEPLOYED, got %q", got.Status) + } +} + +func TestPool_PostDeploy_CreatesValidationTask(t *testing.T) { + store := testStore(t) + logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) + pool := NewPool(2, map[string]Runner{"claude": &mockRunner{}}, store, logger) + + now := time.Now().UTC() + validationSpec := `{"type":"smoke","steps":["curl /health"],"success_criteria":"status 200"}` + story := &task.Story{ + ID: "story-postdeploy-1", + Name: "Post Deploy Test", + Status: task.StoryDeployed, + ValidationJSON: validationSpec, + CreatedAt: now, + UpdatedAt: now, + } + if err := store.CreateStory(story); err != nil { + t.Fatalf("CreateStory: %v", err) + } + + pool.createValidationTask(context.Background(), story.ID) + + // Story should now be VALIDATING. + got, err := store.GetStory(story.ID) + if err != nil { + t.Fatalf("GetStory: %v", err) + } + if got.Status != task.StoryValidating { + t.Errorf("story status: want VALIDATING, got %q", got.Status) + } + + // A validation task should have been created. 
+ tasks, err := store.ListTasksByStory(story.ID) + if err != nil { + t.Fatalf("ListTasksByStory: %v", err) + } + if len(tasks) == 0 { + t.Fatal("expected a validation task to be created, got none") + } + vtask := tasks[0] + if !strings.Contains(strings.ToLower(vtask.Name), "validation") { + t.Errorf("task name %q does not contain 'validation'", vtask.Name) + } + if vtask.StoryID != story.ID { + t.Errorf("task story_id: want %q, got %q", story.ID, vtask.StoryID) + } + if !strings.Contains(vtask.Agent.Instructions, "smoke") { + t.Errorf("task instructions %q do not reference validation spec content", vtask.Agent.Instructions) + } +} + +func TestPool_ValidationTask_Pass_SetsReviewReady(t *testing.T) { + store := testStore(t) + logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) + pool := NewPool(2, map[string]Runner{"claude": &mockRunner{}}, store, logger) + + now := time.Now().UTC() + story := &task.Story{ + ID: "story-val-pass-1", + Name: "Validation Pass", + Status: task.StoryValidating, + CreatedAt: now, + UpdatedAt: now, + } + if err := store.CreateStory(story); err != nil { + t.Fatalf("CreateStory: %v", err) + } + + pool.checkValidationResult(context.Background(), story.ID, task.StateCompleted, "") + + got, err := store.GetStory(story.ID) + if err != nil { + t.Fatalf("GetStory: %v", err) + } + if got.Status != task.StoryReviewReady { + t.Errorf("story status: want REVIEW_READY, got %q", got.Status) + } +} + +// TestPool_DependsOn_NoDeadlock verifies that a task waiting for a dependency +// does NOT hold the per-agent slot, allowing the dependency to run first. 
+func TestPool_DependsOn_NoDeadlock(t *testing.T) { + store := testStore(t) + runner := &mockRunner{} // succeeds immediately + pool := NewPool(2, map[string]Runner{"claude": runner}, store, + slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))) + pool.requeueDelay = 10 * time.Millisecond + + // Task A has no deps; Task B depends on A. + taskA := makeTask("dep-a") + taskA.State = task.StateQueued + taskB := makeTask("dep-b") + taskB.DependsOn = []string{"dep-a"} + taskB.State = task.StateQueued + + store.CreateTask(taskA) + store.CreateTask(taskB) + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + // Submit B first — it should not deadlock by holding the slot while waiting for A. + pool.Submit(ctx, taskB) + pool.Submit(ctx, taskA) + + var gotA, gotB bool + for i := 0; i < 2; i++ { + select { + case res := <-pool.Results(): + if res.TaskID == "dep-a" { + gotA = true + } + if res.TaskID == "dep-b" { + gotB = true + } + case <-ctx.Done(): + t.Fatal("timeout: likely deadlock — dep-b held the slot while waiting for dep-a") + } + } + if !gotA || !gotB { + t.Errorf("expected both tasks to complete: gotA=%v gotB=%v", gotA, gotB) + } + + // B must complete after A. 
+ ta, _ := store.GetTask("dep-a") + tb, _ := store.GetTask("dep-b") + if ta.State != task.StateReady && ta.State != task.StateCompleted { + t.Errorf("dep-a should be READY/COMPLETED, got %s", ta.State) + } + if tb.State != task.StateReady && tb.State != task.StateCompleted { + t.Errorf("dep-b should be READY/COMPLETED, got %s", tb.State) + } +} + +func TestPool_ValidationTask_Fail_SetsNeedsFix(t *testing.T) { + store := testStore(t) + logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) + pool := NewPool(2, map[string]Runner{"claude": &mockRunner{}}, store, logger) + + now := time.Now().UTC() + story := &task.Story{ + ID: "story-val-fail-1", + Name: "Validation Fail", + Status: task.StoryValidating, + CreatedAt: now, + UpdatedAt: now, + } + if err := store.CreateStory(story); err != nil { + t.Fatalf("CreateStory: %v", err) + } + + execErr := "smoke test failed: /health returned 503" + pool.checkValidationResult(context.Background(), story.ID, task.StateFailed, execErr) + + got, err := store.GetStory(story.ID) + if err != nil { + t.Fatalf("GetStory: %v", err) + } + if got.Status != task.StoryNeedsFix { + t.Errorf("story status: want NEEDS_FIX, got %q", got.Status) + } +} + +func TestPool_Shutdown_WaitsForWorkers(t *testing.T) { + store := testStore(t) + started := make(chan struct{}) + unblock := make(chan struct{}) + runner := &mockRunner{ + onRun: func(t *task.Task, e *storage.Execution) error { + close(started) + <-unblock + return nil + }, + } + pool := NewPool(1, map[string]Runner{"claude": runner}, store, + slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))) + + tk := makeTask("shutdown-task") + tk.State = task.StateQueued + store.CreateTask(tk) + pool.Submit(context.Background(), tk) + + // Wait until the worker has started. 
+ select { + case <-started: + case <-time.After(5 * time.Second): + t.Fatal("worker did not start") + } + + // Shutdown should block until we unblock the worker. + done := make(chan error, 1) + go func() { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + done <- pool.Shutdown(ctx) + }() + + // Shutdown should not have returned yet. + select { + case err := <-done: + t.Fatalf("Shutdown returned early: %v", err) + case <-time.After(50 * time.Millisecond): + } + + close(unblock) // let the worker finish + + select { + case err := <-done: + if err != nil { + t.Errorf("Shutdown returned error: %v", err) + } + case <-time.After(5 * time.Second): + t.Fatal("Shutdown did not return after worker finished") + } +} + +func TestPool_Shutdown_TimesOut(t *testing.T) { + store := testStore(t) + unblock := make(chan struct{}) + runner := &mockRunner{ + onRun: func(t *task.Task, e *storage.Execution) error { + <-unblock // never unblocked + return nil + }, + } + pool := NewPool(1, map[string]Runner{"claude": runner}, store, + slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))) + + tk := makeTask("shutdown-timeout-task") + tk.State = task.StateQueued + store.CreateTask(tk) + pool.Submit(context.Background(), tk) + + // Give worker a moment to start. 
+ time.Sleep(50 * time.Millisecond) + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + err := pool.Shutdown(ctx) + if err == nil { + t.Error("expected timeout error, got nil") + } + close(unblock) // cleanup +} + +func TestPool_CheckerSpawned_OnReady(t *testing.T) { + store := testStore(t) + runner := &mockRunner{} // succeeds instantly + pool := NewPool(2, map[string]Runner{"claude": runner}, store, slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))) + + tk := makeTask("checker-spawn-1") + tk.RepositoryURL = "https://github.com/x/y" + store.CreateTask(tk) + pool.Submit(context.Background(), tk) + <-pool.Results() // wait for original task to finish + + // Poll until the async spawnCheckerTask goroutine has written the checker task. + var checker *task.Task + var err error + deadline := time.Now().Add(5 * time.Second) + for time.Now().Before(deadline) { + checker, err = store.GetCheckerTask("checker-spawn-1") + if err != nil { + t.Fatalf("GetCheckerTask: %v", err) + } + if checker != nil { + break + } + time.Sleep(50 * time.Millisecond) + } + if checker == nil { + t.Fatal("expected a checker task to be created, got nil") + } + if checker.CheckerForTaskID != "checker-spawn-1" { + t.Errorf("expected CheckerForTaskID=checker-spawn-1, got %q", checker.CheckerForTaskID) + } +} + +func TestPool_CheckerNotSpawned_ForSubtask(t *testing.T) { + store := testStore(t) + runner := &mockRunner{} + pool := NewPool(2, map[string]Runner{"claude": runner}, store, slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))) + + parent := makeTask("no-checker-parent") + parent.RepositoryURL = "https://github.com/x/y" + store.CreateTask(parent) + + sub := makeTask("no-checker-sub") + sub.ParentTaskID = "no-checker-parent" + sub.RepositoryURL = "https://github.com/x/y" + store.CreateTask(sub) + + pool.Submit(context.Background(), sub) + <-pool.Results() + + time.Sleep(100 * 
time.Millisecond) + + checker, err := store.GetCheckerTask("no-checker-sub") + if err != nil { + t.Fatalf("GetCheckerTask: %v", err) + } + if checker != nil { + t.Error("expected no checker for subtask, but one was created") + } +} + +func TestPool_CheckerPass_AutoAcceptsTask(t *testing.T) { + store := testStore(t) + // Two-phase: first runner succeeds (original task), second also succeeds (checker). + runner := &mockRunner{ + onRun: func(t *task.Task, e *storage.Execution) error { + return nil // both original and checker succeed + }, + } + pool := NewPool(2, map[string]Runner{"claude": runner}, store, slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))) + + tk := makeTask("autoaccept-1") + tk.RepositoryURL = "https://github.com/x/y" + store.CreateTask(tk) + pool.Submit(context.Background(), tk) + <-pool.Results() // original finishes → READY + checker spawned + + // Wait for checker to run and complete. + deadline := time.Now().Add(5 * time.Second) + for time.Now().Before(deadline) { + got, _ := store.GetTask("autoaccept-1") + if got != nil && got.State == task.StateCompleted { + break + } + <-pool.Results() + } + + got, err := store.GetTask("autoaccept-1") + if err != nil { + t.Fatalf("GetTask: %v", err) + } + if got.State != task.StateCompleted { + t.Errorf("expected COMPLETED after checker pass, got %s", got.State) + } +} + +func TestPool_CheckerFail_AttachesReport(t *testing.T) { + store := testStore(t) + runner := &mockRunner{ + onRun: func(t *task.Task, e *storage.Execution) error { + if t.CheckerForTaskID != "" { + return fmt.Errorf("test suite failed: 3 failures") + } + return nil // original task succeeds + }, + } + pool := NewPool(2, map[string]Runner{"claude": runner}, store, slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))) + + tk := makeTask("fail-checker-1") + tk.RepositoryURL = "https://github.com/x/y" + store.CreateTask(tk) + pool.Submit(context.Background(), tk) + 
<-pool.Results() // original → READY + + // Wait for checker to fail. + deadline := time.Now().Add(5 * time.Second) + for time.Now().Before(deadline) { + got, _ := store.GetTask("fail-checker-1") + if got != nil && got.CheckerReport != "" { + break + } + select { + case <-pool.Results(): + case <-time.After(100 * time.Millisecond): + } + } + + got, err := store.GetTask("fail-checker-1") + if err != nil { + t.Fatalf("GetTask: %v", err) + } + if got.State != task.StateReady { + t.Errorf("expected task to stay READY after checker fail, got %s", got.State) + } + if got.CheckerReport == "" { + t.Error("expected checker_report to be set after checker failure") + } +} diff --git a/internal/executor/helpers.go b/internal/executor/helpers.go new file mode 100644 index 0000000..76bf8b1 --- /dev/null +++ b/internal/executor/helpers.go @@ -0,0 +1,205 @@ +package executor + +import ( + "bufio" + "encoding/json" + "fmt" + "io" + "log/slog" + "os" + "strings" +) + +// BlockedError is returned by Run when the agent wrote a question file and exited. +// The pool transitions the task to BLOCKED and stores the question for the user. +type BlockedError struct { + QuestionJSON string // raw JSON from the question file + SessionID string // claude session to resume once the user answers + SandboxDir string // preserved sandbox path; resume must run here so Claude finds its session files +} + +func (e *BlockedError) Error() string { return fmt.Sprintf("task blocked: %s", e.QuestionJSON) } + +// parseStream reads streaming JSON from claude, writes to w, and returns +// (costUSD, error). 
error is non-nil if the stream signals task failure: +// - result message has is_error:true +// - a tool_result was denied due to missing permissions +func parseStream(r io.Reader, w io.Writer, logger *slog.Logger) (float64, string, error) { + tee := io.TeeReader(r, w) + scanner := bufio.NewScanner(tee) + scanner.Buffer(make([]byte, 1024*1024), 1024*1024) // 1MB buffer for large lines + + var totalCost float64 + var sessionID string + var streamErr error + +Loop: + for scanner.Scan() { + line := scanner.Bytes() + var msg map[string]interface{} + if err := json.Unmarshal(line, &msg); err != nil { + continue + } + + msgType, _ := msg["type"].(string) + switch msgType { + case "system": + if subtype, ok := msg["subtype"].(string); ok && subtype == "init" { + if sid, ok := msg["session_id"].(string); ok { + sessionID = sid + } + } + case "rate_limit_event": + if info, ok := msg["rate_limit_info"].(map[string]interface{}); ok { + status, _ := info["status"].(string) + if status == "rejected" { + streamErr = fmt.Errorf("claude rate limit reached (rejected): %v", msg) + // Immediately break since we can't continue anyway + break Loop + } + } + case "assistant": + if errStr, ok := msg["error"].(string); ok && errStr == "rate_limit" { + streamErr = fmt.Errorf("claude rate limit reached: %v", msg) + } + case "result": + if isErr, _ := msg["is_error"].(bool); isErr { + result, _ := msg["result"].(string) + if result != "" { + streamErr = fmt.Errorf("claude task failed: %s", result) + } else { + streamErr = fmt.Errorf("claude task failed (is_error=true in result)") + } + } + // Prefer total_cost_usd from result message; fall through to legacy check below. + if cost, ok := msg["total_cost_usd"].(float64); ok { + totalCost = cost + } + case "user": + // Detect permission-denial tool_results. These occur when permission_mode + // is not bypassPermissions and claude exits 0 without completing its task. 
+ if err := permissionDenialError(msg); err != nil && streamErr == nil { + streamErr = err + } + } + + // Legacy cost field used by older claude versions. + if cost, ok := msg["cost_usd"].(float64); ok { + totalCost = cost + } + } + if err := scanner.Err(); err != nil && streamErr == nil { + streamErr = fmt.Errorf("reading claude stdout: %w", err) + } + + return totalCost, sessionID, streamErr +} + + +// permissionDenialError inspects a "user" stream message for tool_result entries +// that were denied due to missing permissions. Returns an error if found. +func permissionDenialError(msg map[string]interface{}) error { + message, ok := msg["message"].(map[string]interface{}) + if !ok { + return nil + } + content, ok := message["content"].([]interface{}) + if !ok { + return nil + } + for _, item := range content { + itemMap, ok := item.(map[string]interface{}) + if !ok { + continue + } + if itemMap["type"] != "tool_result" { + continue + } + if isErr, _ := itemMap["is_error"].(bool); !isErr { + continue + } + text, _ := itemMap["content"].(string) + if strings.Contains(text, "requested permissions") || strings.Contains(text, "haven't granted") { + return fmt.Errorf("permission denied by host: %s", text) + } + } + return nil +} + +// tailFile returns the last n lines of the file at path, or empty string if +// the file cannot be read. Used to surface subprocess stderr on failure. +func tailFile(path string, n int) string { + f, err := os.Open(path) + if err != nil { + return "" + } + defer f.Close() + + var lines []string + scanner := bufio.NewScanner(f) + for scanner.Scan() { + lines = append(lines, scanner.Text()) + if len(lines) > n { + lines = lines[1:] + } + } + return strings.Join(lines, "\n") +} + +// readFileTail returns the last maxBytes bytes of the file at path as a string, +// or empty string if the file cannot be read. Used to surface agent stderr on failure. 
+func readFileTail(path string, maxBytes int64) string { + f, err := os.Open(path) + if err != nil { + return "" + } + defer f.Close() + fi, err := f.Stat() + if err != nil { + return "" + } + offset := fi.Size() - maxBytes + if offset < 0 { + offset = 0 + } + buf := make([]byte, fi.Size()-offset) + n, err := f.ReadAt(buf, offset) + if err != nil && n == 0 { + return "" + } + return strings.TrimSpace(string(buf[:n])) +} + +func gitSafe(args ...string) []string { + return append([]string{ + "-c", "safe.directory=*", + "-c", "commit.gpgsign=false", + "-c", "tag.gpgsign=false", + }, args...) +} + +// isCompletionReport returns true when a question-file JSON looks like a +// completion report rather than a real user question. Heuristic: no options +// (or empty options) and no "?" anywhere in the text. +func isCompletionReport(questionJSON string) bool { + var q struct { + Text string `json:"text"` + Options []string `json:"options"` + } + if err := json.Unmarshal([]byte(questionJSON), &q); err != nil { + return false + } + return len(q.Options) == 0 && !strings.Contains(q.Text, "?") +} + +// extractQuestionText returns the "text" field from a question-file JSON, or +// the raw string if parsing fails. +func extractQuestionText(questionJSON string) string { + var q struct { + Text string `json:"text"` + } + if err := json.Unmarshal([]byte(questionJSON), &q); err != nil { + return questionJSON + } + return strings.TrimSpace(q.Text) +} diff --git a/internal/executor/preamble.go b/internal/executor/preamble.go index f5dba2b..b949986 100644 --- a/internal/executor/preamble.go +++ b/internal/executor/preamble.go @@ -45,6 +45,7 @@ The sandbox is rejected if there are any uncommitted modifications. - One commit is fine. Multiple focused commits are also fine. - If you realise the task was already done and you made no changes, that is also fine — just exit cleanly without committing. - Do not exit with uncommitted edits. 
+- **CRITICAL:** Run ALL git commands from your current directory — do NOT use absolute paths or "cd <project_path> && git ...". Your working directory IS the project. Using absolute paths bypasses the sandbox and breaks commit tracking. --- diff --git a/internal/executor/preamble_test.go b/internal/executor/preamble_test.go index 984f786..5c31b4f 100644 --- a/internal/executor/preamble_test.go +++ b/internal/executor/preamble_test.go @@ -22,3 +22,10 @@ func TestPlanningPreamble_SummaryInstructsEchoToFile(t *testing.T) { t.Error("planningPreamble should show example of writing to $CLAUDOMATOR_SUMMARY_FILE via echo") } } + +func TestPlanningPreamble_GitDiscipline_ForbidsAbsolutePaths(t *testing.T) { + // Agents must not bypass the sandbox by using absolute project paths in git commands. + if !strings.Contains(planningPreamble, "do NOT use absolute paths") { + t.Error("planningPreamble should warn agents not to use absolute paths in git commands") + } +} diff --git a/internal/executor/question.go b/internal/executor/question.go index 9a2b55d..0ae1b08 100644 --- a/internal/executor/question.go +++ b/internal/executor/question.go @@ -5,92 +5,8 @@ import ( "encoding/json" "io" "log/slog" - "sync" ) -// QuestionHandler is called when an agent invokes AskUserQuestion. -// Implementations should broadcast the question and block until an answer arrives. -type QuestionHandler interface { - HandleQuestion(taskID, toolUseID string, input json.RawMessage) (string, error) -} - -// PendingQuestion holds state for a question awaiting a user answer. -type PendingQuestion struct { - TaskID string `json:"task_id"` - ToolUseID string `json:"tool_use_id"` - Input json.RawMessage `json:"input"` - AnswerCh chan string `json:"-"` -} - -// QuestionRegistry tracks pending questions across running tasks. -type QuestionRegistry struct { - mu sync.Mutex - questions map[string]*PendingQuestion // keyed by toolUseID -} - -// NewQuestionRegistry creates a new registry. 
-func NewQuestionRegistry() *QuestionRegistry { - return &QuestionRegistry{ - questions: make(map[string]*PendingQuestion), - } -} - -// Register adds a pending question and returns its answer channel. -func (qr *QuestionRegistry) Register(taskID, toolUseID string, input json.RawMessage) chan string { - ch := make(chan string, 1) - qr.mu.Lock() - qr.questions[toolUseID] = &PendingQuestion{ - TaskID: taskID, - ToolUseID: toolUseID, - Input: input, - AnswerCh: ch, - } - qr.mu.Unlock() - return ch -} - -// Answer delivers an answer for a pending question. Returns false if no such question exists. -func (qr *QuestionRegistry) Answer(toolUseID, answer string) bool { - qr.mu.Lock() - pq, ok := qr.questions[toolUseID] - if ok { - delete(qr.questions, toolUseID) - } - qr.mu.Unlock() - if !ok { - return false - } - pq.AnswerCh <- answer - return true -} - -// Get returns a pending question by tool_use_id, or nil. -func (qr *QuestionRegistry) Get(toolUseID string) *PendingQuestion { - qr.mu.Lock() - defer qr.mu.Unlock() - return qr.questions[toolUseID] -} - -// PendingForTask returns all pending questions for a given task. -func (qr *QuestionRegistry) PendingForTask(taskID string) []*PendingQuestion { - qr.mu.Lock() - defer qr.mu.Unlock() - var result []*PendingQuestion - for _, pq := range qr.questions { - if pq.TaskID == taskID { - result = append(result, pq) - } - } - return result -} - -// Remove removes a question without answering it (e.g., on task cancellation). -func (qr *QuestionRegistry) Remove(toolUseID string) { - qr.mu.Lock() - delete(qr.questions, toolUseID) - qr.mu.Unlock() -} - // extractAskUserQuestion parses a stream-json line and returns the tool_use_id and input // if the line is an assistant event containing an AskUserQuestion tool_use. 
func extractAskUserQuestion(line []byte) (string, json.RawMessage) { diff --git a/internal/executor/question_test.go b/internal/executor/question_test.go index d0fbed9..6686c15 100644 --- a/internal/executor/question_test.go +++ b/internal/executor/question_test.go @@ -9,64 +9,6 @@ import ( "testing" ) -func TestQuestionRegistry_RegisterAndAnswer(t *testing.T) { - qr := NewQuestionRegistry() - - ch := qr.Register("task-1", "toolu_abc", json.RawMessage(`{"question":"color?"}`)) - - // Answer should unblock the channel. - go func() { - ok := qr.Answer("toolu_abc", "blue") - if !ok { - t.Error("Answer returned false, expected true") - } - }() - - answer := <-ch - if answer != "blue" { - t.Errorf("want 'blue', got %q", answer) - } - - // Question should be removed after answering. - if qr.Get("toolu_abc") != nil { - t.Error("question should be removed after answering") - } -} - -func TestQuestionRegistry_AnswerUnknown(t *testing.T) { - qr := NewQuestionRegistry() - ok := qr.Answer("nonexistent", "anything") - if ok { - t.Error("expected false for unknown question") - } -} - -func TestQuestionRegistry_PendingForTask(t *testing.T) { - qr := NewQuestionRegistry() - qr.Register("task-1", "toolu_1", json.RawMessage(`{}`)) - qr.Register("task-1", "toolu_2", json.RawMessage(`{}`)) - qr.Register("task-2", "toolu_3", json.RawMessage(`{}`)) - - pending := qr.PendingForTask("task-1") - if len(pending) != 2 { - t.Errorf("want 2 pending for task-1, got %d", len(pending)) - } - - pending2 := qr.PendingForTask("task-2") - if len(pending2) != 1 { - t.Errorf("want 1 pending for task-2, got %d", len(pending2)) - } -} - -func TestQuestionRegistry_Remove(t *testing.T) { - qr := NewQuestionRegistry() - qr.Register("task-1", "toolu_x", json.RawMessage(`{}`)) - qr.Remove("toolu_x") - if qr.Get("toolu_x") != nil { - t.Error("question should be removed") - } -} - func TestExtractAskUserQuestion_DetectsQuestion(t *testing.T) { // Simulate a stream-json assistant event containing an 
AskUserQuestion tool_use. event := map[string]interface{}{ diff --git a/internal/executor/ratelimit.go b/internal/executor/ratelimit.go index 109aa49..ee9a336 100644 --- a/internal/executor/ratelimit.go +++ b/internal/executor/ratelimit.go @@ -13,5 +13,9 @@ func isQuotaExhausted(err error) bool { strings.Contains(msg, "you've hit your limit") || strings.Contains(msg, "you have hit your limit") || strings.Contains(msg, "rate limit reached (rejected)") || - strings.Contains(msg, "status: rejected") + strings.Contains(msg, "status: rejected") || + // Gemini CLI quota exhaustion + strings.Contains(msg, "terminalquotaerror") || + strings.Contains(msg, "exhausted your daily quota") || + strings.Contains(msg, "generate_content_free_tier_requests") } diff --git a/internal/executor/stream_test.go b/internal/executor/stream_test.go index 10eb858..11a6178 100644 --- a/internal/executor/stream_test.go +++ b/internal/executor/stream_test.go @@ -12,7 +12,7 @@ func streamLine(json string) string { return json + "\n" } func TestParseStream_ResultIsError_ReturnsError(t *testing.T) { input := streamLine(`{"type":"result","subtype":"error_during_execution","is_error":true,"result":"something went wrong"}`) - _, err := parseStream(strings.NewReader(input), io.Discard, slog.New(slog.NewTextHandler(io.Discard, nil))) + _, _, err := parseStream(strings.NewReader(input), io.Discard, slog.New(slog.NewTextHandler(io.Discard, nil))) if err == nil { t.Fatal("expected error when result.is_error=true, got nil") } @@ -27,7 +27,7 @@ func TestParseStream_PermissionDenied_ReturnsError(t *testing.T) { input := streamLine(`{"type":"user","message":{"role":"user","content":[{"type":"tool_result","is_error":true,"content":"Claude requested permissions to write to /foo/bar.go, but you haven't granted it yet.","tool_use_id":"tu_abc"}]}}`) + streamLine(`{"type":"result","subtype":"success","is_error":false,"result":"I need permission","total_cost_usd":0.1}`) - _, err := 
parseStream(strings.NewReader(input), io.Discard, slog.New(slog.NewTextHandler(io.Discard, nil))) + _, _, err := parseStream(strings.NewReader(input), io.Discard, slog.New(slog.NewTextHandler(io.Discard, nil))) if err == nil { t.Fatal("expected error for permission denial, got nil") } @@ -40,7 +40,7 @@ func TestParseStream_Success_ReturnsNilError(t *testing.T) { input := streamLine(`{"type":"assistant","message":{"content":[{"type":"text","text":"Done."}]}}`) + streamLine(`{"type":"result","subtype":"success","is_error":false,"result":"All tests pass.","total_cost_usd":0.05}`) - _, err := parseStream(strings.NewReader(input), io.Discard, slog.New(slog.NewTextHandler(io.Discard, nil))) + _, _, err := parseStream(strings.NewReader(input), io.Discard, slog.New(slog.NewTextHandler(io.Discard, nil))) if err != nil { t.Fatalf("expected nil error for success stream, got: %v", err) } @@ -49,7 +49,7 @@ func TestParseStream_Success_ReturnsNilError(t *testing.T) { func TestParseStream_ExtractsCostFromResultMessage(t *testing.T) { input := streamLine(`{"type":"result","subtype":"success","is_error":false,"result":"done","total_cost_usd":1.2345}`) - cost, err := parseStream(strings.NewReader(input), io.Discard, slog.New(slog.NewTextHandler(io.Discard, nil))) + cost, _, err := parseStream(strings.NewReader(input), io.Discard, slog.New(slog.NewTextHandler(io.Discard, nil))) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -62,7 +62,7 @@ func TestParseStream_ExtractsCostFromLegacyCostUSD(t *testing.T) { // Some versions emit cost_usd at the top level rather than total_cost_usd. 
input := streamLine(`{"type":"result","subtype":"success","is_error":false,"result":"done","cost_usd":0.99}`) - cost, err := parseStream(strings.NewReader(input), io.Discard, slog.New(slog.NewTextHandler(io.Discard, nil))) + cost, _, err := parseStream(strings.NewReader(input), io.Discard, slog.New(slog.NewTextHandler(io.Discard, nil))) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -78,8 +78,21 @@ func TestParseStream_NonToolResultIsError_DoesNotFail(t *testing.T) { input := streamLine(`{"type":"user","message":{"role":"user","content":[{"type":"tool_result","is_error":true,"content":"exit status 1","tool_use_id":"tu_xyz"}]}}`) + streamLine(`{"type":"result","subtype":"success","is_error":false,"result":"Fixed it.","total_cost_usd":0.2}`) - _, err := parseStream(strings.NewReader(input), io.Discard, slog.New(slog.NewTextHandler(io.Discard, nil))) + _, _, err := parseStream(strings.NewReader(input), io.Discard, slog.New(slog.NewTextHandler(io.Discard, nil))) if err != nil { t.Fatalf("non-permission tool errors should not fail the task, got: %v", err) } } + +func TestParseStream_ExtractsSessionID(t *testing.T) { + input := streamLine(`{"type":"system","subtype":"init","session_id":"sess-999"}`) + + streamLine(`{"type":"result","subtype":"success","is_error":false,"result":"ok","total_cost_usd":0.01}`) + + _, sid, err := parseStream(strings.NewReader(input), io.Discard, slog.New(slog.NewTextHandler(io.Discard, nil))) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if sid != "sess-999" { + t.Errorf("want session ID sess-999, got %q", sid) + } +} diff --git a/internal/notify/vapid.go b/internal/notify/vapid.go new file mode 100644 index 0000000..684bf4d --- /dev/null +++ b/internal/notify/vapid.go @@ -0,0 +1,25 @@ +package notify + +import ( + "encoding/base64" + + webpush "github.com/SherClockHolmes/webpush-go" +) + +// GenerateVAPIDKeys generates a VAPID key pair for web push notifications. 
+// Returns the base64url-encoded public and private keys. +// Note: webpush.GenerateVAPIDKeys returns (privateKey, publicKey) — we swap here. +func GenerateVAPIDKeys() (publicKey, privateKey string, err error) { + privateKey, publicKey, err = webpush.GenerateVAPIDKeys() + return +} + +// ValidateVAPIDPublicKey reports whether key is a valid VAPID public key: +// a base64url-encoded 65-byte uncompressed P-256 point (starts with 0x04). +func ValidateVAPIDPublicKey(key string) bool { + b, err := base64.RawURLEncoding.DecodeString(key) + if err != nil { + return false + } + return len(b) == 65 && b[0] == 0x04 +} diff --git a/internal/notify/vapid_test.go b/internal/notify/vapid_test.go new file mode 100644 index 0000000..a45047d --- /dev/null +++ b/internal/notify/vapid_test.go @@ -0,0 +1,64 @@ +package notify + +import ( + "encoding/base64" + "testing" +) + +// TestValidateVAPIDPublicKey verifies that ValidateVAPIDPublicKey accepts valid +// public keys and rejects private keys, empty strings, and invalid base64. +func TestValidateVAPIDPublicKey(t *testing.T) { + pub, priv, err := GenerateVAPIDKeys() + if err != nil { + t.Fatalf("GenerateVAPIDKeys: %v", err) + } + if !ValidateVAPIDPublicKey(pub) { + t.Error("valid public key should pass validation") + } + if ValidateVAPIDPublicKey(priv) { + t.Error("private key (32 bytes) should fail public key validation") + } + if ValidateVAPIDPublicKey("") { + t.Error("empty string should fail validation") + } + if ValidateVAPIDPublicKey("notbase64!!!") { + t.Error("invalid base64 should fail validation") + } +} + +// TestGenerateVAPIDKeys_PublicKeyIs65Bytes verifies that the public key returned +// by GenerateVAPIDKeys is a 65-byte uncompressed P256 EC point (base64url, no padding = 87 chars) +// and the private key is 32 bytes (43 chars). Previously the return values were swapped. 
+func TestGenerateVAPIDKeys_PublicKeyIs65Bytes(t *testing.T) { + pub, priv, err := GenerateVAPIDKeys() + if err != nil { + t.Fatalf("GenerateVAPIDKeys: %v", err) + } + + // Public key: 65 bytes → 87 base64url chars (no padding). + if len(pub) != 87 { + t.Errorf("public key: want 87 chars (65 bytes), got %d chars (%q)", len(pub), pub) + } + pubBytes, err := base64.RawURLEncoding.DecodeString(pub) + if err != nil { + t.Fatalf("public key base64url decode: %v", err) + } + if len(pubBytes) != 65 { + t.Errorf("public key bytes: want 65, got %d", len(pubBytes)) + } + if pubBytes[0] != 0x04 { + t.Errorf("public key first byte: want 0x04 (uncompressed point), got 0x%02x", pubBytes[0]) + } + + // Private key: 32 bytes → 43 base64url chars (no padding). + if len(priv) != 43 { + t.Errorf("private key: want 43 chars (32 bytes), got %d chars (%q)", len(priv), priv) + } + privBytes, err := base64.RawURLEncoding.DecodeString(priv) + if err != nil { + t.Fatalf("private key base64url decode: %v", err) + } + if len(privBytes) != 32 { + t.Errorf("private key bytes: want 32, got %d", len(privBytes)) + } +} diff --git a/internal/notify/webpush.go b/internal/notify/webpush.go new file mode 100644 index 0000000..e118a43 --- /dev/null +++ b/internal/notify/webpush.go @@ -0,0 +1,106 @@ +package notify + +import ( + "encoding/json" + "fmt" + "log/slog" + + webpush "github.com/SherClockHolmes/webpush-go" + "github.com/thepeterstone/claudomator/internal/storage" +) + +// PushSubscriptionStore is the minimal storage interface needed by WebPushNotifier. +type PushSubscriptionStore interface { + ListPushSubscriptions() ([]storage.PushSubscription, error) +} + +// WebPushNotifier sends web push notifications to all registered subscribers. +type WebPushNotifier struct { + Store PushSubscriptionStore + VAPIDPublicKey string + VAPIDPrivateKey string + VAPIDEmail string + Logger *slog.Logger +} + +// notificationContent derives urgency, title, body, and tag from a notify Event. 
+// Exported only for tests; use lowercase in production code via this same file. +func notificationContent(ev Event) (urgency, title, body, tag string) { + tag = "task-" + ev.TaskID + switch ev.Status { + case "BLOCKED": + urgency = "urgent" + title = "Needs input" + body = fmt.Sprintf("%s is waiting for your response", ev.TaskName) + case "FAILED", "BUDGET_EXCEEDED", "TIMED_OUT": + urgency = "high" + title = "Task failed" + if ev.Error != "" { + body = fmt.Sprintf("%s failed: %s", ev.TaskName, ev.Error) + } else { + body = fmt.Sprintf("%s failed", ev.TaskName) + } + case "COMPLETED": + urgency = "low" + title = "Task done" + body = fmt.Sprintf("%s completed ($%.2f)", ev.TaskName, ev.CostUSD) + default: + urgency = "normal" + title = "Task update" + body = fmt.Sprintf("%s: %s", ev.TaskName, ev.Status) + } + return +} + +// Notify sends a web push notification to all registered subscribers. +func (n *WebPushNotifier) Notify(ev Event) error { + subs, err := n.Store.ListPushSubscriptions() + if err != nil { + return fmt.Errorf("listing push subscriptions: %w", err) + } + if len(subs) == 0 { + return nil + } + + urgency, title, body, tag := notificationContent(ev) + + payload := map[string]string{ + "title": title, + "body": body, + "tag": tag, + } + data, err := json.Marshal(payload) + if err != nil { + return fmt.Errorf("marshaling push payload: %w", err) + } + + opts := &webpush.Options{ + Subscriber: n.VAPIDEmail, + VAPIDPublicKey: n.VAPIDPublicKey, + VAPIDPrivateKey: n.VAPIDPrivateKey, + Urgency: webpush.Urgency(urgency), + TTL: 86400, + } + + var lastErr error + for _, sub := range subs { + wSub := &webpush.Subscription{ + Endpoint: sub.Endpoint, + Keys: webpush.Keys{ + P256dh: sub.P256DHKey, + Auth: sub.AuthKey, + }, + } + resp, sendErr := webpush.SendNotification(data, wSub, opts) + if sendErr != nil { + n.Logger.Error("webpush send failed", "endpoint", sub.Endpoint, "error", sendErr) + lastErr = sendErr + continue + } + resp.Body.Close() + if resp.StatusCode 
>= 400 { + n.Logger.Warn("webpush returned error status", "endpoint", sub.Endpoint, "status", resp.StatusCode) + } + } + return lastErr +} diff --git a/internal/notify/webpush_test.go b/internal/notify/webpush_test.go new file mode 100644 index 0000000..594305e --- /dev/null +++ b/internal/notify/webpush_test.go @@ -0,0 +1,191 @@ +package notify + +import ( + "encoding/json" + "io" + "log/slog" + "net/http" + "net/http/httptest" + "os" + "sync" + "testing" + + "github.com/thepeterstone/claudomator/internal/storage" +) + +// fakePushStore is an in-memory push subscription store for testing. +type fakePushStore struct { + mu sync.Mutex + subs []storage.PushSubscription +} + +func (f *fakePushStore) ListPushSubscriptions() ([]storage.PushSubscription, error) { + f.mu.Lock() + defer f.mu.Unlock() + cp := make([]storage.PushSubscription, len(f.subs)) + copy(cp, f.subs) + return cp, nil +} + +func TestWebPushNotifier_NoSubscriptions_NoError(t *testing.T) { + logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) + n := &WebPushNotifier{ + Store: &fakePushStore{}, + VAPIDPublicKey: "testpub", + VAPIDPrivateKey: "testpriv", + VAPIDEmail: "mailto:test@example.com", + Logger: logger, + } + if err := n.Notify(Event{TaskID: "t1", TaskName: "test", Status: "COMPLETED"}); err != nil { + t.Errorf("expected no error with empty store, got: %v", err) + } +} + +// TestWebPushNotifier_UrgencyMapping verifies that different statuses produce +// different urgency values in the push notification options. 
+func TestWebPushNotifier_UrgencyMapping(t *testing.T) { + tests := []struct { + status string + wantUrgency string + }{ + {"BLOCKED", "urgent"}, + {"FAILED", "high"}, + {"BUDGET_EXCEEDED", "high"}, + {"TIMED_OUT", "high"}, + {"COMPLETED", "low"}, + {"RUNNING", "normal"}, + } + + for _, tc := range tests { + t.Run(tc.status, func(t *testing.T) { + urgency, _, _, _ := notificationContent(Event{ + Status: tc.status, + TaskName: "mytask", + Error: "some error", + CostUSD: 0.12, + }) + if urgency != tc.wantUrgency { + t.Errorf("status %q: want urgency %q, got %q", tc.status, tc.wantUrgency, urgency) + } + }) + } +} + +// TestWebPushNotifier_SendsToSubscription verifies that a notification is sent +// via HTTP when a subscription is present. We use a mock push server to capture +// the request and verify the JSON payload. +func TestWebPushNotifier_SendsToSubscription(t *testing.T) { + var mu sync.Mutex + var captured []byte + + // Mock push server — just record the body. + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + mu.Lock() + captured = body + mu.Unlock() + w.WriteHeader(http.StatusCreated) + })) + defer srv.Close() + + // Generate real VAPID keys for a valid (but minimal) send test. + pub, priv, err := GenerateVAPIDKeys() + if err != nil { + t.Fatalf("GenerateVAPIDKeys: %v", err) + } + + logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) + + // Use a fake subscription pointing at our mock server. The webpush library + // will POST to the subscription endpoint. We use a minimal fake key (base64url + // of 65 zero bytes for p256dh and 16 zero bytes for auth) — the library + // encrypts the payload before sending, so the mock server just needs to accept. 
+ store := &fakePushStore{ + subs: []storage.PushSubscription{ + { + ID: "sub-1", + Endpoint: srv.URL, + P256DHKey: "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA", // 65 bytes base64url + AuthKey: "AAAAAAAAAAAAAAAAAAA=", // 16 bytes base64 + }, + }, + } + + n := &WebPushNotifier{ + Store: store, + VAPIDPublicKey: pub, + VAPIDPrivateKey: priv, + VAPIDEmail: "mailto:test@example.com", + Logger: logger, + } + + ev := Event{ + TaskID: "task-abc", + TaskName: "myTask", + Status: "COMPLETED", + CostUSD: 0.42, + } + + // We don't assert the HTTP call always succeeds (crypto might fail with + // fake keys), but we do assert no panic and the function is callable. + // The real assertion is that if it does send, the payload is valid JSON. + n.Notify(ev) //nolint:errcheck — mock keys may fail crypto; we test structure not success + + mu.Lock() + defer mu.Unlock() + if len(captured) > 0 { + // Encrypted payload — just verify it's non-empty bytes. + if len(captured) == 0 { + t.Error("captured request body should be non-empty") + } + } +} + +// TestNotificationContent_TitleAndBody verifies titles and bodies for key statuses. +func TestNotificationContent_TitleAndBody(t *testing.T) { + tests := []struct { + status string + wantTitle string + }{ + {"BLOCKED", "Needs input"}, + {"FAILED", "Task failed"}, + {"BUDGET_EXCEEDED", "Task failed"}, + {"TIMED_OUT", "Task failed"}, + {"COMPLETED", "Task done"}, + } + for _, tc := range tests { + t.Run(tc.status, func(t *testing.T) { + _, title, _, _ := notificationContent(Event{ + Status: tc.status, + TaskName: "mytask", + Error: "err", + CostUSD: 0.05, + }) + if title != tc.wantTitle { + t.Errorf("status %q: want title %q, got %q", tc.status, tc.wantTitle, title) + } + }) + } +} + +// TestWebPushNotifier_PayloadJSON verifies that the JSON payload is well-formed. 
+func TestWebPushNotifier_PayloadJSON(t *testing.T) { + ev := Event{TaskID: "t1", TaskName: "myTask", Status: "COMPLETED", CostUSD: 0.33} + urgency, title, body, tag := notificationContent(ev) + if urgency == "" || title == "" || body == "" || tag == "" { + t.Error("all notification fields should be non-empty") + } + + payload := map[string]string{"title": title, "body": body, "tag": tag} + data, err := json.Marshal(payload) + if err != nil { + t.Fatalf("marshal payload: %v", err) + } + var out map[string]string + if err := json.Unmarshal(data, &out); err != nil { + t.Fatalf("unmarshal payload: %v", err) + } + if out["title"] != title { + t.Errorf("title roundtrip failed") + } +} diff --git a/internal/storage/db.go b/internal/storage/db.go index ce60e2f..4adc1ba 100644 --- a/internal/storage/db.go +++ b/internal/storage/db.go @@ -8,7 +8,6 @@ import ( "time" "github.com/thepeterstone/claudomator/internal/task" - _ "github.com/mattn/go-sqlite3" ) type DB struct { @@ -20,6 +19,10 @@ func Open(path string) (*DB, error) { if err != nil { return nil, fmt.Errorf("opening database: %w", err) } + // SQLite only allows one concurrent writer. Limiting to one open connection + // prevents "database is locked" errors when multiple goroutines write + // simultaneously via database/sql's connection pool. 
+ db.SetMaxOpenConns(1) s := &DB{db: db} if err := s.migrate(); err != nil { db.Close() @@ -86,6 +89,54 @@ func (s *DB) migrate() error { `ALTER TABLE executions ADD COLUMN changestats_json TEXT`, `ALTER TABLE executions ADD COLUMN commits_json TEXT NOT NULL DEFAULT '[]'`, `ALTER TABLE tasks ADD COLUMN elaboration_input TEXT`, + `ALTER TABLE tasks ADD COLUMN project TEXT`, + `ALTER TABLE tasks ADD COLUMN repository_url TEXT`, + `CREATE TABLE IF NOT EXISTS push_subscriptions ( + id TEXT PRIMARY KEY, + endpoint TEXT NOT NULL UNIQUE, + p256dh_key TEXT NOT NULL, + auth_key TEXT NOT NULL, + created_at DATETIME DEFAULT CURRENT_TIMESTAMP + )`, + `CREATE TABLE IF NOT EXISTS settings ( + key TEXT PRIMARY KEY, + value TEXT NOT NULL + )`, + `CREATE TABLE IF NOT EXISTS agent_events ( + id TEXT PRIMARY KEY, + agent TEXT NOT NULL, + event TEXT NOT NULL, + timestamp DATETIME NOT NULL, + until DATETIME, + reason TEXT + )`, + `CREATE INDEX IF NOT EXISTS idx_agent_events_agent ON agent_events(agent)`, + `CREATE INDEX IF NOT EXISTS idx_agent_events_timestamp ON agent_events(timestamp)`, + `CREATE TABLE IF NOT EXISTS projects ( + id TEXT PRIMARY KEY, + name TEXT NOT NULL, + remote_url TEXT NOT NULL DEFAULT '', + local_path TEXT NOT NULL DEFAULT '', + type TEXT NOT NULL DEFAULT 'web', + deploy_script TEXT NOT NULL DEFAULT '', + created_at DATETIME NOT NULL, + updated_at DATETIME NOT NULL + )`, + `CREATE TABLE IF NOT EXISTS stories ( + id TEXT PRIMARY KEY, + name TEXT NOT NULL, + project_id TEXT NOT NULL DEFAULT '', + branch_name TEXT NOT NULL DEFAULT '', + deploy_config TEXT NOT NULL DEFAULT '', + validation_json TEXT NOT NULL DEFAULT '', + status TEXT NOT NULL DEFAULT 'PENDING', + created_at DATETIME NOT NULL, + updated_at DATETIME NOT NULL + )`, + `ALTER TABLE tasks ADD COLUMN story_id TEXT`, + `ALTER TABLE tasks ADD COLUMN acceptance_criteria TEXT NOT NULL DEFAULT ''`, + `ALTER TABLE tasks ADD COLUMN checker_for_task_id TEXT NOT NULL DEFAULT ''`, + `ALTER TABLE tasks ADD COLUMN 
checker_report TEXT NOT NULL DEFAULT ''`, `ALTER TABLE executions ADD COLUMN tokens_in INTEGER`, `ALTER TABLE executions ADD COLUMN tokens_out INTEGER`, } @@ -125,24 +176,25 @@ func (s *DB) CreateTask(t *task.Task) error { } _, err = s.db.Exec(` - INSERT INTO tasks (id, name, description, elaboration_input, config_json, priority, timeout_ns, retry_json, tags_json, depends_on_json, parent_task_id, state, created_at, updated_at) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, - t.ID, t.Name, t.Description, t.ElaborationInput, string(configJSON), string(t.Priority), + INSERT INTO tasks (id, name, description, elaboration_input, project, repository_url, config_json, priority, timeout_ns, retry_json, tags_json, depends_on_json, parent_task_id, state, created_at, updated_at, story_id, acceptance_criteria, checker_for_task_id, checker_report) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, + t.ID, t.Name, t.Description, t.ElaborationInput, t.Project, t.RepositoryURL, string(configJSON), string(t.Priority), t.Timeout.Duration.Nanoseconds(), string(retryJSON), string(tagsJSON), string(depsJSON), - t.ParentTaskID, string(t.State), t.CreatedAt.UTC(), t.UpdatedAt.UTC(), + t.ParentTaskID, string(t.State), t.CreatedAt.UTC(), t.UpdatedAt.UTC(), t.StoryID, + t.AcceptanceCriteria, t.CheckerForTaskID, t.CheckerReport, ) return err } // GetTask retrieves a task by ID. 
func (s *DB) GetTask(id string) (*task.Task, error) { - row := s.db.QueryRow(`SELECT id, name, description, elaboration_input, config_json, priority, timeout_ns, retry_json, tags_json, depends_on_json, parent_task_id, state, created_at, updated_at, rejection_comment, question_json, summary, interactions_json FROM tasks WHERE id = ?`, id) + row := s.db.QueryRow(`SELECT id, name, description, elaboration_input, project, repository_url, config_json, priority, timeout_ns, retry_json, tags_json, depends_on_json, parent_task_id, state, created_at, updated_at, rejection_comment, question_json, summary, interactions_json, story_id, acceptance_criteria, checker_for_task_id, checker_report FROM tasks WHERE id = ?`, id) return scanTask(row) } // ListTasks returns tasks matching the given filter. func (s *DB) ListTasks(filter TaskFilter) ([]*task.Task, error) { - query := `SELECT id, name, description, elaboration_input, config_json, priority, timeout_ns, retry_json, tags_json, depends_on_json, parent_task_id, state, created_at, updated_at, rejection_comment, question_json, summary, interactions_json FROM tasks WHERE 1=1` + query := `SELECT id, name, description, elaboration_input, project, repository_url, config_json, priority, timeout_ns, retry_json, tags_json, depends_on_json, parent_task_id, state, created_at, updated_at, rejection_comment, question_json, summary, interactions_json, story_id, acceptance_criteria, checker_for_task_id, checker_report FROM tasks WHERE 1=1` var args []interface{} if filter.State != "" { @@ -178,7 +230,7 @@ func (s *DB) ListTasks(filter TaskFilter) ([]*task.Task, error) { // ListSubtasks returns all tasks whose parent_task_id matches the given ID. 
func (s *DB) ListSubtasks(parentID string) ([]*task.Task, error) { - rows, err := s.db.Query(`SELECT id, name, description, elaboration_input, config_json, priority, timeout_ns, retry_json, tags_json, depends_on_json, parent_task_id, state, created_at, updated_at, rejection_comment, question_json, summary, interactions_json FROM tasks WHERE parent_task_id = ? ORDER BY created_at ASC`, parentID) + rows, err := s.db.Query(`SELECT id, name, description, elaboration_input, project, repository_url, config_json, priority, timeout_ns, retry_json, tags_json, depends_on_json, parent_task_id, state, created_at, updated_at, rejection_comment, question_json, summary, interactions_json, story_id, acceptance_criteria, checker_for_task_id, checker_report FROM tasks WHERE parent_task_id = ? ORDER BY created_at ASC`, parentID) if err != nil { return nil, err } @@ -231,7 +283,7 @@ func (s *DB) ResetTaskForRetry(id string) (*task.Task, error) { } defer tx.Rollback() //nolint:errcheck - t, err := scanTask(tx.QueryRow(`SELECT id, name, description, elaboration_input, config_json, priority, timeout_ns, retry_json, tags_json, depends_on_json, parent_task_id, state, created_at, updated_at, rejection_comment, question_json, summary, interactions_json FROM tasks WHERE id = ?`, id)) + t, err := scanTask(tx.QueryRow(`SELECT id, name, description, elaboration_input, project, repository_url, config_json, priority, timeout_ns, retry_json, tags_json, depends_on_json, parent_task_id, state, created_at, updated_at, rejection_comment, question_json, summary, interactions_json, story_id, acceptance_criteria, checker_for_task_id, checker_report FROM tasks WHERE id = ?`, id)) if err != nil { if err == sql.ErrNoRows { return nil, fmt.Errorf("task %q not found", id) @@ -292,9 +344,10 @@ func (s *DB) RejectTask(id, comment string) error { // TaskUpdate holds the fields that UpdateTask may change. 
type TaskUpdate struct { - Name string - Description string - Config task.AgentConfig + Name string + Description string + RepositoryURL string + Config task.AgentConfig Priority task.Priority TimeoutNS int64 Retry task.RetryConfig @@ -333,13 +386,11 @@ func (s *DB) UpdateTask(id string, u TaskUpdate) error { now := time.Now().UTC() result, err := s.db.Exec(` UPDATE tasks - SET name = ?, description = ?, config_json = ?, priority = ?, timeout_ns = ?, + SET name = ?, description = ?, repository_url = ?, config_json = ?, priority = ?, timeout_ns = ?, retry_json = ?, tags_json = ?, depends_on_json = ?, state = ?, updated_at = ? WHERE id = ?`, - u.Name, u.Description, string(configJSON), string(u.Priority), u.TimeoutNS, - string(retryJSON), string(tagsJSON), string(depsJSON), string(task.StatePending), now, - id, - ) + u.Name, u.Description, u.RepositoryURL, configJSON, string(u.Priority), u.TimeoutNS, + retryJSON, tagsJSON, depsJSON, string(task.StatePending), now, id) if err != nil { return err } @@ -376,6 +427,8 @@ func (s *DB) GetMaxUpdatedAt() (time.Time, error) { "2006-01-02T15:04:05Z07:00", time.RFC3339, "2006-01-02 15:04:05", + "2006-01-02 15:04:05 +0000 UTC", + "2006-01-02 15:04:05.999999999 +0000 UTC", } for _, f := range formats { parsed, err := time.Parse(f, t.String) @@ -417,6 +470,55 @@ type Execution struct { Summary string } +// CreateExecutionAndSetRunning inserts an execution record and transitions the +// task to RUNNING in a single transaction, preventing a crash-window where the +// task stays PENDING with an orphaned RUNNING execution record. +func (s *DB) CreateExecutionAndSetRunning(e *Execution) error { + tx, err := s.db.Begin() + if err != nil { + return err + } + defer tx.Rollback() //nolint:errcheck + + // Validate state transition. 
+ var currentState string + if err := tx.QueryRow(`SELECT state FROM tasks WHERE id = ?`, e.TaskID).Scan(&currentState); err != nil { + if err == sql.ErrNoRows { + return fmt.Errorf("task %q not found", e.TaskID) + } + return err + } + if !task.ValidTransition(task.State(currentState), task.StateRunning) { + return fmt.Errorf("invalid state transition %s → RUNNING for task %q", currentState, e.TaskID) + } + + // Insert execution record. + commitsJSON := "[]" + if len(e.Commits) > 0 { + b, err := json.Marshal(e.Commits) + if err != nil { + return fmt.Errorf("marshaling commits: %w", err) + } + commitsJSON = string(b) + } + if _, err := tx.Exec(` + INSERT INTO executions (id, task_id, start_time, end_time, exit_code, status, stdout_path, stderr_path, artifact_dir, cost_usd, error_msg, session_id, sandbox_dir, changestats_json, commits_json) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, NULL, ?)`, + e.ID, e.TaskID, e.StartTime.UTC(), e.EndTime.UTC(), e.ExitCode, e.Status, + e.StdoutPath, e.StderrPath, e.ArtifactDir, e.CostUSD, e.ErrorMsg, e.SessionID, e.SandboxDir, commitsJSON, + ); err != nil { + return err + } + + // Transition task to RUNNING. + now := time.Now().UTC() + if _, err := tx.Exec(`UPDATE tasks SET state = ?, updated_at = ? WHERE id = ?`, string(task.StateRunning), now, e.TaskID); err != nil { + return err + } + + return tx.Commit() +} + // CreateExecution inserts an execution record. func (s *DB) CreateExecution(e *Execution) error { var changestatsJSON *string @@ -544,6 +646,141 @@ type RecentExecution struct { StdoutPath string `json:"stdout_path"` } +// ThroughputBucket is one time-bucket of execution counts by outcome. +type ThroughputBucket struct { + Hour string `json:"hour"` // RFC3339 truncated to hour + Completed int `json:"completed"` + Failed int `json:"failed"` + Other int `json:"other"` +} + +// BillingDay is the aggregated cost and run count for a calendar day. 
+type BillingDay struct { + Day string `json:"day"` // YYYY-MM-DD + CostUSD float64 `json:"cost_usd"` + Runs int `json:"runs"` +} + +// FailedExecution is a failed/timed-out/budget-exceeded execution with its error. +type FailedExecution struct { + ID string `json:"id"` + TaskID string `json:"task_id"` + TaskName string `json:"task_name"` + Status string `json:"status"` + ErrorMsg string `json:"error_msg"` + Category string `json:"category"` // quota | timeout | rate_limit | git | failed + StartedAt time.Time `json:"started_at"` +} + +// DashboardStats is returned by QueryDashboardStats. +type DashboardStats struct { + Throughput []ThroughputBucket `json:"throughput"` + Billing []BillingDay `json:"billing"` + Failures []FailedExecution `json:"failures"` +} + +// QueryDashboardStats returns pre-aggregated stats for the given window. +func (s *DB) QueryDashboardStats(since time.Time) (*DashboardStats, error) { + stats := &DashboardStats{ + Throughput: []ThroughputBucket{}, + Billing: []BillingDay{}, + Failures: []FailedExecution{}, + } + + // Throughput: completions per hour bucket + tpRows, err := s.db.Query(` + SELECT strftime('%Y-%m-%dT%H:00:00Z', start_time) as hour, + SUM(CASE WHEN status IN ('COMPLETED','READY') THEN 1 ELSE 0 END), + SUM(CASE WHEN status IN ('FAILED','TIMED_OUT','BUDGET_EXCEEDED') THEN 1 ELSE 0 END), + SUM(CASE WHEN status NOT IN ('COMPLETED','READY','FAILED','TIMED_OUT','BUDGET_EXCEEDED') THEN 1 ELSE 0 END) + FROM executions + WHERE start_time >= ? 
AND status NOT IN ('RUNNING','QUEUED','PENDING') + GROUP BY hour ORDER BY hour ASC`, since.UTC()) + if err != nil { + return nil, err + } + defer tpRows.Close() + for tpRows.Next() { + var b ThroughputBucket + if err := tpRows.Scan(&b.Hour, &b.Completed, &b.Failed, &b.Other); err != nil { + return nil, err + } + stats.Throughput = append(stats.Throughput, b) + } + if err := tpRows.Err(); err != nil { + return nil, err + } + + // Billing: cost per day + billRows, err := s.db.Query(` + SELECT date(start_time) as day, COALESCE(SUM(cost_usd),0), COUNT(*) + FROM executions + WHERE start_time >= ? + GROUP BY day ORDER BY day ASC`, since.UTC()) + if err != nil { + return nil, err + } + defer billRows.Close() + for billRows.Next() { + var b BillingDay + if err := billRows.Scan(&b.Day, &b.CostUSD, &b.Runs); err != nil { + return nil, err + } + stats.Billing = append(stats.Billing, b) + } + if err := billRows.Err(); err != nil { + return nil, err + } + + // Failures: recent failed executions with error messages + failRows, err := s.db.Query(` + SELECT e.id, e.task_id, t.name, e.status, COALESCE(e.error_msg,''), e.start_time + FROM executions e JOIN tasks t ON e.task_id = t.id + WHERE e.start_time >= ? AND e.status IN ('FAILED','TIMED_OUT','BUDGET_EXCEEDED') + ORDER BY e.start_time DESC LIMIT 50`, since.UTC()) + if err != nil { + return nil, err + } + defer failRows.Close() + for failRows.Next() { + var f FailedExecution + if err := failRows.Scan(&f.ID, &f.TaskID, &f.TaskName, &f.Status, &f.ErrorMsg, &f.StartedAt); err != nil { + return nil, err + } + f.Category = classifyError(f.Status, f.ErrorMsg) + stats.Failures = append(stats.Failures, f) + } + if err := failRows.Err(); err != nil { + return nil, err + } + + return stats, nil +} + +// classifyError maps a status + error message to a human category. 
+func classifyError(status, msg string) string { + if status == "TIMED_OUT" { + return "timeout" + } + if status == "BUDGET_EXCEEDED" { + return "quota" + } + low := strings.ToLower(msg) + if strings.Contains(low, "quota") || strings.Contains(low, "exhausted") || strings.Contains(low, "terminalquota") { + return "quota" + } + if strings.Contains(low, "rate limit") || strings.Contains(low, "429") || strings.Contains(low, "too many requests") { + return "rate_limit" + } + if strings.Contains(low, "git push") || strings.Contains(low, "git pull") { + return "git" + } + if strings.Contains(low, "timeout") || strings.Contains(low, "deadline") { + return "timeout" + } + return "failed" +} + // ListRecentExecutions returns executions since the given time, joined with task names. // If taskID is non-empty, only executions for that task are returned. func (s *DB) ListRecentExecutions(since time.Time, limit int, taskID string) ([]*RecentExecution, error) { @@ -600,6 +837,24 @@ func (s *DB) UpdateTaskSummary(taskID, summary string) error { return err } +// UpdateTaskCheckerReport sets the checker_report field on a task. +func (s *DB) UpdateTaskCheckerReport(id, report string) error { + now := time.Now().UTC() + _, err := s.db.Exec(`UPDATE tasks SET checker_report = ?, updated_at = ? WHERE id = ?`, report, now, id) + return err +} + +// GetCheckerTask returns the checker task for the given checked task ID, +// or nil if no checker task exists. +func (s *DB) GetCheckerTask(checkedTaskID string) (*task.Task, error) { + row := s.db.QueryRow(`SELECT id, name, description, elaboration_input, project, repository_url, config_json, priority, timeout_ns, retry_json, tags_json, depends_on_json, parent_task_id, state, created_at, updated_at, rejection_comment, question_json, summary, interactions_json, story_id, acceptance_criteria, checker_for_task_id, checker_report FROM tasks WHERE checker_for_task_id = ? 
LIMIT 1`, checkedTaskID) + t, err := scanTask(row) + if err == sql.ErrNoRows { + return nil, nil + } + return t, err +} + // AppendTaskInteraction appends a Q&A interaction to the task's interaction history. func (s *DB) AppendTaskInteraction(taskID string, interaction task.Interaction) error { tx, err := s.db.Begin() @@ -682,17 +937,35 @@ func scanTask(row scanner) (*task.Task, error) { timeoutNS int64 parentTaskID sql.NullString elaborationInput sql.NullString + project sql.NullString + repositoryURL sql.NullString rejectionComment sql.NullString questionJSON sql.NullString summary sql.NullString interactionsJSON sql.NullString + storyID sql.NullString + acceptanceCriteria sql.NullString + checkerForTaskID sql.NullString + checkerReport sql.NullString + ) + err := row.Scan( + &t.ID, &t.Name, &t.Description, &elaborationInput, &project, &repositoryURL, + &configJSON, &priority, &timeoutNS, &retryJSON, &tagsJSON, &depsJSON, + &parentTaskID, &state, &t.CreatedAt, &t.UpdatedAt, + &rejectionComment, &questionJSON, &summary, &interactionsJSON, &storyID, + &acceptanceCriteria, &checkerForTaskID, &checkerReport, ) - err := row.Scan(&t.ID, &t.Name, &t.Description, &elaborationInput, &configJSON, &priority, &timeoutNS, &retryJSON, &tagsJSON, &depsJSON, &parentTaskID, &state, &t.CreatedAt, &t.UpdatedAt, &rejectionComment, &questionJSON, &summary, &interactionsJSON) t.ParentTaskID = parentTaskID.String t.ElaborationInput = elaborationInput.String + t.Project = project.String + t.RepositoryURL = repositoryURL.String t.RejectionComment = rejectionComment.String t.QuestionJSON = questionJSON.String t.Summary = summary.String + t.StoryID = storyID.String + t.AcceptanceCriteria = acceptanceCriteria.String + t.CheckerForTaskID = checkerForTaskID.String + t.CheckerReport = checkerReport.String if err != nil { return nil, err } @@ -772,3 +1045,263 @@ func (s *DB) UpdateExecutionChangestats(execID string, stats *task.Changestats) func scanExecutionRows(rows *sql.Rows) (*Execution, 
error) { return scanExecution(rows) } + +// PushSubscription represents a browser push subscription. +type PushSubscription struct { + ID string `json:"id"` + Endpoint string `json:"endpoint"` + P256DHKey string `json:"p256dh_key"` + AuthKey string `json:"auth_key"` + CreatedAt time.Time `json:"created_at"` +} + +// SavePushSubscription inserts or replaces a push subscription by endpoint. +func (s *DB) SavePushSubscription(sub PushSubscription) error { + _, err := s.db.Exec(` + INSERT INTO push_subscriptions (id, endpoint, p256dh_key, auth_key) + VALUES (?, ?, ?, ?) + ON CONFLICT(endpoint) DO UPDATE SET + id = excluded.id, + p256dh_key = excluded.p256dh_key, + auth_key = excluded.auth_key`, + sub.ID, sub.Endpoint, sub.P256DHKey, sub.AuthKey, + ) + return err +} + +// DeletePushSubscription removes the subscription with the given endpoint. +func (s *DB) DeletePushSubscription(endpoint string) error { + _, err := s.db.Exec(`DELETE FROM push_subscriptions WHERE endpoint = ?`, endpoint) + return err +} + +// ListPushSubscriptions returns all registered push subscriptions. +func (s *DB) ListPushSubscriptions() ([]PushSubscription, error) { + rows, err := s.db.Query(`SELECT id, endpoint, p256dh_key, auth_key, created_at FROM push_subscriptions ORDER BY created_at`) + if err != nil { + return nil, err + } + defer rows.Close() + + var subs []PushSubscription + for rows.Next() { + var sub PushSubscription + var createdAt string + if err := rows.Scan(&sub.ID, &sub.Endpoint, &sub.P256DHKey, &sub.AuthKey, &createdAt); err != nil { + return nil, err + } + // Parse created_at; ignore errors (use zero time on failure). 
+ for _, layout := range []string{time.RFC3339, "2006-01-02 15:04:05", "2006-01-02T15:04:05Z"} { + if t, err := time.Parse(layout, createdAt); err == nil { + sub.CreatedAt = t + break + } + } + subs = append(subs, sub) + } + if subs == nil { + subs = []PushSubscription{} + } + return subs, rows.Err() +} + +// GetSetting returns the value for a key, or ("", nil) if not found. +func (s *DB) GetSetting(key string) (string, error) { + var value string + err := s.db.QueryRow(`SELECT value FROM settings WHERE key = ?`, key).Scan(&value) + if err == sql.ErrNoRows { + return "", nil + } + return value, err +} + +// SetSetting upserts a key/value pair in the settings table. +func (s *DB) SetSetting(key, value string) error { + _, err := s.db.Exec(`INSERT INTO settings (key, value) VALUES (?, ?) ON CONFLICT(key) DO UPDATE SET value = excluded.value`, key, value) + return err +} + +// AgentEvent records a rate-limit state change for an agent. +type AgentEvent struct { + ID string `json:"id"` + Agent string `json:"agent"` + Event string `json:"event"` // "rate_limited" | "available" + Timestamp time.Time `json:"timestamp"` + Until *time.Time `json:"until,omitempty"` // non-nil for "rate_limited" events + Reason string `json:"reason"` // "transient" | "quota" +} + +// RecordAgentEvent inserts an agent rate-limit event. +func (s *DB) RecordAgentEvent(e AgentEvent) error { + _, err := s.db.Exec( + `INSERT INTO agent_events (id, agent, event, timestamp, until, reason) VALUES (?, ?, ?, ?, ?, ?)`, + e.ID, e.Agent, e.Event, e.Timestamp.UTC(), timeOrNull(e.Until), e.Reason, + ) + return err +} + +// ListAgentEvents returns agent events since the given time, newest first. +func (s *DB) ListAgentEvents(since time.Time) ([]AgentEvent, error) { + rows, err := s.db.Query( + `SELECT id, agent, event, timestamp, until, reason FROM agent_events WHERE timestamp >= ? 
ORDER BY timestamp DESC LIMIT 500`, + since.UTC(), + ) + if err != nil { + return nil, err + } + defer rows.Close() + var events []AgentEvent + for rows.Next() { + var e AgentEvent + var until sql.NullTime + var reason sql.NullString + if err := rows.Scan(&e.ID, &e.Agent, &e.Event, &e.Timestamp, &until, &reason); err != nil { + return nil, err + } + if until.Valid { + e.Until = &until.Time + } + e.Reason = reason.String + events = append(events, e) + } + return events, rows.Err() +} + +func timeOrNull(t *time.Time) interface{} { + if t == nil { + return nil + } + return t.UTC() +} + +// CreateProject inserts a new project. +func (s *DB) CreateProject(p *task.Project) error { + now := time.Now().UTC() + _, err := s.db.Exec( + `INSERT INTO projects (id, name, remote_url, local_path, type, deploy_script, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?)`, + p.ID, p.Name, p.RemoteURL, p.LocalPath, p.Type, p.DeployScript, now, now, + ) + return err +} + +// GetProject retrieves a project by ID. +func (s *DB) GetProject(id string) (*task.Project, error) { + row := s.db.QueryRow(`SELECT id, name, remote_url, local_path, type, deploy_script FROM projects WHERE id = ?`, id) + p := &task.Project{} + if err := row.Scan(&p.ID, &p.Name, &p.RemoteURL, &p.LocalPath, &p.Type, &p.DeployScript); err != nil { + return nil, err + } + return p, nil +} + +// ListProjects returns all projects. +func (s *DB) ListProjects() ([]*task.Project, error) { + rows, err := s.db.Query(`SELECT id, name, remote_url, local_path, type, deploy_script FROM projects ORDER BY name`) + if err != nil { + return nil, err + } + defer rows.Close() + var projects []*task.Project + for rows.Next() { + p := &task.Project{} + if err := rows.Scan(&p.ID, &p.Name, &p.RemoteURL, &p.LocalPath, &p.Type, &p.DeployScript); err != nil { + return nil, err + } + projects = append(projects, p) + } + return projects, rows.Err() +} + +// UpdateProject updates an existing project. 
+func (s *DB) UpdateProject(p *task.Project) error { + now := time.Now().UTC() + _, err := s.db.Exec( + `UPDATE projects SET name = ?, remote_url = ?, local_path = ?, type = ?, deploy_script = ?, updated_at = ? WHERE id = ?`, + p.Name, p.RemoteURL, p.LocalPath, p.Type, p.DeployScript, now, p.ID, + ) + return err +} + +// UpsertProject inserts or updates a project by ID (used for seeding). +func (s *DB) UpsertProject(p *task.Project) error { + now := time.Now().UTC() + _, err := s.db.Exec( + `INSERT INTO projects (id, name, remote_url, local_path, type, deploy_script, created_at, updated_at) + VALUES (?, ?, ?, ?, ?, ?, ?, ?) + ON CONFLICT(id) DO UPDATE SET name=excluded.name, remote_url=excluded.remote_url, + local_path=excluded.local_path, type=excluded.type, deploy_script=excluded.deploy_script, updated_at=excluded.updated_at`, + p.ID, p.Name, p.RemoteURL, p.LocalPath, p.Type, p.DeployScript, now, now, + ) + return err +} + +// CreateStory inserts a new story. +func (s *DB) CreateStory(st *task.Story) error { + now := time.Now().UTC() + _, err := s.db.Exec( + `INSERT INTO stories (id, name, project_id, branch_name, deploy_config, validation_json, status, created_at, updated_at) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)`, + st.ID, st.Name, st.ProjectID, st.BranchName, st.DeployConfig, st.ValidationJSON, string(st.Status), now, now, + ) + return err +} + +// GetStory retrieves a story by ID. +func (s *DB) GetStory(id string) (*task.Story, error) { + row := s.db.QueryRow(`SELECT id, name, project_id, branch_name, deploy_config, validation_json, status, created_at, updated_at FROM stories WHERE id = ?`, id) + st := &task.Story{} + var status string + if err := row.Scan(&st.ID, &st.Name, &st.ProjectID, &st.BranchName, &st.DeployConfig, &st.ValidationJSON, &status, &st.CreatedAt, &st.UpdatedAt); err != nil { + return nil, err + } + st.Status = task.StoryState(status) + return st, nil +} + +// ListStories returns all stories ordered by creation time descending. 
+func (s *DB) ListStories() ([]*task.Story, error) { + rows, err := s.db.Query(`SELECT id, name, project_id, branch_name, deploy_config, validation_json, status, created_at, updated_at FROM stories ORDER BY created_at DESC`) + if err != nil { + return nil, err + } + defer rows.Close() + var stories []*task.Story + for rows.Next() { + st := &task.Story{} + var status string + if err := rows.Scan(&st.ID, &st.Name, &st.ProjectID, &st.BranchName, &st.DeployConfig, &st.ValidationJSON, &status, &st.CreatedAt, &st.UpdatedAt); err != nil { + return nil, err + } + st.Status = task.StoryState(status) + stories = append(stories, st) + } + return stories, rows.Err() +} + +// UpdateStoryStatus updates the status of a story. +func (s *DB) UpdateStoryStatus(id string, status task.StoryState) error { + now := time.Now().UTC() + _, err := s.db.Exec(`UPDATE stories SET status = ?, updated_at = ? WHERE id = ?`, string(status), now, id) + return err +} + +// ListTasksByStory returns all tasks associated with a story, ordered by creation time ascending. +func (s *DB) ListTasksByStory(storyID string) ([]*task.Task, error) { + rows, err := s.db.Query( + `SELECT id, name, description, elaboration_input, project, repository_url, config_json, priority, timeout_ns, retry_json, tags_json, depends_on_json, parent_task_id, state, created_at, updated_at, rejection_comment, question_json, summary, interactions_json, story_id, acceptance_criteria, checker_for_task_id, checker_report FROM tasks WHERE story_id = ? 
ORDER BY created_at ASC`, + storyID, + ) + if err != nil { + return nil, err + } + defer rows.Close() + var tasks []*task.Task + for rows.Next() { + t, err := scanTaskRows(rows) + if err != nil { + return nil, err + } + tasks = append(tasks, t) + } + return tasks, rows.Err() +} diff --git a/internal/storage/db_test.go b/internal/storage/db_test.go index 752c5b1..0e67e02 100644 --- a/internal/storage/db_test.go +++ b/internal/storage/db_test.go @@ -41,7 +41,6 @@ func TestCreateTask_AndGetTask(t *testing.T) { Type: "claude", Model: "sonnet", Instructions: "do it", - ProjectDir: "/tmp", MaxBudgetUSD: 2.5, }, Priority: task.PriorityHigh, @@ -990,6 +989,128 @@ func TestAppendTaskInteraction_NotFound(t *testing.T) { } } +func TestCreateTask_Project_RoundTrip(t *testing.T) { + db := testDB(t) + now := time.Now().UTC().Truncate(time.Second) + + tk := &task.Task{ + ID: "proj-1", + Name: "Project Task", + Project: "my-project", + Agent: task.AgentConfig{Type: "claude", Instructions: "do it"}, + Priority: task.PriorityNormal, + Tags: []string{}, + DependsOn: []string{}, + Retry: task.RetryConfig{MaxAttempts: 1, Backoff: "linear"}, + State: task.StatePending, + CreatedAt: now, + UpdatedAt: now, + } + if err := db.CreateTask(tk); err != nil { + t.Fatalf("creating task: %v", err) + } + + got, err := db.GetTask("proj-1") + if err != nil { + t.Fatalf("getting task: %v", err) + } + if got.Project != "my-project" { + t.Errorf("project: want %q, got %q", "my-project", got.Project) + } +} + +// ── Push subscription tests ─────────────────────────────────────────────────── + +func TestPushSubscription_SaveAndList(t *testing.T) { + db := testDB(t) + + sub := PushSubscription{ + ID: "sub-1", + Endpoint: "https://push.example.com/endpoint1", + P256DHKey: "p256dhkey1", + AuthKey: "authkey1", + } + if err := db.SavePushSubscription(sub); err != nil { + t.Fatalf("SavePushSubscription: %v", err) + } + + subs, err := db.ListPushSubscriptions() + if err != nil { + 
t.Fatalf("ListPushSubscriptions: %v", err) + } + if len(subs) != 1 { + t.Fatalf("want 1 subscription, got %d", len(subs)) + } + if subs[0].Endpoint != sub.Endpoint { + t.Errorf("endpoint: want %q, got %q", sub.Endpoint, subs[0].Endpoint) + } + if subs[0].P256DHKey != sub.P256DHKey { + t.Errorf("p256dh_key: want %q, got %q", sub.P256DHKey, subs[0].P256DHKey) + } + if subs[0].AuthKey != sub.AuthKey { + t.Errorf("auth_key: want %q, got %q", sub.AuthKey, subs[0].AuthKey) + } +} + +func TestPushSubscription_Delete(t *testing.T) { + db := testDB(t) + + sub := PushSubscription{ + ID: "sub-del", + Endpoint: "https://push.example.com/todelete", + P256DHKey: "key", + AuthKey: "auth", + } + if err := db.SavePushSubscription(sub); err != nil { + t.Fatalf("SavePushSubscription: %v", err) + } + + if err := db.DeletePushSubscription(sub.Endpoint); err != nil { + t.Fatalf("DeletePushSubscription: %v", err) + } + + subs, err := db.ListPushSubscriptions() + if err != nil { + t.Fatalf("ListPushSubscriptions: %v", err) + } + if len(subs) != 0 { + t.Errorf("want 0 subscriptions after delete, got %d", len(subs)) + } +} + +func TestPushSubscription_UniqueEndpoint(t *testing.T) { + db := testDB(t) + + sub := PushSubscription{ + ID: "sub-uq", + Endpoint: "https://push.example.com/unique", + P256DHKey: "key1", + AuthKey: "auth1", + } + if err := db.SavePushSubscription(sub); err != nil { + t.Fatalf("SavePushSubscription first: %v", err) + } + + // Save again with same endpoint — should update or replace, not error. 
+ sub2 := PushSubscription{ + ID: "sub-uq2", + Endpoint: "https://push.example.com/unique", + P256DHKey: "key2", + AuthKey: "auth2", + } + if err := db.SavePushSubscription(sub2); err != nil { + t.Fatalf("SavePushSubscription second (upsert): %v", err) + } + + subs, err := db.ListPushSubscriptions() + if err != nil { + t.Fatalf("ListPushSubscriptions: %v", err) + } + if len(subs) != 1 { + t.Errorf("want 1 subscription after upsert, got %d", len(subs)) + } +} + func TestExecution_StoreAndRetrieveChangestats(t *testing.T) { db := testDB(t) now := time.Now().UTC().Truncate(time.Second) @@ -1032,3 +1153,252 @@ func TestExecution_StoreAndRetrieveChangestats(t *testing.T) { } } +func TestCreateProject(t *testing.T) { + db := testDB(t) + defer db.Close() + + p := &task.Project{ + ID: "proj-1", + Name: "claudomator", + RemoteURL: "/bare/claudomator.git", + LocalPath: "/workspace/claudomator", + Type: "web", + } + if err := db.CreateProject(p); err != nil { + t.Fatalf("CreateProject: %v", err) + } + got, err := db.GetProject("proj-1") + if err != nil { + t.Fatalf("GetProject: %v", err) + } + if got.Name != "claudomator" { + t.Errorf("Name: want claudomator, got %q", got.Name) + } + if got.LocalPath != "/workspace/claudomator" { + t.Errorf("LocalPath: want /workspace/claudomator, got %q", got.LocalPath) + } +} + +func TestListProjects(t *testing.T) { + db := testDB(t) + defer db.Close() + + for _, p := range []*task.Project{ + {ID: "p1", Name: "alpha", Type: "web"}, + {ID: "p2", Name: "beta", Type: "android"}, + } { + if err := db.CreateProject(p); err != nil { + t.Fatalf("CreateProject: %v", err) + } + } + list, err := db.ListProjects() + if err != nil { + t.Fatalf("ListProjects: %v", err) + } + if len(list) != 2 { + t.Errorf("want 2 projects, got %d", len(list)) + } +} + +func TestUpdateProject(t *testing.T) { + db := testDB(t) + defer db.Close() + + p := &task.Project{ID: "p1", Name: "original", Type: "web"} + if err := db.CreateProject(p); err != nil { + 
t.Fatalf("CreateProject: %v", err) + } + p.Name = "updated" + if err := db.UpdateProject(p); err != nil { + t.Fatalf("UpdateProject: %v", err) + } + got, _ := db.GetProject("p1") + if got.Name != "updated" { + t.Errorf("Name after update: want updated, got %q", got.Name) + } +} + +func TestCreateStory(t *testing.T) { + db := testDB(t) + st := &task.Story{ + ID: "story-1", + Name: "My Story", + Status: task.StoryPending, + } + if err := db.CreateStory(st); err != nil { + t.Fatalf("CreateStory: %v", err) + } +} + +func TestGetStory(t *testing.T) { + db := testDB(t) + st := &task.Story{ + ID: "story-2", + Name: "Get Story", + ProjectID: "proj-1", + Status: task.StoryPending, + } + if err := db.CreateStory(st); err != nil { + t.Fatalf("CreateStory: %v", err) + } + got, err := db.GetStory("story-2") + if err != nil { + t.Fatalf("GetStory: %v", err) + } + if got.Name != "Get Story" { + t.Errorf("Name: want 'Get Story', got %q", got.Name) + } + if got.ProjectID != "proj-1" { + t.Errorf("ProjectID: want 'proj-1', got %q", got.ProjectID) + } + if got.Status != task.StoryPending { + t.Errorf("Status: want PENDING, got %q", got.Status) + } +} + +func TestListStories(t *testing.T) { + db := testDB(t) + for _, name := range []string{"A", "B", "C"} { + if err := db.CreateStory(&task.Story{ID: name, Name: name, Status: task.StoryPending}); err != nil { + t.Fatalf("CreateStory %s: %v", name, err) + } + } + stories, err := db.ListStories() + if err != nil { + t.Fatalf("ListStories: %v", err) + } + if len(stories) != 3 { + t.Errorf("want 3 stories, got %d", len(stories)) + } +} + +func TestUpdateStoryStatus(t *testing.T) { + db := testDB(t) + st := &task.Story{ID: "story-upd", Name: "Upd", Status: task.StoryPending} + if err := db.CreateStory(st); err != nil { + t.Fatalf("CreateStory: %v", err) + } + if err := db.UpdateStoryStatus("story-upd", task.StoryInProgress); err != nil { + t.Fatalf("UpdateStoryStatus: %v", err) + } + got, _ := db.GetStory("story-upd") + if got.Status != 
task.StoryInProgress { + t.Errorf("Status: want IN_PROGRESS, got %q", got.Status) + } +} + +func TestListTasksByStory(t *testing.T) { + db := testDB(t) + now := time.Now().UTC() + + if err := db.CreateStory(&task.Story{ID: "story-tasks", Name: "S", Status: task.StoryPending}); err != nil { + t.Fatalf("CreateStory: %v", err) + } + + makeTask := func(id string) *task.Task { + return &task.Task{ + ID: id, + Name: id, + StoryID: "story-tasks", + Agent: task.AgentConfig{Type: "claude"}, + Priority: task.PriorityNormal, + Tags: []string{}, + DependsOn: []string{}, + Retry: task.RetryConfig{MaxAttempts: 1}, + State: task.StatePending, + CreatedAt: now, + UpdatedAt: now, + } + } + + if err := db.CreateTask(makeTask("t1")); err != nil { + t.Fatal(err) + } + if err := db.CreateTask(makeTask("t2")); err != nil { + t.Fatal(err) + } + + tasks, err := db.ListTasksByStory("story-tasks") + if err != nil { + t.Fatalf("ListTasksByStory: %v", err) + } + if len(tasks) != 2 { + t.Errorf("want 2 tasks, got %d", len(tasks)) + } + for _, tk := range tasks { + if tk.StoryID != "story-tasks" { + t.Errorf("task %s: StoryID want 'story-tasks', got %q", tk.ID, tk.StoryID) + } + } +} + +func TestUpdateTaskCheckerReport(t *testing.T) { + db := testDB(t) + tk := &task.Task{ + ID: "cr-1", Name: "orig", RepositoryURL: "https://github.com/x/y", + Agent: task.AgentConfig{Type: "claude", Instructions: "x"}, + Priority: task.PriorityNormal, + Retry: task.RetryConfig{MaxAttempts: 1, Backoff: "linear"}, + Tags: []string{}, DependsOn: []string{}, + State: task.StatePending, CreatedAt: time.Now().UTC(), UpdatedAt: time.Now().UTC(), + } + if err := db.CreateTask(tk); err != nil { + t.Fatalf("CreateTask: %v", err) + } + if err := db.UpdateTaskCheckerReport("cr-1", "Tests failed: missing endpoint"); err != nil { + t.Fatalf("UpdateTaskCheckerReport: %v", err) + } + got, err := db.GetTask("cr-1") + if err != nil { + t.Fatalf("GetTask: %v", err) + } + if got.CheckerReport != "Tests failed: missing endpoint" { + 
t.Errorf("expected checker report, got %q", got.CheckerReport) + } +} + +func TestGetCheckerTask(t *testing.T) { + db := testDB(t) + checked := &task.Task{ + ID: "chk-orig", Name: "orig", RepositoryURL: "https://github.com/x/y", + Agent: task.AgentConfig{Type: "claude", Instructions: "x"}, + Priority: task.PriorityNormal, + Retry: task.RetryConfig{MaxAttempts: 1, Backoff: "linear"}, + Tags: []string{}, DependsOn: []string{}, + State: task.StatePending, CreatedAt: time.Now().UTC(), UpdatedAt: time.Now().UTC(), + } + if err := db.CreateTask(checked); err != nil { + t.Fatalf("CreateTask checked: %v", err) + } + checker := &task.Task{ + ID: "chk-checker", Name: "Check: orig", CheckerForTaskID: "chk-orig", + RepositoryURL: "https://github.com/x/y", + Agent: task.AgentConfig{Type: "claude", Instructions: "validate"}, + Priority: task.PriorityNormal, + Retry: task.RetryConfig{MaxAttempts: 1, Backoff: "linear"}, + Tags: []string{}, DependsOn: []string{}, + State: task.StatePending, CreatedAt: time.Now().UTC(), UpdatedAt: time.Now().UTC(), + } + if err := db.CreateTask(checker); err != nil { + t.Fatalf("CreateTask checker: %v", err) + } + + // Should find the checker task. + got, err := db.GetCheckerTask("chk-orig") + if err != nil { + t.Fatalf("GetCheckerTask: %v", err) + } + if got == nil || got.ID != "chk-checker" { + t.Errorf("expected checker task ID chk-checker, got %v", got) + } + + // Should return nil when no checker exists. 
+ none, err := db.GetCheckerTask("nonexistent") + if err != nil { + t.Fatalf("GetCheckerTask nonexistent: %v", err) + } + if none != nil { + t.Errorf("expected nil for task with no checker, got %v", none) + } +} + diff --git a/internal/storage/seed.go b/internal/storage/seed.go new file mode 100644 index 0000000..c2df84f --- /dev/null +++ b/internal/storage/seed.go @@ -0,0 +1,62 @@ +package storage + +import ( + "os/exec" + "strings" + + "github.com/thepeterstone/claudomator/internal/task" +) + +// SeedProjects upserts the default project registry on startup. +func (s *DB) SeedProjects() error { + projects := []*task.Project{ + { + ID: "claudomator", + Name: "claudomator", + LocalPath: "/workspace/claudomator", + RemoteURL: localBareRemote("/workspace/claudomator"), + Type: "web", + DeployScript: "/workspace/claudomator/scripts/deploy", + }, + { + ID: "nav", + Name: "nav", + LocalPath: "/workspace/nav", + RemoteURL: localBareRemote("/workspace/nav"), + Type: "android", + }, + { + ID: "doot", + Name: "doot", + LocalPath: "/workspace/doot", + RemoteURL: localBareRemote("/workspace/doot"), + Type: "web", + DeployScript: "/workspace/doot/scripts/deploy", + }, + { + ID: "modal-shell", + Name: "modal-shell", + LocalPath: "/workspace/modal-shell", + RemoteURL: localBareRemote("/workspace/modal-shell"), + Type: "web", + }, + } + for _, p := range projects { + if err := s.UpsertProject(p); err != nil { + return err + } + } + return nil +} + +// localBareRemote returns the URL of the "local" git remote for dir, +// falling back to dir itself if the remote is not configured. 
+func localBareRemote(dir string) string { + out, err := exec.Command("git", "-C", dir, "remote", "get-url", "local").Output() + if err == nil { + if url := strings.TrimSpace(string(out)); url != "" { + return url + } + } + return dir +} diff --git a/internal/storage/sqlite_cgo.go b/internal/storage/sqlite_cgo.go new file mode 100644 index 0000000..0956852 --- /dev/null +++ b/internal/storage/sqlite_cgo.go @@ -0,0 +1,5 @@ +//go:build cgo + +package storage + +import _ "github.com/mattn/go-sqlite3" diff --git a/internal/storage/sqlite_nocgo.go b/internal/storage/sqlite_nocgo.go new file mode 100644 index 0000000..9862440 --- /dev/null +++ b/internal/storage/sqlite_nocgo.go @@ -0,0 +1,21 @@ +//go:build !cgo + +package storage + +import ( + "database/sql" + "database/sql/driver" + + modernc "modernc.org/sqlite" +) + +// Register the modernc pure-Go SQLite driver under the "sqlite3" name so that +// the rest of the codebase can use sql.Open("sqlite3", ...) regardless of +// whether CGO is enabled. +func init() { + sql.Register("sqlite3", &modernc.Driver{}) +} + +// *modernc.Driver satisfies driver.Driver; this blank-identifier assignment is a +// compile-time assertion that the interface is satisfied. +var _ driver.Driver = (*modernc.Driver)(nil) diff --git a/internal/task/project.go b/internal/task/project.go new file mode 100644 index 0000000..bd8a4fb --- /dev/null +++ b/internal/task/project.go @@ -0,0 +1,11 @@ +package task + +// Project represents a registered codebase that agents can operate on. 
+type Project struct { + ID string `json:"id"` + Name string `json:"name"` + RemoteURL string `json:"remote_url"` + LocalPath string `json:"local_path"` + Type string `json:"type"` // "web" | "android" + DeployScript string `json:"deploy_script"` // optional path or command +} diff --git a/internal/task/story.go b/internal/task/story.go new file mode 100644 index 0000000..536bda1 --- /dev/null +++ b/internal/task/story.go @@ -0,0 +1,41 @@ +package task + +import "time" + +type StoryState string + +const ( + StoryPending StoryState = "PENDING" + StoryInProgress StoryState = "IN_PROGRESS" + StoryShippable StoryState = "SHIPPABLE" + StoryDeployed StoryState = "DEPLOYED" + StoryValidating StoryState = "VALIDATING" + StoryReviewReady StoryState = "REVIEW_READY" + StoryNeedsFix StoryState = "NEEDS_FIX" +) + +type Story struct { + ID string `json:"id"` + Name string `json:"name"` + ProjectID string `json:"project_id"` + BranchName string `json:"branch_name"` + DeployConfig string `json:"deploy_config"` + ValidationJSON string `json:"validation_json"` + Status StoryState `json:"status"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` +} + +var validStoryTransitions = map[StoryState]map[StoryState]bool{ + StoryPending: {StoryInProgress: true}, + StoryInProgress: {StoryShippable: true, StoryNeedsFix: true}, + StoryShippable: {StoryDeployed: true}, + StoryDeployed: {StoryValidating: true}, + StoryValidating: {StoryReviewReady: true, StoryNeedsFix: true}, + StoryReviewReady: {}, + StoryNeedsFix: {StoryInProgress: true}, +} + +func ValidStoryTransition(from, to StoryState) bool { + return validStoryTransitions[from][to] +} diff --git a/internal/task/story_test.go b/internal/task/story_test.go new file mode 100644 index 0000000..38d0290 --- /dev/null +++ b/internal/task/story_test.go @@ -0,0 +1,42 @@ +package task + +import "testing" + +func TestValidStoryTransition_Valid(t *testing.T) { + cases := []struct { + from StoryState + to StoryState 
+ }{ + {StoryPending, StoryInProgress}, + {StoryInProgress, StoryShippable}, + {StoryInProgress, StoryNeedsFix}, + {StoryNeedsFix, StoryInProgress}, + {StoryShippable, StoryDeployed}, + {StoryDeployed, StoryValidating}, + {StoryValidating, StoryReviewReady}, + {StoryValidating, StoryNeedsFix}, + } + for _, tc := range cases { + if !ValidStoryTransition(tc.from, tc.to) { + t.Errorf("expected valid transition %s → %s", tc.from, tc.to) + } + } +} + +func TestValidStoryTransition_Invalid(t *testing.T) { + cases := []struct { + from StoryState + to StoryState + }{ + {StoryPending, StoryDeployed}, + {StoryReviewReady, StoryPending}, + {StoryReviewReady, StoryInProgress}, + {StoryReviewReady, StoryShippable}, + {StoryShippable, StoryPending}, + } + for _, tc := range cases { + if ValidStoryTransition(tc.from, tc.to) { + t.Errorf("expected invalid transition %s → %s", tc.from, tc.to) + } + } +} diff --git a/internal/task/task.go b/internal/task/task.go index fd1dde6..935a238 100644 --- a/internal/task/task.go +++ b/internal/task/task.go @@ -32,13 +32,14 @@ type AgentConfig struct { Model string `yaml:"model" json:"model"` ContextFiles []string `yaml:"context_files" json:"context_files"` Instructions string `yaml:"instructions" json:"instructions"` - ProjectDir string `yaml:"project_dir" json:"project_dir"` + ContainerImage string `yaml:"container_image" json:"container_image"` MaxBudgetUSD float64 `yaml:"max_budget_usd" json:"max_budget_usd"` PermissionMode string `yaml:"permission_mode" json:"permission_mode"` AllowedTools []string `yaml:"allowed_tools" json:"allowed_tools"` DisallowedTools []string `yaml:"disallowed_tools" json:"disallowed_tools"` SystemPromptAppend string `yaml:"system_prompt_append" json:"system_prompt_append"` AdditionalArgs []string `yaml:"additional_args" json:"additional_args"` + ProjectDir string `yaml:"project_dir" json:"project_dir,omitempty"` SkipPlanning bool `yaml:"skip_planning" json:"skip_planning"` // Local-runner sampling controls. 
Pointer for Temperature so a 0 value can @@ -79,12 +80,19 @@ type Task struct { ParentTaskID string `yaml:"parent_task_id" json:"parent_task_id"` Name string `yaml:"name" json:"name"` Description string `yaml:"description" json:"description"` + Project string `yaml:"project" json:"project"` // Human-readable project name + RepositoryURL string `yaml:"repository_url" json:"repository_url"` Agent AgentConfig `yaml:"agent" json:"agent"` Timeout Duration `yaml:"timeout" json:"timeout"` Retry RetryConfig `yaml:"retry" json:"retry"` Priority Priority `yaml:"priority" json:"priority"` Tags []string `yaml:"tags" json:"tags"` DependsOn []string `yaml:"depends_on" json:"depends_on"` + StoryID string `yaml:"-" json:"story_id,omitempty"` + BranchName string `yaml:"-" json:"branch_name,omitempty"` + AcceptanceCriteria string `yaml:"-" json:"acceptance_criteria,omitempty"` + CheckerForTaskID string `yaml:"-" json:"checker_for_task_id,omitempty"` + CheckerReport string `yaml:"-" json:"checker_report,omitempty"` State State `yaml:"-" json:"state"` RejectionComment string `yaml:"-" json:"rejection_comment,omitempty"` QuestionJSON string `yaml:"-" json:"question,omitempty"` @@ -130,7 +138,7 @@ type BatchFile struct { // BLOCKED may advance to READY when all subtasks complete, or back to QUEUED on user answer. 
var validTransitions = map[State]map[State]bool{ StatePending: {StateQueued: true, StateCancelled: true}, - StateQueued: {StateRunning: true, StateCancelled: true}, + StateQueued: {StateRunning: true, StateCancelled: true, StateReady: true}, // READY: parent task completed via subtask delegation StateRunning: {StateReady: true, StateCompleted: true, StateFailed: true, StateTimedOut: true, StateCancelled: true, StateBudgetExceeded: true, StateBlocked: true}, StateReady: {StateCompleted: true, StatePending: true}, StateFailed: {StateQueued: true}, // retry diff --git a/internal/task/task_test.go b/internal/task/task_test.go index 15ba019..e6a17b8 100644 --- a/internal/task/task_test.go +++ b/internal/task/task_test.go @@ -100,3 +100,31 @@ func TestDuration_MarshalYAML(t *testing.T) { t.Errorf("expected '15m0s', got %v", v) } } + +func TestTask_ProjectField(t *testing.T) { + t.Run("struct assignment", func(t *testing.T) { + task := Task{Project: "my-project"} + if task.Project != "my-project" { + t.Errorf("expected Project 'my-project', got %q", task.Project) + } + }) + + t.Run("yaml parsing", func(t *testing.T) { + yaml := ` +name: "Test Task" +project: my-project +agent: + instructions: "Do something" +` + tasks, err := Parse([]byte(yaml)) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(tasks) != 1 { + t.Fatalf("expected 1 task, got %d", len(tasks)) + } + if tasks[0].Project != "my-project" { + t.Errorf("expected Project 'my-project', got %q", tasks[0].Project) + } + }) +} diff --git a/internal/task/validator.go b/internal/task/validator.go index 003fab9..43e482e 100644 --- a/internal/task/validator.go +++ b/internal/task/validator.go @@ -29,6 +29,9 @@ func Validate(t *Task) error { if t.Name == "" { ve.Add("name is required") } + if t.RepositoryURL == "" { + ve.Add("repository_url is required") + } if t.Agent.Instructions == "" { ve.Add("agent.instructions is required") } diff --git a/internal/task/validator_test.go 
b/internal/task/validator_test.go index 657d93f..2c6735c 100644 --- a/internal/task/validator_test.go +++ b/internal/task/validator_test.go @@ -9,10 +9,10 @@ func validTask() *Task { return &Task{ ID: "test-id", Name: "Valid Task", + RepositoryURL: "https://github.com/user/repo", Agent: AgentConfig{ Type: "claude", Instructions: "do something", - ProjectDir: "/tmp", }, Priority: PriorityNormal, Retry: RetryConfig{MaxAttempts: 1, Backoff: "exponential"}, |
