summary | refs | log | tree | commit | diff
path: root/internal
diff options
context:
space:
mode:
Diffstat (limited to 'internal')
-rw-r--r--internal/api/changestats.go15
-rw-r--r--internal/api/deployment.go4
-rw-r--r--internal/api/drops.go165
-rw-r--r--internal/api/drops_test.go159
-rw-r--r--internal/api/elaborate.go205
-rw-r--r--internal/api/elaborate_test.go11
-rw-r--r--internal/api/executions.go42
-rw-r--r--internal/api/projects.go71
-rw-r--r--internal/api/push.go120
-rw-r--r--internal/api/push_test.go159
-rw-r--r--internal/api/server.go131
-rw-r--r--internal/api/server_test.go464
-rw-r--r--internal/api/stories.go378
-rw-r--r--internal/api/stories_test.go351
-rw-r--r--internal/api/task_view.go47
-rw-r--r--internal/api/webhook.go74
-rw-r--r--internal/api/webhook_test.go114
-rw-r--r--internal/cli/list.go6
-rw-r--r--internal/cli/project_test.go102
-rw-r--r--internal/cli/root.go2
-rw-r--r--internal/cli/run.go32
-rw-r--r--internal/cli/serve.go126
-rw-r--r--internal/cli/status.go3
-rw-r--r--internal/cli/version.go18
-rw-r--r--internal/config/config.go43
-rw-r--r--internal/executor/claude.go170
-rw-r--r--internal/executor/claude_test.go127
-rw-r--r--internal/executor/container.go549
-rw-r--r--internal/executor/container_test.go687
-rw-r--r--internal/executor/executor.go695
-rw-r--r--internal/executor/executor_test.go911
-rw-r--r--internal/executor/helpers.go205
-rw-r--r--internal/executor/preamble.go1
-rw-r--r--internal/executor/preamble_test.go7
-rw-r--r--internal/executor/question.go84
-rw-r--r--internal/executor/question_test.go58
-rw-r--r--internal/executor/ratelimit.go6
-rw-r--r--internal/executor/stream_test.go25
-rw-r--r--internal/notify/vapid.go25
-rw-r--r--internal/notify/vapid_test.go64
-rw-r--r--internal/notify/webpush.go106
-rw-r--r--internal/notify/webpush_test.go191
-rw-r--r--internal/storage/db.go569
-rw-r--r--internal/storage/db_test.go372
-rw-r--r--internal/storage/seed.go62
-rw-r--r--internal/storage/sqlite_cgo.go5
-rw-r--r--internal/storage/sqlite_nocgo.go21
-rw-r--r--internal/task/project.go11
-rw-r--r--internal/task/story.go41
-rw-r--r--internal/task/story_test.go42
-rw-r--r--internal/task/task.go12
-rw-r--r--internal/task/task_test.go28
-rw-r--r--internal/task/validator.go3
-rw-r--r--internal/task/validator_test.go2
54 files changed, 7194 insertions, 727 deletions
diff --git a/internal/api/changestats.go b/internal/api/changestats.go
deleted file mode 100644
index 4f18f7f..0000000
--- a/internal/api/changestats.go
+++ /dev/null
@@ -1,15 +0,0 @@
-package api
-
-import "github.com/thepeterstone/claudomator/internal/task"
-
-// parseChangestatFromOutput delegates to task.ParseChangestatFromOutput.
-// Kept as a package-local wrapper for use within the api package.
-func parseChangestatFromOutput(output string) *task.Changestats {
- return task.ParseChangestatFromOutput(output)
-}
-
-// parseChangestatFromFile delegates to task.ParseChangestatFromFile.
-// Kept as a package-local wrapper for use within the api package.
-func parseChangestatFromFile(path string) *task.Changestats {
- return task.ParseChangestatFromFile(path)
-}
diff --git a/internal/api/deployment.go b/internal/api/deployment.go
index d927545..8972fe2 100644
--- a/internal/api/deployment.go
+++ b/internal/api/deployment.go
@@ -23,7 +23,7 @@ func (s *Server) handleGetDeploymentStatus(w http.ResponseWriter, r *http.Reques
if err != nil {
if err == sql.ErrNoRows {
// No execution yet — return status with no fix commits.
- status := deployment.Check(nil, tk.Agent.ProjectDir)
+ status := deployment.Check(nil, tk.RepositoryURL)
writeJSON(w, http.StatusOK, status)
return
}
@@ -31,6 +31,6 @@ func (s *Server) handleGetDeploymentStatus(w http.ResponseWriter, r *http.Reques
return
}
- status := deployment.Check(exec.Commits, tk.Agent.ProjectDir)
+ status := deployment.Check(exec.Commits, tk.RepositoryURL)
writeJSON(w, http.StatusOK, status)
}
diff --git a/internal/api/drops.go b/internal/api/drops.go
new file mode 100644
index 0000000..a5000f1
--- /dev/null
+++ b/internal/api/drops.go
@@ -0,0 +1,165 @@
+package api
+
+import (
+ "fmt"
+ "io"
+ "net/http"
+ "os"
+ "path/filepath"
+ "strings"
+ "time"
+)
+
+// handleListDrops responds with a JSON array describing every non-directory
+// entry found directly inside the drops directory (name, size in bytes and
+// modification time in UTC). A drops directory that does not exist yet is
+// reported as an empty array; entries whose metadata cannot be read are
+// left out of the listing.
+func (s *Server) handleListDrops(w http.ResponseWriter, r *http.Request) {
+	if s.dropsDir == "" {
+		writeJSON(w, http.StatusServiceUnavailable, map[string]string{"error": "drops directory not configured"})
+		return
+	}
+
+	type fileEntry struct {
+		Name     string    `json:"name"`
+		Size     int64     `json:"size"`
+		Modified time.Time `json:"modified"`
+	}
+
+	dirEntries, err := os.ReadDir(s.dropsDir)
+	switch {
+	case err == nil:
+		// Directory was read; build the listing below.
+	case os.IsNotExist(err):
+		writeJSON(w, http.StatusOK, []map[string]interface{}{})
+		return
+	default:
+		writeJSON(w, http.StatusInternalServerError, map[string]string{"error": "failed to list drops"})
+		return
+	}
+
+	// Non-nil even when empty so the response encodes as [] rather than null.
+	listing := make([]fileEntry, 0, len(dirEntries))
+	for _, de := range dirEntries {
+		if de.IsDir() {
+			continue
+		}
+		meta, statErr := de.Info()
+		if statErr != nil {
+			continue
+		}
+		listing = append(listing, fileEntry{
+			Name:     de.Name(),
+			Size:     meta.Size(),
+			Modified: meta.ModTime().UTC(),
+		})
+	}
+	writeJSON(w, http.StatusOK, listing)
+}
+
+// handleGetDrop serves a single file from the drops directory as an
+// attachment download.
+//
+// The {filename} path value must be a bare file name. Anything empty, equal
+// to ".", containing a path separator or containing ".." is rejected with
+// 400. A missing file, or a name that resolves to a directory, yields 404.
+func (s *Server) handleGetDrop(w http.ResponseWriter, r *http.Request) {
+	if s.dropsDir == "" {
+		writeJSON(w, http.StatusServiceUnavailable, map[string]string{"error": "drops directory not configured"})
+		return
+	}
+
+	filename := r.PathValue("filename")
+	// filepath.Base comparison also catches the OS-specific separator
+	// (e.g. backslash on Windows) that the "/" check alone would miss.
+	if filename == "" || filename == "." || filename != filepath.Base(filename) ||
+		strings.Contains(filename, "/") || strings.Contains(filename, "..") {
+		writeJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid filename"})
+		return
+	}
+
+	path := filepath.Join(s.dropsDir, filename)
+	// Extra safety: the resolved path must stay inside dropsDir. A plain
+	// string-prefix test would also accept siblings such as "<dropsDir>-x",
+	// so compare via the relative path instead.
+	rel, err := filepath.Rel(s.dropsDir, path)
+	if err != nil || rel == ".." || strings.HasPrefix(rel, ".."+string(filepath.Separator)) {
+		writeJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid filename"})
+		return
+	}
+
+	f, err := os.Open(path)
+	if err != nil {
+		if os.IsNotExist(err) {
+			writeJSON(w, http.StatusNotFound, map[string]string{"error": "file not found"})
+			return
+		}
+		writeJSON(w, http.StatusInternalServerError, map[string]string{"error": "failed to open file"})
+		return
+	}
+	defer f.Close()
+
+	// Opening a directory succeeds but it cannot be streamed; the listing
+	// endpoint hides directories, so treat it as not found here as well.
+	info, err := f.Stat()
+	if err != nil {
+		writeJSON(w, http.StatusInternalServerError, map[string]string{"error": "failed to open file"})
+		return
+	}
+	if info.IsDir() {
+		writeJSON(w, http.StatusNotFound, map[string]string{"error": "file not found"})
+		return
+	}
+
+	// %q escapes quotes, backslashes and control characters so the name
+	// cannot break out of the quoted header parameter.
+	w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%q", filename))
+	w.Header().Set("Content-Type", "application/octet-stream")
+	io.Copy(w, f) //nolint:errcheck
+}
+
+// handlePostDrop accepts a file upload into the drops directory, creating
+// the directory on first use.
+//
+// Two request shapes are supported: multipart/form-data (delegated to
+// handleMultipartDrop) and a raw request body whose target name is given by
+// the ?filename= query parameter. The raw body is streamed straight to disk
+// instead of being buffered in memory, and a partially written file is
+// removed if the transfer fails. On success the response is 201 with the
+// stored name and byte count.
+func (s *Server) handlePostDrop(w http.ResponseWriter, r *http.Request) {
+	if s.dropsDir == "" {
+		writeJSON(w, http.StatusServiceUnavailable, map[string]string{"error": "drops directory not configured"})
+		return
+	}
+
+	if err := os.MkdirAll(s.dropsDir, 0700); err != nil {
+		writeJSON(w, http.StatusInternalServerError, map[string]string{"error": "failed to create drops directory"})
+		return
+	}
+
+	ct := r.Header.Get("Content-Type")
+	if strings.Contains(ct, "multipart/form-data") {
+		s.handleMultipartDrop(w, r)
+		return
+	}
+
+	// Raw body with ?filename= query param.
+	filename := r.URL.Query().Get("filename")
+	if filename == "" {
+		writeJSON(w, http.StatusBadRequest, map[string]string{"error": "filename query param required for raw upload"})
+		return
+	}
+	// Only a bare file name is allowed; the filepath.Base comparison also
+	// rejects the OS-specific separator, and "." would name the directory.
+	if filename == "." || filename != filepath.Base(filename) ||
+		strings.Contains(filename, "/") || strings.Contains(filename, "..") {
+		writeJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid filename"})
+		return
+	}
+
+	path := filepath.Join(s.dropsDir, filename)
+	// Same flags and mode os.WriteFile would use.
+	dst, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
+	if err != nil {
+		writeJSON(w, http.StatusInternalServerError, map[string]string{"error": "failed to save file"})
+		return
+	}
+
+	n, copyErr := io.Copy(dst, r.Body)
+	// Close explicitly: a failed close can mean the data never reached disk.
+	closeErr := dst.Close()
+	if copyErr != nil || closeErr != nil {
+		os.Remove(path) //nolint:errcheck // best-effort cleanup of the partial file
+		writeJSON(w, http.StatusInternalServerError, map[string]string{"error": "failed to save file"})
+		return
+	}
+	writeJSON(w, http.StatusCreated, map[string]interface{}{"name": filename, "size": n})
+}
+
+// handleMultipartDrop stores the "file" part of a multipart/form-data upload
+// in the drops directory under the base name of the client-supplied file
+// name. Names that do not identify a file (empty, ".", ".." or a bare path
+// separator) are rejected with 400. On success the response is 201 with the
+// stored name and byte count; a partially written file is removed on failure.
+func (s *Server) handleMultipartDrop(w http.ResponseWriter, r *http.Request) {
+	// 32 MB is the amount of the form kept in memory while parsing, not a cap
+	// on the upload size: larger parts are spooled to temporary files.
+	if err := r.ParseMultipartForm(32 << 20); err != nil {
+		writeJSON(w, http.StatusBadRequest, map[string]string{"error": "failed to parse multipart form: " + err.Error()})
+		return
+	}
+
+	file, header, err := r.FormFile("file")
+	if err != nil {
+		writeJSON(w, http.StatusBadRequest, map[string]string{"error": "missing 'file' field: " + err.Error()})
+		return
+	}
+	defer file.Close()
+
+	filename := filepath.Base(header.Filename)
+	switch filename {
+	case "", ".", "..", string(filepath.Separator):
+		// filepath.Base can return any of these; none names a file in dropsDir.
+		writeJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid filename"})
+		return
+	}
+
+	path := filepath.Join(s.dropsDir, filename)
+	dst, err := os.Create(path)
+	if err != nil {
+		writeJSON(w, http.StatusInternalServerError, map[string]string{"error": "failed to create file"})
+		return
+	}
+
+	n, copyErr := io.Copy(dst, file)
+	// Close explicitly rather than via defer so a failed flush is reported
+	// instead of being answered with 201.
+	closeErr := dst.Close()
+	if copyErr != nil || closeErr != nil {
+		os.Remove(path) //nolint:errcheck // best-effort cleanup of the partial file
+		writeJSON(w, http.StatusInternalServerError, map[string]string{"error": "failed to write file"})
+		return
+	}
+
+	writeJSON(w, http.StatusCreated, map[string]interface{}{"name": filename, "size": n})
+}
+
diff --git a/internal/api/drops_test.go b/internal/api/drops_test.go
new file mode 100644
index 0000000..ab67489
--- /dev/null
+++ b/internal/api/drops_test.go
@@ -0,0 +1,159 @@
+package api
+
+import (
+ "bytes"
+ "encoding/json"
+ "mime/multipart"
+ "net/http"
+ "net/http/httptest"
+ "os"
+ "path/filepath"
+ "strings"
+ "testing"
+)
+
+// testServerWithDrops returns a test server whose drops directory points at
+// a fresh per-test temporary directory, together with that directory's path.
+func testServerWithDrops(t *testing.T) (*Server, string) {
+	t.Helper()
+	server, _ := testServer(t)
+	tmp := t.TempDir()
+	server.SetDropsDir(tmp)
+	return server, tmp
+}
+
+// TestHandleListDrops_Empty checks that listing an empty drops directory
+// answers 200 with an empty JSON array.
+func TestHandleListDrops_Empty(t *testing.T) {
+	server, _ := testServerWithDrops(t)
+
+	resp := httptest.NewRecorder()
+	server.mux.ServeHTTP(resp, httptest.NewRequest("GET", "/api/drops", nil))
+
+	if got := resp.Code; got != http.StatusOK {
+		t.Fatalf("want 200, got %d", got)
+	}
+
+	var listing []map[string]interface{}
+	if err := json.NewDecoder(resp.Body).Decode(&listing); err != nil {
+		t.Fatalf("decode response: %v", err)
+	}
+	if n := len(listing); n != 0 {
+		t.Errorf("want empty list, got %d entries", n)
+	}
+}
+
+// TestHandleListDrops_WithFile checks that a file placed in the drops
+// directory shows up, by name, as the single entry of the listing.
+func TestHandleListDrops_WithFile(t *testing.T) {
+	server, dir := testServerWithDrops(t)
+
+	// Seed the drops directory with one file.
+	seeded := filepath.Join(dir, "hello.txt")
+	if err := os.WriteFile(seeded, []byte("world"), 0600); err != nil {
+		t.Fatal(err)
+	}
+
+	resp := httptest.NewRecorder()
+	server.mux.ServeHTTP(resp, httptest.NewRequest("GET", "/api/drops", nil))
+
+	if resp.Code != http.StatusOK {
+		t.Fatalf("want 200, got %d: %s", resp.Code, resp.Body.String())
+	}
+
+	var listing []map[string]interface{}
+	if err := json.NewDecoder(resp.Body).Decode(&listing); err != nil {
+		t.Fatalf("decode response: %v", err)
+	}
+	if n := len(listing); n != 1 {
+		t.Fatalf("want 1 file, got %d", n)
+	}
+	if name := listing[0]["name"]; name != "hello.txt" {
+		t.Errorf("name: want %q, got %v", "hello.txt", name)
+	}
+}
+
+// TestHandlePostDrop_Multipart uploads one file as multipart/form-data and
+// checks the 201 response, the echoed name, and the bytes written to disk.
+func TestHandlePostDrop_Multipart(t *testing.T) {
+	server, dir := testServerWithDrops(t)
+
+	var payload bytes.Buffer
+	form := multipart.NewWriter(&payload)
+	part, err := form.CreateFormFile("file", "test.txt")
+	if err != nil {
+		t.Fatal(err)
+	}
+	part.Write([]byte("hello world")) //nolint:errcheck
+	form.Close()
+
+	upload := httptest.NewRequest("POST", "/api/drops", &payload)
+	upload.Header.Set("Content-Type", form.FormDataContentType())
+	resp := httptest.NewRecorder()
+	server.mux.ServeHTTP(resp, upload)
+
+	if resp.Code != http.StatusCreated {
+		t.Fatalf("want 201, got %d: %s", resp.Code, resp.Body.String())
+	}
+
+	var decoded map[string]interface{}
+	if err := json.NewDecoder(resp.Body).Decode(&decoded); err != nil {
+		t.Fatalf("decode response: %v", err)
+	}
+	if name := decoded["name"]; name != "test.txt" {
+		t.Errorf("name: want %q, got %v", "test.txt", name)
+	}
+
+	// The upload must have landed in the drops directory.
+	stored, err := os.ReadFile(filepath.Join(dir, "test.txt"))
+	if err != nil {
+		t.Fatalf("reading uploaded file: %v", err)
+	}
+	if string(stored) != "hello world" {
+		t.Errorf("content: want %q, got %q", "hello world", stored)
+	}
+}
+
+// TestHandleGetDrop_Download checks that an existing drop is returned with
+// status 200, an attachment Content-Disposition, and its exact contents.
+func TestHandleGetDrop_Download(t *testing.T) {
+	server, dir := testServerWithDrops(t)
+
+	target := filepath.Join(dir, "download.txt")
+	if err := os.WriteFile(target, []byte("download me"), 0600); err != nil {
+		t.Fatal(err)
+	}
+
+	resp := httptest.NewRecorder()
+	server.mux.ServeHTTP(resp, httptest.NewRequest("GET", "/api/drops/download.txt", nil))
+
+	if resp.Code != http.StatusOK {
+		t.Fatalf("want 200, got %d", resp.Code)
+	}
+
+	if disposition := resp.Header().Get("Content-Disposition"); !strings.Contains(disposition, "attachment") {
+		t.Errorf("want Content-Disposition: attachment, got %q", disposition)
+	}
+	if body := resp.Body.String(); body != "download me" {
+		t.Errorf("body: want %q, got %q", "download me", body)
+	}
+}
+
+// TestHandleGetDrop_PathTraversal requests a name with percent-encoded
+// "../" components and only requires that the server does not answer 200.
+func TestHandleGetDrop_PathTraversal(t *testing.T) {
+	server, _ := testServerWithDrops(t)
+
+	// %2F is an encoded slash, so the requested name decodes to
+	// "../etc/passwd"; whichever layer handles it must refuse to serve it.
+	attack := httptest.NewRequest("GET", "/api/drops/..%2Fetc%2Fpasswd", nil)
+	resp := httptest.NewRecorder()
+	server.mux.ServeHTTP(resp, attack)
+
+	if served := resp.Code == http.StatusOK; served {
+		t.Error("expected non-200 for path traversal attempt")
+	}
+}
+
+// TestHandleGetDrop_NotFound checks that requesting a drop that was never
+// uploaded answers 404.
+func TestHandleGetDrop_NotFound(t *testing.T) {
+	server, _ := testServerWithDrops(t)
+
+	resp := httptest.NewRecorder()
+	server.mux.ServeHTTP(resp, httptest.NewRequest("GET", "/api/drops/notexist.txt", nil))
+
+	if got := resp.Code; got != http.StatusNotFound {
+		t.Fatalf("want 404, got %d", got)
+	}
+}
diff --git a/internal/api/elaborate.go b/internal/api/elaborate.go
index 30095c8..8676b36 100644
--- a/internal/api/elaborate.go
+++ b/internal/api/elaborate.go
@@ -172,7 +172,7 @@ func readProjectContext(workDir string) string {
return ""
}
var sb strings.Builder
- for _, filename := range []string{"CLAUDE.md", "SESSION_STATE.md"} {
+ for _, filename := range []string{"CLAUDE.md", ".agent/worklog.md"} {
path := filepath.Join(workDir, filename)
if data, err := os.ReadFile(path); err == nil {
if sb.Len() > 0 {
@@ -303,6 +303,197 @@ func (s *Server) elaborateWithGemini(ctx context.Context, workDir, fullPrompt st
return &result, nil
}
+// elaboratedStorySubtask is a leaf unit within a story task: a named step
+// with its own instructions. It is decoded from one element of a task's
+// "subtasks" JSON array.
+type elaboratedStorySubtask struct {
+ Name string `json:"name"`
+ Instructions string `json:"instructions"`
+}
+
+// elaboratedStoryTask is one independently-deployable unit in a story plan.
+// Besides its own instructions it carries the acceptance criteria text and
+// the list of subtasks decoded from the "subtasks" JSON array.
+type elaboratedStoryTask struct {
+ Name string `json:"name"`
+ Instructions string `json:"instructions"`
+ AcceptanceCriteria string `json:"acceptance_criteria"`
+ Subtasks []elaboratedStorySubtask `json:"subtasks"`
+}
+
+// elaboratedStoryValidation describes how to verify the story was successful:
+// a validation type, the steps to run, and a description of what success
+// looks like. It is decoded from the "validation" object of the story JSON.
+type elaboratedStoryValidation struct {
+ Type string `json:"type"`
+ Steps []string `json:"steps"`
+ SuccessCriteria string `json:"success_criteria"`
+}
+
+// elaboratedStory is the full implementation plan produced by story
+// elaboration: the story name, a branch name, the tasks, and the validation
+// description. It is the value handleElaborateStory writes back as JSON.
+type elaboratedStory struct {
+ Name string `json:"name"`
+ BranchName string `json:"branch_name"`
+ Tasks []elaboratedStoryTask `json:"tasks"`
+ Validation elaboratedStoryValidation `json:"validation"`
+}
+
+// buildStoryElaboratePrompt returns the fixed instruction text used for story
+// elaboration. It tells the model to answer with JSON only, spells out the
+// expected schema (name, branch_name, tasks with subtasks, validation), and
+// lists the planning rules. The text takes no parameters; the caller supplies
+// the goal separately.
+func buildStoryElaboratePrompt() string {
+ return `You are a software architect. Given a goal, analyze the codebase at /workspace and produce a structured implementation plan as JSON.
+
+Output ONLY valid JSON matching this schema:
+{
+ "name": "story name",
+ "branch_name": "story/kebab-case-name",
+ "tasks": [
+ {
+ "name": "task name",
+ "instructions": "detailed instructions including file paths and what to change",
+ "acceptance_criteria": "specific, verifiable conditions a separate reviewer can check — e.g. 'run go test ./... and verify all pass; confirm GET /api/foo returns 200 with expected JSON shape'",
+ "subtasks": [
+ { "name": "subtask name", "instructions": "..." }
+ ]
+ }
+ ],
+ "validation": {
+ "type": "build|test|smoke",
+ "steps": ["step1", "step2"],
+ "success_criteria": "what success looks like"
+ }
+}
+
+Rules:
+- Tasks must be independently buildable (each can be deployed alone)
+- Subtasks within a task are order-dependent and run sequentially
+- Instructions must include specific file paths, function names, and exact changes
+- Instructions must end with: git add -A && git commit -m "..." && git push origin <branch>
+- Validation should match the scope: small change = build check; new feature = smoke test
+- acceptance_criteria must be concrete and verifiable by a separate agent — no vague assertions like "code looks good"`
+}
+
+// elaborateStoryWithClaude runs the Claude CLI with the story-elaboration
+// system prompt and the given goal, then decodes the reply into an
+// elaboratedStory. When workDir is non-empty the command runs in that
+// directory.
+//
+// A non-zero exit is only treated as fatal when the command printed nothing
+// on stdout; otherwise the JSON wrapper on stdout decides success or failure.
+func (s *Server) elaborateStoryWithClaude(ctx context.Context, workDir, goal string) (*elaboratedStory, error) {
+	var outBuf, errBuf bytes.Buffer
+
+	cmd := exec.CommandContext(ctx, s.claudeBinaryPath(),
+		"-p", goal,
+		"--system-prompt", buildStoryElaboratePrompt(),
+		"--output-format", "json",
+		"--model", "haiku",
+	)
+	cmd.Stdout, cmd.Stderr = &outBuf, &errBuf
+	if workDir != "" {
+		cmd.Dir = workDir
+	}
+
+	runErr := cmd.Run()
+
+	raw := outBuf.Bytes()
+	switch {
+	case len(raw) == 0 && runErr != nil:
+		return nil, fmt.Errorf("claude failed: %w (stderr: %s)", runErr, errBuf.String())
+	case len(raw) == 0:
+		return nil, fmt.Errorf("claude returned no output")
+	}
+
+	// First layer: the CLI's JSON envelope around the model's answer.
+	var envelope claudeJSONResult
+	if err := json.Unmarshal(raw, &envelope); err != nil {
+		return nil, fmt.Errorf("failed to parse claude JSON wrapper: %w (output: %s)", err, string(raw))
+	}
+	if envelope.IsError {
+		return nil, fmt.Errorf("claude error: %s", envelope.Result)
+	}
+
+	// Second layer: the story plan embedded in the answer text.
+	story := new(elaboratedStory)
+	if err := json.Unmarshal([]byte(extractJSON(envelope.Result)), story); err != nil {
+		return nil, fmt.Errorf("failed to parse elaborated story JSON: %w (result: %s)", err, envelope.Result)
+	}
+	return story, nil
+}
+
+// elaborateStoryWithGemini runs the Gemini CLI with the story-elaboration
+// prompt and the goal joined into a single prompt, then decodes the reply
+// into an elaboratedStory. When workDir is non-empty the command runs in
+// that directory. Any command failure is returned together with stderr.
+func (s *Server) elaborateStoryWithGemini(ctx context.Context, workDir, goal string) (*elaboratedStory, error) {
+	prompt := fmt.Sprintf("%s\n\n%s", buildStoryElaboratePrompt(), goal)
+
+	var outBuf, errBuf bytes.Buffer
+	cmd := exec.CommandContext(ctx, s.geminiBinaryPath(),
+		"-p", prompt,
+		"--output-format", "json",
+		"--model", "gemini-2.5-flash-lite",
+	)
+	cmd.Stdout, cmd.Stderr = &outBuf, &errBuf
+	if workDir != "" {
+		cmd.Dir = workDir
+	}
+
+	if runErr := cmd.Run(); runErr != nil {
+		return nil, fmt.Errorf("gemini failed: %w (stderr: %s)", runErr, errBuf.String())
+	}
+
+	// First layer: the CLI's JSON envelope around the model's answer.
+	var envelope geminiJSONResult
+	if jerr := json.Unmarshal(outBuf.Bytes(), &envelope); jerr != nil {
+		return nil, fmt.Errorf("failed to parse gemini JSON wrapper: %w (output: %s)", jerr, outBuf.String())
+	}
+
+	// Second layer: the story plan embedded in the answer text.
+	story := new(elaboratedStory)
+	if jerr := json.Unmarshal([]byte(extractJSON(envelope.Response)), story); jerr != nil {
+		return nil, fmt.Errorf("failed to parse elaborated story JSON: %w (response: %s)", jerr, envelope.Response)
+	}
+	return story, nil
+}
+
+// handleElaborateStory turns a free-form goal into a structured story plan.
+//
+// The JSON request body must contain "goal" and "project_id"; either one
+// missing yields 400, and an unknown project yields 404. If the project has
+// a local path, its git refs are refreshed first with "git fetch origin"
+// (a failed fetch is only logged). Elaboration is attempted with Claude and
+// falls back to Gemini; if both fail, or the plan has no name, the response
+// is 502. On success the plan is returned with 200.
+func (s *Server) handleElaborateStory(w http.ResponseWriter, r *http.Request) {
+	var input struct {
+		Goal      string `json:"goal"`
+		ProjectID string `json:"project_id"`
+	}
+	if err := json.NewDecoder(r.Body).Decode(&input); err != nil {
+		writeJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid JSON: " + err.Error()})
+		return
+	}
+	if input.Goal == "" {
+		writeJSON(w, http.StatusBadRequest, map[string]string{"error": "goal is required"})
+		return
+	}
+	if input.ProjectID == "" {
+		writeJSON(w, http.StatusBadRequest, map[string]string{"error": "project_id is required"})
+		return
+	}
+
+	proj, err := s.store.GetProject(input.ProjectID)
+	if err != nil {
+		writeJSON(w, http.StatusNotFound, map[string]string{"error": "project not found"})
+		return
+	}
+
+	// Update git refs without modifying the working tree. The fetch gets its
+	// own deadline tied to the request so a hung remote cannot block the
+	// handler forever and the child process is killed if the client leaves.
+	if proj.LocalPath != "" {
+		fetchCtx, cancelFetch := context.WithTimeout(r.Context(), elaborateTimeout)
+		gitCmd := exec.CommandContext(fetchCtx, "git", "-C", proj.LocalPath, "fetch", "origin")
+		if err := gitCmd.Run(); err != nil {
+			s.logger.Warn("story elaborate: git fetch failed", "error", err, "path", proj.LocalPath)
+		}
+		cancelFetch()
+	}
+
+	// Started after the fetch so elaboration keeps its full time budget.
+	ctx, cancel := context.WithTimeout(r.Context(), elaborateTimeout)
+	defer cancel()
+
+	result, err := s.elaborateStoryWithClaude(ctx, proj.LocalPath, input.Goal)
+	if err != nil {
+		s.logger.Warn("story elaborate: claude failed, falling back to gemini", "error", err)
+		result, err = s.elaborateStoryWithGemini(ctx, proj.LocalPath, input.Goal)
+		if err != nil {
+			s.logger.Error("story elaborate: fallback gemini also failed", "error", err)
+			writeJSON(w, http.StatusBadGateway, map[string]string{
+				"error": fmt.Sprintf("elaboration failed: %v", err),
+			})
+			return
+		}
+	}
+
+	// A plan without a name is treated as an unusable model response.
+	if result.Name == "" {
+		writeJSON(w, http.StatusBadGateway, map[string]string{
+			"error": "elaboration failed: missing required fields in response",
+		})
+		return
+	}
+
+	writeJSON(w, http.StatusOK, result)
+}
+
func (s *Server) handleElaborateTask(w http.ResponseWriter, r *http.Request) {
if s.elaborateLimiter != nil && !s.elaborateLimiter.allow(realIP(r)) {
writeJSON(w, http.StatusTooManyRequests, map[string]string{"error": "rate limit exceeded"})
@@ -310,7 +501,9 @@ func (s *Server) handleElaborateTask(w http.ResponseWriter, r *http.Request) {
}
var input struct {
- Prompt string `json:"prompt"`
+ Prompt string `json:"prompt"`
+ ProjectID string `json:"project_id"`
+ // project_dir kept for backward compat; project_id takes precedence
ProjectDir string `json:"project_dir"`
}
if err := json.NewDecoder(r.Body).Decode(&input); err != nil {
@@ -323,11 +516,15 @@ func (s *Server) handleElaborateTask(w http.ResponseWriter, r *http.Request) {
}
workDir := s.workDir
- if input.ProjectDir != "" {
+ if input.ProjectID != "" {
+ if proj, err := s.store.GetProject(input.ProjectID); err == nil {
+ workDir = proj.LocalPath
+ }
+ } else if input.ProjectDir != "" {
workDir = input.ProjectDir
}
- if input.ProjectDir != "" {
+ if workDir != s.workDir {
go s.appendRawNarrative(workDir, input.Prompt)
}
diff --git a/internal/api/elaborate_test.go b/internal/api/elaborate_test.go
index 0b5c706..32cec3c 100644
--- a/internal/api/elaborate_test.go
+++ b/internal/api/elaborate_test.go
@@ -350,6 +350,8 @@ func TestElaborateTask_InvalidJSONFromClaude(t *testing.T) {
// Fake Claude returns something that is not valid JSON.
srv.elaborateCmdPath = createFakeClaude(t, "not valid json at all", 0)
+ // Ensure Gemini fallback also fails so we get the expected 502.
+ srv.geminiBinPath = "/nonexistent/gemini"
body := `{"prompt":"do something"}`
req := httptest.NewRequest("POST", "/api/tasks/elaborate", bytes.NewBufferString(body))
@@ -388,14 +390,17 @@ func createFakeClaudeCapturingArgs(t *testing.T, output string, exitCode int, ar
func TestElaborateTask_WithProjectContext(t *testing.T) {
srv, _ := testServer(t)
- // Create a temporary workspace with CLAUDE.md and SESSION_STATE.md
+ // Create a temporary workspace with CLAUDE.md and .agent/worklog.md
workDir := t.TempDir()
claudeContent := "Claude context info"
sessionContent := "Session state info"
if err := os.WriteFile(filepath.Join(workDir, "CLAUDE.md"), []byte(claudeContent), 0600); err != nil {
t.Fatal(err)
}
- if err := os.WriteFile(filepath.Join(workDir, "SESSION_STATE.md"), []byte(sessionContent), 0600); err != nil {
+ if err := os.MkdirAll(filepath.Join(workDir, ".agent"), 0700); err != nil {
+ t.Fatal(err)
+ }
+ if err := os.WriteFile(filepath.Join(workDir, ".agent", "worklog.md"), []byte(sessionContent), 0600); err != nil {
t.Fatal(err)
}
@@ -436,7 +441,7 @@ func TestElaborateTask_WithProjectContext(t *testing.T) {
t.Errorf("expected arguments to contain CLAUDE.md content, got %s", argsStr)
}
if !strings.Contains(argsStr, sessionContent) {
- t.Errorf("expected arguments to contain SESSION_STATE.md content, got %s", argsStr)
+ t.Errorf("expected arguments to contain .agent/worklog.md content, got %s", argsStr)
}
}
diff --git a/internal/api/executions.go b/internal/api/executions.go
index 114425e..4d8ba9c 100644
--- a/internal/api/executions.go
+++ b/internal/api/executions.go
@@ -86,6 +86,48 @@ func (s *Server) handleGetExecutionLog(w http.ResponseWriter, r *http.Request) {
fmt.Fprint(w, content)
}
+// handleGetDashboardStats serves the dashboard statistics the store
+// aggregates for a trailing time window.
+// GET /api/stats?window=7d|24h
+//
+// Only the value "24h" selects the one-day window; any other value, or no
+// value at all, falls back to seven days.
+func (s *Server) handleGetDashboardStats(w http.ResponseWriter, r *http.Request) {
+	var window time.Duration
+	switch r.URL.Query().Get("window") {
+	case "24h":
+		window = 24 * time.Hour
+	default:
+		window = 7 * 24 * time.Hour
+	}
+
+	stats, err := s.store.QueryDashboardStats(time.Now().Add(-window))
+	if err != nil {
+		writeJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()})
+		return
+	}
+	writeJSON(w, http.StatusOK, stats)
+}
+
+// handleGetAgentStatus reports the pool's agent statuses together with the
+// agent events recorded since a cutoff time.
+// GET /api/agents/status?since=<RFC3339>
+//
+// The cutoff defaults to 24 hours ago; a missing or unparsable "since" value
+// keeps that default. The "events" field is always a JSON array, never null.
+func (s *Server) handleGetAgentStatus(w http.ResponseWriter, r *http.Request) {
+	cutoff := time.Now().Add(-24 * time.Hour)
+	if raw := r.URL.Query().Get("since"); raw != "" {
+		if parsed, parseErr := time.Parse(time.RFC3339, raw); parseErr == nil {
+			cutoff = parsed
+		}
+	}
+
+	events, err := s.store.ListAgentEvents(cutoff)
+	if err != nil {
+		writeJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()})
+		return
+	}
+	if events == nil {
+		events = []storage.AgentEvent{}
+	}
+
+	payload := map[string]interface{}{
+		"agents": s.pool.AgentStatuses(),
+		"events": events,
+	}
+	writeJSON(w, http.StatusOK, payload)
+}
+
// tailLogFile reads the last n lines from the file at path.
func tailLogFile(path string, n int) (string, error) {
data, err := os.ReadFile(path)
diff --git a/internal/api/projects.go b/internal/api/projects.go
new file mode 100644
index 0000000..d3dbbf9
--- /dev/null
+++ b/internal/api/projects.go
@@ -0,0 +1,71 @@
+package api
+
+import (
+ "encoding/json"
+ "net/http"
+
+ "github.com/google/uuid"
+ "github.com/thepeterstone/claudomator/internal/task"
+)
+
+// handleListProjects returns every project in the store as a JSON array.
+// A nil result from the store is sent as an empty array rather than null.
+func (s *Server) handleListProjects(w http.ResponseWriter, r *http.Request) {
+	all, err := s.store.ListProjects()
+	switch {
+	case err != nil:
+		writeJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()})
+		return
+	case all == nil:
+		all = []*task.Project{}
+	}
+	writeJSON(w, http.StatusOK, all)
+}
+
+// handleCreateProject decodes a project from the JSON request body and
+// stores it. The name is required; when the body carries no ID a new UUID is
+// generated. The stored project is echoed back with 201.
+func (s *Server) handleCreateProject(w http.ResponseWriter, r *http.Request) {
+	var project task.Project
+	if decodeErr := json.NewDecoder(r.Body).Decode(&project); decodeErr != nil {
+		writeJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid JSON: " + decodeErr.Error()})
+		return
+	}
+
+	switch {
+	case project.Name == "":
+		writeJSON(w, http.StatusBadRequest, map[string]string{"error": "name is required"})
+		return
+	case project.ID == "":
+		project.ID = uuid.New().String()
+	}
+
+	if storeErr := s.store.CreateProject(&project); storeErr != nil {
+		writeJSON(w, http.StatusInternalServerError, map[string]string{"error": storeErr.Error()})
+		return
+	}
+	writeJSON(w, http.StatusCreated, project)
+}
+
+// handleGetProject returns the project named by the {id} path value, or 404
+// when the store lookup fails.
+func (s *Server) handleGetProject(w http.ResponseWriter, r *http.Request) {
+	project, err := s.store.GetProject(r.PathValue("id"))
+	if err == nil {
+		writeJSON(w, http.StatusOK, project)
+		return
+	}
+	writeJSON(w, http.StatusNotFound, map[string]string{"error": "project not found"})
+}
+
+// handleUpdateProject applies the JSON request body on top of the stored
+// project identified by the {id} path value and saves the result. Fields
+// absent from the body keep their stored values, and the ID always stays the
+// one from the URL. The updated project is returned with 200.
+func (s *Server) handleUpdateProject(w http.ResponseWriter, r *http.Request) {
+	projectID := r.PathValue("id")
+
+	project, err := s.store.GetProject(projectID)
+	if err != nil {
+		writeJSON(w, http.StatusNotFound, map[string]string{"error": "project not found"})
+		return
+	}
+
+	// Decode over the loaded record so the body acts as a partial update.
+	if decodeErr := json.NewDecoder(r.Body).Decode(project); decodeErr != nil {
+		writeJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid JSON: " + decodeErr.Error()})
+		return
+	}
+	// The body must not be able to re-key the project.
+	project.ID = projectID
+
+	if saveErr := s.store.UpdateProject(project); saveErr != nil {
+		writeJSON(w, http.StatusInternalServerError, map[string]string{"error": saveErr.Error()})
+		return
+	}
+	writeJSON(w, http.StatusOK, project)
+}
+
diff --git a/internal/api/push.go b/internal/api/push.go
new file mode 100644
index 0000000..dde5441
--- /dev/null
+++ b/internal/api/push.go
@@ -0,0 +1,120 @@
+package api
+
+import (
+ "encoding/json"
+ "net/http"
+
+ "github.com/google/uuid"
+ "github.com/thepeterstone/claudomator/internal/storage"
+ webui "github.com/thepeterstone/claudomator/web"
+)
+
+// pushSubscriptionStore is the minimal interface needed by push handlers.
+// Defining it here lets SetPushStore accept a substitute implementation,
+// which the tests use to inject an in-memory mock.
+type pushSubscriptionStore interface {
+ // SavePushSubscription is called by handlePushSubscribe with the new subscription.
+ SavePushSubscription(sub storage.PushSubscription) error
+ // DeletePushSubscription is called by handlePushUnsubscribe with the subscription's endpoint.
+ DeletePushSubscription(endpoint string) error
+ // ListPushSubscriptions returns the stored subscriptions.
+ ListPushSubscriptions() ([]storage.PushSubscription, error)
+}
+
+// SetVAPIDConfig stores the VAPID public key, private key and contact email
+// on the server for use by the web push endpoints.
+func (s *Server) SetVAPIDConfig(pub, priv, email string) {
+	s.vapidPublicKey, s.vapidPrivateKey, s.vapidEmail = pub, priv, email
+}
+
+// SetPushStore installs the store that the push subscribe and unsubscribe
+// handlers use. When none is set, those handlers fall back to the server's
+// main store.
+func (s *Server) SetPushStore(store pushSubscriptionStore) {
+	s.pushStore = store
+}
+
+// SetDropsDir sets the directory used by the file drop endpoints. While it
+// is empty those endpoints answer 503.
+func (s *Server) SetDropsDir(dir string) {
+	s.dropsDir = dir
+}
+
+// handleGetVAPIDKey responds with the configured VAPID public key under the
+// "public_key" field so a client can create a push subscription.
+func (s *Server) handleGetVAPIDKey(w http.ResponseWriter, r *http.Request) {
+	body := map[string]string{"public_key": s.vapidPublicKey}
+	writeJSON(w, http.StatusOK, body)
+}
+
+// handleServiceWorker serves the embedded sw.js. The response carries a
+// "Service-Worker-Allowed: /" header so the worker may control the whole
+// origin even though it is registered from /api/push/sw.js. If the embedded
+// file cannot be read the response is 404.
+func (s *Server) handleServiceWorker(w http.ResponseWriter, r *http.Request) {
+	script, err := webui.Files.ReadFile("sw.js")
+	if err != nil {
+		http.Error(w, "service worker not found", http.StatusNotFound)
+		return
+	}
+
+	hdr := w.Header()
+	hdr.Set("Content-Type", "application/javascript")
+	hdr.Set("Service-Worker-Allowed", "/")
+	w.WriteHeader(http.StatusOK)
+	_, _ = w.Write(script)
+}
+
+// handlePushSubscribe records a new push subscription. The JSON body must
+// provide "endpoint" plus "keys.p256dh" and "keys.auth"; each subscription
+// gets a freshly generated ID, which is returned with 201.
+func (s *Server) handlePushSubscribe(w http.ResponseWriter, r *http.Request) {
+	var input struct {
+		Endpoint string `json:"endpoint"`
+		Keys     struct {
+			P256DH string `json:"p256dh"`
+			Auth   string `json:"auth"`
+		} `json:"keys"`
+	}
+	if decodeErr := json.NewDecoder(r.Body).Decode(&input); decodeErr != nil {
+		writeJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid JSON: " + decodeErr.Error()})
+		return
+	}
+
+	switch {
+	case input.Endpoint == "":
+		writeJSON(w, http.StatusBadRequest, map[string]string{"error": "endpoint is required"})
+		return
+	case input.Keys.P256DH == "" || input.Keys.Auth == "":
+		writeJSON(w, http.StatusBadRequest, map[string]string{"error": "keys.p256dh and keys.auth are required"})
+		return
+	}
+
+	// Prefer the dedicated push store; otherwise use the main store.
+	target := s.pushStore
+	if target == nil {
+		target = s.store
+	}
+
+	record := storage.PushSubscription{
+		ID:        uuid.New().String(),
+		Endpoint:  input.Endpoint,
+		P256DHKey: input.Keys.P256DH,
+		AuthKey:   input.Keys.Auth,
+	}
+	if saveErr := target.SavePushSubscription(record); saveErr != nil {
+		writeJSON(w, http.StatusInternalServerError, map[string]string{"error": saveErr.Error()})
+		return
+	}
+	writeJSON(w, http.StatusCreated, map[string]string{"id": record.ID})
+}
+
+// handlePushUnsubscribe removes the push subscription whose endpoint is
+// given in the JSON body and answers 204 with no content.
+func (s *Server) handlePushUnsubscribe(w http.ResponseWriter, r *http.Request) {
+	var input struct {
+		Endpoint string `json:"endpoint"`
+	}
+	if decodeErr := json.NewDecoder(r.Body).Decode(&input); decodeErr != nil {
+		writeJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid JSON: " + decodeErr.Error()})
+		return
+	}
+	if input.Endpoint == "" {
+		writeJSON(w, http.StatusBadRequest, map[string]string{"error": "endpoint is required"})
+		return
+	}
+
+	// Prefer the dedicated push store; otherwise use the main store.
+	target := s.pushStore
+	if target == nil {
+		target = s.store
+	}
+
+	if delErr := target.DeletePushSubscription(input.Endpoint); delErr != nil {
+		writeJSON(w, http.StatusInternalServerError, map[string]string{"error": delErr.Error()})
+		return
+	}
+	w.WriteHeader(http.StatusNoContent)
+}
diff --git a/internal/api/push_test.go b/internal/api/push_test.go
new file mode 100644
index 0000000..dfd5a3a
--- /dev/null
+++ b/internal/api/push_test.go
@@ -0,0 +1,159 @@
+package api
+
+import (
+ "bytes"
+ "encoding/json"
+ "net/http"
+ "net/http/httptest"
+ "sync"
+ "testing"
+
+ "github.com/thepeterstone/claudomator/internal/storage"
+)
+
+// mockPushStore is an in-memory pushSubscriptionStore implementation for tests.
+type mockPushStore struct {
+ mu sync.Mutex // guards subs
+ subs []storage.PushSubscription // stored subscriptions, in insertion order
+}
+
+// SavePushSubscription stores sub, overwriting an existing entry that has the
+// same Endpoint and appending it otherwise. It always returns nil.
+func (m *mockPushStore) SavePushSubscription(sub storage.PushSubscription) error {
+	m.mu.Lock()
+	defer m.mu.Unlock()
+
+	for idx := range m.subs {
+		if m.subs[idx].Endpoint != sub.Endpoint {
+			continue
+		}
+		// Same endpoint: replace in place (upsert).
+		m.subs[idx] = sub
+		return nil
+	}
+	m.subs = append(m.subs, sub)
+	return nil
+}
+
+// DeletePushSubscription removes every stored subscription whose Endpoint
+// equals endpoint. An endpoint that is not stored is not an error; the method
+// always returns nil.
+func (m *mockPushStore) DeletePushSubscription(endpoint string) error {
+	m.mu.Lock()
+	defer m.mu.Unlock()
+
+	// Compact in place: surviving entries are shifted to the front of the
+	// same backing array, then the slice is truncated to the surviving count.
+	kept := 0
+	for idx := range m.subs {
+		if m.subs[idx].Endpoint == endpoint {
+			continue
+		}
+		m.subs[kept] = m.subs[idx]
+		kept++
+	}
+	m.subs = m.subs[:kept]
+	return nil
+}
+
+// ListPushSubscriptions returns a copy of the stored subscriptions, so callers
+// can keep or modify the result without affecting the mock. It always returns
+// a nil error.
+func (m *mockPushStore) ListPushSubscriptions() ([]storage.PushSubscription, error) {
+	m.mu.Lock()
+	defer m.mu.Unlock()
+
+	snapshot := make([]storage.PushSubscription, 0, len(m.subs))
+	snapshot = append(snapshot, m.subs...)
+	return snapshot, nil
+}
+
+// testServerWithPush builds a server via testServer, configures fixed VAPID
+// values on it and installs a fresh mockPushStore as its push store. The mock
+// is returned alongside the server so tests can seed and inspect it.
+func testServerWithPush(t *testing.T) (*Server, *mockPushStore) {
+	t.Helper()
+
+	server, _ := testServer(t)
+	pushStore := new(mockPushStore)
+
+	server.SetVAPIDConfig("testpub", "testpriv", "mailto:test@example.com")
+	server.SetPushStore(pushStore)
+	return server, pushStore
+}
+
+// TestHandleGetVAPIDKey checks that GET /api/push/vapid-key responds 200 with
+// a JSON object whose public_key is the value set by testServerWithPush.
+func TestHandleGetVAPIDKey(t *testing.T) {
+	srv, _ := testServerWithPush(t)
+
+	rec := httptest.NewRecorder()
+	srv.mux.ServeHTTP(rec, httptest.NewRequest("GET", "/api/push/vapid-key", nil))
+
+	if got := rec.Code; got != http.StatusOK {
+		t.Fatalf("want 200, got %d", got)
+	}
+
+	var payload map[string]string
+	if err := json.NewDecoder(rec.Body).Decode(&payload); err != nil {
+		t.Fatalf("decode response: %v", err)
+	}
+	const wantKey = "testpub"
+	if got := payload["public_key"]; got != wantKey {
+		t.Errorf("want public_key %q, got %q", wantKey, got)
+	}
+}
+
+// TestHandlePushSubscribe_CreatesSub verifies that POST /api/push/subscribe
+// with a complete payload responds 201 and stores exactly one subscription
+// carrying the submitted endpoint and both keys.
+func TestHandlePushSubscribe_CreatesSub(t *testing.T) {
+	srv, ps := testServerWithPush(t)
+
+	body := `{"endpoint":"https://push.example.com/sub1","keys":{"p256dh":"key1","auth":"auth1"}}`
+	req := httptest.NewRequest("POST", "/api/push/subscribe", bytes.NewBufferString(body))
+	req.Header.Set("Content-Type", "application/json")
+	rec := httptest.NewRecorder()
+	srv.mux.ServeHTTP(rec, req)
+
+	if rec.Code != http.StatusCreated {
+		t.Fatalf("want 201, got %d: %s", rec.Code, rec.Body.String())
+	}
+
+	// Check the error instead of discarding it so a failing store cannot be
+	// mistaken for "no subscriptions".
+	subs, err := ps.ListPushSubscriptions()
+	if err != nil {
+		t.Fatalf("ListPushSubscriptions: %v", err)
+	}
+	if len(subs) != 1 {
+		t.Fatalf("want 1 subscription, got %d", len(subs))
+	}
+	if subs[0].Endpoint != "https://push.example.com/sub1" {
+		t.Errorf("endpoint: want %q, got %q", "https://push.example.com/sub1", subs[0].Endpoint)
+	}
+	// The handler copies keys.p256dh and keys.auth into the stored record.
+	if subs[0].P256DHKey != "key1" {
+		t.Errorf("p256dh key: want %q, got %q", "key1", subs[0].P256DHKey)
+	}
+	if subs[0].AuthKey != "auth1" {
+		t.Errorf("auth key: want %q, got %q", "auth1", subs[0].AuthKey)
+	}
+}
+
+func TestHandlePushSubscribe_MissingEndpoint(t *testing.T) {
+ srv, _ := testServerWithPush(t)
+
+ body := `{"keys":{"p256dh":"key1","auth":"auth1"}}`
+ req := httptest.NewRequest("POST", "/api/push/subscribe", bytes.NewBufferString(body))
+ req.Header.Set("Content-Type", "application/json")
+ rec := httptest.NewRecorder()
+ srv.mux.ServeHTTP(rec, req)
+
+ if rec.Code != http.StatusBadRequest {
+ t.Fatalf("want 400, got %d", rec.Code)
+ }
+}
+
+// TestHandlePushUnsubscribe_DeletesSub seeds one subscription in the mock
+// store and verifies that DELETE /api/push/subscribe for its endpoint responds
+// 204 and leaves the store empty.
+func TestHandlePushUnsubscribe_DeletesSub(t *testing.T) {
+	srv, ps := testServerWithPush(t)
+
+	// Add a subscription; fail fast if seeding does not succeed, since every
+	// later assertion depends on it.
+	seed := storage.PushSubscription{
+		ID:        "sub-1",
+		Endpoint:  "https://push.example.com/todelete",
+		P256DHKey: "key",
+		AuthKey:   "auth",
+	}
+	if err := ps.SavePushSubscription(seed); err != nil {
+		t.Fatalf("seed subscription: %v", err)
+	}
+
+	body := `{"endpoint":"https://push.example.com/todelete"}`
+	req := httptest.NewRequest("DELETE", "/api/push/subscribe", bytes.NewBufferString(body))
+	req.Header.Set("Content-Type", "application/json")
+	rec := httptest.NewRecorder()
+	srv.mux.ServeHTTP(rec, req)
+
+	if rec.Code != http.StatusNoContent {
+		t.Fatalf("want 204, got %d: %s", rec.Code, rec.Body.String())
+	}
+
+	subs, err := ps.ListPushSubscriptions()
+	if err != nil {
+		t.Fatalf("ListPushSubscriptions: %v", err)
+	}
+	if len(subs) != 0 {
+		t.Errorf("want 0 subscriptions after delete, got %d", len(subs))
+	}
+}
+
+func TestHandlePushUnsubscribe_MissingEndpoint(t *testing.T) {
+ srv, _ := testServerWithPush(t)
+
+ body := `{}`
+ req := httptest.NewRequest("DELETE", "/api/push/subscribe", bytes.NewBufferString(body))
+ req.Header.Set("Content-Type", "application/json")
+ rec := httptest.NewRecorder()
+ srv.mux.ServeHTTP(rec, req)
+
+ if rec.Code != http.StatusBadRequest {
+ t.Fatalf("want 400, got %d", rec.Code)
+ }
+}
diff --git a/internal/api/server.go b/internal/api/server.go
index 33048e4..28cfe4a 100644
--- a/internal/api/server.go
+++ b/internal/api/server.go
@@ -16,6 +16,7 @@ import (
"github.com/thepeterstone/claudomator/internal/notify"
"github.com/thepeterstone/claudomator/internal/storage"
"github.com/thepeterstone/claudomator/internal/task"
+ "github.com/thepeterstone/claudomator/internal/version"
webui "github.com/thepeterstone/claudomator/web"
"github.com/google/uuid"
)
@@ -31,6 +32,7 @@ type questionStore interface {
// Server provides the REST API and WebSocket endpoint for Claudomator.
type Server struct {
+ ctx context.Context // server lifecycle context; used for pool submissions
store *storage.DB
logStore logStore // injectable for tests; defaults to store
taskLogStore taskLogStore // injectable for tests; defaults to store
@@ -51,7 +53,12 @@ type Server struct {
elaborateLimiter *ipRateLimiter // per-IP rate limiter for elaborate/validate endpoints
webhookSecret string // HMAC-SHA256 secret for GitHub webhook validation
projects []config.Project // configured projects for webhook routing
- llm *llm.Client // optional local LLM client; when set, elaboration prefers it
+ vapidPublicKey string
+ vapidPrivateKey string
+ vapidEmail string
+ pushStore pushSubscriptionStore
+ dropsDir string
+ llm *llm.Client
}
// SetAPIToken configures a bearer token that must be supplied to access the API.
@@ -59,6 +66,12 @@ func (s *Server) SetAPIToken(token string) {
s.apiToken = token
}
+// SetContext replaces the server's lifecycle context used for pool submissions.
+// Call this before StartHub to tie task submissions to the server's shutdown signal.
+// A nil ctx is replaced by context.Background(), so later submissions are
+// never handed a nil context.
+func (s *Server) SetContext(ctx context.Context) {
+	if ctx == nil {
+		ctx = context.Background()
+	}
+	s.ctx = ctx
+}
+
// SetNotifier configures a notifier that is called on every task completion.
func (s *Server) SetNotifier(n notify.Notifier) {
s.notifier = n
@@ -75,6 +88,9 @@ func (s *Server) SetWorkspaceRoot(path string) {
s.workspaceRoot = path
}
+// Pool exposes the executor pool held by the server so that the caller can
+// shut it down gracefully.
+func (s *Server) Pool() *executor.Pool {
+	return s.pool
+}
+
// SetLLM wires a local OpenAI-compatible LLM client for use by elaboration
// (and future internal helpers). When non-nil, elaboration will prefer it
// over the Claude CLI; on failure it falls back to claude → gemini.
@@ -82,9 +98,11 @@ func (s *Server) SetLLM(c *llm.Client) {
s.llm = c
}
+
func NewServer(store *storage.DB, pool *executor.Pool, logger *slog.Logger, claudeBinPath, geminiBinPath string) *Server {
wd, _ := os.Getwd()
s := &Server{
+ ctx: context.Background(),
store: store,
logStore: store,
taskLogStore: store,
@@ -125,6 +143,8 @@ func (s *Server) routes() {
s.mux.HandleFunc("GET /api/tasks/{id}/subtasks", s.handleListSubtasks)
s.mux.HandleFunc("GET /api/tasks/{id}/executions", s.handleListExecutions)
s.mux.HandleFunc("GET /api/executions", s.handleListRecentExecutions)
+ s.mux.HandleFunc("GET /api/stats", s.handleGetDashboardStats)
+ s.mux.HandleFunc("GET /api/agents/status", s.handleGetAgentStatus)
s.mux.HandleFunc("GET /api/executions/{id}", s.handleGetExecution)
s.mux.HandleFunc("GET /api/executions/{id}/log", s.handleGetExecutionLog)
s.mux.HandleFunc("GET /api/tasks/{id}/logs/stream", s.handleStreamTaskLogs)
@@ -135,29 +155,53 @@ func (s *Server) routes() {
s.mux.HandleFunc("GET /api/ws", s.handleWebSocket)
s.mux.HandleFunc("GET /api/workspaces", s.handleListWorkspaces)
s.mux.HandleFunc("GET /api/tasks/{id}/deployment-status", s.handleGetDeploymentStatus)
+ s.mux.HandleFunc("GET /api/projects", s.handleListProjects)
+ s.mux.HandleFunc("POST /api/projects", s.handleCreateProject)
+ s.mux.HandleFunc("GET /api/projects/{id}", s.handleGetProject)
+ s.mux.HandleFunc("PUT /api/projects/{id}", s.handleUpdateProject)
+ s.mux.HandleFunc("POST /api/stories/elaborate", s.handleElaborateStory)
+ s.mux.HandleFunc("POST /api/stories/approve", s.handleApproveStory)
+ s.mux.HandleFunc("GET /api/stories", s.handleListStories)
+ s.mux.HandleFunc("POST /api/stories", s.handleCreateStory)
+ s.mux.HandleFunc("GET /api/stories/{id}", s.handleGetStory)
+ s.mux.HandleFunc("GET /api/stories/{id}/tasks", s.handleListStoryTasks)
+ s.mux.HandleFunc("POST /api/stories/{id}/tasks", s.handleAddTaskToStory)
+ s.mux.HandleFunc("PUT /api/stories/{id}/status", s.handleUpdateStoryStatus)
+ s.mux.HandleFunc("POST /api/stories/{id}/ship", s.handleShipStory)
+ s.mux.HandleFunc("GET /api/stories/{id}/deployment-status", s.handleStoryDeploymentStatus)
s.mux.HandleFunc("GET /api/health", s.handleHealth)
+ s.mux.HandleFunc("GET /api/version", s.handleVersion)
s.mux.HandleFunc("POST /api/webhooks/github", s.handleGitHubWebhook)
+ s.mux.HandleFunc("GET /api/push/vapid-key", s.handleGetVAPIDKey)
+ s.mux.HandleFunc("GET /api/push/sw.js", s.handleServiceWorker)
+ s.mux.HandleFunc("POST /api/push/subscribe", s.handlePushSubscribe)
+ s.mux.HandleFunc("DELETE /api/push/subscribe", s.handlePushUnsubscribe)
+ s.mux.HandleFunc("GET /api/drops", s.handleListDrops)
+ s.mux.HandleFunc("GET /api/drops/{filename}", s.handleGetDrop)
+ s.mux.HandleFunc("POST /api/drops", s.handlePostDrop)
s.mux.Handle("GET /", http.FileServerFS(webui.Files))
}
-// forwardResults listens on the executor pool's result channel and broadcasts via WebSocket.
+// forwardResults listens on the executor pool's result and started channels and broadcasts via WebSocket.
func (s *Server) forwardResults() {
+ go func() {
+ for taskID := range s.pool.Started() {
+ event := map[string]interface{}{
+ "type": "task_started",
+ "task_id": taskID,
+ "timestamp": time.Now().UTC(),
+ }
+ data, _ := json.Marshal(event)
+ s.hub.Broadcast(data)
+ }
+ }()
for result := range s.pool.Results() {
s.processResult(result)
}
}
// processResult broadcasts a task completion event via WebSocket and calls the notifier if set.
-// It also parses git diff stats from the execution stdout log and persists them.
func (s *Server) processResult(result *executor.Result) {
- if result.Execution.StdoutPath != "" {
- if stats := parseChangestatFromFile(result.Execution.StdoutPath); stats != nil {
- if err := s.store.UpdateExecutionChangestats(result.Execution.ID, stats); err != nil {
- s.logger.Error("failed to store changestats", "execID", result.Execution.ID, "error", err)
- }
- }
- }
-
event := map[string]interface{}{
"type": "task_completed",
"task_id": result.TaskID,
@@ -318,7 +362,7 @@ func (s *Server) handleAnswerQuestion(w http.ResponseWriter, r *http.Request) {
ResumeAnswer: input.Answer,
SandboxDir: latest.SandboxDir,
}
- if err := s.pool.SubmitResume(context.Background(), tk, resumeExec); err != nil {
+ if err := s.pool.SubmitResume(s.ctx, tk, resumeExec); err != nil {
writeJSON(w, http.StatusServiceUnavailable, map[string]string{"error": err.Error()})
return
}
@@ -363,7 +407,7 @@ func (s *Server) handleResumeTimedOutTask(w http.ResponseWriter, r *http.Request
ResumeSessionID: latest.SessionID,
ResumeAnswer: resumeMsg,
}
- if err := s.pool.SubmitResume(context.Background(), tk, resumeExec); err != nil {
+ if err := s.pool.SubmitResume(s.ctx, tk, resumeExec); err != nil {
writeJSON(w, http.StatusServiceUnavailable, map[string]string{"error": err.Error()})
return
}
@@ -415,11 +459,17 @@ func (s *Server) handleHealth(w http.ResponseWriter, r *http.Request) {
})
}
+// handleVersion serves GET /api/version. It replies 200 with a JSON object
+// whose "version" field holds the value returned by version.Version().
+func (s *Server) handleVersion(w http.ResponseWriter, r *http.Request) {
+	payload := map[string]string{"version": version.Version()}
+	writeJSON(w, http.StatusOK, payload)
+}
+
func (s *Server) handleCreateTask(w http.ResponseWriter, r *http.Request) {
var input struct {
Name string `json:"name"`
Description string `json:"description"`
ElaborationInput string `json:"elaboration_input"`
+ Project string `json:"project"`
+ RepositoryURL string `json:"repository_url"`
Agent task.AgentConfig `json:"agent"`
Claude task.AgentConfig `json:"claude"` // legacy alias
Timeout string `json:"timeout"`
@@ -443,6 +493,8 @@ func (s *Server) handleCreateTask(w http.ResponseWriter, r *http.Request) {
Name: input.Name,
Description: input.Description,
ElaborationInput: input.ElaborationInput,
+ Project: input.Project,
+ RepositoryURL: input.RepositoryURL,
Agent: input.Agent,
Priority: task.Priority(input.Priority),
Tags: input.Tags,
@@ -453,6 +505,7 @@ func (s *Server) handleCreateTask(w http.ResponseWriter, r *http.Request) {
UpdatedAt: now,
ParentTaskID: input.ParentTaskID,
}
+
if t.Agent.Type == "" {
t.Agent.Type = "claude"
}
@@ -523,7 +576,11 @@ func (s *Server) handleListTasks(w http.ResponseWriter, r *http.Request) {
if tasks == nil {
tasks = []*task.Task{}
}
- writeJSON(w, http.StatusOK, tasks)
+ views := make([]*taskView, len(tasks))
+ for i, tk := range tasks {
+ views[i] = s.enrichTask(tk)
+ }
+ writeJSON(w, http.StatusOK, views)
}
func (s *Server) handleGetTask(w http.ResponseWriter, r *http.Request) {
@@ -533,8 +590,43 @@ func (s *Server) handleGetTask(w http.ResponseWriter, r *http.Request) {
writeJSON(w, http.StatusNotFound, map[string]string{"error": "task not found"})
return
}
- writeJSON(w, http.StatusOK, t)
+ writeJSON(w, http.StatusOK, s.enrichTask(t))
+}
+// retryableDepStates is the set of task states in which cascadeRetryDeps resets
+// and resubmits a dependency; dependencies in any other state are left alone.
+var retryableDepStates = map[task.State]bool{
+ task.StateFailed: true,
+ task.StateTimedOut: true,
+ task.StateCancelled: true,
+ task.StateBudgetExceeded: true,
+}
+
+// cascadeRetryDeps resets any dependency (recursively) that is in a retryable
+// terminal state, and submits it to the pool. This ensures that retrying a
+// CANCELLED task that was blocked by a failed dep will also restart that dep.
+//
+// Lookup, reset and submit failures are logged and the walk continues with
+// the next dependency; nothing is returned to the caller. Dependency cycles
+// are detected and skipped so that malformed DependsOn data cannot cause
+// unbounded recursion.
+//
+// NOTE(review): ctx is passed straight to pool.Submit. If the pool keeps it
+// for the lifetime of the run, callers should pass a context that outlives
+// the HTTP request — confirm against executor.Pool.
+func (s *Server) cascadeRetryDeps(ctx context.Context, t *task.Task) {
+	s.cascadeRetryDepsFrom(ctx, t, map[string]bool{t.ID: true})
+}
+
+// cascadeRetryDepsFrom does the work of cascadeRetryDeps. onPath holds the IDs
+// of the tasks on the current recursion path (not every task ever visited), so
+// a dependency shared by several tasks is still examined once per parent while
+// a back-edge to a task on the path is recognised as a cycle.
+func (s *Server) cascadeRetryDepsFrom(ctx context.Context, t *task.Task, onPath map[string]bool) {
+	for _, depID := range t.DependsOn {
+		if onPath[depID] {
+			s.logger.Warn("cascadeRetryDeps: dependency cycle detected", "taskID", t.ID, "depID", depID)
+			continue
+		}
+		dep, err := s.store.GetTask(depID)
+		if err != nil {
+			s.logger.Warn("cascadeRetryDeps: dep not found", "depID", depID, "error", err)
+			continue
+		}
+		if !retryableDepStates[dep.State] {
+			continue
+		}
+		// Recursively cascade first (depth-first so root deps go first).
+		onPath[depID] = true
+		s.cascadeRetryDepsFrom(ctx, dep, onPath)
+		delete(onPath, depID)
+
+		reset, err := s.store.ResetTaskForRetry(depID)
+		if err != nil {
+			s.logger.Warn("cascadeRetryDeps: reset failed", "depID", depID, "error", err)
+			continue
+		}
+		if submitErr := s.pool.Submit(ctx, reset); submitErr != nil {
+			s.logger.Warn("cascadeRetryDeps: submit failed", "depID", depID, "error", submitErr)
+		}
+	}
+}
+
func (s *Server) handleRunTask(w http.ResponseWriter, r *http.Request) {
id := r.PathValue("id")
agentParam := r.URL.Query().Get("agent") // Use a different name to avoid confusion
@@ -583,7 +675,11 @@ func (s *Server) handleRunTask(w http.ResponseWriter, r *http.Request) {
}
// The task `t` now has the correct agent configuration.
- if err := s.pool.Submit(context.Background(), t); err != nil {
+ // 6. Cascade-retry any deps that are in a terminal failure state so the
+ // task isn't immediately re-cancelled by checkDepsReady.
+ s.cascadeRetryDeps(r.Context(), originalTask)
+
+ if err := s.pool.Submit(s.ctx, t); err != nil {
writeJSON(w, http.StatusServiceUnavailable, map[string]string{"error": fmt.Sprintf("executor pool: %v", err)})
return
}
@@ -611,6 +707,9 @@ func (s *Server) handleAcceptTask(w http.ResponseWriter, r *http.Request) {
writeJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()})
return
}
+ if t.StoryID != "" {
+ go s.pool.CheckStoryCompletion(r.Context(), t.StoryID)
+ }
writeJSON(w, http.StatusOK, map[string]string{"message": "task accepted", "task_id": id})
}
diff --git a/internal/api/server_test.go b/internal/api/server_test.go
index 2139e36..2530d55 100644
--- a/internal/api/server_test.go
+++ b/internal/api/server_test.go
@@ -16,6 +16,7 @@ import (
"context"
+ "github.com/google/uuid"
"github.com/thepeterstone/claudomator/internal/executor"
"github.com/thepeterstone/claudomator/internal/notify"
"github.com/thepeterstone/claudomator/internal/storage"
@@ -89,6 +90,9 @@ func testServerWithRunner(t *testing.T, runner executor.Runner) (*Server, *stora
t.Cleanup(func() { store.Close() })
logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))
+ if mr, ok := runner.(*mockRunner); ok {
+ mr.logDir = t.TempDir()
+ }
runners := map[string]executor.Runner{
"claude": runner,
"gemini": runner,
@@ -99,11 +103,39 @@ func testServerWithRunner(t *testing.T, runner executor.Runner) (*Server, *stora
}
type mockRunner struct {
- err error
- sleep time.Duration
+ err error
+ sleep time.Duration
+ logDir string
+ onRun func(*task.Task, *storage.Execution) error
+}
+
+func (m *mockRunner) ExecLogDir(execID string) string {
+ if m.logDir == "" {
+ return ""
+ }
+ return filepath.Join(m.logDir, execID)
}
-func (m *mockRunner) Run(ctx context.Context, _ *task.Task, _ *storage.Execution) error {
+func (m *mockRunner) Run(ctx context.Context, t *task.Task, e *storage.Execution) error {
+ if e.ID == "" {
+ e.ID = uuid.New().String()
+ }
+ if m.logDir != "" {
+ dir := m.ExecLogDir(e.ID)
+ if err := os.MkdirAll(dir, 0755); err != nil {
+ return err
+ }
+ e.StdoutPath = filepath.Join(dir, "stdout.log")
+ e.StderrPath = filepath.Join(dir, "stderr.log")
+ e.ArtifactDir = dir
+ // Create an empty file at least
+ os.WriteFile(e.StdoutPath, []byte(""), 0644)
+ }
+ if m.onRun != nil {
+ if err := m.onRun(t, e); err != nil {
+ return err
+ }
+ }
if m.sleep > 0 {
select {
case <-time.After(m.sleep):
@@ -143,41 +175,26 @@ func testServerWithGeminiMockRunner(t *testing.T) (*Server, *storage.DB) {
logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelDebug}))
- // Create the mock gemini binary script. Use single-quoted heredoc so
- // bash does not try to evaluate the literal backticks as command
- // substitution.
- mockBinDir := t.TempDir()
- mockGeminiPath := filepath.Join(mockBinDir, "mock-gemini-binary.sh")
- mockScriptContent := `#!/bin/bash
-cat <<'EOF'
-` + "```json" + `
-{"type":"content_block_start","content_block":{"text":"Hello, Gemini!","type":"text"}}
-{"type":"content_block_delta","content_block":{"text":" How are you?"}}
-{"type":"content_block_end"}
-{"type":"message_delta","message":{"role":"model"}}
-{"type":"message_end"}
-` + "```" + `
-EOF
-exit 0
-`
- if err := os.WriteFile(mockGeminiPath, []byte(mockScriptContent), 0755); err != nil {
- t.Fatalf("writing mock gemini script: %v", err)
- }
-
- // Configure GeminiRunner to use the mock script.
- geminiRunner := &executor.GeminiRunner{
- BinaryPath: mockGeminiPath,
- Logger: logger,
- LogDir: t.TempDir(), // Ensure log directory is temporary for test
- APIURL: "http://localhost:8080", // Placeholder, not used by this mock
+ mr := &mockRunner{
+ logDir: t.TempDir(),
+ onRun: func(t *task.Task, e *storage.Execution) error {
+ lines := []string{
+ `{"type":"content_block_start","content_block":{"text":"Hello, Gemini!","type":"text"}}`,
+ `{"type":"content_block_delta","content_block":{"text":" How are you?"}}`,
+ `{"type":"content_block_end"}`,
+ `{"type":"message_delta","message":{"role":"model"}}`,
+ `{"type":"message_end"}`,
+ }
+ return os.WriteFile(e.StdoutPath, []byte(strings.Join(lines, "\n")), 0644)
+ },
}
runners := map[string]executor.Runner{
- "claude": &mockRunner{}, // Keep mock for claude to not interfere
- "gemini": geminiRunner,
+ "claude": mr,
+ "gemini": mr,
}
pool := executor.NewPool(2, runners, store, logger)
- srv := NewServer(store, pool, logger, "claude", "gemini") // Pass original binary paths
+ srv := NewServer(store, pool, logger, "claude", "gemini")
return srv, store
}
@@ -200,6 +217,7 @@ func TestGeminiLogs_ParsedCorrectly(t *testing.T) {
tk := createTestTask(t, srv, `{
"name": "Gemini Log Test Task",
"description": "Test Gemini log parsing",
+ "repository_url": "https://github.com/user/repo",
"agent": {
"type": "gemini",
"instructions": "generate some output",
@@ -346,6 +364,7 @@ func TestCreateTask_Success(t *testing.T) {
payload := `{
"name": "API Task",
"description": "Created via API",
+ "repository_url": "https://github.com/user/repo",
"agent": {
"type": "claude",
"instructions": "do the thing",
@@ -399,6 +418,50 @@ func TestCreateTask_ValidationFailure(t *testing.T) {
}
}
+// TestProject_RoundTrip creates a task with a "project" value through
+// POST /api/tasks and verifies that the same value appears both in the create
+// response and in a subsequent GET /api/tasks/{id}.
+func TestProject_RoundTrip(t *testing.T) {
+	srv, _ := testServer(t)
+
+	payload := `{
+		"name": "Project Task",
+		"project": "test-project",
+		"repository_url": "https://github.com/user/repo",
+		"agent": {
+			"type": "claude",
+			"instructions": "do the thing",
+			"model": "sonnet"
+		}
+	}`
+	req := httptest.NewRequest("POST", "/api/tasks", bytes.NewBufferString(payload))
+	req.Header.Set("Content-Type", "application/json")
+	w := httptest.NewRecorder()
+	srv.Handler().ServeHTTP(w, req)
+
+	if w.Code != http.StatusCreated {
+		t.Fatalf("create: want 201, got %d; body: %s", w.Code, w.Body.String())
+	}
+
+	// Fail on an undecodable body rather than comparing against a zero-valued
+	// task, which would produce a misleading "project" mismatch (and an empty
+	// ID in the follow-up GET).
+	var created task.Task
+	if err := json.NewDecoder(w.Body).Decode(&created); err != nil {
+		t.Fatalf("decode create response: %v", err)
+	}
+	if created.Project != "test-project" {
+		t.Errorf("create response: project want 'test-project', got %q", created.Project)
+	}
+
+	// GET the task and verify project is persisted
+	req2 := httptest.NewRequest("GET", "/api/tasks/"+created.ID, nil)
+	w2 := httptest.NewRecorder()
+	srv.Handler().ServeHTTP(w2, req2)
+
+	if w2.Code != http.StatusOK {
+		t.Fatalf("get: want 200, got %d; body: %s", w2.Code, w2.Body.String())
+	}
+
+	var fetched task.Task
+	if err := json.NewDecoder(w2.Body).Decode(&fetched); err != nil {
+		t.Fatalf("decode get response: %v", err)
+	}
+	if fetched.Project != "test-project" {
+		t.Errorf("get response: project want 'test-project', got %q", fetched.Project)
+	}
+}
+
func TestListTasks_Empty(t *testing.T) {
srv, _ := testServer(t)
@@ -436,6 +499,7 @@ func TestListTasks_WithTasks(t *testing.T) {
for i := 0; i < 3; i++ {
tk := &task.Task{
ID: fmt.Sprintf("lt-%d", i), Name: fmt.Sprintf("T%d", i),
+ RepositoryURL: "https://github.com/user/repo",
Agent: task.AgentConfig{Type: "claude", Instructions: "x"}, Priority: task.PriorityNormal,
Retry: task.RetryConfig{MaxAttempts: 1, Backoff: "linear"},
Tags: []string{}, DependsOn: []string{}, State: task.StatePending,
@@ -473,6 +537,7 @@ func createTaskWithState(t *testing.T, store *storage.DB, id string, state task.
tk := &task.Task{
ID: id,
Name: "test-task-" + id,
+ RepositoryURL: "https://github.com/user/repo",
Agent: task.AgentConfig{Type: "claude", Instructions: "do something"},
Priority: task.PriorityNormal,
Retry: task.RetryConfig{MaxAttempts: 1, Backoff: "linear"},
@@ -851,6 +916,7 @@ func TestRunTask_ManualRunIgnoresRetryLimit(t *testing.T) {
tk := &task.Task{
ID: "retry-limit-manual",
Name: "Retry Limit Task",
+ RepositoryURL: "https://github.com/user/repo",
Agent: task.AgentConfig{Instructions: "do something"},
Priority: task.PriorityNormal,
Retry: task.RetryConfig{MaxAttempts: 1, Backoff: "linear"},
@@ -888,6 +954,7 @@ func TestRunTask_WithinRetryLimit_Returns202(t *testing.T) {
tk := &task.Task{
ID: "retry-within-1",
Name: "Retry Within Task",
+ RepositoryURL: "https://github.com/user/repo",
Agent: task.AgentConfig{Instructions: "do something"},
Priority: task.PriorityNormal,
Retry: task.RetryConfig{MaxAttempts: 3, Backoff: "linear"},
@@ -935,7 +1002,7 @@ func TestDeleteTask_Success(t *testing.T) {
srv, store := testServer(t)
// Create a task to delete.
- created := createTestTask(t, srv, `{"name":"Delete Me","agent":{"type":"claude","instructions":"x","model":"sonnet"}}`)
+ created := createTestTask(t, srv, `{"name":"Delete Me","repository_url":"https://github.com/user/repo","agent":{"type":"claude","instructions":"x","model":"sonnet"}}`)
req := httptest.NewRequest("DELETE", "/api/tasks/"+created.ID, nil)
w := httptest.NewRecorder()
@@ -970,6 +1037,7 @@ func TestDeleteTask_RunningTaskRejected(t *testing.T) {
tk := &task.Task{
ID: "running-task-del",
Name: "Running Task",
+ RepositoryURL: "https://github.com/user/repo",
Agent: task.AgentConfig{Instructions: "x", Model: "sonnet"},
Priority: task.PriorityNormal,
Tags: []string{},
@@ -1524,6 +1592,7 @@ func TestRunTask_AgentTimesOut_TaskSetToTimedOut(t *testing.T) {
tk := &task.Task{
ID: "async-timeout-1",
Name: "timeout-test",
+ RepositoryURL: "https://github.com/user/repo",
Agent: task.AgentConfig{Type: "claude", Instructions: "do something"},
Priority: task.PriorityNormal,
Retry: task.RetryConfig{MaxAttempts: 1, Backoff: "linear"},
@@ -1581,34 +1650,31 @@ func TestRunTask_AgentCancelled_TaskSetToCancelled(t *testing.T) {
}
}
-// TestGetTask_IncludesChangestats verifies that after processResult parses git diff stats
-// from the execution stdout log, they appear in the execution history response.
+// TestGetTask_IncludesChangestats verifies that changestats stored on an execution
+// are returned correctly by GET /api/tasks/{id}/executions.
func TestGetTask_IncludesChangestats(t *testing.T) {
srv, store := testServer(t)
tk := createTaskWithState(t, store, "cs-task-1", task.StateCompleted)
- // Write a stdout log with a git diff --stat summary line.
- dir := t.TempDir()
- stdoutPath := filepath.Join(dir, "stdout.log")
- logContent := "Agent output line 1\n3 files changed, 50 insertions(+), 10 deletions(-)\nAgent output line 2\n"
- if err := os.WriteFile(stdoutPath, []byte(logContent), 0600); err != nil {
- t.Fatal(err)
- }
-
exec := &storage.Execution{
- ID: "cs-exec-1",
- TaskID: tk.ID,
- StartTime: time.Now().UTC(),
- EndTime: time.Now().UTC().Add(time.Minute),
- Status: "COMPLETED",
- StdoutPath: stdoutPath,
+ ID: "cs-exec-1",
+ TaskID: tk.ID,
+ StartTime: time.Now().UTC(),
+ EndTime: time.Now().UTC().Add(time.Minute),
+ Status: "COMPLETED",
}
if err := store.CreateExecution(exec); err != nil {
t.Fatal(err)
}
- // processResult should parse changestats from the stdout log and store them.
+ // Pool stores changestats after execution; simulate by calling UpdateExecutionChangestats directly.
+ cs := &task.Changestats{FilesChanged: 3, LinesAdded: 50, LinesRemoved: 10}
+ if err := store.UpdateExecutionChangestats(exec.ID, cs); err != nil {
+ t.Fatal(err)
+ }
+
+ // processResult broadcasts but does NOT parse changestats (that's the pool's job).
result := &executor.Result{
TaskID: tk.ID,
Execution: exec,
@@ -1782,3 +1848,299 @@ func TestDeploymentStatus_NotFound(t *testing.T) {
t.Fatalf("want 404, got %d", w.Code)
}
}
+
+// TestListTasks_ReadyTask_IncludesDeploymentStatus verifies that GET /api/tasks
+// returns a deployment_status object for a READY task, and that the object
+// carries the deployed_commit and includes_fix keys.
+func TestListTasks_ReadyTask_IncludesDeploymentStatus(t *testing.T) {
+	srv, store := testServer(t)
+
+	ready := createTaskWithState(t, store, "enrich-list-ready-1", task.StateReady)
+	if err := store.CreateExecution(&storage.Execution{
+		ID:        "enrich-list-exec-1",
+		TaskID:    ready.ID,
+		StartTime: time.Now(),
+		EndTime:   time.Now(),
+		Status:    "COMPLETED",
+		Commits:   []task.GitCommit{{Hash: "aabbcc", Message: "fix: list test"}},
+	}); err != nil {
+		t.Fatal(err)
+	}
+
+	rec := httptest.NewRecorder()
+	srv.Handler().ServeHTTP(rec, httptest.NewRequest("GET", "/api/tasks", nil))
+	if rec.Code != http.StatusOK {
+		t.Fatalf("want 200, got %d; body: %s", rec.Code, rec.Body.String())
+	}
+
+	var listed []map[string]interface{}
+	if err := json.NewDecoder(rec.Body).Decode(&listed); err != nil {
+		t.Fatalf("decode: %v", err)
+	}
+
+	// Locate the READY task among the listed entries.
+	var entry map[string]interface{}
+	for i := range listed {
+		if listed[i]["id"] == ready.ID {
+			entry = listed[i]
+			break
+		}
+	}
+	if entry == nil {
+		t.Fatalf("task %q not found in list response", ready.ID)
+	}
+
+	status, isObject := entry["deployment_status"].(map[string]interface{})
+	if !isObject {
+		t.Fatalf("READY task missing deployment_status field; got: %v", entry["deployment_status"])
+	}
+	for _, key := range []string{"deployed_commit", "includes_fix"} {
+		if _, present := status[key]; !present {
+			t.Error("deployment_status missing " + key)
+		}
+	}
+}
+
+// TestGetTask_ReadyTask_IncludesDeploymentStatus verifies that
+// GET /api/tasks/{id} returns a deployment_status object for a READY task,
+// and that the object carries the deployed_commit and includes_fix keys.
+func TestGetTask_ReadyTask_IncludesDeploymentStatus(t *testing.T) {
+	srv, store := testServer(t)
+
+	ready := createTaskWithState(t, store, "enrich-get-ready-1", task.StateReady)
+	if err := store.CreateExecution(&storage.Execution{
+		ID:        "enrich-get-exec-1",
+		TaskID:    ready.ID,
+		StartTime: time.Now(),
+		EndTime:   time.Now(),
+		Status:    "COMPLETED",
+		Commits:   []task.GitCommit{{Hash: "ddeeff", Message: "fix: get test"}},
+	}); err != nil {
+		t.Fatal(err)
+	}
+
+	rec := httptest.NewRecorder()
+	srv.Handler().ServeHTTP(rec, httptest.NewRequest("GET", "/api/tasks/"+ready.ID, nil))
+	if rec.Code != http.StatusOK {
+		t.Fatalf("want 200, got %d", rec.Code)
+	}
+
+	var body map[string]interface{}
+	if err := json.NewDecoder(rec.Body).Decode(&body); err != nil {
+		t.Fatalf("decode: %v", err)
+	}
+
+	status, isObject := body["deployment_status"].(map[string]interface{})
+	if !isObject {
+		t.Fatalf("READY task GET response missing deployment_status; got: %v", body["deployment_status"])
+	}
+	for _, key := range []string{"deployed_commit", "includes_fix"} {
+		if _, present := status[key]; !present {
+			t.Error("deployment_status missing " + key)
+		}
+	}
+}
+
+// TestListTasks_NonReadyTask_OmitsDeploymentStatus verifies that a PENDING
+// task listed by GET /api/tasks has no deployment_status key at all.
+func TestListTasks_NonReadyTask_OmitsDeploymentStatus(t *testing.T) {
+	srv, store := testServer(t)
+
+	pending := createTaskWithState(t, store, "enrich-list-pending-1", task.StatePending)
+
+	rec := httptest.NewRecorder()
+	srv.Handler().ServeHTTP(rec, httptest.NewRequest("GET", "/api/tasks", nil))
+	if rec.Code != http.StatusOK {
+		t.Fatalf("want 200, got %d", rec.Code)
+	}
+
+	var listed []map[string]interface{}
+	if err := json.NewDecoder(rec.Body).Decode(&listed); err != nil {
+		t.Fatalf("decode: %v", err)
+	}
+
+	// Locate the PENDING task among the listed entries.
+	var entry map[string]interface{}
+	for i := range listed {
+		if listed[i]["id"] == pending.ID {
+			entry = listed[i]
+			break
+		}
+	}
+	if entry == nil {
+		t.Fatalf("task %q not found in list", pending.ID)
+	}
+
+	if _, present := entry["deployment_status"]; present {
+		t.Error("PENDING task should not include deployment_status field")
+	}
+}
+
+// TestProjects_CRUD exercises the project endpoints end to end: it creates a
+// project with POST /api/projects, fetches it with GET /api/projects/{id} and
+// checks that GET /api/projects lists at least one project. The update route
+// is not exercised here.
+func TestProjects_CRUD(t *testing.T) {
+	srv, _ := testServer(t)
+
+	// Create
+	body := `{"name":"testproj","local_path":"/workspace/testproj","type":"web"}`
+	req := httptest.NewRequest("POST", "/api/projects", strings.NewReader(body))
+	req.Header.Set("Content-Type", "application/json")
+	w := httptest.NewRecorder()
+	srv.Handler().ServeHTTP(w, req)
+	if w.Code != http.StatusCreated {
+		t.Fatalf("POST /api/projects: want 201, got %d; body: %s", w.Code, w.Body.String())
+	}
+	// Report an undecodable body directly instead of letting it surface as a
+	// misleading "no id" failure.
+	var created map[string]interface{}
+	if err := json.NewDecoder(w.Body).Decode(&created); err != nil {
+		t.Fatalf("decode created project: %v", err)
+	}
+	id, _ := created["id"].(string)
+	if id == "" {
+		t.Fatal("created project has no id")
+	}
+
+	// Get
+	req = httptest.NewRequest("GET", "/api/projects/"+id, nil)
+	w = httptest.NewRecorder()
+	srv.Handler().ServeHTTP(w, req)
+	if w.Code != http.StatusOK {
+		t.Fatalf("GET /api/projects/%s: want 200, got %d", id, w.Code)
+	}
+
+	// List
+	req = httptest.NewRequest("GET", "/api/projects", nil)
+	w = httptest.NewRecorder()
+	srv.Handler().ServeHTTP(w, req)
+	if w.Code != http.StatusOK {
+		t.Fatalf("GET /api/projects: want 200, got %d", w.Code)
+	}
+	var list []interface{}
+	if err := json.NewDecoder(w.Body).Decode(&list); err != nil {
+		t.Fatalf("decode project list: %v", err)
+	}
+	if len(list) == 0 {
+		t.Error("expected at least one project in list")
+	}
+}
+
+// TestHandleRunTask_CascadesRetryToFailedDeps verifies that running a CANCELLED
+// task whose dependency is FAILED also re-queues that dependency: after
+// POST /api/tasks/{id}/run answers 202, both tasks are expected to be QUEUED.
+func TestHandleRunTask_CascadesRetryToFailedDeps(t *testing.T) {
+	// Use a blocking runner so tasks stay QUEUED long enough to assert state.
+	release := make(chan struct{})
+	t.Cleanup(func() { close(release) })
+	blocking := &mockRunner{
+		onRun: func(*task.Task, *storage.Execution) error {
+			<-release
+			return nil
+		},
+	}
+	srv, store := testServerWithRunner(t, blocking)
+
+	now := time.Now().UTC()
+	// newTask builds a task that differs from its sibling only in identity,
+	// state, dependencies and instructions.
+	newTask := func(id, name, instructions string, state task.State, deps []string) *task.Task {
+		return &task.Task{
+			ID:        id,
+			Name:      name,
+			State:     state,
+			DependsOn: deps,
+			Priority:  task.PriorityNormal,
+			Tags:      []string{},
+			Agent:     task.AgentConfig{Type: "claude", Instructions: instructions},
+			Retry:     task.RetryConfig{MaxAttempts: 1, Backoff: "exponential"},
+			CreatedAt: now,
+			UpdatedAt: now,
+		}
+	}
+
+	// Task A: the dependency, in FAILED state.
+	depA := newTask("cascade-dep-a", "Dep A", "do A", task.StateFailed, []string{})
+	if err := store.CreateTask(depA); err != nil {
+		t.Fatalf("CreateTask A: %v", err)
+	}
+
+	// Task B: depends on A, in CANCELLED state (was cancelled because A failed).
+	dependentB := newTask("cascade-task-b", "Task B", "do B", task.StateCancelled, []string{depA.ID})
+	if err := store.CreateTask(dependentB); err != nil {
+		t.Fatalf("CreateTask B: %v", err)
+	}
+
+	// Run task B — should cascade-retry dep A.
+	rec := httptest.NewRecorder()
+	srv.mux.ServeHTTP(rec, httptest.NewRequest("POST", "/api/tasks/cascade-task-b/run", nil))
+	if rec.Code != http.StatusAccepted {
+		t.Fatalf("expected 202, got %d: %s", rec.Code, rec.Body.String())
+	}
+
+	// Dep A should now be QUEUED.
+	gotA, err := store.GetTask(depA.ID)
+	if err != nil {
+		t.Fatalf("GetTask A: %v", err)
+	}
+	if gotA.State != task.StateQueued {
+		t.Errorf("dep A: want QUEUED after cascade, got %s", gotA.State)
+	}
+
+	// Task B itself should be QUEUED.
+	gotB, err := store.GetTask(dependentB.ID)
+	if err != nil {
+		t.Fatalf("GetTask B: %v", err)
+	}
+	if gotB.State != task.StateQueued {
+		t.Errorf("task B: want QUEUED, got %s", gotB.State)
+	}
+}
+
+// TestShipStory_ShippableStory_Returns202 stores a project and a story in the
+// StoryShippable status, then verifies that POST /api/stories/{id}/ship
+// answers 202.
+func TestShipStory_ShippableStory_Returns202(t *testing.T) {
+	srv, store := testServer(t)
+
+	if err := store.CreateProject(&task.Project{
+		ID:           "ship-proj-1",
+		Name:         "test",
+		RemoteURL:    "https://github.com/x/y",
+		Type:         "web",
+		DeployScript: "",
+	}); err != nil {
+		t.Fatalf("CreateProject: %v", err)
+	}
+
+	if err := store.CreateStory(&task.Story{
+		ID:        "ship-story-1",
+		Name:      "Ship Test",
+		ProjectID: "ship-proj-1",
+		Status:    task.StoryShippable,
+		CreatedAt: time.Now().UTC(),
+		UpdatedAt: time.Now().UTC(),
+	}); err != nil {
+		t.Fatalf("CreateStory: %v", err)
+	}
+
+	rec := httptest.NewRecorder()
+	srv.Handler().ServeHTTP(rec, httptest.NewRequest("POST", "/api/stories/ship-story-1/ship", nil))
+
+	if got := rec.Code; got != http.StatusAccepted {
+		t.Errorf("expected 202, got %d: %s", got, rec.Body.String())
+	}
+}
+
+// TestShipStory_NonShippable_Returns409 stores a story in the StoryInProgress
+// status and verifies that POST /api/stories/{id}/ship answers 409.
+func TestShipStory_NonShippable_Returns409(t *testing.T) {
+	srv, store := testServer(t)
+
+	if err := store.CreateStory(&task.Story{
+		ID:        "nonship-1",
+		Name:      "Not Ready",
+		ProjectID: "",
+		Status:    task.StoryInProgress,
+		CreatedAt: time.Now().UTC(),
+		UpdatedAt: time.Now().UTC(),
+	}); err != nil {
+		t.Fatalf("CreateStory: %v", err)
+	}
+
+	rec := httptest.NewRecorder()
+	srv.Handler().ServeHTTP(rec, httptest.NewRequest("POST", "/api/stories/nonship-1/ship", nil))
+
+	if got := rec.Code; got != http.StatusConflict {
+		t.Errorf("expected 409, got %d", got)
+	}
+}
diff --git a/internal/api/stories.go b/internal/api/stories.go
new file mode 100644
index 0000000..fa10ccd
--- /dev/null
+++ b/internal/api/stories.go
@@ -0,0 +1,378 @@
+package api
+
+import (
+ "database/sql"
+ "encoding/json"
+ "fmt"
+ "net/http"
+ "os/exec"
+ "strings"
+ "time"
+
+ "github.com/google/uuid"
+ "github.com/thepeterstone/claudomator/internal/deployment"
+ "github.com/thepeterstone/claudomator/internal/task"
+)
+
+// createStoryBranch creates a new git branch in localPath from the latest main
+// and pushes it to remoteURL (the bare repo). Idempotent: treats "already exists" as success.
+func createStoryBranch(localPath, branchName, remoteURL string) error {
+ // Fetch latest from the bare repo so we have an up-to-date base.
+ if out, err := exec.Command("git", "-C", localPath, "fetch", remoteURL, "main").CombinedOutput(); err != nil {
+ return fmt.Errorf("git fetch: %w (output: %s)", err, string(out))
+ }
+ base := "FETCH_HEAD"
+ out, err := exec.Command("git", "-C", localPath, "checkout", "-b", branchName, base).CombinedOutput()
+ if err != nil {
+ if !strings.Contains(string(out), "already exists") {
+ return fmt.Errorf("git checkout -b: %w (output: %s)", err, string(out))
+ }
+ }
+ if out, err := exec.Command("git", "-C", localPath, "push", remoteURL, branchName).CombinedOutput(); err != nil {
+ return fmt.Errorf("git push: %w (output: %s)", err, string(out))
+ }
+ return nil
+}
+
+// handleListStories responds with every story returned by the store.
+// A nil result is replaced by an empty slice before it is written; a store error
+// is reported as HTTP 500 with the error text.
+func (s *Server) handleListStories(w http.ResponseWriter, r *http.Request) {
+	list, err := s.store.ListStories()
+	switch {
+	case err != nil:
+		writeJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()})
+	case list == nil:
+		writeJSON(w, http.StatusOK, []*task.Story{})
+	default:
+		writeJSON(w, http.StatusOK, list)
+	}
+}
+
+// handleCreateStory decodes a task.Story from the request body, fills in defaults
+// and persists it.
+//
+// Defaults applied to fields the client left empty:
+//   - ID: a freshly generated UUID
+//   - Status: task.StoryPending
+//   - CreatedAt / UpdatedAt: the current UTC time (handleApproveStory stamps its
+//     stories the same way)
+//
+// Responses: 400 for malformed JSON or a missing name, 500 when the store rejects
+// the story, 201 with the stored story otherwise.
+func (s *Server) handleCreateStory(w http.ResponseWriter, r *http.Request) {
+	var st task.Story
+	if err := json.NewDecoder(r.Body).Decode(&st); err != nil {
+		writeJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid JSON: " + err.Error()})
+		return
+	}
+	if st.Name == "" {
+		writeJSON(w, http.StatusBadRequest, map[string]string{"error": "name is required"})
+		return
+	}
+	if st.ID == "" {
+		st.ID = uuid.New().String()
+	}
+	if st.Status == "" {
+		st.Status = task.StoryPending
+	}
+	// Only stamp timestamps the client did not supply.
+	now := time.Now().UTC()
+	if st.CreatedAt.IsZero() {
+		st.CreatedAt = now
+	}
+	if st.UpdatedAt.IsZero() {
+		st.UpdatedAt = now
+	}
+	if err := s.store.CreateStory(&st); err != nil {
+		writeJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()})
+		return
+	}
+	writeJSON(w, http.StatusCreated, st)
+}
+
+// handleGetStory returns the story named by the {id} path value.
+// Any error from the store lookup is reported as HTTP 404 "story not found".
+func (s *Server) handleGetStory(w http.ResponseWriter, r *http.Request) {
+	story, err := s.store.GetStory(r.PathValue("id"))
+	if err == nil {
+		writeJSON(w, http.StatusOK, story)
+		return
+	}
+	writeJSON(w, http.StatusNotFound, map[string]string{"error": "story not found"})
+}
+
+// handleListStoryTasks returns the tasks that belong to the story named by the
+// {id} path value. The story must exist (404 otherwise); a nil task list is
+// replaced by an empty slice before it is written.
+func (s *Server) handleListStoryTasks(w http.ResponseWriter, r *http.Request) {
+	storyID := r.PathValue("id")
+
+	// Existence check only; the story record itself is not needed.
+	_, lookupErr := s.store.GetStory(storyID)
+	if lookupErr != nil {
+		writeJSON(w, http.StatusNotFound, map[string]string{"error": "story not found"})
+		return
+	}
+
+	storyTasks, listErr := s.store.ListTasksByStory(storyID)
+	if listErr != nil {
+		writeJSON(w, http.StatusInternalServerError, map[string]string{"error": listErr.Error()})
+		return
+	}
+	if storyTasks == nil {
+		storyTasks = []*task.Task{}
+	}
+	writeJSON(w, http.StatusOK, storyTasks)
+}
+
+// handleAddTaskToStory creates a new task (state task.StatePending) inside the
+// story named by the {id} path value and chains it behind the story's last task.
+//
+// The request body carries the task fields declared in the input struct below.
+// The agent configuration may be sent under "agent" or, when "agent" has no
+// instructions, under "claude". Defaults: agent type "claude", priority
+// task.PriorityNormal, empty tag list, one retry attempt with exponential backoff.
+//
+// Responses: 404 when the story lookup fails, 400 for malformed JSON or an
+// unparsable timeout, 500 for store failures, 201 with the created task otherwise.
+func (s *Server) handleAddTaskToStory(w http.ResponseWriter, r *http.Request) {
+	storyID := r.PathValue("id")
+	// Only the story's existence matters here; the record itself is not used.
+	if _, err := s.store.GetStory(storyID); err != nil {
+		writeJSON(w, http.StatusNotFound, map[string]string{"error": "story not found"})
+		return
+	}
+
+	var input struct {
+		Name          string           `json:"name"`
+		Description   string           `json:"description"`
+		Project       string           `json:"project"`
+		RepositoryURL string           `json:"repository_url"`
+		Agent         task.AgentConfig `json:"agent"`
+		Claude        task.AgentConfig `json:"claude"`
+		Timeout       string           `json:"timeout"`
+		Priority      string           `json:"priority"`
+		Tags          []string         `json:"tags"`
+		ParentTaskID  string           `json:"parent_task_id"`
+	}
+	if err := json.NewDecoder(r.Body).Decode(&input); err != nil {
+		writeJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid JSON: " + err.Error()})
+		return
+	}
+	// Fall back to the "claude" key when "agent" carries no instructions.
+	if input.Agent.Instructions == "" && input.Claude.Instructions != "" {
+		input.Agent = input.Claude
+	}
+
+	// Validate the timeout before querying the store so a bad request fails fast.
+	var timeout time.Duration
+	if input.Timeout != "" {
+		dur, err := time.ParseDuration(input.Timeout)
+		if err != nil {
+			writeJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid timeout: " + err.Error()})
+			return
+		}
+		timeout = dur
+	}
+
+	existing, err := s.store.ListTasksByStory(storyID)
+	if err != nil {
+		writeJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()})
+		return
+	}
+
+	now := time.Now().UTC()
+	t := &task.Task{
+		ID:            uuid.New().String(),
+		Name:          input.Name,
+		Description:   input.Description,
+		Project:       input.Project,
+		RepositoryURL: input.RepositoryURL,
+		Agent:         input.Agent,
+		Priority:      task.Priority(input.Priority),
+		Tags:          input.Tags,
+		DependsOn:     []string{},
+		Retry:         task.RetryConfig{MaxAttempts: 1, Backoff: "exponential"},
+		State:         task.StatePending,
+		StoryID:       storyID,
+		ParentTaskID:  input.ParentTaskID,
+		CreatedAt:     now,
+		UpdatedAt:     now,
+	}
+	t.Timeout.Duration = timeout
+
+	if t.Agent.Type == "" {
+		t.Agent.Type = "claude"
+	}
+	if t.Priority == "" {
+		t.Priority = task.PriorityNormal
+	}
+	if t.Tags == nil {
+		t.Tags = []string{}
+	}
+
+	// Auto-wire depends_on: new task depends on the last existing task (sorted ASC by created_at).
+	if len(existing) > 0 {
+		t.DependsOn = []string{existing[len(existing)-1].ID}
+	}
+
+	if err := s.store.CreateTask(t); err != nil {
+		writeJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()})
+		return
+	}
+	writeJSON(w, http.StatusCreated, t)
+}
+
+// handleUpdateStoryStatus moves the story named by the {id} path value to the
+// status given in the request body ({"status": "..."}).
+//
+// Responses: 404 when the story lookup fails, 400 for malformed JSON, 409 when
+// task.ValidStoryTransition rejects the move, 500 when the store update fails,
+// 200 with a confirmation message otherwise.
+func (s *Server) handleUpdateStoryStatus(w http.ResponseWriter, r *http.Request) {
+	storyID := r.PathValue("id")
+	current, err := s.store.GetStory(storyID)
+	if err != nil {
+		writeJSON(w, http.StatusNotFound, map[string]string{"error": "story not found"})
+		return
+	}
+
+	var body struct {
+		Status task.StoryState `json:"status"`
+	}
+	if decodeErr := json.NewDecoder(r.Body).Decode(&body); decodeErr != nil {
+		writeJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid JSON: " + decodeErr.Error()})
+		return
+	}
+
+	next := body.Status
+	if allowed := task.ValidStoryTransition(current.Status, next); !allowed {
+		reason := "invalid story status transition from " + string(current.Status) + " to " + string(next)
+		writeJSON(w, http.StatusConflict, map[string]string{"error": reason})
+		return
+	}
+
+	if updateErr := s.store.UpdateStoryStatus(storyID, next); updateErr != nil {
+		writeJSON(w, http.StatusInternalServerError, map[string]string{"error": updateErr.Error()})
+		return
+	}
+
+	confirmation := map[string]string{
+		"message":  "story status updated",
+		"story_id": storyID,
+		"status":   string(next),
+	}
+	writeJSON(w, http.StatusOK, confirmation)
+}
+
+// handleApproveStory turns an approved story plan into stored records: one story
+// (status task.StoryPending), its tasks, and their subtasks.
+//
+// Wiring:
+//   - each top-level task depends on the previous top-level task;
+//   - each subtask has its parent as ParentTaskID and depends on the previous
+//     subtask of the same parent;
+//   - when project_id resolves to a project, its RemoteURL becomes every task's
+//     RepositoryURL.
+//
+// After the records are stored, the story branch is created when branch_name is set
+// and the project has both a LocalPath and a RemoteURL. A branch failure is only
+// logged; the request still succeeds.
+//
+// Records are written one at a time: if a task insert fails, the handler answers
+// 500 and whatever was already stored stays in place.
+//
+// Responses: 400 for malformed JSON or a missing name, 500 for encoding or store
+// failures, 201 with {"story", "task_ids"} otherwise. task_ids lists the top-level
+// tasks only.
+func (s *Server) handleApproveStory(w http.ResponseWriter, r *http.Request) {
+	var input struct {
+		Name       string                    `json:"name"`
+		BranchName string                    `json:"branch_name"`
+		ProjectID  string                    `json:"project_id"`
+		Tasks      []elaboratedStoryTask     `json:"tasks"`
+		Validation elaboratedStoryValidation `json:"validation"`
+	}
+	if err := json.NewDecoder(r.Body).Decode(&input); err != nil {
+		writeJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid JSON: " + err.Error()})
+		return
+	}
+	if input.Name == "" {
+		writeJSON(w, http.StatusBadRequest, map[string]string{"error": "name is required"})
+		return
+	}
+
+	// The validation spec is stored on the story as a JSON string.
+	validationJSON, err := json.Marshal(input.Validation)
+	if err != nil {
+		writeJSON(w, http.StatusInternalServerError, map[string]string{"error": "encode validation: " + err.Error()})
+		return
+	}
+
+	now := time.Now().UTC()
+	story := &task.Story{
+		ID:             uuid.New().String(),
+		Name:           input.Name,
+		ProjectID:      input.ProjectID,
+		BranchName:     input.BranchName,
+		ValidationJSON: string(validationJSON),
+		Status:         task.StoryPending,
+		CreatedAt:      now,
+		UpdatedAt:      now,
+	}
+	if err := s.store.CreateStory(story); err != nil {
+		writeJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()})
+		return
+	}
+
+	// Resolve the project once. A failed lookup is tolerated: tasks are then created
+	// without a repository URL and no branch is created.
+	var repoURL, localPath string
+	if input.ProjectID != "" {
+		if proj, err := s.store.GetProject(input.ProjectID); err == nil {
+			repoURL = proj.RemoteURL
+			localPath = proj.LocalPath
+		}
+	}
+
+	taskIDs := make([]string, 0, len(input.Tasks))
+	var prevTaskID string
+	for _, tp := range input.Tasks {
+		t := &task.Task{
+			ID:                 uuid.New().String(),
+			Name:               tp.Name,
+			Project:            input.ProjectID,
+			RepositoryURL:      repoURL,
+			StoryID:            story.ID,
+			Agent:              task.AgentConfig{Type: "claude", Instructions: tp.Instructions},
+			AcceptanceCriteria: tp.AcceptanceCriteria,
+			Priority:           task.PriorityNormal,
+			Tags:               []string{},
+			DependsOn:          []string{},
+			Retry:              task.RetryConfig{MaxAttempts: 1, Backoff: "exponential"},
+			State:              task.StatePending,
+			// Timestamps are taken per record rather than reusing `now`.
+			CreatedAt: time.Now().UTC(),
+			UpdatedAt: time.Now().UTC(),
+		}
+		if prevTaskID != "" {
+			t.DependsOn = []string{prevTaskID}
+		}
+		if err := s.store.CreateTask(t); err != nil {
+			writeJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()})
+			return
+		}
+		taskIDs = append(taskIDs, t.ID)
+
+		var prevSubtaskID string
+		for _, sub := range tp.Subtasks {
+			st := &task.Task{
+				ID:            uuid.New().String(),
+				Name:          sub.Name,
+				Project:       input.ProjectID,
+				RepositoryURL: repoURL,
+				StoryID:       story.ID,
+				ParentTaskID:  t.ID,
+				Agent:         task.AgentConfig{Type: "claude", Instructions: sub.Instructions},
+				Priority:      task.PriorityNormal,
+				Tags:          []string{},
+				DependsOn:     []string{},
+				Retry:         task.RetryConfig{MaxAttempts: 1, Backoff: "exponential"},
+				State:         task.StatePending,
+				CreatedAt:     time.Now().UTC(),
+				UpdatedAt:     time.Now().UTC(),
+			}
+			if prevSubtaskID != "" {
+				st.DependsOn = []string{prevSubtaskID}
+			}
+			if err := s.store.CreateTask(st); err != nil {
+				writeJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()})
+				return
+			}
+			prevSubtaskID = st.ID
+		}
+		prevTaskID = t.ID
+	}
+
+	// Create the story branch (non-fatal if it fails). localPath and repoURL are only
+	// non-empty when the project lookup above succeeded.
+	if input.BranchName != "" && localPath != "" && repoURL != "" {
+		if err := createStoryBranch(localPath, input.BranchName, repoURL); err != nil {
+			s.logger.Warn("story approve: failed to create branch", "error", err, "branch", input.BranchName)
+		}
+	}
+
+	writeJSON(w, http.StatusCreated, map[string]any{
+		"story":    story,
+		"task_ids": taskIDs,
+	})
+}
+
+// handleShipStory asks the pool to ship (merge + deploy) the story named by the
+// {id} path value.
+// POST /api/stories/{id}/ship
+//
+// Any error from the pool is reported as HTTP 409 with the error text; success is
+// acknowledged with HTTP 202.
+func (s *Server) handleShipStory(w http.ResponseWriter, r *http.Request) {
+	storyID := r.PathValue("id")
+
+	shipErr := s.pool.ShipStory(r.Context(), storyID)
+	if shipErr == nil {
+		writeJSON(w, http.StatusAccepted, map[string]string{"message": "story shipping initiated", "story_id": storyID})
+		return
+	}
+	writeJSON(w, http.StatusConflict, map[string]string{"error": shipErr.Error()})
+}
+
+// handleStoryDeploymentStatus aggregates the deployment status across all tasks in a story.
+// GET /api/stories/{id}/deployment-status
+//
+// It gathers the commits recorded on the latest execution of every task in the
+// story and passes them, together with the story project's remote URL (empty when
+// the story has no project or the lookup fails), to deployment.Check.
+//
+// Responses: 404 when the story lookup fails, 500 for other store failures, 200
+// with the deployment status otherwise.
+func (s *Server) handleStoryDeploymentStatus(w http.ResponseWriter, r *http.Request) {
+	id := r.PathValue("id")
+
+	story, err := s.store.GetStory(id)
+	if err != nil {
+		writeJSON(w, http.StatusNotFound, map[string]string{"error": "story not found"})
+		return
+	}
+
+	tasks, err := s.store.ListTasksByStory(id)
+	if err != nil {
+		writeJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()})
+		return
+	}
+
+	// Collect all commits from the latest execution of each task.
+	var allCommits []task.GitCommit
+	for _, t := range tasks {
+		// Named latest (not exec) so the imported os/exec package is not shadowed.
+		latest, err := s.store.GetLatestExecution(t.ID)
+		if err != nil {
+			// A task with no execution rows contributes no commits.
+			if err == sql.ErrNoRows {
+				continue
+			}
+			writeJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()})
+			return
+		}
+		allCommits = append(allCommits, latest.Commits...)
+	}
+
+	// Determine project remote URL for the deployment check.
+	projectRemoteURL := ""
+	if story.ProjectID != "" {
+		if proj, err := s.store.GetProject(story.ProjectID); err == nil {
+			projectRemoteURL = proj.RemoteURL
+		}
+	}
+
+	status := deployment.Check(allCommits, projectRemoteURL)
+	writeJSON(w, http.StatusOK, status)
+}
diff --git a/internal/api/stories_test.go b/internal/api/stories_test.go
new file mode 100644
index 0000000..f43ad86
--- /dev/null
+++ b/internal/api/stories_test.go
@@ -0,0 +1,351 @@
+package api
+
+import (
+ "bytes"
+ "encoding/json"
+ "net/http"
+ "net/http/httptest"
+ "strings"
+ "testing"
+ "time"
+
+ "github.com/thepeterstone/claudomator/internal/deployment"
+ "github.com/thepeterstone/claudomator/internal/task"
+)
+
+// TestCreateStory_API checks that POST /api/stories answers 201 and returns a
+// story with the submitted name, a non-empty ID and status task.StoryPending.
+func TestCreateStory_API(t *testing.T) {
+	srv, _ := testServer(t)
+
+	payload := bytes.NewBufferString(`{"name":"My Story","project_id":"proj-1"}`)
+	createReq := httptest.NewRequest("POST", "/api/stories", payload)
+	createReq.Header.Set("Content-Type", "application/json")
+	rec := httptest.NewRecorder()
+	srv.mux.ServeHTTP(rec, createReq)
+
+	if rec.Code != http.StatusCreated {
+		t.Fatalf("expected 201, got %d: %s", rec.Code, rec.Body.String())
+	}
+
+	var created task.Story
+	if err := json.NewDecoder(rec.Body).Decode(&created); err != nil {
+		t.Fatalf("decode response: %v", err)
+	}
+	if got := created.Name; got != "My Story" {
+		t.Errorf("Name: want 'My Story', got %q", got)
+	}
+	if created.ID == "" {
+		t.Error("ID should be auto-generated")
+	}
+	if got := created.Status; got != task.StoryPending {
+		t.Errorf("Status: want PENDING, got %q", got)
+	}
+}
+
+// TestGetStory_API creates a story through the API and checks that
+// GET /api/stories/{id} returns it with the same ID and name.
+func TestGetStory_API(t *testing.T) {
+	srv, _ := testServer(t)
+
+	// Create a story first.
+	body := `{"name":"Get Me"}`
+	req := httptest.NewRequest("POST", "/api/stories", bytes.NewBufferString(body))
+	req.Header.Set("Content-Type", "application/json")
+	w := httptest.NewRecorder()
+	srv.mux.ServeHTTP(w, req)
+	if w.Code != http.StatusCreated {
+		t.Fatalf("create story: expected 201, got %d", w.Code)
+	}
+	// The created ID drives the GET below, so a decode failure must stop the test
+	// here instead of surfacing later as a confusing lookup failure.
+	var created task.Story
+	if err := json.NewDecoder(w.Body).Decode(&created); err != nil {
+		t.Fatalf("create story: decode response: %v", err)
+	}
+
+	// Fetch it.
+	req2 := httptest.NewRequest("GET", "/api/stories/"+created.ID, nil)
+	w2 := httptest.NewRecorder()
+	srv.mux.ServeHTTP(w2, req2)
+
+	if w2.Code != http.StatusOK {
+		t.Fatalf("get story: expected 200, got %d: %s", w2.Code, w2.Body.String())
+	}
+	var got task.Story
+	if err := json.NewDecoder(w2.Body).Decode(&got); err != nil {
+		t.Fatalf("decode: %v", err)
+	}
+	if got.ID != created.ID {
+		t.Errorf("ID: want %q, got %q", created.ID, got.ID)
+	}
+	if got.Name != "Get Me" {
+		t.Errorf("Name: want 'Get Me', got %q", got.Name)
+	}
+}
+
+// TestAddTaskToStory_AutoWiresDependsOn adds three tasks to a story one after the
+// other and checks that each new task depends on the one added just before it.
+func TestAddTaskToStory_AutoWiresDependsOn(t *testing.T) {
+	srv, _ := testServer(t)
+
+	// Create a story.
+	storyBody := `{"name":"Story For Tasks"}`
+	req := httptest.NewRequest("POST", "/api/stories", bytes.NewBufferString(storyBody))
+	req.Header.Set("Content-Type", "application/json")
+	w := httptest.NewRecorder()
+	srv.mux.ServeHTTP(w, req)
+	if w.Code != http.StatusCreated {
+		t.Fatalf("create story: %d %s", w.Code, w.Body.String())
+	}
+	// The story ID is part of every URL below, so a decode failure is fatal.
+	var story task.Story
+	if err := json.NewDecoder(w.Body).Decode(&story); err != nil {
+		t.Fatalf("create story: decode response: %v", err)
+	}
+
+	// addTask posts one task to the story and returns the decoded response.
+	addTask := func(name string) *task.Task {
+		t.Helper() // report failures at the call site, not inside the closure
+		body := `{"name":"` + name + `","agent":{"type":"claude","instructions":"do it"}}`
+		r := httptest.NewRequest("POST", "/api/stories/"+story.ID+"/tasks", bytes.NewBufferString(body))
+		r.Header.Set("Content-Type", "application/json")
+		wr := httptest.NewRecorder()
+		srv.mux.ServeHTTP(wr, r)
+		if wr.Code != http.StatusCreated {
+			t.Fatalf("add task %s: expected 201, got %d: %s", name, wr.Code, wr.Body.String())
+		}
+		var tk task.Task
+		if err := json.NewDecoder(wr.Body).Decode(&tk); err != nil {
+			t.Fatalf("add task %s: decode response: %v", name, err)
+		}
+		return &tk
+	}
+
+	task1 := addTask("Task 1")
+	task2 := addTask("Task 2")
+	task3 := addTask("Task 3")
+
+	// task1 has no dependencies.
+	if len(task1.DependsOn) != 0 {
+		t.Errorf("task1.DependsOn: want [], got %v", task1.DependsOn)
+	}
+	// task2 depends on task1.
+	if len(task2.DependsOn) != 1 || task2.DependsOn[0] != task1.ID {
+		t.Errorf("task2.DependsOn: want [%s], got %v", task1.ID, task2.DependsOn)
+	}
+	// task3 depends on task2.
+	if len(task3.DependsOn) != 1 || task3.DependsOn[0] != task2.ID {
+		t.Errorf("task3.DependsOn: want [%s], got %v", task2.ID, task3.DependsOn)
+	}
+}
+
+func TestBuildStoryElaboratePrompt(t *testing.T) {
+ prompt := buildStoryElaboratePrompt()
+ checks := []struct {
+ label string
+ want string
+ }{
+ {"schema: name field", `"name"`},
+ {"schema: branch_name field", `"branch_name"`},
+ {"schema: tasks field", `"tasks"`},
+ {"schema: validation field", `"validation"`},
+ {"rule: git push", "git push origin"},
+ {"rule: sequential subtasks", "sequentially"},
+ {"rule: specific file paths", "file paths"},
+ }
+ for _, c := range checks {
+ if !strings.Contains(prompt, c.want) {
+ t.Errorf("%s: prompt should contain %q", c.label, c.want)
+ }
+ }
+}
+
+// TestHandleStoryApprove_WiresDepends approves a story with three top-level tasks
+// and checks, by reading the tasks back from the store, that they form a
+// depends_on chain in submission order.
+func TestHandleStoryApprove_WiresDepends(t *testing.T) {
+	srv, _ := testServer(t)
+
+	plan := `{
+		"name": "My Story",
+		"branch_name": "story/my-story",
+		"tasks": [
+			{"name": "Task 1", "instructions": "do task 1", "subtasks": []},
+			{"name": "Task 2", "instructions": "do task 2", "subtasks": []},
+			{"name": "Task 3", "instructions": "do task 3", "subtasks": []}
+		],
+		"validation": {"type": "build", "steps": ["go build ./..."], "success_criteria": "compiles"}
+	}`
+	approveReq := httptest.NewRequest("POST", "/api/stories/approve", bytes.NewBufferString(plan))
+	approveReq.Header.Set("Content-Type", "application/json")
+	rec := httptest.NewRecorder()
+	srv.mux.ServeHTTP(rec, approveReq)
+
+	if rec.Code != http.StatusCreated {
+		t.Fatalf("expected 201, got %d: %s", rec.Code, rec.Body.String())
+	}
+
+	var approved struct {
+		Story   task.Story `json:"story"`
+		TaskIDs []string   `json:"task_ids"`
+	}
+	if err := json.NewDecoder(rec.Body).Decode(&approved); err != nil {
+		t.Fatalf("decode response: %v", err)
+	}
+	if n := len(approved.TaskIDs); n != 3 {
+		t.Fatalf("expected 3 task IDs, got %d", n)
+	}
+	if got := approved.Story.Name; got != "My Story" {
+		t.Errorf("story name: want 'My Story', got %q", got)
+	}
+
+	// Verify depends_on chain via the store.
+	first, err := srv.store.GetTask(approved.TaskIDs[0])
+	if err != nil {
+		t.Fatalf("GetTask[0]: %v", err)
+	}
+	second, err := srv.store.GetTask(approved.TaskIDs[1])
+	if err != nil {
+		t.Fatalf("GetTask[1]: %v", err)
+	}
+	third, err := srv.store.GetTask(approved.TaskIDs[2])
+	if err != nil {
+		t.Fatalf("GetTask[2]: %v", err)
+	}
+
+	// The first task starts the chain; each later task points at its predecessor.
+	if len(first.DependsOn) != 0 {
+		t.Errorf("task1.DependsOn: want [], got %v", first.DependsOn)
+	}
+	if len(second.DependsOn) != 1 || second.DependsOn[0] != first.ID {
+		t.Errorf("task2.DependsOn: want [%s], got %v", first.ID, second.DependsOn)
+	}
+	if len(third.DependsOn) != 1 || third.DependsOn[0] != second.ID {
+		t.Errorf("task3.DependsOn: want [%s], got %v", second.ID, third.DependsOn)
+	}
+}
+
+// TestHandleStoryApprove_SetsRepositoryURL approves a story that references a
+// project and checks that every task ID returned in the response carries the
+// project's RemoteURL as its RepositoryURL.
+func TestHandleStoryApprove_SetsRepositoryURL(t *testing.T) {
+	srv, store := testServer(t)
+
+	proj := &task.Project{
+		ID:        "proj-repo",
+		Name:      "claudomator",
+		RemoteURL: "/site/git.terst.org/repos/claudomator.git",
+		// LocalPath intentionally empty: branch creation is a non-fatal side effect,
+		// omitting it keeps the test fast and free of real git operations.
+	}
+	if err := store.CreateProject(proj); err != nil {
+		t.Fatalf("CreateProject: %v", err)
+	}
+
+	body := `{
+		"name": "Repo URL Story",
+		"branch_name": "story/repo-url",
+		"project_id": "proj-repo",
+		"tasks": [
+			{"name": "Task A", "instructions": "do A", "subtasks": []},
+			{"name": "Task B", "instructions": "do B", "subtasks": [
+				{"name": "Sub B1", "instructions": "do B1"}
+			]}
+		],
+		"validation": {"type": "build", "steps": ["go build ./..."], "success_criteria": "ok"}
+	}`
+	req := httptest.NewRequest("POST", "/api/stories/approve", bytes.NewBufferString(body))
+	req.Header.Set("Content-Type", "application/json")
+	w := httptest.NewRecorder()
+	srv.mux.ServeHTTP(w, req)
+
+	if w.Code != http.StatusCreated {
+		t.Fatalf("expected 201, got %d: %s", w.Code, w.Body.String())
+	}
+
+	var resp struct {
+		TaskIDs []string `json:"task_ids"`
+	}
+	if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
+		t.Fatalf("decode: %v", err)
+	}
+	// Without this guard the loop below would pass vacuously on an empty list.
+	if len(resp.TaskIDs) == 0 {
+		t.Fatal("expected task_ids in response")
+	}
+
+	for _, id := range resp.TaskIDs {
+		tk, err := store.GetTask(id)
+		if err != nil {
+			t.Fatalf("GetTask %s: %v", id, err)
+		}
+		if tk.RepositoryURL != proj.RemoteURL {
+			t.Errorf("task %s RepositoryURL: want %q, got %q", tk.ID, proj.RemoteURL, tk.RepositoryURL)
+		}
+	}
+}
+
+// TestApproveStory_AcceptanceCriteriaStored approves a story whose single task
+// carries acceptance_criteria and checks that the text is stored on the task.
+func TestApproveStory_AcceptanceCriteriaStored(t *testing.T) {
+	srv, store := testServer(t)
+
+	proj := &task.Project{
+		ID: "ac-proj", Name: "test", RemoteURL: "https://github.com/x/y",
+		Type: "web", DeployScript: "",
+	}
+	if err := store.CreateProject(proj); err != nil {
+		t.Fatalf("CreateProject: %v", err)
+	}
+
+	body := `{
+		"name": "AC Story",
+		"branch_name": "story/ac-test",
+		"project_id": "ac-proj",
+		"tasks": [
+			{
+				"name": "Add feature",
+				"instructions": "implement the thing",
+				"acceptance_criteria": "run go test ./... and verify all pass",
+				"subtasks": []
+			}
+		],
+		"validation": {"type": "test", "steps": [], "success_criteria": "tests pass"}
+	}`
+	req := httptest.NewRequest("POST", "/api/stories/approve", strings.NewReader(body))
+	req.Header.Set("Content-Type", "application/json")
+	w := httptest.NewRecorder()
+	srv.mux.ServeHTTP(w, req)
+
+	if w.Code != http.StatusCreated {
+		t.Fatalf("expected 201, got %d: %s", w.Code, w.Body.String())
+	}
+
+	var resp struct {
+		TaskIDs []string `json:"task_ids"`
+	}
+	if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
+		t.Fatalf("decode response: %v", err)
+	}
+	if len(resp.TaskIDs) == 0 {
+		t.Fatal("expected task_ids in response")
+	}
+
+	tk, err := store.GetTask(resp.TaskIDs[0])
+	if err != nil {
+		t.Fatalf("GetTask: %v", err)
+	}
+	if tk.AcceptanceCriteria != "run go test ./... and verify all pass" {
+		t.Errorf("expected acceptance criteria stored on task, got %q", tk.AcceptanceCriteria)
+	}
+}
+
+// TestHandleStoryDeploymentStatus exercises the deployment-status endpoint for a
+// story without tasks (200, IncludesFix false) and for an unknown story ID (404).
+func TestHandleStoryDeploymentStatus(t *testing.T) {
+	srv, store := testServer(t)
+
+	// Seed a story with no tasks attached.
+	stamp := time.Now().UTC()
+	seeded := &task.Story{
+		ID:        "deploy-story-1",
+		Name:      "Deploy Status Story",
+		Status:    task.StoryInProgress,
+		CreatedAt: stamp,
+		UpdatedAt: stamp,
+	}
+	if err := store.CreateStory(seeded); err != nil {
+		t.Fatalf("CreateStory: %v", err)
+	}
+
+	// Request deployment status — no tasks yet.
+	knownRec := httptest.NewRecorder()
+	srv.mux.ServeHTTP(knownRec, httptest.NewRequest("GET", "/api/stories/deploy-story-1/deployment-status", nil))
+	if knownRec.Code != http.StatusOK {
+		t.Fatalf("expected 200, got %d: %s", knownRec.Code, knownRec.Body.String())
+	}
+
+	var got deployment.Status
+	if err := json.NewDecoder(knownRec.Body).Decode(&got); err != nil {
+		t.Fatalf("decode: %v", err)
+	}
+	// No tasks → no commits → IncludesFix = false (nothing to check).
+	if got.IncludesFix {
+		t.Error("expected IncludesFix=false when no commits")
+	}
+
+	// 404 for unknown story.
+	unknownRec := httptest.NewRecorder()
+	srv.mux.ServeHTTP(unknownRec, httptest.NewRequest("GET", "/api/stories/nonexistent/deployment-status", nil))
+	if code := unknownRec.Code; code != http.StatusNotFound {
+		t.Errorf("expected 404 for unknown story, got %d", code)
+	}
+}
diff --git a/internal/api/task_view.go b/internal/api/task_view.go
new file mode 100644
index 0000000..6a4b58e
--- /dev/null
+++ b/internal/api/task_view.go
@@ -0,0 +1,47 @@
+package api
+
+import (
+ "database/sql"
+
+ "github.com/thepeterstone/claudomator/internal/deployment"
+ "github.com/thepeterstone/claudomator/internal/task"
+)
+
+// taskView wraps a task with computed fields that are derived from execution
+// history and deployment state. It is used as the JSON response type for task
+// list and get endpoints so that callers receive enriched data in one request.
+//
+// The three computed fields carry `omitempty`, so they are left out of the JSON
+// when enrichTask does not populate them.
+type taskView struct {
+ // Task is embedded, so the view exposes the task's own fields directly.
+ *task.Task
+ // Changestats is copied from the latest execution; enrichTask sets it only
+ // for tasks in task.StateReady.
+ Changestats *task.Changestats `json:"changestats,omitempty"`
+ // DeploymentStatus is the result of deployment.Check; enrichTask sets it only
+ // for tasks in task.StateReady.
+ DeploymentStatus *deployment.Status `json:"deployment_status,omitempty"`
+ // ErrorMsg is the latest execution's error message; enrichTask sets it only
+ // for tasks whose state is listed in failedStates.
+ ErrorMsg string `json:"error_msg,omitempty"`
+}
+
+// failedStates is the set of task states for which enrichTask copies the latest
+// execution's error message into the task view.
+var failedStates = map[task.State]bool{
+ task.StateFailed: true,
+ task.StateBudgetExceeded: true,
+ task.StateTimedOut: true,
+}
+
+// enrichTask builds the API view of tk from its latest execution.
+//
+//   - Lookup fails: the bare view is returned, except that a task in
+//     task.StateReady whose error is sql.ErrNoRows still gets a deployment status,
+//     computed from no commits.
+//   - Task in task.StateReady: changestats and a deployment status computed from
+//     the execution's commits are attached.
+//   - Task in one of failedStates: the execution's non-empty error message is
+//     attached.
+func (s *Server) enrichTask(tk *task.Task) *taskView {
+	enriched := &taskView{Task: tk}
+	ready := tk.State == task.StateReady
+
+	latest, err := s.store.GetLatestExecution(tk.ID)
+	if err != nil {
+		noExecution := err == sql.ErrNoRows
+		if noExecution && ready {
+			enriched.DeploymentStatus = deployment.Check(nil, tk.RepositoryURL)
+		}
+		return enriched
+	}
+
+	if ready {
+		enriched.Changestats = latest.Changestats
+		enriched.DeploymentStatus = deployment.Check(latest.Commits, tk.RepositoryURL)
+	}
+	if latest.ErrorMsg != "" && failedStates[tk.State] {
+		enriched.ErrorMsg = latest.ErrorMsg
+	}
+	return enriched
+}
diff --git a/internal/api/webhook.go b/internal/api/webhook.go
index 9437f7d..3af4cc8 100644
--- a/internal/api/webhook.go
+++ b/internal/api/webhook.go
@@ -8,6 +8,7 @@ import (
"encoding/json"
"fmt"
"io"
+ "log/slog"
"net/http"
"path/filepath"
"strings"
@@ -18,17 +19,26 @@ import (
"github.com/thepeterstone/claudomator/internal/task"
)
+// prRef is a minimal pull request entry used to extract branch names.
+// It decodes one element of a webhook payload's "pull_requests" array; the
+// check_run and workflow_run handlers use the first entry's Head.Ref as the
+// branch when the payload's own head branch is empty.
+type prRef struct {
+ Head struct {
+ // Ref is decoded from the entry's "head.ref" value.
+ Ref string `json:"ref"`
+ } `json:"head"`
+}
+
// checkRunPayload is the GitHub check_run webhook payload.
type checkRunPayload struct {
Action string `json:"action"`
CheckRun struct {
- Name string `json:"name"`
- Conclusion string `json:"conclusion"`
- HTMLURL string `json:"html_url"`
- HeadSHA string `json:"head_sha"`
- CheckSuite struct {
+ Name string `json:"name"`
+ Conclusion string `json:"conclusion"`
+ HTMLURL string `json:"html_url"`
+ DetailsURL string `json:"details_url"`
+ HeadSHA string `json:"head_sha"`
+ CheckSuite struct {
HeadBranch string `json:"head_branch"`
} `json:"check_suite"`
+ PullRequests []prRef `json:"pull_requests"`
} `json:"check_run"`
Repository struct {
Name string `json:"name"`
@@ -40,11 +50,12 @@ type checkRunPayload struct {
type workflowRunPayload struct {
Action string `json:"action"`
WorkflowRun struct {
- Name string `json:"name"`
- Conclusion string `json:"conclusion"`
- HTMLURL string `json:"html_url"`
- HeadSHA string `json:"head_sha"`
- HeadBranch string `json:"head_branch"`
+ Name string `json:"name"`
+ Conclusion string `json:"conclusion"`
+ HTMLURL string `json:"html_url"`
+ HeadSHA string `json:"head_sha"`
+ HeadBranch string `json:"head_branch"`
+ PullRequests []prRef `json:"pull_requests"`
} `json:"workflow_run"`
Repository struct {
Name string `json:"name"`
@@ -98,6 +109,7 @@ func (s *Server) handleGitHubWebhook(w http.ResponseWriter, r *http.Request) {
}
eventType := r.Header.Get("X-GitHub-Event")
+ slog.Info("github webhook received", "event", eventType, "bytes", len(body))
switch eventType {
case "check_run":
s.handleCheckRunEvent(w, body)
@@ -118,13 +130,22 @@ func (s *Server) handleCheckRunEvent(w http.ResponseWriter, body []byte) {
w.WriteHeader(http.StatusNoContent)
return
}
+ branch := p.CheckRun.CheckSuite.HeadBranch
+ if branch == "" && len(p.CheckRun.PullRequests) > 0 {
+ branch = p.CheckRun.PullRequests[0].Head.Ref
+ }
+ htmlURL := p.CheckRun.HTMLURL
+ if htmlURL == "" {
+ htmlURL = p.CheckRun.DetailsURL
+ }
+ slog.Info("check_run webhook", "repo", p.Repository.FullName, "conclusion", p.CheckRun.Conclusion, "branch", branch, "html_url", htmlURL)
s.createCIFailureTask(w,
p.Repository.Name,
p.Repository.FullName,
- p.CheckRun.CheckSuite.HeadBranch,
+ branch,
p.CheckRun.HeadSHA,
p.CheckRun.Name,
- p.CheckRun.HTMLURL,
+ htmlURL,
)
}
@@ -142,10 +163,15 @@ func (s *Server) handleWorkflowRunEvent(w http.ResponseWriter, body []byte) {
w.WriteHeader(http.StatusNoContent)
return
}
+ branch := p.WorkflowRun.HeadBranch
+ if branch == "" && len(p.WorkflowRun.PullRequests) > 0 {
+ branch = p.WorkflowRun.PullRequests[0].Head.Ref
+ }
+ slog.Info("workflow_run webhook", "repo", p.Repository.FullName, "conclusion", p.WorkflowRun.Conclusion, "branch", branch, "html_url", p.WorkflowRun.HTMLURL)
s.createCIFailureTask(w,
p.Repository.Name,
p.Repository.FullName,
- p.WorkflowRun.HeadBranch,
+ branch,
p.WorkflowRun.HeadSHA,
p.WorkflowRun.Name,
p.WorkflowRun.HTMLURL,
@@ -155,6 +181,10 @@ func (s *Server) handleWorkflowRunEvent(w http.ResponseWriter, body []byte) {
func (s *Server) createCIFailureTask(w http.ResponseWriter, repoName, fullName, branch, sha, checkName, htmlURL string) {
project := matchProject(s.projects, repoName)
+ if htmlURL == "" && fullName != "" && sha != "" {
+ htmlURL = fmt.Sprintf("https://github.com/%s/commit/%s", fullName, sha)
+ }
+
fallback := fmt.Sprintf(
"A CI failure has been detected and requires investigation.\n\n"+
"Repository: %s\n"+
@@ -188,20 +218,22 @@ func (s *Server) createCIFailureTask(w http.ResponseWriter, repoName, fullName,
Name: fmt.Sprintf("Fix CI failure: %s on %s", checkName, branch),
Agent: task.AgentConfig{
Type: "claude",
+ Model: "sonnet",
Instructions: instructions,
MaxBudgetUSD: 3.0,
AllowedTools: []string{"Read", "Edit", "Bash", "Glob", "Grep"},
},
- Priority: task.PriorityNormal,
- Tags: []string{"ci", "auto"},
- DependsOn: []string{},
- Retry: task.RetryConfig{MaxAttempts: 1, Backoff: "exponential"},
- State: task.StatePending,
- CreatedAt: now,
- UpdatedAt: now,
+ Priority: task.PriorityNormal,
+ Tags: []string{"ci", "auto"},
+ DependsOn: []string{},
+ Retry: task.RetryConfig{MaxAttempts: 1, Backoff: "exponential"},
+ State: task.StatePending,
+ CreatedAt: now,
+ UpdatedAt: now,
+ RepositoryURL: fmt.Sprintf("https://github.com/%s.git", fullName),
}
if project != nil {
- t.Agent.ProjectDir = project.Dir
+ t.Project = project.Name
}
if err := s.store.CreateTask(t); err != nil {
diff --git a/internal/api/webhook_test.go b/internal/api/webhook_test.go
index 8b0599a..967b62b 100644
--- a/internal/api/webhook_test.go
+++ b/internal/api/webhook_test.go
@@ -124,8 +124,8 @@ func TestGitHubWebhook_CheckRunFailure_CreatesTask(t *testing.T) {
if !strings.Contains(tk.Name, "main") {
t.Errorf("task name %q does not contain branch", tk.Name)
}
- if tk.Agent.ProjectDir != "/workspace/myrepo" {
- t.Errorf("task project dir = %q, want /workspace/myrepo", tk.Agent.ProjectDir)
+ if tk.RepositoryURL != "https://github.com/owner/myrepo.git" {
+ t.Errorf("task repository url = %q, want https://github.com/owner/myrepo.git", tk.RepositoryURL)
}
if !contains(tk.Tags, "ci") || !contains(tk.Tags, "auto") {
t.Errorf("task tags %v missing expected ci/auto tags", tk.Tags)
@@ -237,6 +237,104 @@ func TestGitHubWebhook_NoSecretConfigured_SkipsHMACCheck(t *testing.T) {
}
}
+// TestGitHubWebhook_CreatesTask_WithDefaultModel posts a failing check_run
+// webhook and checks that the task it creates has a non-empty agent model.
+func TestGitHubWebhook_CreatesTask_WithDefaultModel(t *testing.T) {
+	srv, store := testServer(t)
+	srv.projects = []config.Project{{Name: "myrepo", Dir: "/workspace/myrepo"}}
+
+	w := webhookPost(t, srv, "check_run", checkRunFailurePayload, "")
+
+	if w.Code != http.StatusOK {
+		t.Fatalf("want 200, got %d", w.Code)
+	}
+	// The task ID is read from the response, so a decode failure is reported
+	// directly instead of showing up as a "task not found" error.
+	var resp map[string]string
+	if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
+		t.Fatalf("decode response: %v", err)
+	}
+	tk, err := store.GetTask(resp["task_id"])
+	if err != nil {
+		t.Fatalf("task not found: %v", err)
+	}
+	if tk.Agent.Model == "" {
+		t.Error("expected model to be set, got empty string")
+	}
+}
+
+// checkRunNullBranchPayload is a check_run webhook fixture for a failed run whose
+// check_suite.head_branch is null and whose html_url is empty. The branch is only
+// available through pull_requests[0].head.ref and the link only through details_url.
+const checkRunNullBranchPayload = `{
+ "action": "completed",
+ "check_run": {
+ "name": "CI Build",
+ "conclusion": "failure",
+ "html_url": "",
+ "details_url": "https://github.com/owner/myrepo/actions/runs/999/jobs/123",
+ "head_sha": "abc123def",
+ "check_suite": {
+ "head_branch": null
+ },
+ "pull_requests": [{"head": {"ref": "feature/my-branch"}}]
+ },
+ "repository": {
+ "name": "myrepo",
+ "full_name": "owner/myrepo"
+ }
+}`
+
+// TestGitHubWebhook_CheckRunNullBranch_UsesPRRefAndDetailsURL posts a check_run
+// payload with a null head_branch and an empty html_url. It checks that the
+// created task's name contains the pull request's head ref and that its
+// instructions contain the details_url.
+func TestGitHubWebhook_CheckRunNullBranch_UsesPRRefAndDetailsURL(t *testing.T) {
+	srv, store := testServer(t)
+	srv.projects = []config.Project{{Name: "myrepo", Dir: "/workspace/myrepo"}}
+
+	w := webhookPost(t, srv, "check_run", checkRunNullBranchPayload, "")
+
+	if w.Code != http.StatusOK {
+		t.Fatalf("want 200, got %d; body: %s", w.Code, w.Body.String())
+	}
+	// Fail on an undecodable response rather than looking up an empty task ID.
+	var resp map[string]string
+	if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
+		t.Fatalf("decode response: %v", err)
+	}
+	tk, err := store.GetTask(resp["task_id"])
+	if err != nil {
+		t.Fatalf("task not found: %v", err)
+	}
+	if !strings.Contains(tk.Name, "feature/my-branch") {
+		t.Errorf("task name %q should contain PR branch", tk.Name)
+	}
+	if !strings.Contains(tk.Agent.Instructions, "actions/runs/999") {
+		t.Errorf("instructions should contain details_url fallback, got: %s", tk.Agent.Instructions)
+	}
+}
+
+// workflowRunNullBranchPayload is a workflow_run webhook fixture for a failed run
+// whose head_branch is null and whose html_url is empty. The branch is only
+// available through pull_requests[0].head.ref.
+const workflowRunNullBranchPayload = `{
+ "action": "completed",
+ "workflow_run": {
+ "name": "CI Pipeline",
+ "conclusion": "failure",
+ "html_url": "",
+ "head_sha": "def456abc",
+ "head_branch": null,
+ "pull_requests": [{"head": {"ref": "fix/something"}}]
+ },
+ "repository": {
+ "name": "myrepo",
+ "full_name": "owner/myrepo"
+ }
+}`
+
+// TestGitHubWebhook_WorkflowRunNullBranch_UsesPRRef posts a workflow_run payload
+// with a null head_branch and checks that the created task's name contains the
+// pull request's head ref.
+func TestGitHubWebhook_WorkflowRunNullBranch_UsesPRRef(t *testing.T) {
+	srv, store := testServer(t)
+	srv.projects = []config.Project{{Name: "myrepo", Dir: "/workspace/myrepo"}}
+
+	w := webhookPost(t, srv, "workflow_run", workflowRunNullBranchPayload, "")
+
+	if w.Code != http.StatusOK {
+		t.Fatalf("want 200, got %d; body: %s", w.Code, w.Body.String())
+	}
+	// Fail on an undecodable response rather than looking up an empty task ID.
+	var resp map[string]string
+	if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
+		t.Fatalf("decode response: %v", err)
+	}
+	tk, err := store.GetTask(resp["task_id"])
+	if err != nil {
+		t.Fatalf("task not found: %v", err)
+	}
+	if !strings.Contains(tk.Name, "fix/something") {
+		t.Errorf("task name %q should contain PR branch", tk.Name)
+	}
+}
+
func TestGitHubWebhook_UnknownEvent_Returns204(t *testing.T) {
srv, _ := testServer(t)
@@ -277,14 +375,14 @@ func TestGitHubWebhook_FallbackToSingleProject(t *testing.T) {
if err != nil {
t.Fatalf("task not found: %v", err)
}
- if tk.Agent.ProjectDir != "/workspace/someapp" {
- t.Errorf("expected fallback to /workspace/someapp, got %q", tk.Agent.ProjectDir)
+ if tk.RepositoryURL != "https://github.com/owner/myrepo.git" {
+ t.Errorf("expected fallback repository url, got %q", tk.RepositoryURL)
}
}
-func TestGitHubWebhook_NoProjectsConfigured_CreatesTaskWithoutProjectDir(t *testing.T) {
+func TestGitHubWebhook_NoProjectsConfigured_CreatesTaskWithGitHubURL(t *testing.T) {
srv, store := testServer(t)
- // No projects configured — task should still be created, just no project dir set.
+ // No projects configured — task should still be created with the GitHub remote URL.
w := webhookPost(t, srv, "check_run", checkRunFailurePayload, "")
@@ -297,8 +395,8 @@ func TestGitHubWebhook_NoProjectsConfigured_CreatesTaskWithoutProjectDir(t *test
if err != nil {
t.Fatalf("task not found: %v", err)
}
- if tk.Agent.ProjectDir != "" {
- t.Errorf("expected empty project dir, got %q", tk.Agent.ProjectDir)
+ if tk.RepositoryURL == "" {
+ t.Error("expected non-empty repository_url from GitHub webhook payload")
}
}
diff --git a/internal/cli/list.go b/internal/cli/list.go
index 3425388..ab80868 100644
--- a/internal/cli/list.go
+++ b/internal/cli/list.go
@@ -49,10 +49,10 @@ func listTasks(state string) error {
}
w := tabwriter.NewWriter(os.Stdout, 0, 0, 2, ' ', 0)
- fmt.Fprintln(w, "ID\tNAME\tSTATE\tPRIORITY\tCREATED")
+ fmt.Fprintln(w, "ID\tNAME\tPROJECT\tSTATE\tPRIORITY\tCREATED")
for _, t := range tasks {
- fmt.Fprintf(w, "%.8s\t%s\t%s\t%s\t%s\n",
- t.ID, t.Name, t.State, t.Priority, t.CreatedAt.Format("2006-01-02 15:04"))
+ fmt.Fprintf(w, "%.8s\t%s\t%s\t%s\t%s\t%s\n",
+ t.ID, t.Name, t.Project, t.State, t.Priority, t.CreatedAt.Format("2006-01-02 15:04"))
}
w.Flush()
return nil
diff --git a/internal/cli/project_test.go b/internal/cli/project_test.go
new file mode 100644
index 0000000..c62e181
--- /dev/null
+++ b/internal/cli/project_test.go
@@ -0,0 +1,102 @@
+package cli
+
+import (
+ "bytes"
+ "io"
+ "os"
+ "path/filepath"
+ "strings"
+ "testing"
+ "time"
+
+ "github.com/thepeterstone/claudomator/internal/config"
+ "github.com/thepeterstone/claudomator/internal/storage"
+ "github.com/thepeterstone/claudomator/internal/task"
+)
+
+// makeProjectTask opens (creating if needed) the database at dir/test.db,
+// stores one pending task whose Project is "test-project", closes the
+// database again and returns the stored task.
+func makeProjectTask(t *testing.T, dir string) *task.Task {
+	t.Helper()
+
+	dbFile := filepath.Join(dir, "test.db")
+	store, openErr := storage.Open(dbFile)
+	if openErr != nil {
+		t.Fatalf("storage.Open: %v", openErr)
+	}
+	defer store.Close()
+
+	stamp := time.Now().UTC()
+	created := &task.Task{
+		ID:        "proj-task-id",
+		Name:      "Project Task",
+		Project:   "test-project",
+		Agent:     task.AgentConfig{Type: "claude", Instructions: "do it", Model: "sonnet"},
+		Priority:  task.PriorityNormal,
+		Tags:      []string{},
+		DependsOn: []string{},
+		Retry:     task.RetryConfig{MaxAttempts: 1, Backoff: "exponential"},
+		State:     task.StatePending,
+		CreatedAt: stamp,
+		UpdatedAt: stamp,
+	}
+
+	createErr := store.CreateTask(created)
+	if createErr != nil {
+		t.Fatalf("CreateTask: %v", createErr)
+	}
+	return created
+}
+
+// captureStdout runs fn with os.Stdout redirected to a pipe and returns
+// everything fn wrote to it.
+//
+// The pipe is drained concurrently, so fn may write more than the OS pipe
+// buffer holds without blocking. os.Stdout is restored even when fn exits
+// early through a panic or runtime.Goexit (for example t.Fatalf). It panics
+// if the pipe cannot be created, since the helper has no way to report an
+// error.
+func captureStdout(fn func()) string {
+	r, w, err := os.Pipe()
+	if err != nil {
+		panic("captureStdout: os.Pipe: " + err.Error())
+	}
+	defer r.Close()
+
+	var buf bytes.Buffer
+	done := make(chan struct{})
+	go func() {
+		defer close(done)
+		// Best effort: a read error simply ends the capture early.
+		_, _ = io.Copy(&buf, r)
+	}()
+
+	old := os.Stdout
+	os.Stdout = w
+	func() {
+		// Runs on every exit path of fn, so stdout never stays redirected and
+		// the reader goroutine always sees EOF.
+		defer func() {
+			os.Stdout = old
+			w.Close()
+		}()
+		fn()
+	}()
+
+	<-done
+	return buf.String()
+}
+
+// withDB runs fn with the package-level cfg pointing at dbPath.
+//
+// If cfg is nil, a temporary empty config is installed for the duration of
+// fn. Both the cfg pointer and the previous DBPath value are restored
+// afterwards, so an already-configured cfg is not left pointing at the test
+// database.
+func withDB(t *testing.T, dbPath string, fn func()) {
+	t.Helper()
+	origCfg := cfg
+	if cfg == nil {
+		cfg = &config.Config{}
+	}
+	origPath := cfg.DBPath
+	cfg.DBPath = dbPath
+	defer func() {
+		// Undo the field change first: when origCfg is non-nil it is the very
+		// struct that was modified above.
+		cfg.DBPath = origPath
+		cfg = origCfg
+	}()
+	fn()
+}
+
+func TestListTasks_ShowsProject(t *testing.T) {
+ dir := t.TempDir()
+ dbPath := filepath.Join(dir, "test.db")
+ makeProjectTask(t, dir)
+
+ withDB(t, dbPath, func() {
+ out := captureStdout(func() {
+ if err := listTasks(""); err != nil {
+ t.Fatalf("listTasks: %v", err)
+ }
+ })
+ if !strings.Contains(out, "test-project") {
+ t.Errorf("list output missing project 'test-project':\n%s", out)
+ }
+ })
+}
+
+func TestStatusCmd_ShowsProject(t *testing.T) {
+ dir := t.TempDir()
+ dbPath := filepath.Join(dir, "test.db")
+ tk := makeProjectTask(t, dir)
+
+ withDB(t, dbPath, func() {
+ out := captureStdout(func() {
+ if err := showStatus(tk.ID); err != nil {
+ t.Fatalf("showStatus: %v", err)
+ }
+ })
+ if !strings.Contains(out, "test-project") {
+ t.Errorf("status output missing project 'test-project':\n%s", out)
+ }
+ })
+}
diff --git a/internal/cli/root.go b/internal/cli/root.go
index 7c4f2ff..e57a9d9 100644
--- a/internal/cli/root.go
+++ b/internal/cli/root.go
@@ -60,6 +60,7 @@ func NewRootCmd() *cobra.Command {
}
cfg.DBPath = filepath.Join(cfg.DataDir, "claudomator.db")
cfg.LogDir = filepath.Join(cfg.DataDir, "executions")
+ cfg.DropsDir = filepath.Join(cfg.DataDir, "drops")
return nil
}
@@ -73,6 +74,7 @@ func NewRootCmd() *cobra.Command {
newStartCmd(),
newCreateCmd(),
newReportCmd(),
+ newVersionCmd(),
)
return cmd
diff --git a/internal/cli/run.go b/internal/cli/run.go
index 2d7c3d7..48f34b7 100644
--- a/internal/cli/run.go
+++ b/internal/cli/run.go
@@ -72,16 +72,31 @@ func runTasks(file string, parallel int, dryRun bool) error {
logger := newLogger(verbose)
+ apiURL := "http://localhost" + cfg.ServerAddr
+ if len(cfg.ServerAddr) > 0 && cfg.ServerAddr[0] != ':' {
+ apiURL = "http://" + cfg.ServerAddr
+ }
+
runners := map[string]executor.Runner{
- "claude": &executor.ClaudeRunner{
- BinaryPath: cfg.ClaudeBinaryPath,
- Logger: logger,
- LogDir: cfg.LogDir,
+ "claude": &executor.ContainerRunner{
+ Image: cfg.ClaudeImage,
+ Logger: logger,
+ LogDir: cfg.LogDir,
+ APIURL: apiURL,
+ DropsDir: cfg.DropsDir,
+ SSHAuthSock: cfg.SSHAuthSock,
+ ClaudeBinary: cfg.ClaudeBinaryPath,
+ GeminiBinary: cfg.GeminiBinaryPath,
},
- "gemini": &executor.GeminiRunner{
- BinaryPath: cfg.GeminiBinaryPath,
- Logger: logger,
- LogDir: cfg.LogDir,
+ "gemini": &executor.ContainerRunner{
+ Image: cfg.GeminiImage,
+ Logger: logger,
+ LogDir: cfg.LogDir,
+ APIURL: apiURL,
+ DropsDir: cfg.DropsDir,
+ SSHAuthSock: cfg.SSHAuthSock,
+ ClaudeBinary: cfg.ClaudeBinaryPath,
+ GeminiBinary: cfg.GeminiBinaryPath,
},
}
@@ -95,6 +110,7 @@ func runTasks(file string, parallel int, dryRun bool) error {
}
}
+
pool := executor.NewPool(parallel, runners, store, logger)
pool.Classifier = &executor.Classifier{
LLM: localClient,
diff --git a/internal/cli/serve.go b/internal/cli/serve.go
index 5101b81..459c35b 100644
--- a/internal/cli/serve.go
+++ b/internal/cli/serve.go
@@ -35,6 +35,8 @@ func newServeCmd() *cobra.Command {
cmd.Flags().StringVar(&addr, "addr", ":8484", "listen address")
cmd.Flags().StringVar(&workspaceRoot, "workspace-root", "/workspace", "root directory for listing workspaces")
+ cmd.Flags().StringVar(&cfg.ClaudeImage, "claude-image", cfg.ClaudeImage, "docker image for claude agents")
+ cmd.Flags().StringVar(&cfg.GeminiImage, "gemini-image", cfg.GeminiImage, "docker image for gemini agents")
return cmd
}
@@ -50,25 +52,68 @@ func serve(addr string) error {
}
defer store.Close()
+ // Load VAPID keys from DB; generate and persist if missing.
+ if cfg.VAPIDPublicKey == "" || cfg.VAPIDPrivateKey == "" {
+ pub, _ := store.GetSetting("vapid_public_key")
+ priv, _ := store.GetSetting("vapid_private_key")
+ if pub == "" || priv == "" || !notify.ValidateVAPIDPublicKey(pub) {
+ pub, priv, err = notify.GenerateVAPIDKeys()
+ if err != nil {
+ return fmt.Errorf("generating VAPID keys: %w", err)
+ }
+ _ = store.SetSetting("vapid_public_key", pub)
+ _ = store.SetSetting("vapid_private_key", priv)
+ }
+ cfg.VAPIDPublicKey = pub
+ cfg.VAPIDPrivateKey = priv
+ }
+
logger := newLogger(verbose)
apiURL := "http://localhost" + addr
if len(addr) > 0 && addr[0] != ':' {
apiURL = "http://" + addr
}
-
+
+ // Use configured credentials dir; sync-credentials keeps this populated.
+ claudeConfigDir := cfg.ClaudeConfigDir
+
+ repoDir, _ := os.Getwd()
runners := map[string]executor.Runner{
- "claude": &executor.ClaudeRunner{
- BinaryPath: cfg.ClaudeBinaryPath,
- Logger: logger,
- LogDir: cfg.LogDir,
- APIURL: apiURL,
+ // ContainerRunner: binaries are resolved via PATH inside the container image,
+ // so ClaudeBinary/GeminiBinary are left empty (host paths would not exist inside).
+ "claude": &executor.ContainerRunner{
+ Image: cfg.ClaudeImage,
+ Logger: logger,
+ LogDir: cfg.LogDir,
+ APIURL: apiURL,
+ DropsDir: cfg.DropsDir,
+ SSHAuthSock: cfg.SSHAuthSock,
+ ClaudeConfigDir: claudeConfigDir,
+ CredentialSyncCmd: filepath.Join(repoDir, "scripts", "sync-credentials"),
+ Store: store,
+ },
+ "gemini": &executor.ContainerRunner{
+ Image: cfg.GeminiImage,
+ Logger: logger,
+ LogDir: cfg.LogDir,
+ APIURL: apiURL,
+ DropsDir: cfg.DropsDir,
+ SSHAuthSock: cfg.SSHAuthSock,
+ ClaudeConfigDir: claudeConfigDir,
+ CredentialSyncCmd: filepath.Join(repoDir, "scripts", "sync-credentials"),
+ Store: store,
},
- "gemini": &executor.GeminiRunner{
- BinaryPath: cfg.GeminiBinaryPath,
- Logger: logger,
- LogDir: cfg.LogDir,
- APIURL: apiURL,
+ "container": &executor.ContainerRunner{
+ Image: "claudomator-agent:latest",
+ Logger: logger,
+ LogDir: cfg.LogDir,
+ APIURL: apiURL,
+ DropsDir: cfg.DropsDir,
+ SSHAuthSock: cfg.SSHAuthSock,
+ ClaudeConfigDir: claudeConfigDir,
+ CredentialSyncCmd: filepath.Join(repoDir, "scripts", "sync-credentials"),
+ Store: store,
},
}
@@ -83,6 +128,7 @@ func serve(addr string) error {
logger.Info("local runner registered", "endpoint", cfg.LocalModel.Endpoint, "model", cfg.LocalModel.Model)
}
+
pool := executor.NewPool(cfg.MaxConcurrent, runners, store, logger)
pool.Classifier = &executor.Classifier{
LLM: localClient,
@@ -91,14 +137,36 @@ func serve(addr string) error {
if localClient != nil {
pool.LLM = localClient
}
+
+ if err := store.SeedProjects(); err != nil {
+ logger.Error("failed to seed projects", "error", err)
+ }
+
pool.RecoverStaleRunning(context.Background())
pool.RecoverStaleQueued(context.Background())
pool.RecoverStaleBlocked()
srv := api.NewServer(store, pool, logger, cfg.ClaudeBinaryPath, cfg.GeminiBinaryPath)
+
+ // Configure notifiers: combine webhook (if set) with web push.
+ notifiers := []notify.Notifier{}
if cfg.WebhookURL != "" {
- srv.SetNotifier(notify.NewWebhookNotifier(cfg.WebhookURL, logger))
+ notifiers = append(notifiers, notify.NewWebhookNotifier(cfg.WebhookURL, logger))
+ }
+ webPushNotifier := &notify.WebPushNotifier{
+ Store: store,
+ VAPIDPublicKey: cfg.VAPIDPublicKey,
+ VAPIDPrivateKey: cfg.VAPIDPrivateKey,
+ VAPIDEmail: cfg.VAPIDEmail,
+ Logger: logger,
}
+ notifiers = append(notifiers, webPushNotifier)
+ srv.SetNotifier(notify.NewMultiNotifier(logger, notifiers...))
+
+ srv.SetVAPIDConfig(cfg.VAPIDPublicKey, cfg.VAPIDPrivateKey, cfg.VAPIDEmail)
+ srv.SetPushStore(store)
+ srv.SetDropsDir(cfg.DropsDir)
+
if cfg.WorkspaceRoot != "" {
srv.SetWorkspaceRoot(cfg.WorkspaceRoot)
}
@@ -115,6 +183,11 @@ func serve(addr string) error {
"deploy": filepath.Join(wd, "scripts", "deploy"),
})
+ // Graceful shutdown.
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+
+ srv.SetContext(ctx)
srv.StartHub()
httpSrv := &http.Server{
@@ -122,19 +195,31 @@ func serve(addr string) error {
Handler: srv.Handler(),
}
- // Graceful shutdown.
- ctx, cancel := context.WithCancel(context.Background())
- defer cancel()
+ workerTimeout := 3 * time.Minute
+ if cfg.ShutdownTimeout > 0 {
+ workerTimeout = cfg.ShutdownTimeout
+ }
sigCh := make(chan os.Signal, 1)
signal.Notify(sigCh, os.Interrupt, syscall.SIGTERM)
go func() {
<-sigCh
- logger.Info("shutting down server...")
- shutdownCtx, shutdownCancel := context.WithTimeout(ctx, 5*time.Second)
- defer shutdownCancel()
- if err := httpSrv.Shutdown(shutdownCtx); err != nil {
- logger.Warn("shutdown error", "err", err)
+ logger.Info("shutting down: draining workers...", "timeout", workerTimeout)
+
+ // Stop the HTTP server so no new requests come in.
+ httpCtx, httpCancel := context.WithTimeout(ctx, 5*time.Second)
+ defer httpCancel()
+ if err := httpSrv.Shutdown(httpCtx); err != nil {
+ logger.Warn("http shutdown error", "err", err)
+ }
+
+ // Wait for in-flight task workers to finish.
+ workerCtx, workerCancel := context.WithTimeout(context.Background(), workerTimeout)
+ defer workerCancel()
+ if err := srv.Pool().Shutdown(workerCtx); err != nil {
+ logger.Warn("worker drain timed out", "err", err)
+ } else {
+ logger.Info("all workers finished cleanly")
}
}()
@@ -144,3 +229,4 @@ func serve(addr string) error {
}
return nil
}
+
diff --git a/internal/cli/status.go b/internal/cli/status.go
index 16b88b0..77a30d5 100644
--- a/internal/cli/status.go
+++ b/internal/cli/status.go
@@ -39,6 +39,9 @@ func showStatus(id string) error {
fmt.Printf("State: %s\n", t.State)
fmt.Printf("Priority: %s\n", t.Priority)
fmt.Printf("Model: %s\n", t.Agent.Model)
+ if t.Project != "" {
+ fmt.Printf("Project: %s\n", t.Project)
+ }
if t.Description != "" {
fmt.Printf("Description: %s\n", t.Description)
}
diff --git a/internal/cli/version.go b/internal/cli/version.go
new file mode 100644
index 0000000..789416a
--- /dev/null
+++ b/internal/cli/version.go
@@ -0,0 +1,18 @@
+package cli
+
+import (
+ "fmt"
+
+ "github.com/thepeterstone/claudomator/internal/version"
+ "github.com/spf13/cobra"
+)
+
+// newVersionCmd builds the "version" subcommand, which prints the
+// claudomator version string to standard output.
+func newVersionCmd() *cobra.Command {
+	printVersion := func(_ *cobra.Command, _ []string) {
+		fmt.Printf("claudomator version %s\n", version.Version())
+	}
+
+	versionCmd := &cobra.Command{
+		Use:   "version",
+		Short: "Show the version of claudomator",
+		Run:   printVersion,
+	}
+	return versionCmd
+}
diff --git a/internal/config/config.go b/internal/config/config.go
index 5801239..25187cf 100644
--- a/internal/config/config.go
+++ b/internal/config/config.go
@@ -5,6 +5,7 @@ import (
"fmt"
"os"
"path/filepath"
+ "time"
"github.com/BurntSushi/toml"
)
@@ -45,19 +46,28 @@ func (m LocalModel) UseForElaborate() bool {
}
type Config struct {
- DataDir string `toml:"data_dir"`
- DBPath string `toml:"-"`
- LogDir string `toml:"-"`
- ClaudeBinaryPath string `toml:"claude_binary_path"`
- GeminiBinaryPath string `toml:"gemini_binary_path"`
- MaxConcurrent int `toml:"max_concurrent"`
- DefaultTimeout string `toml:"default_timeout"`
- ServerAddr string `toml:"server_addr"`
- WebhookURL string `toml:"webhook_url"`
- WorkspaceRoot string `toml:"workspace_root"`
- WebhookSecret string `toml:"webhook_secret"`
- Projects []Project `toml:"projects"`
- LocalModel LocalModel `toml:"local_model"`
+ DataDir string `toml:"data_dir"`
+ DBPath string `toml:"-"`
+ LogDir string `toml:"-"`
+ DropsDir string `toml:"-"`
+ SSHAuthSock string `toml:"ssh_auth_sock"`
+ ClaudeBinaryPath string `toml:"claude_binary_path"`
+ GeminiBinaryPath string `toml:"gemini_binary_path"`
+ ClaudeImage string `toml:"claude_image"`
+ GeminiImage string `toml:"gemini_image"`
+ MaxConcurrent int `toml:"max_concurrent"`
+ ShutdownTimeout time.Duration `toml:"shutdown_timeout"`
+ DefaultTimeout string `toml:"default_timeout"`
+ ServerAddr string `toml:"server_addr"`
+ WebhookURL string `toml:"webhook_url"`
+ WorkspaceRoot string `toml:"workspace_root"`
+ WebhookSecret string `toml:"webhook_secret"`
+ Projects []Project `toml:"projects"`
+ VAPIDPublicKey string `toml:"vapid_public_key"`
+ VAPIDPrivateKey string `toml:"vapid_private_key"`
+ VAPIDEmail string `toml:"vapid_email"`
+ ClaudeConfigDir string `toml:"claude_config_dir"`
+ LocalModel LocalModel `toml:"local_model"`
}
func Default() (*Config, error) {
@@ -73,12 +83,17 @@ func Default() (*Config, error) {
DataDir: dataDir,
DBPath: filepath.Join(dataDir, "claudomator.db"),
LogDir: filepath.Join(dataDir, "executions"),
+ DropsDir: filepath.Join(dataDir, "drops"),
+ SSHAuthSock: os.Getenv("SSH_AUTH_SOCK"),
ClaudeBinaryPath: "claude",
GeminiBinaryPath: "gemini",
+ ClaudeImage: "claudomator-agent:latest",
+ GeminiImage: "claudomator-agent:latest",
MaxConcurrent: 3,
DefaultTimeout: "15m",
ServerAddr: ":8484",
WorkspaceRoot: "/workspace",
+ ClaudeConfigDir: "/workspace/claudomator/credentials/claude",
}, nil
}
@@ -97,7 +112,7 @@ func LoadFile(path string) (*Config, error) {
// EnsureDirs creates the data directory structure.
func (c *Config) EnsureDirs() error {
- for _, dir := range []string{c.DataDir, c.LogDir} {
+ for _, dir := range []string{c.DataDir, c.LogDir, c.DropsDir} {
if err := os.MkdirAll(dir, 0700); err != nil {
return err
}
diff --git a/internal/executor/claude.go b/internal/executor/claude.go
index fa68382..3c87f26 100644
--- a/internal/executor/claude.go
+++ b/internal/executor/claude.go
@@ -1,11 +1,8 @@
package executor
import (
- "bufio"
"context"
- "encoding/json"
"fmt"
- "io"
"log/slog"
"os"
"os/exec"
@@ -30,14 +27,6 @@ type ClaudeRunner struct {
// BlockedError is returned by Run when the agent wrote a question file and exited.
// The pool transitions the task to BLOCKED and stores the question for the user.
-type BlockedError struct {
- QuestionJSON string // raw JSON from the question file
- SessionID string // claude session to resume once the user answers
- SandboxDir string // preserved sandbox path; resume must run here so Claude finds its session files
-}
-
-func (e *BlockedError) Error() string { return fmt.Sprintf("task blocked: %s", e.QuestionJSON) }
-
// ExecLogDir returns the log directory for the given execution ID.
// Implements LogPather so the pool can persist paths before execution starts.
func (r *ClaudeRunner) ExecLogDir(execID string) string {
@@ -200,50 +189,6 @@ func (r *ClaudeRunner) Run(ctx context.Context, t *task.Task, e *storage.Executi
return nil
}
-// isCompletionReport returns true when a question-file JSON looks like a
-// completion report rather than a real user question. Heuristic: no options
-// (or empty options) and no "?" anywhere in the text.
-func isCompletionReport(questionJSON string) bool {
- var q struct {
- Text string `json:"text"`
- Options []string `json:"options"`
- }
- if err := json.Unmarshal([]byte(questionJSON), &q); err != nil {
- return false
- }
- return len(q.Options) == 0 && !strings.Contains(q.Text, "?")
-}
-
-// extractQuestionText returns the "text" field from a question-file JSON, or
-// the raw string if parsing fails.
-func extractQuestionText(questionJSON string) string {
- var q struct {
- Text string `json:"text"`
- }
- if err := json.Unmarshal([]byte(questionJSON), &q); err != nil {
- return questionJSON
- }
- return strings.TrimSpace(q.Text)
-}
-
-// gitSafe returns git arguments that prepend safety overrides so that
-// commands succeed regardless of the repository owner or the host's global
-// git configuration. Specifically:
-//
-// - "-c safe.directory=*" lets us operate on directories owned by a
-// different OS user.
-// - "-c commit.gpgsign=false" / "-c tag.gpgsign=false" stop git from
-// trying to sign commits via the host's signing tooling. Sandbox commits
-// are internal and don't need to be signed; an unconfigured or broken
-// signing setup on the host should never block a sandbox merge.
-func gitSafe(args ...string) []string {
- return append([]string{
- "-c", "safe.directory=*",
- "-c", "commit.gpgsign=false",
- "-c", "tag.gpgsign=false",
- }, args...)
-}
-
// sandboxCloneSource returns the URL to clone the sandbox from. It prefers a
// remote named "local" (a local bare repo that accepts pushes cleanly), then
// falls back to "origin", then to the working copy path itself.
@@ -497,7 +442,7 @@ func (r *ClaudeRunner) execOnce(ctx context.Context, args []string, workingDir,
wg.Add(1)
go func() {
defer wg.Done()
- costUSD, streamErr = parseStream(stdoutR, stdoutFile, r.Logger)
+ costUSD, _, streamErr = parseStream(stdoutR, stdoutFile, r.Logger)
stdoutR.Close()
}()
@@ -605,116 +550,3 @@ func (r *ClaudeRunner) buildArgs(t *task.Task, e *storage.Execution, questionFil
return args
}
-// parseStream reads streaming JSON from claude, writes to w, and returns
-// (costUSD, error). error is non-nil if the stream signals task failure:
-// - result message has is_error:true
-// - a tool_result was denied due to missing permissions
-func parseStream(r io.Reader, w io.Writer, logger *slog.Logger) (float64, error) {
- tee := io.TeeReader(r, w)
- scanner := bufio.NewScanner(tee)
- scanner.Buffer(make([]byte, 1024*1024), 1024*1024) // 1MB buffer for large lines
-
- var totalCost float64
- var streamErr error
-
- for scanner.Scan() {
- line := scanner.Bytes()
- var msg map[string]interface{}
- if err := json.Unmarshal(line, &msg); err != nil {
- continue
- }
-
- msgType, _ := msg["type"].(string)
- switch msgType {
- case "rate_limit_event":
- if info, ok := msg["rate_limit_info"].(map[string]interface{}); ok {
- status, _ := info["status"].(string)
- if status == "rejected" {
- streamErr = fmt.Errorf("claude rate limit reached (rejected): %v", msg)
- // Immediately break since we can't continue anyway
- break
- }
- }
- case "assistant":
- if errStr, ok := msg["error"].(string); ok && errStr == "rate_limit" {
- streamErr = fmt.Errorf("claude rate limit reached: %v", msg)
- }
- case "result":
- if isErr, _ := msg["is_error"].(bool); isErr {
- result, _ := msg["result"].(string)
- if result != "" {
- streamErr = fmt.Errorf("claude task failed: %s", result)
- } else {
- streamErr = fmt.Errorf("claude task failed (is_error=true in result)")
- }
- }
- // Prefer total_cost_usd from result message; fall through to legacy check below.
- if cost, ok := msg["total_cost_usd"].(float64); ok {
- totalCost = cost
- }
- case "user":
- // Detect permission-denial tool_results. These occur when permission_mode
- // is not bypassPermissions and claude exits 0 without completing its task.
- if err := permissionDenialError(msg); err != nil && streamErr == nil {
- streamErr = err
- }
- }
-
- // Legacy cost field used by older claude versions.
- if cost, ok := msg["cost_usd"].(float64); ok {
- totalCost = cost
- }
- }
-
- return totalCost, streamErr
-}
-
-// permissionDenialError inspects a "user" stream message for tool_result entries
-// that were denied due to missing permissions. Returns an error if found.
-func permissionDenialError(msg map[string]interface{}) error {
- message, ok := msg["message"].(map[string]interface{})
- if !ok {
- return nil
- }
- content, ok := message["content"].([]interface{})
- if !ok {
- return nil
- }
- for _, item := range content {
- itemMap, ok := item.(map[string]interface{})
- if !ok {
- continue
- }
- if itemMap["type"] != "tool_result" {
- continue
- }
- if isErr, _ := itemMap["is_error"].(bool); !isErr {
- continue
- }
- text, _ := itemMap["content"].(string)
- if strings.Contains(text, "requested permissions") || strings.Contains(text, "haven't granted") {
- return fmt.Errorf("permission denied by host: %s", text)
- }
- }
- return nil
-}
-
-// tailFile returns the last n lines of the file at path, or empty string if
-// the file cannot be read. Used to surface subprocess stderr on failure.
-func tailFile(path string, n int) string {
- f, err := os.Open(path)
- if err != nil {
- return ""
- }
- defer f.Close()
-
- var lines []string
- scanner := bufio.NewScanner(f)
- for scanner.Scan() {
- lines = append(lines, scanner.Text())
- if len(lines) > n {
- lines = lines[1:]
- }
- }
- return strings.Join(lines, "\n")
-}
diff --git a/internal/executor/claude_test.go b/internal/executor/claude_test.go
index cbb5947..c01e160 100644
--- a/internal/executor/claude_test.go
+++ b/internal/executor/claude_test.go
@@ -2,7 +2,6 @@ package executor
import (
"context"
- "errors"
"fmt"
"io"
"log/slog"
@@ -697,57 +696,6 @@ func TestTeardownSandbox_CleanSandboxWithNoNewCommits_RemovesSandbox(t *testing.
}
}
-// TestBlockedError_IncludesSandboxDir verifies that when a task is blocked in a
-// sandbox, the BlockedError carries the sandbox path so the resume execution can
-// run in the same directory (where Claude's session files are stored).
-func TestBlockedError_IncludesSandboxDir(t *testing.T) {
- src := t.TempDir()
- initGitRepo(t, src)
-
- logDir := t.TempDir()
-
- // Use a script that writes question.json to the env-var path and exits 0
- // (simulating a blocked agent that asks a question before exiting).
- scriptPath := filepath.Join(t.TempDir(), "fake-claude.sh")
- if err := os.WriteFile(scriptPath, []byte(`#!/bin/sh
-if [ -n "$CLAUDOMATOR_QUESTION_FILE" ]; then
- printf '{"text":"Should I continue?"}' > "$CLAUDOMATOR_QUESTION_FILE"
-fi
-`), 0755); err != nil {
- t.Fatalf("write script: %v", err)
- }
-
- r := &ClaudeRunner{
- BinaryPath: scriptPath,
- Logger: slog.New(slog.NewTextHandler(io.Discard, nil)),
- LogDir: logDir,
- }
- tk := &task.Task{
- Agent: task.AgentConfig{
- Type: "claude",
- Instructions: "do something",
- ProjectDir: src,
- SkipPlanning: true,
- },
- }
- exec := &storage.Execution{ID: "blocked-exec-uuid", TaskID: "task-1"}
-
- err := r.Run(context.Background(), tk, exec)
-
- var blocked *BlockedError
- if !errors.As(err, &blocked) {
- t.Fatalf("expected BlockedError, got: %v", err)
- }
- if blocked.SandboxDir == "" {
- t.Error("BlockedError.SandboxDir should be set when task runs in a sandbox")
- }
- // Sandbox should still exist (preserved for resume).
- if _, statErr := os.Stat(blocked.SandboxDir); os.IsNotExist(statErr) {
- t.Error("sandbox directory should be preserved when blocked")
- } else {
- os.RemoveAll(blocked.SandboxDir) // cleanup
- }
-}
// TestClaudeRunner_Run_ResumeUsesStoredSandboxDir verifies that when a resume
// execution has SandboxDir set, the runner uses that directory (not project_dir)
@@ -853,69 +801,6 @@ func TestClaudeRunner_Run_StaleSandboxDir_ClonesAfresh(t *testing.T) {
}
}
-func TestIsCompletionReport(t *testing.T) {
- tests := []struct {
- name string
- json string
- expected bool
- }{
- {
- name: "real question with options",
- json: `{"text": "Should I proceed with implementation?", "options": ["Yes", "No"]}`,
- expected: false,
- },
- {
- name: "real question no options",
- json: `{"text": "Which approach do you prefer?"}`,
- expected: false,
- },
- {
- name: "completion report no options no question mark",
- json: `{"text": "All tests pass. Implementation complete. Summary written to CLAUDOMATOR_SUMMARY_FILE."}`,
- expected: true,
- },
- {
- name: "completion report with empty options",
- json: `{"text": "Feature implemented and committed.", "options": []}`,
- expected: true,
- },
- {
- name: "invalid json treated as not a report",
- json: `not json`,
- expected: false,
- },
- }
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- got := isCompletionReport(tt.json)
- if got != tt.expected {
- t.Errorf("isCompletionReport(%q) = %v, want %v", tt.json, got, tt.expected)
- }
- })
- }
-}
-
-func TestTailFile_ReturnsLastNLines(t *testing.T) {
- f, err := os.CreateTemp("", "tailfile-*")
- if err != nil {
- t.Fatal(err)
- }
- defer os.Remove(f.Name())
- for i := 1; i <= 30; i++ {
- fmt.Fprintf(f, "line %d\n", i)
- }
- f.Close()
-
- got := tailFile(f.Name(), 5)
- lines := strings.Split(got, "\n")
- if len(lines) != 5 {
- t.Fatalf("want 5 lines, got %d: %q", len(lines), got)
- }
- if lines[0] != "line 26" || lines[4] != "line 30" {
- t.Errorf("want lines 26-30, got: %q", got)
- }
-}
-
func TestTailFile_MissingFile_ReturnsEmpty(t *testing.T) {
got := tailFile("/nonexistent/path/file.log", 10)
if got != "" {
@@ -923,15 +808,3 @@ func TestTailFile_MissingFile_ReturnsEmpty(t *testing.T) {
}
}
-func TestGitSafe_PrependsSafeDirectory(t *testing.T) {
- got := gitSafe("-C", "/some/path", "status")
- want := []string{"-c", "safe.directory=*", "-c", "commit.gpgsign=false", "-c", "tag.gpgsign=false", "-C", "/some/path", "status"}
- if len(got) != len(want) {
- t.Fatalf("gitSafe() = %v, want %v", got, want)
- }
- for i := range want {
- if got[i] != want[i] {
- t.Errorf("gitSafe()[%d] = %q, want %q", i, got[i], want[i])
- }
- }
-}
diff --git a/internal/executor/container.go b/internal/executor/container.go
new file mode 100644
index 0000000..61ac29c
--- /dev/null
+++ b/internal/executor/container.go
@@ -0,0 +1,549 @@
+package executor
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "log/slog"
+ "os"
+ "os/exec"
+ "path/filepath"
+ "strings"
+ "sync"
+ "syscall"
+
+ "github.com/thepeterstone/claudomator/internal/storage"
+ "github.com/thepeterstone/claudomator/internal/task"
+)
+
+// ContainerRunner executes an agent inside a container.
+//
+// Run clones the task repository into a host workspace, stages credentials
+// into a per-workspace home directory and starts the container through the
+// "docker" CLI.
+type ContainerRunner struct {
+ Image string // default image if not specified in task
+ Logger *slog.Logger
+ // LogDir is the base directory for per-execution logs (see ExecLogDir).
+ // When empty, logs are written to <workspace>/.claudomator-logs instead.
+ LogDir string
+ // APIURL is passed to the container as CLAUDOMATOR_API_URL, with
+ // "localhost" rewritten to "host.docker.internal".
+ APIURL string
+ // DropsDir is passed to the container as CLAUDOMATOR_DROP_DIR.
+ DropsDir string
+ SSHAuthSock string // optional path to host SSH agent
+ ClaudeBinary string // optional path to claude binary in container
+ GeminiBinary string // optional path to gemini binary in container
+ ClaudeConfigDir string // host dir holding .credentials.json and .claude.json; both are copied into the staged agent home
+ CredentialSyncCmd string // optional path to sync-credentials script for auth-error auto-recovery
+ Store Store // optional; used to look up stories and projects for story-aware cloning
+ // Command allows mocking exec.CommandContext for tests.
+ Command func(ctx context.Context, name string, arg ...string) *exec.Cmd
+}
+
+// isAuthError reports whether err's message contains one of the known
+// markers of a missing or expired agent login. A nil error is never an
+// auth error.
+func isAuthError(err error) bool {
+	if err == nil {
+		return false
+	}
+	msg := err.Error()
+	for _, marker := range []string{
+		"Not logged in",
+		"OAuth token has expired",
+		"authentication_error",
+		"Please run /login",
+	} {
+		if strings.Contains(msg, marker) {
+			return true
+		}
+	}
+	return false
+}
+
+// command builds the *exec.Cmd for an external program. It goes through the
+// Command hook when one is installed (used by tests) and falls back to
+// exec.CommandContext otherwise.
+func (r *ContainerRunner) command(ctx context.Context, name string, arg ...string) *exec.Cmd {
+	factory := r.Command
+	if factory == nil {
+		factory = exec.CommandContext
+	}
+	return factory(ctx, name, arg...)
+}
+
+// ExecLogDir returns the log directory for execution execID, i.e. LogDir
+// joined with execID. It returns "" when no LogDir is configured.
+func (r *ContainerRunner) ExecLogDir(execID string) string {
+	if base := r.LogDir; base != "" {
+		return filepath.Join(base, execID)
+	}
+	return ""
+}
+
+// ensureStoryBranch checks whether branchName exists in remoteURL and creates
+// it from main if not. Uses localPath as a reference clone for speed if set.
+func (r *ContainerRunner) ensureStoryBranch(ctx context.Context, remoteURL, branchName, localPath string) error {
+ // Check if branch already exists.
+ out, err := r.command(ctx, "git", "ls-remote", "--heads", remoteURL, branchName).CombinedOutput()
+ if err == nil && len(strings.TrimSpace(string(out))) > 0 {
+ return nil // already exists
+ }
+
+ r.Logger.Info("story branch missing, creating from main", "branch", branchName, "remote", remoteURL)
+
+ // Clone into a temp dir so we can create the branch.
+ tmp, err := os.MkdirTemp("", "claudomator-branchsetup-*")
+ if err != nil {
+ return fmt.Errorf("mktemp for branch setup: %w", err)
+ }
+ defer os.RemoveAll(tmp)
+
+ // Remove the dir git clone expects to create.
+ if err := os.Remove(tmp); err != nil {
+ return fmt.Errorf("removing tmp dir before clone: %w", err)
+ }
+
+ var cloneArgs []string
+ if localPath != "" {
+ cloneArgs = []string{"clone", "--reference", localPath, remoteURL, tmp}
+ } else {
+ cloneArgs = []string{"clone", remoteURL, tmp}
+ }
+ if out, err := r.command(ctx, "git", cloneArgs...).CombinedOutput(); err != nil {
+ return fmt.Errorf("git clone for branch setup: %w\n%s", err, string(out))
+ }
+ if out, err := r.command(ctx, "git", "-C", tmp, "checkout", "-b", branchName).CombinedOutput(); err != nil {
+ return fmt.Errorf("git checkout -b %q: %w\n%s", branchName, err, string(out))
+ }
+ if out, err := r.command(ctx, "git", "-C", tmp, "push", "origin", branchName).CombinedOutput(); err != nil {
+ return fmt.Errorf("git push %q: %w\n%s", branchName, err, string(out))
+ }
+ r.Logger.Info("story branch created and pushed", "branch", branchName)
+ return nil
+}
+
+// Run executes task t for execution e inside a container.
+//
+// It prepares a host workspace (reusing e.SandboxDir when that directory still
+// exists, otherwise cloning t.RepositoryURL into a fresh temp dir), checks out
+// the story branch if there is one, stages agent credentials into
+// <workspace>/.agent-home and then delegates to runContainer. If the run fails
+// with an auth error and CredentialSyncCmd is set, the sync command is run,
+// credentials are re-staged and the container is retried once.
+//
+// The workspace is deleted only after a successful, non-blocked run. On
+// failure, or when the task is blocked (*BlockedError), it is left on disk and
+// its path stays in e.SandboxDir.
+//
+// Run returns an error when the task has no repository URL, when workspace or
+// credential setup fails, or whatever runContainer returns.
+func (r *ContainerRunner) Run(ctx context.Context, t *task.Task, e *storage.Execution) error {
+	var err error
+	repoURL := t.RepositoryURL
+	if repoURL == "" {
+		return fmt.Errorf("task %s has no repository_url", t.ID)
+	}
+
+	// 1. Setup workspace on host
+	var workspace string
+	isResume := false
+	if e.SandboxDir != "" {
+		if _, err = os.Stat(e.SandboxDir); err == nil {
+			workspace = e.SandboxDir
+			isResume = true
+			r.Logger.Info("resuming in preserved workspace", "path", workspace)
+		}
+	}
+
+	if workspace == "" {
+		workspace, err = os.MkdirTemp("", "claudomator-workspace-*")
+		if err != nil {
+			return fmt.Errorf("creating workspace: %w", err)
+		}
+		// chmod applied after clone; see step 3.
+	}
+
+	// Note: workspace is only removed on success. On failure, it's preserved for debugging.
+	// If the task becomes BLOCKED, it's also preserved for resumption.
+	success := false
+	isBlocked := false
+	defer func() {
+		if success && !isBlocked {
+			os.RemoveAll(workspace)
+		} else {
+			r.Logger.Warn("preserving workspace", "path", workspace, "success", success, "blocked", isBlocked)
+		}
+	}()
+
+	// Resolve story branch and project local path if this is a story task.
+	var storyBranch string
+	var storyLocalPath string
+	if t.StoryID != "" && r.Store != nil {
+		if story, err := r.Store.GetStory(t.StoryID); err == nil && story != nil {
+			storyBranch = story.BranchName
+			if story.ProjectID != "" {
+				if proj, err := r.Store.GetProject(story.ProjectID); err == nil && proj != nil {
+					storyLocalPath = proj.LocalPath
+				}
+			}
+		}
+	}
+	// Fall back to task-level BranchName (e.g. set explicitly by executor or tests).
+	if storyBranch == "" {
+		storyBranch = t.BranchName
+	}
+
+	// 2. Ensure story branch exists in the remote before cloning.
+	// If the branch is missing (e.g. story approved before fix, or branch push failed),
+	// create it using the project local path as a reference repo.
+	if storyBranch != "" && !isResume {
+		if err := r.ensureStoryBranch(ctx, repoURL, storyBranch, storyLocalPath); err != nil {
+			r.Logger.Warn("ensureStoryBranch failed (will attempt checkout anyway)", "branch", storyBranch, "error", err)
+		}
+	}
+
+	// 3. Clone repo into workspace if not resuming.
+	// git clone requires the target directory to not exist; remove the MkdirTemp-created dir first.
+	if !isResume {
+		if err := os.Remove(workspace); err != nil {
+			return fmt.Errorf("removing workspace before clone: %w", err)
+		}
+		r.Logger.Info("cloning repository", "url", repoURL, "workspace", workspace)
+		var cloneArgs []string
+		if storyLocalPath != "" {
+			cloneArgs = []string{"clone", "--reference", storyLocalPath, repoURL, workspace}
+		} else {
+			cloneArgs = []string{"clone", repoURL, workspace}
+		}
+		if out, err := r.command(ctx, "git", cloneArgs...).CombinedOutput(); err != nil {
+			return fmt.Errorf("git clone failed: %w\n%s", err, string(out))
+		}
+		if storyBranch != "" {
+			r.Logger.Info("checking out story branch", "branch", storyBranch)
+			if out, err := r.command(ctx, "git", "-C", workspace, "checkout", storyBranch).CombinedOutput(); err != nil {
+				return fmt.Errorf("git checkout story branch %q failed: %w\n%s", storyBranch, err, string(out))
+			}
+		}
+		if err = os.Chmod(workspace, 0755); err != nil {
+			return fmt.Errorf("chmod cloned workspace: %w", err)
+		}
+	}
+	e.SandboxDir = workspace
+
+	// Set up a writable $HOME staging dir so any agent tool (claude, gemini, etc.)
+	// can freely create subdirs (session-env, .gemini, .cache, …) without hitting
+	// a non-existent or read-only home. We copy only the claude credentials into it.
+	agentHome := filepath.Join(workspace, ".agent-home")
+	if err := os.MkdirAll(filepath.Join(agentHome, ".claude"), 0755); err != nil {
+		return fmt.Errorf("creating agent home staging dir: %w", err)
+	}
+	if err := os.MkdirAll(filepath.Join(agentHome, ".gemini"), 0755); err != nil {
+		return fmt.Errorf("creating .gemini dir: %w", err)
+	}
+	r.stageClaudeAuth(agentHome)
+
+	// Pre-flight: verify credentials were actually copied before spinning up a container.
+	if r.ClaudeConfigDir != "" {
+		credsPath := filepath.Join(agentHome, ".claude", ".credentials.json")
+		settingsPath := filepath.Join(agentHome, ".claude.json")
+		if _, err := os.Stat(credsPath); os.IsNotExist(err) {
+			return fmt.Errorf("credentials not found at %s — run sync-credentials", r.ClaudeConfigDir)
+		}
+		if _, err := os.Stat(settingsPath); os.IsNotExist(err) {
+			return fmt.Errorf("claude settings (.claude.json) not found at %s — run sync-credentials", r.ClaudeConfigDir)
+		}
+	}
+
+	// Run container (with auth retry on failure).
+	runErr := r.runContainer(ctx, t, e, workspace, agentHome, isResume, storyBranch)
+	if runErr != nil && isAuthError(runErr) && r.CredentialSyncCmd != "" {
+		r.Logger.Warn("auth failure detected, syncing credentials and retrying once", "taskID", t.ID)
+		syncOut, syncErr := r.command(ctx, r.CredentialSyncCmd).CombinedOutput()
+		if syncErr != nil {
+			r.Logger.Warn("sync-credentials failed", "error", syncErr, "output", string(syncOut))
+		}
+		// Re-copy credentials into agentHome with fresh files. stageClaudeAuth
+		// is a no-op without ClaudeConfigDir, so nothing is ever read relative
+		// to the process working directory.
+		r.stageClaudeAuth(agentHome)
+		runErr = r.runContainer(ctx, t, e, workspace, agentHome, isResume, storyBranch)
+	}
+
+	if runErr == nil {
+		success = true
+	}
+	var blockedErr *BlockedError
+	if errors.As(runErr, &blockedErr) {
+		isBlocked = true
+		success = true // preserve workspace for resumption
+	}
+	return runErr
+}
+
+// stageClaudeAuth copies .credentials.json and .claude.json from
+// ClaudeConfigDir into the staged agent home so the claude CLI can read (and
+// rewrite) them without touching the host copies. It is best effort: source
+// files that cannot be read are skipped (Run's pre-flight check reports them),
+// and failed writes are logged. It does nothing when ClaudeConfigDir is empty.
+func (r *ContainerRunner) stageClaudeAuth(agentHome string) {
+	if r.ClaudeConfigDir == "" {
+		return
+	}
+	files := []struct {
+		src  string // name inside ClaudeConfigDir
+		dst  string // path relative to agentHome
+		mode os.FileMode
+	}{
+		{".credentials.json", filepath.Join(".claude", ".credentials.json"), 0600},
+		{".claude.json", ".claude.json", 0644},
+	}
+	for _, f := range files {
+		data, err := os.ReadFile(filepath.Join(r.ClaudeConfigDir, f.src))
+		if err != nil {
+			continue
+		}
+		dst := filepath.Join(agentHome, f.dst)
+		if err := os.WriteFile(dst, data, f.mode); err != nil {
+			r.Logger.Warn("copying claude auth file failed", "path", dst, "error", err)
+		}
+	}
+}
+
+// runContainer runs the docker container for the given task and handles log setup,
+// environment files, instructions, and post-execution git operations.
+//
+// It fills in e.StdoutPath, e.StderrPath, e.ArtifactDir, e.CostUSD and, when
+// available, e.SessionID and e.Summary. It returns a *BlockedError when the
+// agent left a question.json that is not a completion report, a wrapped error
+// when the container, the stream parser or the final git push fails, and nil
+// otherwise.
+func (r *ContainerRunner) runContainer(ctx context.Context, t *task.Task, e *storage.Execution, workspace, agentHome string, isResume bool, storyBranch string) error {
+ repoURL := t.RepositoryURL
+
+ // Image precedence: task-level image, then runner default, then built-in fallback.
+ image := t.Agent.ContainerImage
+ if image == "" {
+ image = r.Image
+ }
+ if image == "" {
+ image = "claudomator-agent:latest"
+ }
+
+ // 3. Prepare logs
+ // Without a configured LogDir the logs live inside the workspace itself.
+ logDir := r.ExecLogDir(e.ID)
+ if logDir == "" {
+ logDir = filepath.Join(workspace, ".claudomator-logs")
+ }
+ if err := os.MkdirAll(logDir, 0700); err != nil {
+ return fmt.Errorf("creating log dir: %w", err)
+ }
+ e.StdoutPath = filepath.Join(logDir, "stdout.log")
+ e.StderrPath = filepath.Join(logDir, "stderr.log")
+ e.ArtifactDir = logDir
+
+ // os.Create truncates, so a retry or resume starts with empty log files.
+ stdoutFile, err := os.Create(e.StdoutPath)
+ if err != nil {
+ return fmt.Errorf("creating stdout log: %w", err)
+ }
+ defer stdoutFile.Close()
+
+ stderrFile, err := os.Create(e.StderrPath)
+ if err != nil {
+ return fmt.Errorf("creating stderr log: %w", err)
+ }
+ defer stderrFile.Close()
+
+ // 4. Run container
+
+ // Write API keys to a temporary env file to avoid exposure in 'ps' or 'docker inspect'
+ // NOTE(review): this function never removes the env file, so it stays in
+ // the workspace for as long as the workspace is preserved.
+ envFile := filepath.Join(workspace, ".claudomator-env")
+ envContent := fmt.Sprintf("ANTHROPIC_API_KEY=%s\nGOOGLE_API_KEY=%s\nGEMINI_API_KEY=%s\n", os.Getenv("ANTHROPIC_API_KEY"), os.Getenv("GOOGLE_API_KEY"), os.Getenv("GEMINI_API_KEY"))
+ if err := os.WriteFile(envFile, []byte(envContent), 0600); err != nil {
+ return fmt.Errorf("writing env file: %w", err)
+ }
+
+ // Inject custom instructions via file to avoid CLI length limits
+ instructionsFile := filepath.Join(workspace, ".claudomator-instructions.txt")
+ if err := os.WriteFile(instructionsFile, []byte(t.Agent.Instructions), 0644); err != nil {
+ return fmt.Errorf("writing instructions: %w", err)
+ }
+
+ args := r.buildDockerArgs(workspace, agentHome, e.TaskID)
+ innerCmd := r.buildInnerCmd(t, e, isResume)
+
+ // Final argv: docker <run flags> <image> <inner command...>
+ fullArgs := append(args, image)
+ fullArgs = append(fullArgs, innerCmd...)
+
+ r.Logger.Info("starting container", "image", image, "taskID", t.ID)
+ cmd := r.command(ctx, "docker", fullArgs...)
+ cmd.Stderr = stderrFile
+ // Own process group, so the whole group can be signalled on cancellation below.
+ cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true}
+
+ // Use os.Pipe for stdout so we can parse it in real-time
+ var stdoutR, stdoutW *os.File
+ stdoutR, stdoutW, err = os.Pipe()
+ if err != nil {
+ return fmt.Errorf("creating stdout pipe: %w", err)
+ }
+ cmd.Stdout = stdoutW
+
+ if err := cmd.Start(); err != nil {
+ stdoutW.Close()
+ stdoutR.Close()
+ return fmt.Errorf("starting container: %w", err)
+ }
+ // The child holds its own copy of the write end; closing ours lets the
+ // reader see EOF once the child exits.
+ stdoutW.Close()
+
+ // Watch for context cancellation to kill the process group (Issue 1)
+ // Closing done (deferred) stops the watcher when this function returns.
+ done := make(chan struct{})
+ defer close(done)
+ go func() {
+ select {
+ case <-ctx.Done():
+ r.Logger.Info("killing container process group due to context cancellation", "taskID", t.ID)
+ // Negative PID addresses the process group created via Setpgid.
+ syscall.Kill(-cmd.Process.Pid, syscall.SIGKILL)
+ case <-done:
+ }
+ }()
+
+ // Stream stdout to the log file and parse cost/errors.
+ var costUSD float64
+ var sessionID string
+ var streamErr error
+ var wg sync.WaitGroup
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ costUSD, sessionID, streamErr = parseStream(stdoutR, stdoutFile, r.Logger)
+ stdoutR.Close()
+ }()
+
+ // Wait for the process first, then for the parser goroutine; its results
+ // are only read after wg.Wait returns.
+ waitErr := cmd.Wait()
+ wg.Wait()
+
+ e.CostUSD = costUSD
+ if sessionID != "" {
+ e.SessionID = sessionID
+ }
+
+ // Check whether the agent left a question before exiting.
+ // This is evaluated before waitErr, so a question file produces a
+ // BlockedError regardless of the container's exit status.
+ questionFile := filepath.Join(logDir, "question.json")
+ if data, readErr := os.ReadFile(questionFile); readErr == nil {
+ os.Remove(questionFile) // consumed
+ questionJSON := strings.TrimSpace(string(data))
+ if isCompletionReport(questionJSON) {
+ r.Logger.Info("treating question file as completion report", "taskID", e.TaskID)
+ e.Summary = extractQuestionText(questionJSON)
+ } else {
+ if e.SessionID == "" {
+ r.Logger.Warn("missing session ID; resume will start fresh", "taskID", e.TaskID)
+ }
+ return &BlockedError{
+ QuestionJSON: questionJSON,
+ SessionID: e.SessionID,
+ SandboxDir: workspace,
+ }
+ }
+ }
+
+ // Read agent summary if written.
+ // A summary file overrides any summary taken from a completion report above.
+ summaryFile := filepath.Join(logDir, "summary.txt")
+ if summaryData, readErr := os.ReadFile(summaryFile); readErr == nil {
+ os.Remove(summaryFile) // consumed
+ e.Summary = strings.TrimSpace(string(summaryData))
+ }
+
+ // 5. Post-execution: push changes if successful
+ if waitErr == nil && streamErr == nil {
+ // Check if there are any commits to push (HEAD ahead of origin/HEAD).
+ // If origin/HEAD doesn't exist (e.g. fresh clone with no commits), we attempt push anyway.
+ hasCommits := true
+ if out, err := r.command(ctx, "git", "-C", workspace, "rev-list", "origin/HEAD..HEAD").CombinedOutput(); err == nil {
+ if len(strings.TrimSpace(string(out))) == 0 {
+ hasCommits = false
+ }
+ }
+
+ if hasCommits {
+ pushRef := "HEAD"
+ if storyBranch != "" {
+ pushRef = storyBranch
+ }
+ r.Logger.Info("pushing changes back to remote", "url", repoURL, "ref", pushRef)
+ if out, err := r.command(ctx, "git", "-C", workspace, "push", "origin", pushRef).CombinedOutput(); err != nil {
+ r.Logger.Warn("git push failed", "error", err, "output", string(out))
+ return fmt.Errorf("git push failed: %w\n%s", err, string(out))
+ }
+ } else {
+ // No commits pushed — check whether the agent left uncommitted work behind.
+ // If so, fail loudly: the work would be silently lost when the sandbox is deleted.
+ if err := detectUncommittedChanges(workspace); err != nil {
+ return err
+ }
+ r.Logger.Info("no new commits to push", "taskID", t.ID)
+ }
+ }
+
+ if waitErr != nil {
+ // Append the tail of stderr so error classifiers (isQuotaExhausted, isRateLimitError)
+ // can inspect agent-specific messages (e.g. Gemini TerminalQuotaError).
+ stderrTail := readFileTail(e.StderrPath, 4096)
+ if stderrTail != "" {
+ return fmt.Errorf("container execution failed: %w\n%s", waitErr, stderrTail)
+ }
+ return fmt.Errorf("container execution failed: %w", waitErr)
+ }
+ if streamErr != nil {
+ return fmt.Errorf("stream parsing failed: %w", streamErr)
+ }
+
+ return nil
+}
+
+func (r *ContainerRunner) buildDockerArgs(workspace, claudeHome, taskID string) []string {
+ // --env-file takes a HOST path.
+ hostEnvFile := filepath.Join(workspace, ".claudomator-env")
+
+ // Replace localhost with host.docker.internal so the container can reach the host API.
+ apiURL := strings.ReplaceAll(r.APIURL, "localhost", "host.docker.internal")
+
+ args := []string{
+ "run", "--rm",
+ // Allow container to reach the host via host.docker.internal.
+ "--add-host=host.docker.internal:host-gateway",
+ // Run as the current process UID:GID so the container can read host-owned files.
+ fmt.Sprintf("--user=%d:%d", os.Getuid(), os.Getgid()),
+ "-v", workspace + ":/workspace",
+ "-v", claudeHome + ":/home/agent",
+ "-w", "/workspace",
+ "--env-file", hostEnvFile,
+ "-e", "HOME=/home/agent",
+ "-e", "CLAUDOMATOR_API_URL=" + apiURL,
+ "-e", "CLAUDOMATOR_TASK_ID=" + taskID,
+ "-e", "CLAUDOMATOR_DROP_DIR=" + r.DropsDir,
+ }
+ if r.SSHAuthSock != "" {
+ args = append(args, "-v", r.SSHAuthSock+":/tmp/ssh-auth.sock", "-e", "SSH_AUTH_SOCK=/tmp/ssh-auth.sock")
+ }
+ return args
+}
+
+// buildInnerCmd returns the argv executed inside the container: an
+// "sh -c" script that loads the instructions file into $INST and passes it
+// to the agent CLI via -p.
+//
+// For agent type "gemini" the gemini binary is invoked with the prompt only.
+// Any other type runs the claude binary with stream-json output and, when
+// isResume is true and e.ResumeSessionID is set, "--resume <id>".
+func (r *ContainerRunner) buildInnerCmd(t *task.Task, e *storage.Execution, isResume bool) []string {
+	// Reading the file into a shell variable and expanding it inside double
+	// quotes keeps the instruction text a single argument whatever it contains.
+	const loadInstructions = "INST=$(cat /workspace/.claudomator-instructions.txt); "
+	const promptFlag = ` -p "$INST"`
+
+	// orDefault picks the configured binary path or the bare command name.
+	orDefault := func(configured, fallback string) string {
+		if configured == "" {
+			return fallback
+		}
+		return configured
+	}
+
+	if t.Agent.Type == "gemini" {
+		script := loadInstructions + orDefault(r.GeminiBinary, "gemini") + promptFlag
+		return []string{"sh", "-c", script}
+	}
+
+	// Claude
+	script := loadInstructions + orDefault(r.ClaudeBinary, "claude") + promptFlag
+	if isResume && e.ResumeSessionID != "" {
+		script += " --resume " + e.ResumeSessionID
+	}
+	script += " --output-format stream-json --verbose --permission-mode bypassPermissions"
+
+	return []string{"sh", "-c", script}
+}
+
+// scaffoldPrefixes are files/dirs written by the harness into the workspace
+// around the agent run. They are not part of the repo and must not trigger the
+// uncommitted-changes check.
+//
+// .claudomator-logs is the in-workspace log directory used when no LogDir is
+// configured; without it a clean run would be reported as leaving untracked
+// files behind.
+var scaffoldPrefixes = []string{
+	".claudomator-env",
+	".claudomator-instructions.txt",
+	".claudomator-logs",
+	".agent-home",
+}
+
+// isScaffold reports whether path (relative to the workspace root, with "/"
+// separators as printed by git) is a scaffold entry itself or lies underneath
+// one. Names that merely share a prefix, such as ".claudomator-envx", do not
+// match.
+func isScaffold(path string) bool {
+	for _, p := range scaffoldPrefixes {
+		if path == p || strings.HasPrefix(path, p+"/") {
+			return true
+		}
+	}
+	return false
+}
+
+// detectUncommittedChanges returns an error when the git workspace still holds
+// work the agent did not commit: tracked files that differ from HEAD, or
+// untracked files that are not gitignored. Entries recognised by isScaffold
+// are written by the harness and are skipped.
+//
+// A git command that fails is treated as "nothing found" for that check.
+func detectUncommittedChanges(workspace string) error {
+	// runGit executes git against workspace and returns the trimmed combined
+	// output; ok is false when git exits with an error.
+	runGit := func(args ...string) (listing string, ok bool) {
+		argv := append([]string{"-c", "safe.directory=*", "-C", workspace}, args...)
+		raw, err := exec.Command("git", argv...).CombinedOutput()
+		return strings.TrimSpace(string(raw)), err == nil
+	}
+	// nonScaffold keeps the non-empty lines of listing that are not harness files.
+	nonScaffold := func(listing string) []string {
+		var kept []string
+		for _, name := range strings.Split(listing, "\n") {
+			if name == "" || isScaffold(name) {
+				continue
+			}
+			kept = append(kept, name)
+		}
+		return kept
+	}
+
+	// Modified or staged tracked files. The full listing is reported.
+	if changed, ok := runGit("diff", "--name-only", "HEAD"); ok && len(nonScaffold(changed)) > 0 {
+		return fmt.Errorf("agent left uncommitted changes (work would be lost on sandbox deletion):\n%s\nInstructions must include: git add -A && git commit && git push origin main", changed)
+	}
+
+	// Untracked new source files (excludes gitignored files).
+	if others, ok := runGit("ls-files", "--others", "--exclude-standard"); ok {
+		if dirty := nonScaffold(others); len(dirty) > 0 {
+			return fmt.Errorf("agent left untracked files not committed (work would be lost on sandbox deletion):\n%s\nInstructions must include: git add -A && git commit && git push origin main", strings.Join(dirty, "\n"))
+		}
+	}
+
+	return nil
+}
+
diff --git a/internal/executor/container_test.go b/internal/executor/container_test.go
new file mode 100644
index 0000000..f0b2a3a
--- /dev/null
+++ b/internal/executor/container_test.go
@@ -0,0 +1,687 @@
+package executor
+
+import (
+ "context"
+ "fmt"
+ "io"
+ "log/slog"
+ "os"
+ "os/exec"
+ "path/filepath"
+ "strings"
+ "testing"
+
+ "github.com/thepeterstone/claudomator/internal/storage"
+ "github.com/thepeterstone/claudomator/internal/task"
+)
+
+func TestContainerRunner_BuildDockerArgs(t *testing.T) {
+ runner := &ContainerRunner{
+ APIURL: "http://localhost:8484",
+ DropsDir: "/data/drops",
+ SSHAuthSock: "/tmp/ssh.sock",
+ }
+ workspace := "/tmp/ws"
+ taskID := "task-123"
+
+ agentHome := "/tmp/ws/.agent-home"
+ args := runner.buildDockerArgs(workspace, agentHome, taskID)
+
+ expected := []string{
+ "run", "--rm",
+ "--add-host=host.docker.internal:host-gateway",
+ fmt.Sprintf("--user=%d:%d", os.Getuid(), os.Getgid()),
+ "-v", "/tmp/ws:/workspace",
+ "-v", "/tmp/ws/.agent-home:/home/agent",
+ "-w", "/workspace",
+ "--env-file", "/tmp/ws/.claudomator-env",
+ "-e", "HOME=/home/agent",
+ "-e", "CLAUDOMATOR_API_URL=http://host.docker.internal:8484",
+ "-e", "CLAUDOMATOR_TASK_ID=task-123",
+ "-e", "CLAUDOMATOR_DROP_DIR=/data/drops",
+ "-v", "/tmp/ssh.sock:/tmp/ssh-auth.sock",
+ "-e", "SSH_AUTH_SOCK=/tmp/ssh-auth.sock",
+ }
+
+ if len(args) != len(expected) {
+ t.Fatalf("expected %d args, got %d. Got: %v", len(expected), len(args), args)
+ }
+ for i, v := range args {
+ if v != expected[i] {
+ t.Errorf("arg %d: expected %q, got %q", i, expected[i], v)
+ }
+ }
+}
+
+// TestContainerRunner_BuildInnerCmd checks the shell command built for each
+// agent type, the --resume handling, and custom binary paths.
+func TestContainerRunner_BuildInnerCmd(t *testing.T) {
+	// build runs buildInnerCmd for a task of the given agent type.
+	build := func(r *ContainerRunner, agentType string, e *storage.Execution, resume bool) []string {
+		tk := &task.Task{Agent: task.AgentConfig{Type: agentType}}
+		return r.buildInnerCmd(tk, e, resume)
+	}
+	runner := &ContainerRunner{}
+
+	t.Run("claude-fresh", func(t *testing.T) {
+		cmdStr := strings.Join(build(runner, "claude", &storage.Execution{}, false), " ")
+
+		if strings.Contains(cmdStr, "--resume") {
+			t.Errorf("unexpected --resume flag in fresh run: %q", cmdStr)
+		}
+		if !strings.Contains(cmdStr, "INST=$(cat /workspace/.claudomator-instructions.txt); claude -p \"$INST\"") {
+			t.Errorf("expected cat instructions in sh command, got %q", cmdStr)
+		}
+	})
+
+	t.Run("claude-resume", func(t *testing.T) {
+		resumed := &storage.Execution{ResumeSessionID: "orig-session-123"}
+		cmdStr := strings.Join(build(runner, "claude", resumed, true), " ")
+
+		if !strings.Contains(cmdStr, "--resume orig-session-123") {
+			t.Errorf("expected --resume flag with correct session ID, got %q", cmdStr)
+		}
+	})
+
+	t.Run("gemini", func(t *testing.T) {
+		cmdStr := strings.Join(build(runner, "gemini", &storage.Execution{}, false), " ")
+
+		if !strings.Contains(cmdStr, "gemini -p \"$INST\"") {
+			t.Errorf("expected gemini command with safer quoting, got %q", cmdStr)
+		}
+	})
+
+	t.Run("custom-binaries", func(t *testing.T) {
+		runnerCustom := &ContainerRunner{
+			ClaudeBinary: "/usr/bin/claude-v2",
+			GeminiBinary: "/usr/local/bin/gemini-pro",
+		}
+
+		cmdClaude := build(runnerCustom, "claude", &storage.Execution{}, false)
+		if !strings.Contains(strings.Join(cmdClaude, " "), "/usr/bin/claude-v2 -p") {
+			t.Errorf("expected custom claude binary, got %q", cmdClaude)
+		}
+
+		cmdGemini := build(runnerCustom, "gemini", &storage.Execution{}, false)
+		if !strings.Contains(strings.Join(cmdGemini, " "), "/usr/local/bin/gemini-pro -p") {
+			t.Errorf("expected custom gemini binary, got %q", cmdGemini)
+		}
+	})
+}
+
+// TestContainerRunner_Run_PreservesWorkspaceOnFailure verifies that a failing
+// container run leaves the workspace on disk and records it in SandboxDir.
+func TestContainerRunner_Run_PreservesWorkspaceOnFailure(t *testing.T) {
+	// fakeCommand stands in for docker and git: docker exits 1, git clone
+	// creates the target directory, everything else succeeds.
+	fakeCommand := func(ctx context.Context, name string, arg ...string) *exec.Cmd {
+		switch {
+		case name == "docker":
+			return exec.Command("sh", "-c", "exit 1")
+		case name == "git" && len(arg) > 0 && arg[0] == "clone":
+			os.MkdirAll(arg[len(arg)-1], 0755)
+		}
+		return exec.Command("true")
+	}
+	runner := &ContainerRunner{
+		Logger:  slog.New(slog.NewTextHandler(io.Discard, nil)),
+		Image:   "busybox",
+		Command: fakeCommand,
+	}
+
+	tk := &task.Task{
+		ID:            "test-task",
+		RepositoryURL: "https://github.com/example/repo.git",
+		Agent:         task.AgentConfig{Type: "claude"},
+	}
+	execution := &storage.Execution{ID: "test-exec", TaskID: "test-task"}
+
+	if err := runner.Run(context.Background(), tk, execution); err == nil {
+		t.Fatal("expected error due to mocked docker failure")
+	}
+
+	// Verify SandboxDir was set and directory exists.
+	if execution.SandboxDir == "" {
+		t.Fatal("expected SandboxDir to be set even on failure")
+	}
+	if _, statErr := os.Stat(execution.SandboxDir); statErr != nil {
+		t.Errorf("expected sandbox directory to be preserved, but stat failed: %v", statErr)
+		return
+	}
+	os.RemoveAll(execution.SandboxDir)
+}
+
+// TestBlockedError_IncludesSandboxDir builds a BlockedError that carries a
+// sandbox dir and checks its message. Driving a real blocked run would need a
+// mocked 'docker run', so only the error value itself is exercised here.
+func TestBlockedError_IncludesSandboxDir(t *testing.T) {
+	blocked := &BlockedError{
+		QuestionJSON: `{"text":"?"}`,
+		SessionID:    "s1",
+		SandboxDir:   "/tmp/s1",
+	}
+	if msg := blocked.Error(); !strings.Contains(msg, "task blocked") {
+		t.Errorf("wrong error message: %v", blocked)
+	}
+}
+
+func TestIsCompletionReport(t *testing.T) {
+ tests := []struct {
+ name string
+ json string
+ expected bool
+ }{
+ {
+ name: "real question with options",
+ json: `{"text": "Should I proceed with implementation?", "options": ["Yes", "No"]}`,
+ expected: false,
+ },
+ {
+ name: "real question no options",
+ json: `{"text": "Which approach do you prefer?"}`,
+ expected: false,
+ },
+ {
+ name: "completion report no options no question mark",
+ json: `{"text": "All tests pass. Implementation complete. Summary written to CLAUDOMATOR_SUMMARY_FILE."}`,
+ expected: true,
+ },
+ {
+ name: "completion report with empty options",
+ json: `{"text": "Feature implemented and committed.", "options": []}`,
+ expected: true,
+ },
+ {
+ name: "invalid json treated as not a report",
+ json: `not json`,
+ expected: false,
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got := isCompletionReport(tt.json)
+ if got != tt.expected {
+ t.Errorf("isCompletionReport(%q) = %v, want %v", tt.json, got, tt.expected)
+ }
+ })
+ }
+}
+
+// TestTailFile_ReturnsLastNLines writes 30 numbered lines and expects
+// tailFile(path, 5) to yield lines 26 through 30.
+func TestTailFile_ReturnsLastNLines(t *testing.T) {
+	tmp, err := os.CreateTemp("", "tailfile-*")
+	if err != nil {
+		t.Fatal(err)
+	}
+	path := tmp.Name()
+	defer os.Remove(path)
+	for n := 1; n <= 30; n++ {
+		fmt.Fprintf(tmp, "line %d\n", n)
+	}
+	tmp.Close()
+
+	got := tailFile(path, 5)
+	// Trim first so a trailing newline does not count as an extra line.
+	lines := strings.Split(strings.TrimSpace(got), "\n")
+	if len(lines) != 5 {
+		t.Fatalf("want 5 lines, got %d: %q", len(lines), got)
+	}
+	first, last := lines[0], lines[4]
+	if first != "line 26" || last != "line 30" {
+		t.Errorf("want lines 26-30, got: %q", got)
+	}
+}
+
+// TestDetectUncommittedChanges_ModifiedFile commits a file, edits it without
+// committing, and expects the check to report uncommitted changes.
+func TestDetectUncommittedChanges_ModifiedFile(t *testing.T) {
+	repo := t.TempDir()
+	// git runs a git subcommand inside repo and aborts the test on failure.
+	git := func(args ...string) {
+		argv := append([]string{"git"}, args...)
+		c := exec.Command(argv[0], argv[1:]...)
+		c.Dir = repo
+		if out, err := c.CombinedOutput(); err != nil {
+			t.Fatalf("%v: %s", argv, out)
+		}
+	}
+	source := filepath.Join(repo, "main.go")
+
+	git("init", repo)
+	git("config", "user.email", "test@test.com")
+	git("config", "user.name", "Test")
+	// Create and commit a file
+	if err := os.WriteFile(source, []byte("package main"), 0644); err != nil {
+		t.Fatal(err)
+	}
+	git("add", "main.go")
+	git("commit", "-m", "init")
+	// Now modify without committing — simulates agent that forgot to commit
+	if err := os.WriteFile(source, []byte("package main\n// changed"), 0644); err != nil {
+		t.Fatal(err)
+	}
+
+	err := detectUncommittedChanges(repo)
+	if err == nil {
+		t.Fatal("expected error for modified uncommitted file, got nil")
+	}
+	if !strings.Contains(err.Error(), "uncommitted") {
+		t.Errorf("error should mention uncommitted, got: %v", err)
+	}
+}
+
+// TestDetectUncommittedChanges_NewUntrackedSourceFile expects an error when
+// the workspace holds a new file that was never added to git.
+func TestDetectUncommittedChanges_NewUntrackedSourceFile(t *testing.T) {
+	repo := t.TempDir()
+	// git runs a git subcommand inside repo and aborts the test on failure.
+	git := func(args ...string) {
+		argv := append([]string{"git"}, args...)
+		c := exec.Command(argv[0], argv[1:]...)
+		c.Dir = repo
+		if out, err := c.CombinedOutput(); err != nil {
+			t.Fatalf("%v: %s", argv, out)
+		}
+	}
+
+	git("init", repo)
+	git("config", "user.email", "test@test.com")
+	git("config", "user.name", "Test")
+	git("commit", "--allow-empty", "-m", "init")
+	// Agent wrote a new file but never committed it
+	if err := os.WriteFile(filepath.Join(repo, "newfile.go"), []byte("package main"), 0644); err != nil {
+		t.Fatal(err)
+	}
+
+	if err := detectUncommittedChanges(repo); err == nil {
+		t.Fatal("expected error for new untracked source file, got nil")
+	}
+}
+
+// TestDetectUncommittedChanges_ScaffoldFilesIgnored checks that files the
+// harness injects into the workspace do not count as uncommitted work.
+func TestDetectUncommittedChanges_ScaffoldFilesIgnored(t *testing.T) {
+	repo := t.TempDir()
+	// git runs a git subcommand inside repo and aborts the test on failure.
+	git := func(args ...string) {
+		argv := append([]string{"git"}, args...)
+		c := exec.Command(argv[0], argv[1:]...)
+		c.Dir = repo
+		if out, err := c.CombinedOutput(); err != nil {
+			t.Fatalf("%v: %s", argv, out)
+		}
+	}
+
+	git("init", repo)
+	git("config", "user.email", "test@test.com")
+	git("config", "user.name", "Test")
+	git("commit", "--allow-empty", "-m", "init")
+
+	// Write only scaffold files that the harness injects — should not trigger error
+	_ = os.WriteFile(filepath.Join(repo, ".claudomator-env"), []byte("KEY=val"), 0600)
+	_ = os.WriteFile(filepath.Join(repo, ".claudomator-instructions.txt"), []byte("do stuff"), 0644)
+	_ = os.MkdirAll(filepath.Join(repo, ".agent-home", ".claude"), 0755)
+
+	if err := detectUncommittedChanges(repo); err != nil {
+		t.Errorf("scaffold files should not trigger uncommitted error, got: %v", err)
+	}
+}
+
+// TestDetectUncommittedChanges_CleanRepo expects no error for a repository
+// whose only file is committed and unmodified.
+func TestDetectUncommittedChanges_CleanRepo(t *testing.T) {
+	repo := t.TempDir()
+	// git runs a git subcommand inside repo and aborts the test on failure.
+	git := func(args ...string) {
+		argv := append([]string{"git"}, args...)
+		c := exec.Command(argv[0], argv[1:]...)
+		c.Dir = repo
+		if out, err := c.CombinedOutput(); err != nil {
+			t.Fatalf("%v: %s", argv, out)
+		}
+	}
+
+	git("init", repo)
+	git("config", "user.email", "test@test.com")
+	git("config", "user.name", "Test")
+	if err := os.WriteFile(filepath.Join(repo, "main.go"), []byte("package main"), 0644); err != nil {
+		t.Fatal(err)
+	}
+	git("add", "main.go")
+	git("commit", "-m", "init")
+
+	// No modifications — should pass
+	if err := detectUncommittedChanges(repo); err != nil {
+		t.Errorf("clean repo should not error, got: %v", err)
+	}
+}
+
+// TestGitSafe_PrependsSafeDirectory checks that gitSafe puts its fixed "-c"
+// overrides in front of the caller's arguments, in order.
+func TestGitSafe_PrependsSafeDirectory(t *testing.T) {
+	want := []string{
+		"-c", "safe.directory=*",
+		"-c", "commit.gpgsign=false",
+		"-c", "tag.gpgsign=false",
+		"-C", "/some/path", "status",
+	}
+	got := gitSafe("-C", "/some/path", "status")
+
+	if len(got) != len(want) {
+		t.Fatalf("gitSafe() = %v, want %v", got, want)
+	}
+	for i, w := range want {
+		if got[i] != w {
+			t.Errorf("gitSafe()[%d] = %q, want %q", i, got[i], w)
+		}
+	}
+}
+
+// TestContainerRunner_MissingCredentials_FailsFast verifies that Run aborts
+// with a "credentials not found" error when ClaudeConfigDir has no
+// .credentials.json.
+func TestContainerRunner_MissingCredentials_FailsFast(t *testing.T) {
+	logger := slog.New(slog.NewTextHandler(io.Discard, nil))
+
+	claudeConfigDir := t.TempDir()
+
+	// Set up ClaudeConfigDir with MISSING credentials (so pre-flight fails)
+	// Don't create .credentials.json
+	// But DO create .claude.json so the test isolates the credentials check
+	if err := os.WriteFile(filepath.Join(claudeConfigDir, ".claude.json"), []byte("{}"), 0644); err != nil {
+		t.Fatal(err)
+	}
+
+	runner := &ContainerRunner{
+		Logger:          logger,
+		Image:           "busybox",
+		ClaudeConfigDir: claudeConfigDir,
+		Command: func(ctx context.Context, name string, arg ...string) *exec.Cmd {
+			// Mock git clone: create the target directory and succeed.
+			if name == "git" && len(arg) > 0 && arg[0] == "clone" {
+				dir := arg[len(arg)-1]
+				os.MkdirAll(dir, 0755)
+				return exec.Command("true")
+			}
+			return exec.Command("true")
+		},
+	}
+
+	tk := &task.Task{
+		ID:            "test-missing-creds",
+		RepositoryURL: "https://github.com/example/repo.git",
+		Agent:         task.AgentConfig{Type: "claude"},
+	}
+	e := &storage.Execution{ID: "test-exec", TaskID: "test-missing-creds"}
+	// Run preserves the workspace when it fails, so remove it here; otherwise
+	// every test run leaves a claudomator-workspace-* directory behind.
+	t.Cleanup(func() {
+		if e.SandboxDir != "" {
+			os.RemoveAll(e.SandboxDir)
+		}
+	})
+
+	err := runner.Run(context.Background(), tk, e)
+	if err == nil {
+		t.Fatal("expected error due to missing credentials, got nil")
+	}
+	if !strings.Contains(err.Error(), "credentials not found") {
+		t.Errorf("expected 'credentials not found' error, got: %v", err)
+	}
+}
+
+// TestContainerRunner_MissingSettings_FailsFast runs a ContainerRunner whose
+// ClaudeConfigDir holds .credentials.json but no .claude.json and expects Run
+// to fail with an error mentioning "claude settings".
+func TestContainerRunner_MissingSettings_FailsFast(t *testing.T) {
+	configDir := t.TempDir()
+
+	// Credentials are present; the settings file is the piece left out.
+	credsPath := filepath.Join(configDir, ".credentials.json")
+	if err := os.WriteFile(credsPath, []byte("{}"), 0600); err != nil {
+		t.Fatal(err)
+	}
+
+	// Stub for external commands: all of them become "true", and a fake
+	// "git clone" additionally creates the target directory (last argument).
+	stub := func(ctx context.Context, name string, arg ...string) *exec.Cmd {
+		if name == "git" && len(arg) > 0 && arg[0] == "clone" {
+			target := arg[len(arg)-1]
+			os.MkdirAll(target, 0755)
+		}
+		return exec.Command("true")
+	}
+
+	quiet := slog.New(slog.NewTextHandler(io.Discard, nil))
+	runner := &ContainerRunner{
+		Logger:          quiet,
+		Image:           "busybox",
+		ClaudeConfigDir: configDir,
+		Command:         stub,
+	}
+
+	tk := &task.Task{
+		ID:            "test-missing-settings",
+		RepositoryURL: "https://github.com/example/repo.git",
+		Agent:         task.AgentConfig{Type: "claude"},
+	}
+	e := &storage.Execution{ID: "test-exec-2", TaskID: "test-missing-settings"}
+
+	runErr := runner.Run(context.Background(), tk, e)
+	if runErr == nil {
+		t.Fatal("expected error due to missing settings, got nil")
+	}
+	if msg := runErr.Error(); !strings.Contains(msg, "claude settings") {
+		t.Errorf("expected 'claude settings' error, got: %v", runErr)
+	}
+}
+
+func TestIsAuthError_DetectsAllVariants(t *testing.T) {
+ tests := []struct {
+ msg string
+ want bool
+ }{
+ {"Not logged in", true},
+ {"OAuth token has expired", true},
+ {"authentication_error: invalid token", true},
+ {"Please run /login to authenticate", true},
+ {"container execution failed: exit status 1", false},
+ {"git clone failed", false},
+ {"", false},
+ }
+ for _, tt := range tests {
+ var err error
+ if tt.msg != "" {
+ err = fmt.Errorf("%s", tt.msg)
+ }
+ got := isAuthError(err)
+ if got != tt.want {
+ t.Errorf("isAuthError(%q) = %v, want %v", tt.msg, got, tt.want)
+ }
+ }
+}
+
+// TestContainerRunner_AuthError_SyncsAndRetries makes the first fake "docker"
+// invocation fail with an auth error ("Not logged in") and asserts that the
+// runner then invokes the credential sync command and calls docker again.
+func TestContainerRunner_AuthError_SyncsAndRetries(t *testing.T) {
+	logger := slog.New(slog.NewTextHandler(io.Discard, nil))
+
+	// mustWrite creates a fixture file and aborts the test on failure, so a
+	// broken fixture cannot masquerade as a runner bug.
+	mustWrite := func(path string, data []byte, perm os.FileMode) {
+		t.Helper()
+		if err := os.WriteFile(path, data, perm); err != nil {
+			t.Fatalf("write %s: %v", path, err)
+		}
+	}
+
+	// Create a sync script that creates a marker file.
+	syncDir := t.TempDir()
+	syncMarker := filepath.Join(syncDir, "sync-called")
+	syncScript := filepath.Join(syncDir, "sync-creds")
+	mustWrite(syncScript, []byte("#!/bin/sh\ntouch "+syncMarker+"\n"), 0755)
+
+	// Both credential files exist so the pre-flight checks pass.
+	claudeConfigDir := t.TempDir()
+	mustWrite(filepath.Join(claudeConfigDir, ".credentials.json"), []byte(`{"token":"fresh"}`), 0600)
+	mustWrite(filepath.Join(claudeConfigDir, ".claude.json"), []byte("{}"), 0644)
+
+	callCount := 0 // number of fake docker invocations
+	runner := &ContainerRunner{
+		Logger:            logger,
+		Image:             "busybox",
+		ClaudeConfigDir:   claudeConfigDir,
+		CredentialSyncCmd: syncScript,
+		Command: func(ctx context.Context, name string, arg ...string) *exec.Cmd {
+			if name == "git" {
+				// A fake clone still has to produce the destination directory.
+				if len(arg) > 0 && arg[0] == "clone" {
+					dir := arg[len(arg)-1]
+					os.MkdirAll(dir, 0755)
+				}
+				return exec.Command("true")
+			}
+			if name == "docker" {
+				callCount++
+				if callCount == 1 {
+					// First docker call fails with auth error
+					return exec.Command("sh", "-c", "echo 'Not logged in' >&2; exit 1")
+				}
+				// Second docker call "succeeds"
+				return exec.Command("sh", "-c", "exit 0")
+			}
+			if name == syncScript {
+				return exec.Command("sh", "-c", "touch "+syncMarker)
+			}
+			return exec.Command("true")
+		},
+	}
+
+	tk := &task.Task{
+		ID:            "auth-retry-test",
+		RepositoryURL: "https://github.com/example/repo.git",
+		Agent:         task.AgentConfig{Type: "claude", Instructions: "test"},
+	}
+	e := &storage.Execution{ID: "auth-retry-exec", TaskID: "auth-retry-test"}
+
+	// The first attempt fails with an auth error, triggering sync+retry. The
+	// returned error is deliberately ignored: the retried run may still fail
+	// later (git push etc.), and only the retry behaviour is asserted here.
+	_ = runner.Run(context.Background(), tk, e)
+
+	if callCount < 2 {
+		t.Errorf("expected docker to be called at least twice (original + retry), got %d", callCount)
+	}
+	// Any Stat failure, not only "does not exist", means the marker cannot be
+	// confirmed.
+	if _, err := os.Stat(syncMarker); err != nil {
+		t.Errorf("expected sync-credentials to be called, but marker file not found: %v", err)
+	}
+}
+
+// TestContainerRunner_ClonesStoryBranch runs a task that has a BranchName and
+// asserts that some command issued by the runner contains a "checkout"
+// argument followed by that branch name.
+func TestContainerRunner_ClonesStoryBranch(t *testing.T) {
+	const storyBranch = "story/my-feature"
+
+	// checkoutArgs records the arguments from "checkout" onward of the most
+	// recent command that contained a "checkout" argument.
+	var checkoutArgs []string
+	fake := func(ctx context.Context, name string, arg ...string) *exec.Cmd {
+		if name == "git" && len(arg) > 0 && arg[0] == "clone" {
+			os.MkdirAll(arg[len(arg)-1], 0755)
+			return exec.Command("true")
+		}
+		// Matches both "git checkout <branch>" and "git -C <dir> checkout <branch>".
+		for pos := range arg {
+			if arg[pos] != "checkout" {
+				continue
+			}
+			checkoutArgs = make([]string, len(arg)-pos)
+			copy(checkoutArgs, arg[pos:])
+			break
+		}
+		if name == "docker" {
+			// The container step fails; only the git calls matter here.
+			return exec.Command("sh", "-c", "exit 1")
+		}
+		return exec.Command("true")
+	}
+
+	runner := &ContainerRunner{
+		Logger:  slog.New(slog.NewTextHandler(io.Discard, nil)),
+		Image:   "busybox",
+		Command: fake,
+	}
+
+	tk := &task.Task{
+		ID:            "story-branch-test",
+		RepositoryURL: "https://example.com/repo.git",
+		BranchName:    storyBranch,
+		Agent:         task.AgentConfig{Type: "claude"},
+	}
+	e := &storage.Execution{ID: "exec-1", TaskID: "story-branch-test"}
+
+	runner.Run(context.Background(), tk, e)
+	os.RemoveAll(e.SandboxDir)
+
+	// Assert git checkout was called with the story branch name.
+	if len(checkoutArgs) == 0 {
+		t.Fatal("expected git checkout to be called for story branch, but it was not")
+	}
+	sawBranch := false
+	for idx := 0; idx < len(checkoutArgs) && !sawBranch; idx++ {
+		sawBranch = checkoutArgs[idx] == storyBranch
+	}
+	if !sawBranch {
+		t.Errorf("expected git checkout story/my-feature, got args: %v", checkoutArgs)
+	}
+}
+
+// TestContainerRunner_ClonesDefaultBranchWhenNoBranchName runs a task without
+// a BranchName and asserts that the recorded "git clone" arguments contain no
+// "--branch" flag.
+func TestContainerRunner_ClonesDefaultBranchWhenNoBranchName(t *testing.T) {
+	// cloneArgs is a copy of the arguments of the last fake "git clone".
+	var cloneArgs []string
+	fake := func(ctx context.Context, name string, arg ...string) *exec.Cmd {
+		switch {
+		case name == "git" && len(arg) > 0 && arg[0] == "clone":
+			cloneArgs = make([]string, len(arg))
+			copy(cloneArgs, arg)
+			// The clone destination is the final argument.
+			os.MkdirAll(arg[len(arg)-1], 0755)
+			return exec.Command("true")
+		case name == "docker":
+			return exec.Command("sh", "-c", "exit 1")
+		default:
+			return exec.Command("true")
+		}
+	}
+
+	runner := &ContainerRunner{
+		Logger:  slog.New(slog.NewTextHandler(io.Discard, nil)),
+		Image:   "busybox",
+		Command: fake,
+	}
+
+	tk := &task.Task{
+		ID:            "no-branch-test",
+		RepositoryURL: "https://example.com/repo.git",
+		Agent:         task.AgentConfig{Type: "claude"},
+	}
+	e := &storage.Execution{ID: "exec-2", TaskID: "no-branch-test"}
+
+	runner.Run(context.Background(), tk, e)
+	os.RemoveAll(e.SandboxDir)
+
+	for idx := range cloneArgs {
+		if cloneArgs[idx] != "--branch" {
+			continue
+		}
+		t.Errorf("expected no --branch flag for task without BranchName, got args: %v", cloneArgs)
+	}
+}
+
+// TestEnsureStoryBranch_CreatesMissingBranch builds a bare repository with a
+// single commit on "main" and checks that ensureStoryBranch makes a branch
+// that was previously absent visible via "git ls-remote --heads".
+func TestEnsureStoryBranch_CreatesMissingBranch(t *testing.T) {
+	dir := t.TempDir()
+	bare := filepath.Join(dir, "bare.git")
+	local := filepath.Join(dir, "local")
+
+	// git runs one git command and aborts the test with its output on failure.
+	git := func(args ...string) {
+		t.Helper()
+		if out, err := exec.Command("git", args...).CombinedOutput(); err != nil {
+			t.Fatalf("git %v: %v\n%s", args, err, out)
+		}
+	}
+
+	// Create bare repo with an initial commit.
+	git("init", "--bare", bare)
+	git("clone", bare, local)
+	// The identity and signing settings are passed inline so the commit does
+	// not depend on the machine's global git configuration.
+	git("-C", local,
+		"-c", "user.name=Test", "-c", "user.email=test@test.com", "-c", "commit.gpgsign=false",
+		"commit", "--allow-empty", "-m", "init")
+	// HEAD:main publishes the commit as "main" whatever the local default
+	// branch is called.
+	git("-C", local, "push", "origin", "HEAD:main")
+
+	runner := &ContainerRunner{Logger: slog.Default()}
+
+	branch := "story/test-branch"
+
+	// Branch should not exist yet.
+	out, err := exec.Command("git", "ls-remote", "--heads", bare, branch).CombinedOutput()
+	if err != nil {
+		t.Fatalf("git ls-remote: %v\n%s", err, out)
+	}
+	if len(strings.TrimSpace(string(out))) > 0 {
+		t.Fatal("branch should not exist before ensureStoryBranch")
+	}
+
+	if err := runner.ensureStoryBranch(context.Background(), bare, branch, ""); err != nil {
+		t.Fatalf("ensureStoryBranch: %v", err)
+	}
+
+	// Branch should now exist in the bare repo.
+	out, err = exec.Command("git", "ls-remote", "--heads", bare, branch).CombinedOutput()
+	if err != nil || len(strings.TrimSpace(string(out))) == 0 {
+		t.Errorf("branch %q not found in bare repo after ensureStoryBranch: %s", branch, out)
+	}
+}
+
+// TestEnsureStoryBranch_IdempotentIfExists pushes the story branch to the bare
+// repository up front and checks that ensureStoryBranch then returns nil
+// instead of failing on the existing branch.
+func TestEnsureStoryBranch_IdempotentIfExists(t *testing.T) {
+	dir := t.TempDir()
+	bare := filepath.Join(dir, "bare.git")
+	local := filepath.Join(dir, "local")
+
+	// git runs one git command and aborts the test with its output on failure.
+	git := func(args ...string) {
+		t.Helper()
+		if out, err := exec.Command("git", args...).CombinedOutput(); err != nil {
+			t.Fatalf("git %v: %v\n%s", args, err, out)
+		}
+	}
+
+	git("init", "--bare", bare)
+	git("clone", bare, local)
+	// The identity and signing settings are passed inline so the commit does
+	// not depend on the machine's global git configuration.
+	git("-C", local,
+		"-c", "user.name=Test", "-c", "user.email=test@test.com", "-c", "commit.gpgsign=false",
+		"commit", "--allow-empty", "-m", "init")
+	git("-C", local, "push", "origin", "HEAD:main")
+
+	branch := "story/existing-branch"
+	// Pre-create the branch.
+	git("-C", local, "checkout", "-b", branch)
+	git("-C", local, "push", "origin", branch)
+
+	runner := &ContainerRunner{Logger: slog.Default()}
+
+	// Should be a no-op, not an error.
+	if err := runner.ensureStoryBranch(context.Background(), bare, branch, ""); err != nil {
+		t.Fatalf("ensureStoryBranch on existing branch: %v", err)
+	}
+}
diff --git a/internal/executor/executor.go b/internal/executor/executor.go
index 315030d..09169bd 100644
--- a/internal/executor/executor.go
+++ b/internal/executor/executor.go
@@ -2,9 +2,11 @@ package executor
import (
"context"
+ "encoding/json"
"errors"
"fmt"
"log/slog"
+ "os/exec"
"path/filepath"
"strings"
"sync"
@@ -25,6 +27,7 @@ type Store interface {
ListSubtasks(parentID string) ([]*task.Task, error)
ListExecutions(taskID string) ([]*storage.Execution, error)
CreateExecution(e *storage.Execution) error
+ CreateExecutionAndSetRunning(e *storage.Execution) error
UpdateExecution(e *storage.Execution) error
UpdateTaskState(id string, newState task.State) error
UpdateTaskQuestion(taskID, questionJSON string) error
@@ -32,6 +35,14 @@ type Store interface {
AppendTaskInteraction(taskID string, interaction task.Interaction) error
UpdateTaskAgent(id string, agent task.AgentConfig) error
UpdateExecutionChangestats(execID string, stats *task.Changestats) error
+ RecordAgentEvent(e storage.AgentEvent) error
+ GetProject(id string) (*task.Project, error)
+ GetStory(id string) (*task.Story, error)
+ ListTasksByStory(storyID string) ([]*task.Task, error)
+ UpdateStoryStatus(id string, status task.StoryState) error
+ CreateTask(t *task.Task) error
+ UpdateTaskCheckerReport(id, report string) error
+ GetCheckerTask(checkedTaskID string) (*task.Task, error)
}
// LogPather is an optional interface runners can implement to provide the log
@@ -56,24 +67,28 @@ type workItem struct {
// Pool manages a bounded set of concurrent task workers.
type Pool struct {
maxConcurrent int
+ maxPerAgent int
runners map[string]Runner
store Store
logger *slog.Logger
- depPollInterval time.Duration // how often waitForDependencies polls; defaults to 5s
-
- mu sync.Mutex
- active int
- activePerAgent map[string]int
- rateLimited map[string]time.Time // agentType -> until
- cancels map[string]context.CancelFunc // taskID → cancel
- resultCh chan *Result
- workCh chan workItem // internal bounded queue; Submit enqueues here
- doneCh chan struct{} // signals when a worker slot is freed
- Questions *QuestionRegistry
- Classifier *Classifier
- // LLM, when non-nil, enables LLM-synthesized summaries for executions
- // whose stdout did not include a "## Summary" heading.
- LLM *llm.Client
+ depPollInterval time.Duration // how often waitForDependencies polls; defaults to 5s
+ requeueDelay time.Duration // how long to wait before requeuing a blocked-per-agent task; defaults to 30s
+
+ mu sync.Mutex
+ active int
+ activePerAgent map[string]int
+ rateLimited map[string]time.Time // agentType -> until
+ cancels map[string]context.CancelFunc // taskID → cancel
+ consecutiveFailures map[string]int // agentType -> count
+ closed bool // set to true when Shutdown has been called
+ resultCh chan *Result
+ startedCh chan string // task IDs that just transitioned to RUNNING
+ workCh chan workItem // internal bounded queue; Submit enqueues here
+ doneCh chan struct{} // signals when a worker slot is freed
+ workerWg sync.WaitGroup // tracks in-flight execute/executeResume goroutines
+ dispatchDone chan struct{} // closed when the dispatch goroutine exits
+ Classifier *Classifier
+ LLM *llm.Client
}
// Result is emitted when a task execution completes.
@@ -88,18 +103,22 @@ func NewPool(maxConcurrent int, runners map[string]Runner, store Store, logger *
maxConcurrent = 1
}
p := &Pool{
- maxConcurrent: maxConcurrent,
- runners: runners,
- store: store,
- logger: logger,
- depPollInterval: 5 * time.Second,
- activePerAgent: make(map[string]int),
- rateLimited: make(map[string]time.Time),
- cancels: make(map[string]context.CancelFunc),
- resultCh: make(chan *Result, maxConcurrent*2),
- workCh: make(chan workItem, maxConcurrent*10+100),
- doneCh: make(chan struct{}, maxConcurrent),
- Questions: NewQuestionRegistry(),
+ maxConcurrent: maxConcurrent,
+ maxPerAgent: 1,
+ runners: runners,
+ store: store,
+ logger: logger,
+ depPollInterval: 5 * time.Second,
+ requeueDelay: 30 * time.Second,
+ activePerAgent: make(map[string]int),
+ rateLimited: make(map[string]time.Time),
+ cancels: make(map[string]context.CancelFunc),
+ consecutiveFailures: make(map[string]int),
+ resultCh: make(chan *Result, maxConcurrent*2),
+ startedCh: make(chan string, maxConcurrent*2),
+ workCh: make(chan workItem, maxConcurrent*10+100),
+ doneCh: make(chan struct{}, maxConcurrent),
+ dispatchDone: make(chan struct{}),
}
go p.dispatch()
return p
@@ -109,6 +128,7 @@ func NewPool(maxConcurrent int, runners map[string]Runner, store Store, logger *
// and launches goroutines as soon as a pool slot is available. This prevents
// tasks from being rejected when the pool is temporarily at capacity.
func (p *Pool) dispatch() {
+ defer close(p.dispatchDone)
for item := range p.workCh {
for {
p.mu.Lock()
@@ -116,9 +136,9 @@ func (p *Pool) dispatch() {
p.active++
p.mu.Unlock()
if item.exec != nil {
- go p.executeResume(item.ctx, item.task, item.exec)
+ p.workerWg.Add(1); go func(i workItem) { defer p.workerWg.Done(); p.executeResume(i.ctx, i.task, i.exec) }(item)
} else {
- go p.execute(item.ctx, item.task)
+ p.workerWg.Add(1); go func(i workItem) { defer p.workerWg.Done(); p.execute(i.ctx, i.task) }(item)
}
break
}
@@ -132,19 +152,64 @@ func (p *Pool) dispatch() {
// work queue is full. When the pool is at capacity the task is buffered and
// dispatched as soon as a slot becomes available.
func (p *Pool) Submit(ctx context.Context, t *task.Task) error {
+ p.mu.Lock()
+ // Reject new work once Shutdown has marked the pool closed.
+ if p.closed {
+ p.mu.Unlock()
+ return fmt.Errorf("executor pool is shut down")
+ }
+ // Send while holding the lock so that Shutdown cannot close workCh between
+ // the closed-check above and the send below. The dispatch goroutine never
+ // holds p.mu while receiving from workCh, so this cannot deadlock.
select {
case p.workCh <- workItem{ctx: ctx, task: t}:
+ p.mu.Unlock()
return nil
+ // Queue full: fail immediately rather than block while holding p.mu.
default:
+ p.mu.Unlock()
return fmt.Errorf("executor work queue full (capacity %d)", cap(p.workCh))
}
}
+// Started returns a channel that emits task IDs when they transition to RUNNING.
+// The channel is buffered, and the send visible in executeResume is
+// non-blocking, so an ID is dropped when the buffer is full. Treat the channel
+// as a best-effort notification rather than a complete record.
+func (p *Pool) Started() <-chan string {
+ return p.startedCh
+}
+
// Results returns the channel for reading execution results.
func (p *Pool) Results() <-chan *Result {
return p.resultCh
}
+// Shutdown stops accepting new work and waits for all in-flight workers to
+// finish. Returns ctx.Err() if the context deadline is exceeded before all
+// workers complete.
+//
+// Shutdown is safe to call more than once: only the first call closes the work
+// queue, and later calls simply wait for the same completion signals.
+func (p *Pool) Shutdown(ctx context.Context) error {
+	// Mark the pool closed under the lock so that Submit, which checks the
+	// flag and sends while holding the same lock, can never send on a closed
+	// channel.
+	p.mu.Lock()
+	alreadyClosed := p.closed
+	p.closed = true
+	p.mu.Unlock()
+	// Closing an already closed channel panics, so close exactly once.
+	if !alreadyClosed {
+		close(p.workCh)
+	}
+
+	// Stop the dispatch goroutine. We must wait for it to exit before calling
+	// workerWg.Wait() to avoid a race between dispatch's Add(1) and Wait().
+	select {
+	case <-p.dispatchDone:
+	case <-ctx.Done():
+		return ctx.Err()
+	}
+
+	// Wait for the workers in a helper goroutine so the wait can be abandoned
+	// when ctx expires first.
+	done := make(chan struct{})
+	go func() {
+		p.workerWg.Wait()
+		close(done)
+	}()
+
+	select {
+	case <-done:
+		return nil
+	case <-ctx.Done():
+		return ctx.Err()
+	}
+}
+
// Cancel requests cancellation of a running task. Returns false if the task
// is not currently running in this pool.
func (p *Pool) Cancel(taskID string) bool {
@@ -250,11 +315,12 @@ func (p *Pool) executeResume(ctx context.Context, t *task.Task, exec *storage.Ex
exec.StartTime = time.Now().UTC()
exec.Status = "RUNNING"
- if err := p.store.CreateExecution(exec); err != nil {
+ if err := p.store.CreateExecutionAndSetRunning(exec); err != nil {
p.logger.Error("failed to create resume execution record", "error", err)
}
- if err := p.store.UpdateTaskState(t.ID, task.StateRunning); err != nil {
- p.logger.Error("failed to update task state", "error", err)
+ select {
+ case p.startedCh <- t.ID:
+ default:
}
var cancel context.CancelFunc
@@ -273,6 +339,19 @@ func (p *Pool) executeResume(ctx context.Context, t *task.Task, exec *storage.Ex
p.mu.Unlock()
}()
+ // Populate RepositoryURL from Project registry if missing (ADR-007).
+ if t.RepositoryURL == "" && t.Project != "" {
+ if proj, err := p.store.GetProject(t.Project); err == nil && proj.RemoteURL != "" {
+ t.RepositoryURL = proj.RemoteURL
+ }
+ }
+ // Populate BranchName from Story if missing (ADR-007).
+ if t.BranchName == "" && t.StoryID != "" {
+ if story, err := p.store.GetStory(t.StoryID); err == nil && story.BranchName != "" {
+ t.BranchName = story.BranchName
+ }
+ }
+
err = runner.Run(ctx, t, exec)
exec.EndTime = time.Now().UTC()
@@ -289,16 +368,32 @@ func (p *Pool) handleRunResult(ctx context.Context, t *task.Task, exec *storage.
if retry.IsRateLimitError(err) || isQuotaExhausted(err) {
p.mu.Lock()
retryAfter := retry.ParseRetryAfter(err.Error())
- if retryAfter == 0 {
- if isQuotaExhausted(err) {
+ reason := "transient"
+ if isQuotaExhausted(err) {
+ reason = "quota"
+ if retryAfter == 0 {
retryAfter = 5 * time.Hour
- } else {
- retryAfter = 1 * time.Minute
}
+ } else if retryAfter == 0 {
+ retryAfter = 1 * time.Minute
}
- p.rateLimited[agentType] = time.Now().Add(retryAfter)
+ until := time.Now().Add(retryAfter)
+ p.rateLimited[agentType] = until
p.logger.Info("agent rate limited", "agent", agentType, "retryAfter", retryAfter, "quotaExhausted", isQuotaExhausted(err))
p.mu.Unlock()
+ go func() {
+ ev := storage.AgentEvent{
+ ID: uuid.New().String(),
+ Agent: agentType,
+ Event: "rate_limited",
+ Timestamp: time.Now(),
+ Until: &until,
+ Reason: reason,
+ }
+ if recErr := p.store.RecordAgentEvent(ev); recErr != nil {
+ p.logger.Warn("failed to record agent event", "error", recErr)
+ }
+ }()
}
var blockedErr *BlockedError
@@ -335,9 +430,51 @@ func (p *Pool) handleRunResult(ctx context.Context, t *task.Task, exec *storage.
if err := p.store.UpdateTaskState(t.ID, task.StateFailed); err != nil {
p.logger.Error("failed to update task state", "taskID", t.ID, "state", task.StateFailed, "error", err)
}
+ p.mu.Lock()
+ p.consecutiveFailures[agentType]++
+ p.mu.Unlock()
+ }
+ // If this is a checker task, attach the failure report for any terminal
+ // failure state (FAILED, TIMED_OUT, CANCELLED, BUDGET_EXCEEDED).
+ if t.CheckerForTaskID != "" && exec.ErrorMsg != "" {
+ if reportErr := p.store.UpdateTaskCheckerReport(t.CheckerForTaskID, exec.ErrorMsg); reportErr != nil {
+ p.logger.Error("handleRunResult: failed to set checker report", "taskID", t.CheckerForTaskID, "error", reportErr)
+ }
+ }
+ if t.StoryID != "" && exec.Status == "FAILED" {
+ storyID := t.StoryID
+ errMsg := exec.ErrorMsg
+ go func() {
+ story, getErr := p.store.GetStory(storyID)
+ if getErr != nil {
+ return
+ }
+ if story.Status == task.StoryValidating {
+ p.checkValidationResult(ctx, storyID, task.StateFailed, errMsg)
+ }
+ }()
}
} else {
- if t.ParentTaskID == "" {
+ p.mu.Lock()
+ p.consecutiveFailures[agentType] = 0
+ p.mu.Unlock()
+ if t.CheckerForTaskID != "" {
+ // Checker task succeeded — auto-accept the checked task.
+ exec.Status = "COMPLETED"
+ if err := p.store.UpdateTaskState(t.ID, task.StateCompleted); err != nil {
+ p.logger.Error("handleRunResult: failed to complete checker task", "taskID", t.ID, "error", err)
+ }
+ checkedTask, getErr := p.store.GetTask(t.CheckerForTaskID)
+ if getErr == nil {
+ if acceptErr := p.store.UpdateTaskState(t.CheckerForTaskID, task.StateCompleted); acceptErr != nil {
+ p.logger.Error("handleRunResult: failed to auto-accept checked task", "taskID", t.CheckerForTaskID, "error", acceptErr)
+ } else if checkedTask.StoryID != "" {
+ go p.checkStoryCompletion(context.Background(), checkedTask.StoryID)
+ }
+ } else {
+ p.logger.Error("handleRunResult: failed to get checked task", "taskID", t.CheckerForTaskID, "error", getErr)
+ }
+ } else if t.ParentTaskID == "" {
subtasks, subErr := p.store.ListSubtasks(t.ID)
if subErr != nil {
p.logger.Error("failed to list subtasks", "taskID", t.ID, "error", subErr)
@@ -352,6 +489,7 @@ func (p *Pool) handleRunResult(ctx context.Context, t *task.Task, exec *storage.
if err := p.store.UpdateTaskState(t.ID, task.StateReady); err != nil {
p.logger.Error("failed to update task state", "taskID", t.ID, "state", task.StateReady, "error", err)
}
+ go p.spawnCheckerTask(context.Background(), t)
}
} else {
exec.Status = "COMPLETED"
@@ -360,6 +498,21 @@ func (p *Pool) handleRunResult(ctx context.Context, t *task.Task, exec *storage.
}
p.maybeUnblockParent(t.ParentTaskID)
}
+ if t.StoryID != "" {
+ storyID := t.StoryID
+ go func() {
+ story, getErr := p.store.GetStory(storyID)
+ if getErr != nil {
+ p.logger.Error("handleRunResult: failed to get story", "storyID", storyID, "error", getErr)
+ return
+ }
+ if story.Status == task.StoryValidating {
+ p.checkValidationResult(ctx, storyID, task.StateCompleted, "")
+ } else {
+ p.checkStoryCompletion(ctx, storyID)
+ }
+ }()
+ }
}
summary := exec.Summary
@@ -374,6 +527,13 @@ func (p *Pool) handleRunResult(ctx context.Context, t *task.Task, exec *storage.
p.logger.Error("failed to update task summary", "taskID", t.ID, "error", summaryErr)
}
}
+ terminalFailure := exec.Status == "FAILED" || exec.Status == "TIMED_OUT" || exec.Status == "CANCELLED" || exec.Status == "BUDGET_EXCEEDED"
+ if t.CheckerForTaskID != "" && terminalFailure && summary != "" {
+ // Overwrite the initial error-message report with the richer summary.
+ if reportErr := p.store.UpdateTaskCheckerReport(t.CheckerForTaskID, summary); reportErr != nil {
+ p.logger.Error("handleRunResult: failed to update checker report with summary", "taskID", t.CheckerForTaskID, "error", reportErr)
+ }
+ }
if exec.StdoutPath != "" {
if cs := task.ParseChangestatFromFile(exec.StdoutPath); cs != nil {
exec.Changestats = cs
@@ -388,6 +548,256 @@ func (p *Pool) handleRunResult(ctx context.Context, t *task.Task, exec *storage.
p.resultCh <- &Result{TaskID: t.ID, Execution: exec, Err: err}
}
+// CheckStoryCompletion is the exported entry point for story completion checks
+// called from outside the package (e.g. the API accept handler). It delegates
+// directly to checkStoryCompletion.
+func (p *Pool) CheckStoryCompletion(ctx context.Context, storyID string) {
+	p.checkStoryCompletion(ctx, storyID)
+}
+
+// checkStoryCompletion transitions a story whose status is
+// task.StoryInProgress to task.StoryShippable once every top-level task in the
+// story is in task.StateCompleted.
+//
+// Subtasks (tasks with a ParentTaskID) are intentionally ignored, since a
+// parent task reaching its final state already accounts for them. A story
+// with no top-level tasks is never advanced. Failures are logged, not
+// returned.
+func (p *Pool) checkStoryCompletion(ctx context.Context, storyID string) {
+	story, err := p.store.GetStory(storyID)
+	if err != nil {
+		p.logger.Error("checkStoryCompletion: failed to get story", "storyID", storyID, "error", err)
+		return
+	}
+	// Anything other than in-progress is already SHIPPABLE or beyond.
+	if story.Status != task.StoryInProgress {
+		return
+	}
+
+	storyTasks, err := p.store.ListTasksByStory(storyID)
+	if err != nil {
+		p.logger.Error("checkStoryCompletion: failed to list tasks", "storyID", storyID, "error", err)
+		return
+	}
+
+	sawTopLevel := false
+	for _, st := range storyTasks {
+		if st.ParentTaskID != "" {
+			continue // subtasks are covered by their parent
+		}
+		// READY alone is not sufficient: a checker may still be pending.
+		if st.State != task.StateCompleted {
+			return
+		}
+		sawTopLevel = true
+	}
+	// Covers both an empty task list and a list made up solely of subtasks.
+	if !sawTopLevel {
+		return
+	}
+
+	if err := p.store.UpdateStoryStatus(storyID, task.StoryShippable); err != nil {
+		p.logger.Error("checkStoryCompletion: failed to update story status", "storyID", storyID, "error", err)
+		return
+	}
+	p.logger.Info("story transitioned to SHIPPABLE", "storyID", storyID)
+	// Deploy is now triggered explicitly by the human via POST /api/stories/{id}/ship.
+}
+
+// ShipStory starts the merge-and-deploy sequence for a story in the
+// background by launching triggerStoryDeploy.
+//
+// It returns an error when the story cannot be loaded or is not in SHIPPABLE
+// state. A nil return only means the deploy was started: its outcome is
+// logged by triggerStoryDeploy and is not reported to the caller.
+func (p *Pool) ShipStory(ctx context.Context, storyID string) error {
+	story, err := p.store.GetStory(storyID)
+	switch {
+	case err != nil:
+		return fmt.Errorf("story not found: %w", err)
+	case story.Status != task.StoryShippable:
+		return fmt.Errorf("story is not SHIPPABLE (current status: %s)", story.Status)
+	}
+	// The deploy gets a fresh background context instead of the caller's ctx.
+	go p.triggerStoryDeploy(context.Background(), storyID)
+	return nil
+}
+
+// spawnCheckerTask creates, queues and submits a read-only "checker" task that
+// validates the work of the given task.
+//
+// Nothing is spawned when the checked task is a subtask, is itself a checker,
+// has no repository URL, or already has a checker. The acceptance criteria
+// default to the implementor's instructions when the task defines none.
+// Failures are logged, not returned.
+func (p *Pool) spawnCheckerTask(ctx context.Context, checked *task.Task) {
+	// Never spawn a checker for subtasks, checker tasks, or tasks without a repository.
+	switch {
+	case checked.ParentTaskID != "":
+		return
+	case checked.CheckerForTaskID != "":
+		return
+	case checked.RepositoryURL == "":
+		return
+	}
+
+	// Idempotent: don't create a second checker if one already exists.
+	prior, err := p.store.GetCheckerTask(checked.ID)
+	if err != nil {
+		p.logger.Error("spawnCheckerTask: GetCheckerTask failed", "taskID", checked.ID, "error", err)
+		return
+	}
+	if prior != nil {
+		return
+	}
+
+	criteria := checked.Agent.Instructions
+	if checked.AcceptanceCriteria != "" {
+		criteria = checked.AcceptanceCriteria
+	}
+
+	instructions := fmt.Sprintf(`You are validating a completed task. Do not make any changes to the code or repository.
+
+Task: %s
+Instructions given to the implementor:
+%s
+
+Acceptance criteria:
+%s
+
+Steps:
+1. Clone the repository and review the changes made.
+2. Verify each acceptance criterion is met. Run tests or make HTTP requests as needed.
+3. If all criteria are satisfied, exit normally (success).
+4. If any criterion is not met, use the Bash tool to exit with a non-zero code:
+ bash -c "exit 1"
+ Before exiting, write a brief summary of what failed.`, checked.Name, checked.Agent.Instructions, criteria)
+
+	// The checker gets a restricted tool list and a small budget.
+	checkerAgent := task.AgentConfig{
+		Type:         "claude",
+		Instructions: instructions,
+		MaxBudgetUSD: 0.50,
+		AllowedTools: []string{"Bash", "Read", "Glob", "Grep"},
+	}
+
+	createdAt := time.Now().UTC()
+	checker := &task.Task{
+		ID:               uuid.New().String(),
+		Name:             "Check: " + checked.Name,
+		CheckerForTaskID: checked.ID,
+		RepositoryURL:    checked.RepositoryURL,
+		Agent:            checkerAgent,
+		Timeout:          task.Duration{Duration: 10 * time.Minute},
+		Priority:         task.PriorityNormal,
+		Tags:             []string{},
+		DependsOn:        []string{},
+		Retry:            task.RetryConfig{MaxAttempts: 1, Backoff: "linear"},
+		State:            task.StatePending,
+		CreatedAt:        createdAt,
+		UpdatedAt:        createdAt,
+	}
+
+	// Persist as PENDING, move to QUEUED, then hand the task to the pool.
+	if createErr := p.store.CreateTask(checker); createErr != nil {
+		p.logger.Error("spawnCheckerTask: CreateTask failed", "error", createErr)
+		return
+	}
+	checker.State = task.StateQueued
+	if stateErr := p.store.UpdateTaskState(checker.ID, task.StateQueued); stateErr != nil {
+		p.logger.Error("spawnCheckerTask: UpdateTaskState failed", "error", stateErr)
+		return
+	}
+	if submitErr := p.Submit(ctx, checker); submitErr != nil {
+		p.logger.Error("spawnCheckerTask: Submit failed", "error", submitErr)
+	}
+}
+
+// triggerStoryDeploy merges the story branch into main, runs the project's
+// deploy script, marks the story DEPLOYED and then starts validation task
+// creation in the background.
+//
+// It returns silently when the story has no project or the project has no
+// deploy script. The merge is skipped unless both the story's BranchName and
+// the project's LocalPath are set. The first failing step is logged and aborts
+// the sequence, leaving the story status unchanged.
+func (p *Pool) triggerStoryDeploy(ctx context.Context, storyID string) {
+	story, err := p.store.GetStory(storyID)
+	if err != nil {
+		p.logger.Error("triggerStoryDeploy: failed to get story", "storyID", storyID, "error", err)
+		return
+	}
+	if story.ProjectID == "" {
+		return
+	}
+	proj, err := p.store.GetProject(story.ProjectID)
+	if err != nil {
+		p.logger.Error("triggerStoryDeploy: failed to get project", "storyID", storyID, "projectID", story.ProjectID, "error", err)
+		return
+	}
+	if proj.DeployScript == "" {
+		return
+	}
+
+	// Merge story branch to main before deploying (ADR-007).
+	if needsMerge := story.BranchName != "" && proj.LocalPath != ""; needsMerge {
+		// inRepo builds a full git command line scoped to the project checkout.
+		inRepo := func(sub ...string) []string {
+			return append([]string{"git", "-C", proj.LocalPath}, sub...)
+		}
+		steps := [][]string{
+			inRepo("fetch", "origin"),
+			inRepo("checkout", "main"),
+			inRepo("merge", "--no-ff", story.BranchName, "-m", "Merge "+story.BranchName),
+			inRepo("push", "origin", "main"),
+		}
+		for _, step := range steps {
+			output, runErr := exec.CommandContext(ctx, step[0], step[1:]...).CombinedOutput()
+			if runErr != nil {
+				p.logger.Error("triggerStoryDeploy: merge failed", "cmd", step, "output", string(output), "error", runErr)
+				return
+			}
+		}
+		p.logger.Info("story branch merged to main", "storyID", storyID, "branch", story.BranchName)
+	}
+
+	deployOut, err := exec.CommandContext(ctx, proj.DeployScript).CombinedOutput()
+	if err != nil {
+		p.logger.Error("triggerStoryDeploy: deploy script failed", "storyID", storyID, "script", proj.DeployScript, "output", string(deployOut), "error", err)
+		return
+	}
+	if err := p.store.UpdateStoryStatus(storyID, task.StoryDeployed); err != nil {
+		p.logger.Error("triggerStoryDeploy: failed to update story status", "storyID", storyID, "error", err)
+		return
+	}
+	p.logger.Info("story transitioned to DEPLOYED", "storyID", storyID)
+	go p.createValidationTask(ctx, storyID)
+}
+
+// createValidationTask creates a validation task from the story's
+// ValidationJSON, transitions the story to VALIDATING and submits the task to
+// the pool.
+//
+// The story is left untouched when it has no ValidationJSON or when the JSON
+// does not decode into an object. Failures are logged, not returned. This
+// includes a failed Submit, which would otherwise leave the story in
+// VALIDATING with no trace of why its task never ran.
+func (p *Pool) createValidationTask(ctx context.Context, storyID string) {
+	story, err := p.store.GetStory(storyID)
+	if err != nil {
+		p.logger.Error("createValidationTask: failed to get story", "storyID", storyID, "error", err)
+		return
+	}
+	if story.ValidationJSON == "" {
+		p.logger.Warn("createValidationTask: story has no ValidationJSON, skipping", "storyID", storyID)
+		return
+	}
+
+	// Decoded only to reject malformed specs up front; the raw JSON text is
+	// what gets embedded in the instructions below.
+	var spec map[string]any
+	if err := json.Unmarshal([]byte(story.ValidationJSON), &spec); err != nil {
+		p.logger.Error("createValidationTask: failed to parse ValidationJSON", "storyID", storyID, "error", err)
+		return
+	}
+
+	instructions := fmt.Sprintf("Validate the deployment for story %q.\n\nValidation spec:\n%s", story.Name, story.ValidationJSON)
+
+	now := time.Now().UTC()
+	vtask := &task.Task{
+		ID:        uuid.New().String(),
+		Name:      fmt.Sprintf("validation: %s", story.Name),
+		StoryID:   storyID,
+		State:     task.StateQueued,
+		Agent:     task.AgentConfig{Type: "claude", Instructions: instructions},
+		Tags:      []string{},
+		DependsOn: []string{},
+		CreatedAt: now,
+		UpdatedAt: now,
+	}
+
+	if err := p.store.CreateTask(vtask); err != nil {
+		p.logger.Error("createValidationTask: failed to create task", "storyID", storyID, "error", err)
+		return
+	}
+	if err := p.store.UpdateStoryStatus(storyID, task.StoryValidating); err != nil {
+		p.logger.Error("createValidationTask: failed to update story status", "storyID", storyID, "error", err)
+		return
+	}
+	p.logger.Info("validation task created and story transitioned to VALIDATING", "storyID", storyID, "taskID", vtask.ID)
+	if err := p.Submit(ctx, vtask); err != nil {
+		p.logger.Error("createValidationTask: failed to submit task", "storyID", storyID, "taskID", vtask.ID, "error", err)
+	}
+}
+
+// checkValidationResult inspects a completed validation task and transitions
+// the story to REVIEW_READY or NEEDS_FIX accordingly.
+func (p *Pool) checkValidationResult(ctx context.Context, storyID string, taskState task.State, errorMsg string) {
+ if taskState == task.StateCompleted {
+ if err := p.store.UpdateStoryStatus(storyID, task.StoryReviewReady); err != nil {
+ p.logger.Error("checkValidationResult: failed to update story status", "storyID", storyID, "error", err)
+ return
+ }
+ p.logger.Info("story transitioned to REVIEW_READY", "storyID", storyID)
+ } else {
+ if err := p.store.UpdateStoryStatus(storyID, task.StoryNeedsFix); err != nil {
+ p.logger.Error("checkValidationResult: failed to update story status", "storyID", storyID, "error", err)
+ return
+ }
+ p.logger.Info("story transitioned to NEEDS_FIX", "storyID", storyID, "error", errorMsg)
+ }
+}
+
// ActiveCount returns the number of currently running tasks.
func (p *Pool) ActiveCount() int {
p.mu.Lock()
@@ -395,6 +805,34 @@ func (p *Pool) ActiveCount() int {
return p.active
}
+// AgentStatusInfo holds the current state of a single agent, as reported by Pool.AgentStatuses.
+type AgentStatusInfo struct {
+ Agent string `json:"agent"` // agent name, i.e. the key its runner is registered under
+ ActiveTasks int `json:"active_tasks"` // number of tasks currently counted against this agent
+ RateLimited bool `json:"rate_limited"` // true while the agent's rate-limit deadline is still in the future
+ Until *time.Time `json:"until,omitempty"` // rate-limit deadline; nil (and omitted from JSON) when not rate-limited
+}
+
+// AgentStatuses returns a snapshot of every registered agent: its name, the
+// number of tasks currently counted against it and, when a rate limit is still
+// in effect, the time at which that limit expires.
+//
+// The result is always a non-nil slice, so it JSON-encodes as [] rather than
+// null when no agents are registered. Entries follow map iteration order, which
+// is unspecified; callers that need a stable order must sort the result.
+func (p *Pool) AgentStatuses() []AgentStatusInfo {
+	p.mu.Lock()
+	defer p.mu.Unlock()
+
+	now := time.Now()
+	// The final length is known up front, so allocate exactly once.
+	out := make([]AgentStatusInfo, 0, len(p.runners))
+	for agent := range p.runners {
+		info := AgentStatusInfo{
+			Agent:       agent,
+			ActiveTasks: p.activePerAgent[agent],
+		}
+		// A deadline that has already passed is reported as "not rate-limited";
+		// this method never mutates p.rateLimited.
+		if deadline, ok := p.rateLimited[agent]; ok && now.Before(deadline) {
+			until := deadline // private copy so the returned pointer is not shared
+			info.RateLimited = true
+			info.Until = &until
+		}
+		out = append(out, info)
+	}
+	return out
+}
+
// pickAgent selects the best agent from the given SystemStatus using explicit
// load balancing: prefer the available (non-rate-limited) agent with the fewest
// active tasks. If all agents are rate-limited, fall back to fewest active.
@@ -436,6 +874,18 @@ func (p *Pool) execute(ctx context.Context, t *task.Task) {
activeTasks[agent] = p.activePerAgent[agent]
if deadline, ok := p.rateLimited[agent]; ok && now.After(deadline) {
delete(p.rateLimited, agent)
+ agentName := agent
+ go func() {
+ ev := storage.AgentEvent{
+ ID: uuid.New().String(),
+ Agent: agentName,
+ Event: "available",
+ Timestamp: time.Now(),
+ }
+ if recErr := p.store.RecordAgentEvent(ev); recErr != nil {
+ p.logger.Warn("failed to record agent available event", "error", recErr)
+ }
+ }()
}
rateLimited[agent] = now.Before(p.rateLimited[agent])
}
@@ -479,9 +929,58 @@ func (p *Pool) execute(ctx context.Context, t *task.Task) {
agentType = "claude"
}
+ // Check dependencies before taking the per-agent slot to avoid deadlock:
+ // if a dependent task holds the slot while waiting for its dependency to run,
+ // the dependency can never start (maxPerAgent=1).
+ if len(t.DependsOn) > 0 {
+ ready, depErr := p.checkDepsReady(t)
+ if depErr != nil {
+ // A dependency hit a terminal failure — cancel this task immediately.
+ now := time.Now().UTC()
+ exec := &storage.Execution{
+ ID: uuid.New().String(),
+ TaskID: t.ID,
+ StartTime: now,
+ EndTime: now,
+ Status: "CANCELLED",
+ ErrorMsg: depErr.Error(),
+ }
+ if createErr := p.store.CreateExecution(exec); createErr != nil {
+ p.logger.Error("failed to create execution record", "error", createErr)
+ }
+ if err := p.store.UpdateTaskState(t.ID, task.StateCancelled); err != nil {
+ p.logger.Error("failed to update task state", "taskID", t.ID, "state", task.StateCancelled, "error", err)
+ }
+ p.resultCh <- &Result{TaskID: t.ID, Execution: exec, Err: depErr}
+ return
+ }
+ if !ready {
+ // Dependencies not yet done — requeue without holding the slot.
+ time.AfterFunc(p.requeueDelay, func() { p.workCh <- workItem{ctx: ctx, task: t} })
+ return
+ }
+ }
p.mu.Lock()
+
+ if p.activePerAgent[agentType] >= p.maxPerAgent {
+ p.mu.Unlock()
+ time.AfterFunc(p.requeueDelay, func() { p.workCh <- workItem{ctx: ctx, task: t} })
+ return
+ }
if deadline, ok := p.rateLimited[agentType]; ok && time.Now().After(deadline) {
delete(p.rateLimited, agentType)
+ agentName := agentType
+ go func() {
+ ev := storage.AgentEvent{
+ ID: uuid.New().String(),
+ Agent: agentName,
+ Event: "available",
+ Timestamp: time.Now(),
+ }
+ if recErr := p.store.RecordAgentEvent(ev); recErr != nil {
+ p.logger.Warn("failed to record agent available event", "error", recErr)
+ }
+ }()
}
p.activePerAgent[agentType]++
p.mu.Unlock()
@@ -512,30 +1011,6 @@ func (p *Pool) execute(ctx context.Context, t *task.Task) {
return
}
- // Wait for all dependencies to complete before starting execution.
- if len(t.DependsOn) > 0 {
- if err := p.waitForDependencies(ctx, t); err != nil {
- now := time.Now().UTC()
- exec := &storage.Execution{
- ID: uuid.New().String(),
- TaskID: t.ID,
- StartTime: now,
- EndTime: now,
- Status: "FAILED",
- ErrorMsg: err.Error(),
- }
- if createErr := p.store.CreateExecution(exec); createErr != nil {
- p.logger.Error("failed to create execution record", "error", createErr)
- }
- if err := p.store.UpdateTaskState(t.ID, task.StateFailed); err != nil {
- p.logger.Error("failed to update task state", "taskID", t.ID, "state", task.StateFailed, "error", err)
- }
- p.decActiveAgent(agentType, &cleaned)
- p.resultCh <- &Result{TaskID: t.ID, Execution: exec, Err: err}
- return
- }
- }
-
execID := uuid.New().String()
exec := &storage.Execution{
ID: execID,
@@ -554,12 +1029,13 @@ func (p *Pool) execute(ctx context.Context, t *task.Task) {
}
}
- // Record execution start.
- if err := p.store.CreateExecution(exec); err != nil {
+ // Record execution start atomically with the RUNNING state transition.
+ if err := p.store.CreateExecutionAndSetRunning(exec); err != nil {
p.logger.Error("failed to create execution record", "error", err)
}
- if err := p.store.UpdateTaskState(t.ID, task.StateRunning); err != nil {
- p.logger.Error("failed to update task state", "error", err)
+ select {
+ case p.startedCh <- t.ID:
+ default:
}
// Apply task timeout and register cancel so callers can stop this task.
@@ -583,6 +1059,19 @@ func (p *Pool) execute(ctx context.Context, t *task.Task) {
priorExecs, priorErr := p.store.ListExecutions(t.ID)
t = withFailureHistory(t, priorExecs, priorErr)
+ // Populate RepositoryURL from Project registry if missing (ADR-007).
+ if t.RepositoryURL == "" && t.Project != "" {
+ if proj, err := p.store.GetProject(t.Project); err == nil && proj.RemoteURL != "" {
+ t.RepositoryURL = proj.RemoteURL
+ }
+ }
+ // Populate BranchName from Story if missing (ADR-007).
+ if t.BranchName == "" && t.StoryID != "" {
+ if story, err := p.store.GetStory(t.StoryID); err == nil && story.BranchName != "" {
+ t.BranchName = story.BranchName
+ }
+ }
+
// Run the task.
err = runner.Run(ctx, t, exec)
exec.EndTime = time.Now().UTC()
@@ -650,18 +1139,31 @@ func (p *Pool) RecoverStaleQueued(ctx context.Context) {
}
}
-// RecoverStaleBlocked promotes any BLOCKED parent task to READY when all of its
-// subtasks are already COMPLETED. This handles the case where the server was
-// restarted after subtasks finished but before maybeUnblockParent could fire.
+// RecoverStaleBlocked promotes any BLOCKED or QUEUED parent task to READY when
+// all of its subtasks are already COMPLETED. This handles the case where the
+// server was restarted after subtasks finished but before maybeUnblockParent
+// could fire, and also the case where story approval pre-created subtasks
+// without ever running the parent task.
// Call this once on server startup, after RecoverStaleRunning and RecoverStaleQueued.
func (p *Pool) RecoverStaleBlocked() {
- tasks, err := p.store.ListTasks(storage.TaskFilter{State: task.StateBlocked})
- if err != nil {
- p.logger.Error("RecoverStaleBlocked: list tasks", "error", err)
- return
- }
- for _, t := range tasks {
- p.maybeUnblockParent(t.ID)
+ ctx := context.Background()
+ for _, state := range []task.State{task.StateBlocked, task.StateQueued} {
+ tasks, err := p.store.ListTasks(storage.TaskFilter{State: state})
+ if err != nil {
+ p.logger.Error("RecoverStaleBlocked: list tasks", "error", err, "state", state)
+ continue
+ }
+ for _, t := range tasks {
+ if t.ParentTaskID != "" {
+ continue // only promote actual parents
+ }
+ before := t.State
+ p.maybeUnblockParent(t.ID)
+ // If the parent was promoted, check story completion.
+ if after, err := p.store.GetTask(t.ID); err == nil && after.State != before && t.StoryID != "" {
+ p.checkStoryCompletion(ctx, t.StoryID)
+ }
+ }
}
}
@@ -673,6 +1175,32 @@ var terminalFailureStates = map[task.State]bool{
task.StateBudgetExceeded: true,
}
+// depDoneStates are task states that satisfy a DependsOn dependency: a dependency in one of these states no longer holds back the tasks that depend on it.
+var depDoneStates = map[task.State]bool{
+ task.StateCompleted: true,
+ task.StateReady: true, // leaf tasks finish at READY
+}
+
+// checkDepsReady does a single synchronous check of t.DependsOn.
+//
+// Returns:
+//   - (true, nil) if every dependency is in one of depDoneStates;
+//   - (false, nil) if at least one dependency is still pending and none has failed;
+//   - (false, err) if a dependency cannot be loaded or is in one of
+//     terminalFailureStates.
+//
+// All dependencies are inspected before "pending" is reported, so a terminal
+// failure is surfaced even when an earlier dependency in the list is still
+// running. Otherwise the caller would keep requeueing a task that can never run.
+func (p *Pool) checkDepsReady(t *task.Task) (bool, error) {
+	ready := true
+	for _, depID := range t.DependsOn {
+		dep, err := p.store.GetTask(depID)
+		if err != nil {
+			return false, fmt.Errorf("dependency %q not found: %w", depID, err)
+		}
+		switch {
+		case depDoneStates[dep.State]:
+			// Satisfied; nothing to do.
+		case terminalFailureStates[dep.State]:
+			return false, fmt.Errorf("dependency %q ended in state %s", depID, dep.State)
+		default:
+			// Still pending. Keep scanning so a failure further down the list wins.
+			ready = false
+		}
+	}
+	return ready, nil
+}
+
// withFailureHistory returns a shallow copy of t with prior failed execution
// error messages prepended to SystemPromptAppend so the agent knows what went
// wrong in previous attempts.
@@ -710,16 +1238,16 @@ func withFailureHistory(t *task.Task, execs []*storage.Execution, err error) *ta
return &copy
}
-// maybeUnblockParent transitions the parent task from BLOCKED to READY if all
-// of its subtasks are in the COMPLETED state. If any subtask is not COMPLETED
-// (including FAILED, CANCELLED, RUNNING, etc.) the parent stays BLOCKED.
+// maybeUnblockParent transitions the parent task to READY if all of its subtasks
+// are in the COMPLETED state. Handles both BLOCKED parents (ran, created subtasks,
+// paused) and QUEUED parents (story approval created subtasks without running parent).
func (p *Pool) maybeUnblockParent(parentID string) {
parent, err := p.store.GetTask(parentID)
if err != nil {
p.logger.Error("maybeUnblockParent: get parent", "parentID", parentID, "error", err)
return
}
- if parent.State != task.StateBlocked {
+ if parent.State != task.StateBlocked && parent.State != task.StateQueued {
return
}
subtasks, err := p.store.ListSubtasks(parentID)
@@ -727,6 +1255,11 @@ func (p *Pool) maybeUnblockParent(parentID string) {
p.logger.Error("maybeUnblockParent: list subtasks", "parentID", parentID, "error", err)
return
}
+ // A task with no subtasks was never blocked by subtask delegation — don't promote it.
+ // This prevents incorrectly promoting leaf tasks that are stuck in QUEUED to READY.
+ if len(subtasks) == 0 {
+ return
+ }
for _, sub := range subtasks {
if sub.State != task.StateCompleted {
return
@@ -747,7 +1280,7 @@ func (p *Pool) waitForDependencies(ctx context.Context, t *task.Task) error {
if err != nil {
return fmt.Errorf("dependency %q not found: %w", depID, err)
}
- if dep.State == task.StateCompleted {
+ if depDoneStates[dep.State] {
continue
}
if terminalFailureStates[dep.State] {
diff --git a/internal/executor/executor_test.go b/internal/executor/executor_test.go
index b1173cb..9214872 100644
--- a/internal/executor/executor_test.go
+++ b/internal/executor/executor_test.go
@@ -5,6 +5,7 @@ import (
"fmt"
"log/slog"
"os"
+ "os/exec"
"path/filepath"
"strings"
"sync"
@@ -600,10 +601,17 @@ func TestPool_RecoverStaleRunning(t *testing.T) {
// Execution record should be closed as FAILED.
execs, _ := store.ListExecutions(tk.ID)
- if len(execs) == 0 || execs[0].Status != "FAILED" {
+ var failedExec *storage.Execution
+ for _, e := range execs {
+ if e.ID == "exec-stale-1" {
+ failedExec = e
+ break
+ }
+ }
+ if failedExec == nil || failedExec.Status != "FAILED" {
t.Errorf("execution status: want FAILED, got %+v", execs)
}
- if execs[0].ErrorMsg == "" {
+ if failedExec.ErrorMsg == "" {
t.Error("expected non-empty error message on recovered execution")
}
@@ -739,6 +747,119 @@ func TestPool_RecoverStaleBlocked_KeepsBlockedWhenSubtaskIncomplete(t *testing.T
}
}
+// TestPool_RecoverStaleBlocked_PromotesQueuedParentWithAllSubtasksDone verifies
+// that a top-level QUEUED task whose subtasks are all COMPLETED is promoted to
+// READY by RecoverStaleBlocked, and that the owning story stays IN_PROGRESS.
+func TestPool_RecoverStaleBlocked_PromotesQueuedParentWithAllSubtasksDone(t *testing.T) {
+	store := testStore(t)
+	logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))
+	pool := NewPool(2, map[string]Runner{"claude": &mockRunner{}}, store, logger)
+
+	now := time.Now().UTC()
+	story := &task.Story{
+		ID: "story-queued-parent", Name: "Queued Parent Story",
+		Status: task.StoryInProgress, CreatedAt: now, UpdatedAt: now,
+	}
+	if err := store.CreateStory(story); err != nil {
+		t.Fatalf("CreateStory: %v", err)
+	}
+
+	// Parent task stuck QUEUED (approved with pre-created subtasks, never run).
+	parent := makeTask("queued-parent-1")
+	parent.State = task.StateQueued
+	parent.StoryID = story.ID
+	if err := store.CreateTask(parent); err != nil {
+		t.Fatalf("CreateTask parent: %v", err)
+	}
+
+	for i := 0; i < 2; i++ {
+		sub := makeTask(fmt.Sprintf("queued-sub-%d", i))
+		sub.ParentTaskID = parent.ID
+		sub.StoryID = story.ID
+		sub.State = task.StateCompleted
+		if err := store.CreateTask(sub); err != nil {
+			t.Fatalf("CreateTask subtask %d: %v", i, err)
+		}
+	}
+
+	pool.RecoverStaleBlocked()
+
+	got, err := store.GetTask(parent.ID)
+	if err != nil {
+		t.Fatalf("GetTask: %v", err)
+	}
+	if got.State != task.StateReady {
+		t.Errorf("parent state: want READY, got %s", got.State)
+	}
+
+	// Story should still be IN_PROGRESS — READY tasks don't satisfy the completion check;
+	// the task must be accepted (READY → COMPLETED) before the story advances to SHIPPABLE.
+	s, err := store.GetStory(story.ID)
+	if err != nil {
+		t.Fatalf("GetStory: %v", err)
+	}
+	if s.Status != task.StoryInProgress {
+		t.Errorf("story status: want IN_PROGRESS, got %s", s.Status)
+	}
+}
+
+// TestPool_RecoverStaleBlocked_DoesNotPromoteQueuedLeafTask verifies that a top-level
+// QUEUED task with NO subtasks is not promoted to READY by RecoverStaleBlocked.
+// This guards against the bug where a task that failed to start (stuck in QUEUED due
+// to a DB error) was incorrectly promoted to READY because the "all subtasks done"
+// check is vacuously true when there are no subtasks.
+func TestPool_RecoverStaleBlocked_DoesNotPromoteQueuedLeafTask(t *testing.T) {
+	store := testStore(t)
+	logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))
+	pool := NewPool(2, map[string]Runner{"claude": &mockRunner{}}, store, logger)
+
+	// A top-level task stuck in QUEUED with no subtasks (e.g. DB lock prevented RUNNING transition).
+	leaf := makeTask("queued-leaf-no-subtasks")
+	leaf.State = task.StateQueued
+	// Fail fast if the fixture cannot be stored; otherwise the GetTask below
+	// would report a misleading lookup error.
+	if err := store.CreateTask(leaf); err != nil {
+		t.Fatalf("CreateTask: %v", err)
+	}
+
+	pool.RecoverStaleBlocked()
+
+	got, err := store.GetTask(leaf.ID)
+	if err != nil {
+		t.Fatalf("GetTask: %v", err)
+	}
+	if got.State != task.StateQueued {
+		t.Errorf("leaf task state: want QUEUED (unchanged), got %s", got.State)
+	}
+}
+
+// TestPool_CheckStoryCompletion_ReadyTasksNotSufficient verifies that READY tasks
+// alone do not advance a story to SHIPPABLE — tasks must be COMPLETED.
+func TestPool_CheckStoryCompletion_ReadyTasksNotSufficient(t *testing.T) {
+	store := testStore(t)
+	logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))
+	pool := NewPool(2, map[string]Runner{"claude": &mockRunner{}}, store, logger)
+
+	now := time.Now().UTC()
+	story := &task.Story{
+		ID:        "story-ready-only",
+		Name:      "Ready Only Story",
+		Status:    task.StoryInProgress,
+		CreatedAt: now,
+		UpdatedAt: now,
+	}
+	if err := store.CreateStory(story); err != nil {
+		t.Fatalf("CreateStory: %v", err)
+	}
+
+	// seed stores a story task starting at PENDING and walks it through the
+	// given states, failing the test on the first store error so a broken
+	// fixture cannot masquerade as a passing assertion.
+	seed := func(id string, states ...task.State) {
+		t.Helper()
+		tk := makeTask(id)
+		tk.StoryID = story.ID
+		tk.State = task.StatePending
+		if err := store.CreateTask(tk); err != nil {
+			t.Fatalf("CreateTask %s: %v", id, err)
+		}
+		for _, s := range states {
+			if err := store.UpdateTaskState(id, s); err != nil {
+				t.Fatalf("UpdateTaskState %s → %s: %v", id, s, err)
+			}
+		}
+	}
+
+	// One task driven to READY (checker pending), one COMPLETED.
+	seed("ro-task-1", task.StateQueued, task.StateRunning, task.StateReady)
+	seed("ro-task-2", task.StateQueued, task.StateRunning, task.StateReady, task.StateCompleted)
+
+	pool.checkStoryCompletion(context.Background(), story.ID)
+
+	got, err := store.GetStory(story.ID)
+	if err != nil {
+		t.Fatalf("GetStory: %v", err)
+	}
+	if got.Status != task.StoryInProgress {
+		t.Errorf("story status: want IN_PROGRESS (tk1 still READY/checker pending), got %s", got.Status)
+	}
+}
+
func TestPool_ActivePerAgent_DeletesZeroEntries(t *testing.T) {
store := testStore(t)
runner := &mockRunner{}
@@ -1014,7 +1135,10 @@ func (m *minimalMockStore) ListSubtasks(parentID string) ([]*task.Task, error) {
return nil, nil
}
func (m *minimalMockStore) ListExecutions(_ string) ([]*storage.Execution, error) { return nil, nil }
-func (m *minimalMockStore) CreateExecution(e *storage.Execution) error { return nil }
+func (m *minimalMockStore) CreateExecution(e *storage.Execution) error { return nil }
+func (m *minimalMockStore) CreateExecutionAndSetRunning(e *storage.Execution) error {
+ return nil
+}
func (m *minimalMockStore) UpdateExecution(e *storage.Execution) error {
return m.updateExecErr
}
@@ -1064,6 +1188,14 @@ func (m *minimalMockStore) UpdateExecutionChangestats(execID string, stats *task
m.mu.Unlock()
return nil
}
+func (m *minimalMockStore) RecordAgentEvent(_ storage.AgentEvent) error { return nil } // no-op stub: the event is discarded
+func (m *minimalMockStore) GetProject(_ string) (*task.Project, error) { return nil, nil } // NOTE(review): (nil, nil) — Pool.execute reads proj.RemoteURL when err == nil, so a test task with Project set would dereference nil; confirm no test does that
+func (m *minimalMockStore) GetStory(_ string) (*task.Story, error) { return nil, nil } // NOTE(review): (nil, nil) — Pool.execute reads story.BranchName when err == nil; same nil-dereference caveat for tasks with StoryID set
+func (m *minimalMockStore) ListTasksByStory(_ string) ([]*task.Task, error) { return nil, nil } // stub: always reports no tasks
+func (m *minimalMockStore) UpdateStoryStatus(_ string, _ task.StoryState) error { return nil } // no-op stub: always succeeds
+func (m *minimalMockStore) CreateTask(_ *task.Task) error { return nil } // no-op stub: the task is not stored
+func (m *minimalMockStore) UpdateTaskCheckerReport(_ string, _ string) error { return nil } // no-op stub: always succeeds
+func (m *minimalMockStore) GetCheckerTask(_ string) (*task.Task, error) { return nil, nil } // stub: returns a nil task with no error
func (m *minimalMockStore) lastStateUpdate() (string, task.State, bool) {
m.mu.Lock()
@@ -1078,17 +1210,18 @@ func (m *minimalMockStore) lastStateUpdate() (string, task.State, bool) {
func newPoolWithMockStore(store Store) *Pool {
logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))
return &Pool{
- maxConcurrent: 2,
- runners: map[string]Runner{"claude": &mockRunner{}},
- store: store,
- logger: logger,
- activePerAgent: make(map[string]int),
- rateLimited: make(map[string]time.Time),
- cancels: make(map[string]context.CancelFunc),
- resultCh: make(chan *Result, 4),
- workCh: make(chan workItem, 4),
- doneCh: make(chan struct{}, 2),
- Questions: NewQuestionRegistry(),
+ maxConcurrent: 2,
+ maxPerAgent: 1,
+ runners: map[string]Runner{"claude": &mockRunner{}},
+ store: store,
+ logger: logger,
+ activePerAgent: make(map[string]int),
+ rateLimited: make(map[string]time.Time),
+ cancels: make(map[string]context.CancelFunc),
+ consecutiveFailures: make(map[string]int),
+ resultCh: make(chan *Result, 4),
+ workCh: make(chan workItem, 4),
+ doneCh: make(chan struct{}, 2),
}
}
@@ -1236,6 +1369,11 @@ func TestPool_SpecificAgent_SkipsLoadBalancing(t *testing.T) {
logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))
pool := NewPool(4, runners, store, logger)
+ // Raise per-agent limit so the concurrency gate doesn't interfere with this test.
+ // The injected activePerAgent is only to make pickAgent prefer "claude",
+ // verifying that explicit agent type bypasses load balancing.
+ pool.maxPerAgent = 10
+
// Inject 2 active tasks for gemini, 0 for claude.
// pickAgent would normally pick "claude".
pool.mu.Lock()
@@ -1425,3 +1563,748 @@ func TestExecute_MalformedChangestats(t *testing.T) {
t.Errorf("expected nil changestats for malformed output, got %+v", execs[0].Changestats)
}
}
+
+// TestPool_MaxPerAgent_BlocksSecondTask verifies that two tasks for the same
+// agent never run at the same time, even though the pool itself has room for
+// two concurrent tasks.
+func TestPool_MaxPerAgent_BlocksSecondTask(t *testing.T) {
+	store := testStore(t)
+
+	var mu sync.Mutex
+	concurrentRuns := 0
+	maxConcurrent := 0
+
+	runner := &mockRunner{
+		delay: 100 * time.Millisecond,
+		onRun: func(tk *task.Task, e *storage.Execution) error {
+			// Record the high-water mark of simultaneous runs.
+			mu.Lock()
+			concurrentRuns++
+			if concurrentRuns > maxConcurrent {
+				maxConcurrent = concurrentRuns
+			}
+			mu.Unlock()
+			// Stay "running" long enough for an overlapping run to be observed.
+			time.Sleep(100 * time.Millisecond)
+			mu.Lock()
+			concurrentRuns--
+			mu.Unlock()
+			return nil
+		},
+	}
+	runners := map[string]Runner{"claude": runner}
+	logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))
+	pool := NewPool(2, runners, store, logger) // pool size 2, but maxPerAgent=1
+	pool.requeueDelay = 50 * time.Millisecond  // speed up test
+
+	tk1 := makeTask("mpa-1")
+	tk2 := makeTask("mpa-2")
+	// Fail fast on a broken fixture instead of waiting for the result timeout.
+	for _, tk := range []*task.Task{tk1, tk2} {
+		if err := store.CreateTask(tk); err != nil {
+			t.Fatalf("CreateTask %v: %v", tk.ID, err)
+		}
+	}
+
+	pool.Submit(context.Background(), tk1)
+	pool.Submit(context.Background(), tk2)
+
+	for i := 0; i < 2; i++ {
+		select {
+		case <-pool.Results():
+		case <-time.After(10 * time.Second):
+			t.Fatal("timed out waiting for result")
+		}
+	}
+
+	mu.Lock()
+	got := maxConcurrent
+	mu.Unlock()
+	if got > 1 {
+		t.Errorf("maxPerAgent=1 violated: %d claude tasks ran concurrently", got)
+	}
+}
+
+// TestPool_MaxPerAgent_AllowsDifferentAgents verifies that the per-agent limit
+// does not serialize tasks that target different agents: one "claude" task and
+// one "gemini" task are expected to overlap.
+func TestPool_MaxPerAgent_AllowsDifferentAgents(t *testing.T) {
+	store := testStore(t)
+
+	var mu sync.Mutex
+	concurrentRuns := 0
+	maxConcurrent := 0
+
+	// Both runners share the counters above, so the high-water mark spans agents.
+	makeSlowRunner := func() *mockRunner {
+		return &mockRunner{
+			onRun: func(tk *task.Task, e *storage.Execution) error {
+				mu.Lock()
+				concurrentRuns++
+				if concurrentRuns > maxConcurrent {
+					maxConcurrent = concurrentRuns
+				}
+				mu.Unlock()
+				time.Sleep(80 * time.Millisecond)
+				mu.Lock()
+				concurrentRuns--
+				mu.Unlock()
+				return nil
+			},
+		}
+	}
+
+	runners := map[string]Runner{
+		"claude": makeSlowRunner(),
+		"gemini": makeSlowRunner(),
+	}
+	logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))
+	pool := NewPool(2, runners, store, logger)
+
+	tk1 := makeTask("da-1")
+	tk1.Agent.Type = "claude"
+	tk2 := makeTask("da-2")
+	tk2.Agent.Type = "gemini"
+	// Fail fast on a broken fixture instead of waiting for the result timeout.
+	for _, tk := range []*task.Task{tk1, tk2} {
+		if err := store.CreateTask(tk); err != nil {
+			t.Fatalf("CreateTask %v: %v", tk.ID, err)
+		}
+	}
+
+	pool.Submit(context.Background(), tk1)
+	pool.Submit(context.Background(), tk2)
+
+	for i := 0; i < 2; i++ {
+		select {
+		case <-pool.Results():
+		case <-time.After(5 * time.Second):
+			t.Fatal("timed out waiting for result")
+		}
+	}
+
+	mu.Lock()
+	got := maxConcurrent
+	mu.Unlock()
+	if got < 2 {
+		t.Errorf("different agents should run concurrently; max concurrent was %d", got)
+	}
+}
+
+// TestPool_ConsecutiveFailures_ResetOnSuccess verifies that the per-agent
+// consecutive-failure counter is 1 after a failed run and returns to 0 after
+// the next successful run on the same agent.
+func TestPool_ConsecutiveFailures_ResetOnSuccess(t *testing.T) {
+	store := testStore(t)
+
+	callCount := 0
+	runner := &mockRunner{
+		onRun: func(tk *task.Task, e *storage.Execution) error {
+			callCount++
+			if callCount == 1 {
+				return fmt.Errorf("first failure")
+			}
+			return nil // second call succeeds
+		},
+	}
+	runners := map[string]Runner{"claude": runner}
+	logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))
+	pool := NewPool(2, runners, store, logger)
+
+	// runOne stores and submits a task, then waits (bounded) for its result so a
+	// pool that never reports back fails this test instead of hanging the suite.
+	runOne := func(id string) {
+		t.Helper()
+		tk := makeTask(id)
+		if err := store.CreateTask(tk); err != nil {
+			t.Fatalf("CreateTask %s: %v", id, err)
+		}
+		pool.Submit(context.Background(), tk)
+		select {
+		case <-pool.Results():
+		case <-time.After(10 * time.Second):
+			t.Fatalf("timed out waiting for result of %s", id)
+		}
+	}
+
+	// failures reads the counter under the pool lock.
+	failures := func() int {
+		pool.mu.Lock()
+		defer pool.mu.Unlock()
+		return pool.consecutiveFailures["claude"]
+	}
+
+	// First task fails.
+	runOne("rs-1")
+	if failsBefore := failures(); failsBefore != 1 {
+		t.Errorf("expected 1 failure after first task, got %d", failsBefore)
+	}
+
+	// Second task succeeds — counter resets.
+	runOne("rs-2")
+	if failsAfter := failures(); failsAfter != 0 {
+		t.Errorf("expected consecutiveFailures reset to 0 after success, got %d", failsAfter)
+	}
+}
+
+// TestPool_CheckStoryCompletion_AllComplete checks that a story whose top-level
+// tasks have all reached COMPLETED is moved to SHIPPABLE by checkStoryCompletion.
+func TestPool_CheckStoryCompletion_AllComplete(t *testing.T) {
+	const storyID = "story-comp-1"
+
+	store := testStore(t)
+	quietLogger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))
+	pool := NewPool(2, map[string]Runner{"claude": &mockRunner{}}, store, quietLogger)
+
+	// The story starts out IN_PROGRESS.
+	created := time.Now().UTC()
+	if err := store.CreateStory(&task.Story{
+		ID:        storyID,
+		Name:      "Completion Test",
+		Status:    task.StoryInProgress,
+		CreatedAt: created,
+		UpdatedAt: created,
+	}); err != nil {
+		t.Fatalf("CreateStory: %v", err)
+	}
+
+	// Every task is walked through the same sequence of transitions, ending at COMPLETED.
+	lifecycle := []task.State{task.StateQueued, task.StateRunning, task.StateReady, task.StateCompleted}
+	taskIDs := []string{"sctask-1", "sctask-2"}
+	for idx := range taskIDs {
+		taskID := taskIDs[idx]
+
+		tk := makeTask(taskID)
+		tk.StoryID = storyID
+		tk.State = task.StatePending
+		if err := store.CreateTask(tk); err != nil {
+			t.Fatalf("CreateTask %d: %v", idx, err)
+		}
+		for _, next := range lifecycle {
+			if err := store.UpdateTaskState(taskID, next); err != nil {
+				t.Fatalf("UpdateTaskState %s → %s: %v", taskID, next, err)
+			}
+		}
+	}
+
+	pool.checkStoryCompletion(context.Background(), storyID)
+
+	reloaded, err := store.GetStory(storyID)
+	if err != nil {
+		t.Fatalf("GetStory: %v", err)
+	}
+	if reloaded.Status != task.StoryShippable {
+		t.Errorf("story status: want SHIPPABLE, got %v", reloaded.Status)
+	}
+}
+
+// TestPool_CheckStoryCompletion_PartialComplete verifies that a story stays
+// IN_PROGRESS while one of its top-level tasks is only READY and another has
+// not been started at all.
+func TestPool_CheckStoryCompletion_PartialComplete(t *testing.T) {
+	store := testStore(t)
+	logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))
+	pool := NewPool(2, map[string]Runner{"claude": &mockRunner{}}, store, logger)
+
+	now := time.Now().UTC()
+	story := &task.Story{
+		ID:        "story-partial-1",
+		Name:      "Partial Test",
+		Status:    task.StoryInProgress,
+		CreatedAt: now,
+		UpdatedAt: now,
+	}
+	if err := store.CreateStory(story); err != nil {
+		t.Fatalf("CreateStory: %v", err)
+	}
+
+	// First top-level task driven to READY. Store errors are fatal so that a
+	// broken fixture cannot make the "no transition" assertion pass by accident.
+	tk1 := makeTask("sptask-1")
+	tk1.StoryID = "story-partial-1"
+	tk1.State = task.StatePending
+	if err := store.CreateTask(tk1); err != nil {
+		t.Fatalf("CreateTask sptask-1: %v", err)
+	}
+	for _, s := range []task.State{task.StateQueued, task.StateRunning, task.StateReady} {
+		if err := store.UpdateTaskState("sptask-1", s); err != nil {
+			t.Fatalf("UpdateTaskState sptask-1 → %s: %v", s, err)
+		}
+	}
+
+	// Second top-level task left in its initial state (not done).
+	tk2 := makeTask("sptask-2")
+	tk2.StoryID = "story-partial-1"
+	if err := store.CreateTask(tk2); err != nil {
+		t.Fatalf("CreateTask sptask-2: %v", err)
+	}
+
+	pool.checkStoryCompletion(context.Background(), "story-partial-1")
+
+	got, err := store.GetStory("story-partial-1")
+	if err != nil {
+		t.Fatalf("GetStory: %v", err)
+	}
+	if got.Status != task.StoryInProgress {
+		t.Errorf("story status: want IN_PROGRESS (no transition), got %v", got.Status)
+	}
+}
+
+// TestPool_StoryDeploy_RunsDeployScript verifies that triggerStoryDeploy runs
+// the project's deploy script (observed through a marker file the script
+// creates) and leaves the story in the DEPLOYED state.
+func TestPool_StoryDeploy_RunsDeployScript(t *testing.T) {
+	store := testStore(t)
+	runner := &mockRunner{}
+	runners := map[string]Runner{"claude": runner}
+	logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))
+	pool := NewPool(2, runners, store, logger)
+
+	// Create a deploy script that writes a marker file.
+	tmpDir := t.TempDir()
+	markerFile := filepath.Join(tmpDir, "deployed.marker")
+	scriptPath := filepath.Join(tmpDir, "deploy.sh")
+	scriptContent := "#!/bin/sh\ntouch " + markerFile + "\n"
+	if err := os.WriteFile(scriptPath, []byte(scriptContent), 0755); err != nil {
+		t.Fatalf("write deploy script: %v", err)
+	}
+
+	proj := &task.Project{
+		ID:           "proj-deploy-1",
+		Name:         "Deploy Test Project",
+		DeployScript: scriptPath,
+	}
+	if err := store.CreateProject(proj); err != nil {
+		t.Fatalf("create project: %v", err)
+	}
+
+	story := &task.Story{
+		ID:        "story-deploy-1",
+		Name:      "Deploy Test Story",
+		ProjectID: proj.ID,
+		Status:    task.StoryShippable,
+	}
+	if err := store.CreateStory(story); err != nil {
+		t.Fatalf("create story: %v", err)
+	}
+
+	pool.triggerStoryDeploy(context.Background(), story.ID)
+
+	// Any Stat failure (not only "does not exist") means the marker cannot be
+	// confirmed, so treat every error as a failed deploy.
+	if _, err := os.Stat(markerFile); err != nil {
+		t.Errorf("deploy script did not run: marker file not found: %v", err)
+	}
+
+	got, err := store.GetStory(story.ID)
+	if err != nil {
+		t.Fatalf("get story: %v", err)
+	}
+	if got.Status != task.StoryDeployed {
+		t.Errorf("story status: want DEPLOYED, got %q", got.Status)
+	}
+}
+
+// runGit executes "git <args...>" and aborts the calling test if the command
+// fails, including git's combined stdout/stderr in the failure message. When
+// dir is empty the command runs in the test process's current directory.
+func runGit(t *testing.T, dir string, args ...string) {
+	t.Helper()
+
+	gitCmd := exec.Command("git", args...)
+	// An empty Dir is the zero value and means "current directory", so the
+	// assignment can be unconditional.
+	gitCmd.Dir = dir
+
+	output, runErr := gitCmd.CombinedOutput()
+	if runErr == nil {
+		return
+	}
+	t.Fatalf("git %v: %v\n%s", args, runErr, output)
+}
+
+// TestPool_StoryDeploy_MergesStoryBranch verifies that triggerStoryDeploy brings
+// the story branch's changes into the project's working copy (observed via a
+// file that exists only on the story branch) and marks the story DEPLOYED.
+func TestPool_StoryDeploy_MergesStoryBranch(t *testing.T) {
+	tmpDir := t.TempDir()
+
+	// mustWrite aborts the test when a fixture file cannot be written, so a
+	// broken setup is not misreported as a merge failure later on.
+	mustWrite := func(path, content string, perm os.FileMode) {
+		t.Helper()
+		if err := os.WriteFile(path, []byte(content), perm); err != nil {
+			t.Fatalf("write %s: %v", path, err)
+		}
+	}
+
+	// Set up bare repo + working copy with a story branch.
+	bareDir := filepath.Join(tmpDir, "bare.git")
+	localDir := filepath.Join(tmpDir, "local")
+	runGit(t, "", "init", "--bare", bareDir)
+	runGit(t, "", "clone", bareDir, localDir)
+	runGit(t, localDir, "config", "user.email", "test@test.com")
+	runGit(t, localDir, "config", "user.name", "Test")
+
+	// Initial commit on main.
+	runGit(t, localDir, "checkout", "-b", "main")
+	mustWrite(filepath.Join(localDir, "README.md"), "initial", 0644)
+	runGit(t, localDir, "add", ".")
+	runGit(t, localDir, "commit", "-m", "initial")
+	runGit(t, localDir, "push", "-u", "origin", "main")
+
+	// Story branch with a feature commit.
+	runGit(t, localDir, "checkout", "-b", "story/test-feature")
+	mustWrite(filepath.Join(localDir, "feature.go"), "package main", 0644)
+	runGit(t, localDir, "add", ".")
+	runGit(t, localDir, "commit", "-m", "feature work")
+	runGit(t, localDir, "push", "origin", "story/test-feature")
+	runGit(t, localDir, "checkout", "main")
+
+	store := testStore(t)
+	logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))
+	pool := NewPool(2, map[string]Runner{"claude": &mockRunner{}}, store, logger)
+
+	scriptPath := filepath.Join(tmpDir, "deploy.sh")
+	mustWrite(scriptPath, "#!/bin/sh\nexit 0\n", 0755)
+
+	proj := &task.Project{
+		ID: "proj-merge-1", Name: "Merge Test",
+		LocalPath: localDir, DeployScript: scriptPath,
+	}
+	if err := store.CreateProject(proj); err != nil {
+		t.Fatalf("create project: %v", err)
+	}
+	story := &task.Story{
+		ID: "story-merge-1", Name: "Merge Test Story",
+		ProjectID: proj.ID, BranchName: "story/test-feature",
+		Status: task.StoryShippable,
+	}
+	if err := store.CreateStory(story); err != nil {
+		t.Fatalf("create story: %v", err)
+	}
+
+	pool.triggerStoryDeploy(context.Background(), story.ID)
+
+	// feature.go should now be on main in the working copy. Any Stat error
+	// means its presence cannot be confirmed.
+	if _, err := os.Stat(filepath.Join(localDir, "feature.go")); err != nil {
+		t.Errorf("story branch was not merged to main: feature.go missing: %v", err)
+	}
+	got, err := store.GetStory(story.ID)
+	if err != nil {
+		t.Fatalf("get story: %v", err)
+	}
+	if got.Status != task.StoryDeployed {
+		t.Errorf("story status: want DEPLOYED, got %q", got.Status)
+	}
+}
+
+// TestPool_PostDeploy_CreatesValidationTask checks that createValidationTask,
+// given a DEPLOYED story carrying a validation spec, moves the story to
+// VALIDATING and stores a validation task that is linked to the story and
+// quotes the spec in its instructions.
+func TestPool_PostDeploy_CreatesValidationTask(t *testing.T) {
+	store := testStore(t)
+	quietLogger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))
+	pool := NewPool(2, map[string]Runner{"claude": &mockRunner{}}, store, quietLogger)
+
+	created := time.Now().UTC()
+	deployed := &task.Story{
+		ID:             "story-postdeploy-1",
+		Name:           "Post Deploy Test",
+		Status:         task.StoryDeployed,
+		ValidationJSON: `{"type":"smoke","steps":["curl /health"],"success_criteria":"status 200"}`,
+		CreatedAt:      created,
+		UpdatedAt:      created,
+	}
+	if err := store.CreateStory(deployed); err != nil {
+		t.Fatalf("CreateStory: %v", err)
+	}
+
+	pool.createValidationTask(context.Background(), deployed.ID)
+
+	// The story must have moved on to VALIDATING.
+	reloaded, err := store.GetStory(deployed.ID)
+	if err != nil {
+		t.Fatalf("GetStory: %v", err)
+	}
+	if reloaded.Status != task.StoryValidating {
+		t.Errorf("story status: want VALIDATING, got %q", reloaded.Status)
+	}
+
+	// Exactly what was created is inspected through the story's task list.
+	storyTasks, err := store.ListTasksByStory(deployed.ID)
+	if err != nil {
+		t.Fatalf("ListTasksByStory: %v", err)
+	}
+	if len(storyTasks) == 0 {
+		t.Fatal("expected a validation task to be created, got none")
+	}
+
+	first := storyTasks[0]
+	if lowered := strings.ToLower(first.Name); !strings.Contains(lowered, "validation") {
+		t.Errorf("task name %q does not contain 'validation'", first.Name)
+	}
+	if first.StoryID != deployed.ID {
+		t.Errorf("task story_id: want %q, got %q", deployed.ID, first.StoryID)
+	}
+	if !strings.Contains(first.Agent.Instructions, "smoke") {
+		t.Errorf("task instructions %q do not reference validation spec content", first.Agent.Instructions)
+	}
+}
+
+// TestPool_ValidationTask_Pass_SetsReviewReady checks that reporting a COMPLETED
+// validation task for a VALIDATING story moves that story to REVIEW_READY.
+func TestPool_ValidationTask_Pass_SetsReviewReady(t *testing.T) {
+	const storyID = "story-val-pass-1"
+
+	store := testStore(t)
+	quietLogger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))
+	pool := NewPool(2, map[string]Runner{"claude": &mockRunner{}}, store, quietLogger)
+
+	created := time.Now().UTC()
+	if err := store.CreateStory(&task.Story{
+		ID:        storyID,
+		Name:      "Validation Pass",
+		Status:    task.StoryValidating,
+		CreatedAt: created,
+		UpdatedAt: created,
+	}); err != nil {
+		t.Fatalf("CreateStory: %v", err)
+	}
+
+	// A passing validation carries no error message.
+	pool.checkValidationResult(context.Background(), storyID, task.StateCompleted, "")
+
+	reloaded, err := store.GetStory(storyID)
+	if err != nil {
+		t.Fatalf("GetStory: %v", err)
+	}
+	if reloaded.Status != task.StoryReviewReady {
+		t.Errorf("story status: want REVIEW_READY, got %q", reloaded.Status)
+	}
+}
+
+// TestPool_DependsOn_NoDeadlock verifies that a task waiting for a dependency
+// does NOT hold the per-agent slot, allowing the dependency to run first.
+func TestPool_DependsOn_NoDeadlock(t *testing.T) {
+ store := testStore(t)
+ runner := &mockRunner{} // succeeds immediately
+ pool := NewPool(2, map[string]Runner{"claude": runner}, store,
+ slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})))
+ pool.requeueDelay = 10 * time.Millisecond
+
+ // Task A has no deps; Task B depends on A.
+ taskA := makeTask("dep-a")
+ taskA.State = task.StateQueued
+ taskB := makeTask("dep-b")
+ taskB.DependsOn = []string{"dep-a"}
+ taskB.State = task.StateQueued
+
+ store.CreateTask(taskA)
+ store.CreateTask(taskB)
+
+ ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+ defer cancel()
+
+ // Submit B first — it should not deadlock by holding the slot while waiting for A.
+ pool.Submit(ctx, taskB)
+ pool.Submit(ctx, taskA)
+
+ var gotA, gotB bool
+ for i := 0; i < 2; i++ {
+ select {
+ case res := <-pool.Results():
+ if res.TaskID == "dep-a" {
+ gotA = true
+ }
+ if res.TaskID == "dep-b" {
+ gotB = true
+ }
+ case <-ctx.Done():
+ t.Fatal("timeout: likely deadlock — dep-b held the slot while waiting for dep-a")
+ }
+ }
+ if !gotA || !gotB {
+ t.Errorf("expected both tasks to complete: gotA=%v gotB=%v", gotA, gotB)
+ }
+
+ // Both tasks must have reached READY/COMPLETED (relative ordering is not asserted here).
+ ta, _ := store.GetTask("dep-a")
+ tb, _ := store.GetTask("dep-b")
+ if ta.State != task.StateReady && ta.State != task.StateCompleted {
+ t.Errorf("dep-a should be READY/COMPLETED, got %s", ta.State)
+ }
+ if tb.State != task.StateReady && tb.State != task.StateCompleted {
+ t.Errorf("dep-b should be READY/COMPLETED, got %s", tb.State)
+ }
+}
+
+func TestPool_ValidationTask_Fail_SetsNeedsFix(t *testing.T) {
+ store := testStore(t)
+ logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))
+ pool := NewPool(2, map[string]Runner{"claude": &mockRunner{}}, store, logger)
+
+ now := time.Now().UTC()
+ story := &task.Story{
+ ID: "story-val-fail-1",
+ Name: "Validation Fail",
+ Status: task.StoryValidating,
+ CreatedAt: now,
+ UpdatedAt: now,
+ }
+ if err := store.CreateStory(story); err != nil {
+ t.Fatalf("CreateStory: %v", err)
+ }
+
+ execErr := "smoke test failed: /health returned 503"
+ pool.checkValidationResult(context.Background(), story.ID, task.StateFailed, execErr)
+
+ got, err := store.GetStory(story.ID)
+ if err != nil {
+ t.Fatalf("GetStory: %v", err)
+ }
+ if got.Status != task.StoryNeedsFix {
+ t.Errorf("story status: want NEEDS_FIX, got %q", got.Status)
+ }
+}
+
+func TestPool_Shutdown_WaitsForWorkers(t *testing.T) {
+ store := testStore(t)
+ started := make(chan struct{})
+ unblock := make(chan struct{})
+ runner := &mockRunner{
+ onRun: func(t *task.Task, e *storage.Execution) error {
+ close(started)
+ <-unblock
+ return nil
+ },
+ }
+ pool := NewPool(1, map[string]Runner{"claude": runner}, store,
+ slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})))
+
+ tk := makeTask("shutdown-task")
+ tk.State = task.StateQueued
+ store.CreateTask(tk)
+ pool.Submit(context.Background(), tk)
+
+ // Wait until the worker has started.
+ select {
+ case <-started:
+ case <-time.After(5 * time.Second):
+ t.Fatal("worker did not start")
+ }
+
+ // Shutdown should block until we unblock the worker.
+ done := make(chan error, 1)
+ go func() {
+ ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+ defer cancel()
+ done <- pool.Shutdown(ctx)
+ }()
+
+ // Shutdown should not have returned yet.
+ select {
+ case err := <-done:
+ t.Fatalf("Shutdown returned early: %v", err)
+ case <-time.After(50 * time.Millisecond):
+ }
+
+ close(unblock) // let the worker finish
+
+ select {
+ case err := <-done:
+ if err != nil {
+ t.Errorf("Shutdown returned error: %v", err)
+ }
+ case <-time.After(5 * time.Second):
+ t.Fatal("Shutdown did not return after worker finished")
+ }
+}
+
+func TestPool_Shutdown_TimesOut(t *testing.T) {
+ store := testStore(t)
+ unblock := make(chan struct{})
+ runner := &mockRunner{
+ onRun: func(t *task.Task, e *storage.Execution) error {
+ <-unblock // never unblocked
+ return nil
+ },
+ }
+ pool := NewPool(1, map[string]Runner{"claude": runner}, store,
+ slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})))
+
+ tk := makeTask("shutdown-timeout-task")
+ tk.State = task.StateQueued
+ store.CreateTask(tk)
+ pool.Submit(context.Background(), tk)
+
+ // Give worker a moment to start.
+ time.Sleep(50 * time.Millisecond)
+
+ ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
+ defer cancel()
+ err := pool.Shutdown(ctx)
+ if err == nil {
+ t.Error("expected timeout error, got nil")
+ }
+ close(unblock) // cleanup
+}
+
+func TestPool_CheckerSpawned_OnReady(t *testing.T) {
+ store := testStore(t)
+ runner := &mockRunner{} // succeeds instantly
+ pool := NewPool(2, map[string]Runner{"claude": runner}, store, slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})))
+
+ tk := makeTask("checker-spawn-1")
+ tk.RepositoryURL = "https://github.com/x/y"
+ store.CreateTask(tk)
+ pool.Submit(context.Background(), tk)
+ <-pool.Results() // wait for original task to finish
+
+ // Poll until the async spawnCheckerTask goroutine has written the checker task.
+ var checker *task.Task
+ var err error
+ deadline := time.Now().Add(5 * time.Second)
+ for time.Now().Before(deadline) {
+ checker, err = store.GetCheckerTask("checker-spawn-1")
+ if err != nil {
+ t.Fatalf("GetCheckerTask: %v", err)
+ }
+ if checker != nil {
+ break
+ }
+ time.Sleep(50 * time.Millisecond)
+ }
+ if checker == nil {
+ t.Fatal("expected a checker task to be created, got nil")
+ }
+ if checker.CheckerForTaskID != "checker-spawn-1" {
+ t.Errorf("expected CheckerForTaskID=checker-spawn-1, got %q", checker.CheckerForTaskID)
+ }
+}
+
+func TestPool_CheckerNotSpawned_ForSubtask(t *testing.T) {
+ store := testStore(t)
+ runner := &mockRunner{}
+ pool := NewPool(2, map[string]Runner{"claude": runner}, store, slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})))
+
+ parent := makeTask("no-checker-parent")
+ parent.RepositoryURL = "https://github.com/x/y"
+ store.CreateTask(parent)
+
+ sub := makeTask("no-checker-sub")
+ sub.ParentTaskID = "no-checker-parent"
+ sub.RepositoryURL = "https://github.com/x/y"
+ store.CreateTask(sub)
+
+ pool.Submit(context.Background(), sub)
+ <-pool.Results()
+
+ time.Sleep(100 * time.Millisecond)
+
+ checker, err := store.GetCheckerTask("no-checker-sub")
+ if err != nil {
+ t.Fatalf("GetCheckerTask: %v", err)
+ }
+ if checker != nil {
+ t.Error("expected no checker for subtask, but one was created")
+ }
+}
+
+func TestPool_CheckerPass_AutoAcceptsTask(t *testing.T) {
+ store := testStore(t)
+ // Two-phase: the first run succeeds (original task), the second run also succeeds (checker).
+ runner := &mockRunner{
+ onRun: func(t *task.Task, e *storage.Execution) error {
+ return nil // both original and checker succeed
+ },
+ }
+ pool := NewPool(2, map[string]Runner{"claude": runner}, store, slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})))
+
+ tk := makeTask("autoaccept-1")
+ tk.RepositoryURL = "https://github.com/x/y"
+ store.CreateTask(tk)
+ pool.Submit(context.Background(), tk)
+ <-pool.Results() // original finishes → READY + checker spawned
+
+ // Wait for checker to run and complete.
+ deadline := time.Now().Add(5 * time.Second)
+ for time.Now().Before(deadline) {
+ got, _ := store.GetTask("autoaccept-1")
+ if got != nil && got.State == task.StateCompleted {
+ break
+ }
+ <-pool.Results()
+ }
+
+ got, err := store.GetTask("autoaccept-1")
+ if err != nil {
+ t.Fatalf("GetTask: %v", err)
+ }
+ if got.State != task.StateCompleted {
+ t.Errorf("expected COMPLETED after checker pass, got %s", got.State)
+ }
+}
+
+func TestPool_CheckerFail_AttachesReport(t *testing.T) {
+ store := testStore(t)
+ runner := &mockRunner{
+ onRun: func(t *task.Task, e *storage.Execution) error {
+ if t.CheckerForTaskID != "" {
+ return fmt.Errorf("test suite failed: 3 failures")
+ }
+ return nil // original task succeeds
+ },
+ }
+ pool := NewPool(2, map[string]Runner{"claude": runner}, store, slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})))
+
+ tk := makeTask("fail-checker-1")
+ tk.RepositoryURL = "https://github.com/x/y"
+ store.CreateTask(tk)
+ pool.Submit(context.Background(), tk)
+ <-pool.Results() // original → READY
+
+ // Wait for checker to fail.
+ deadline := time.Now().Add(5 * time.Second)
+ for time.Now().Before(deadline) {
+ got, _ := store.GetTask("fail-checker-1")
+ if got != nil && got.CheckerReport != "" {
+ break
+ }
+ select {
+ case <-pool.Results():
+ case <-time.After(100 * time.Millisecond):
+ }
+ }
+
+ got, err := store.GetTask("fail-checker-1")
+ if err != nil {
+ t.Fatalf("GetTask: %v", err)
+ }
+ if got.State != task.StateReady {
+ t.Errorf("expected task to stay READY after checker fail, got %s", got.State)
+ }
+ if got.CheckerReport == "" {
+ t.Error("expected checker_report to be set after checker failure")
+ }
+}
diff --git a/internal/executor/helpers.go b/internal/executor/helpers.go
new file mode 100644
index 0000000..76bf8b1
--- /dev/null
+++ b/internal/executor/helpers.go
@@ -0,0 +1,205 @@
+package executor
+
+import (
+ "bufio"
+ "encoding/json"
+ "fmt"
+ "io"
+ "log/slog"
+ "os"
+ "strings"
+)
+
+// BlockedError is returned by Run when the agent wrote a question file and exited.
+// The pool transitions the task to BLOCKED and stores the question for the user.
+type BlockedError struct {
+ QuestionJSON string // raw JSON from the question file
+ SessionID string // claude session to resume once the user answers
+ SandboxDir string // preserved sandbox path; resume must run here so Claude finds its session files
+}
+
+func (e *BlockedError) Error() string { return fmt.Sprintf("task blocked: %s", e.QuestionJSON) }
+
+// parseStream reads streaming JSON from claude, copies it to w, and returns
+// (costUSD, sessionID, error). error is non-nil if reading fails or the stream signals failure:
+// - result message has is_error:true, or a rate limit was reported
+// - a tool_result was denied due to missing permissions
+func parseStream(r io.Reader, w io.Writer, logger *slog.Logger) (float64, string, error) {
+ tee := io.TeeReader(r, w)
+ scanner := bufio.NewScanner(tee)
+ scanner.Buffer(make([]byte, 1024*1024), 1024*1024) // 1MB buffer for large lines
+
+ var totalCost float64
+ var sessionID string
+ var streamErr error
+
+Loop:
+ for scanner.Scan() {
+ line := scanner.Bytes()
+ var msg map[string]interface{}
+ if err := json.Unmarshal(line, &msg); err != nil {
+ continue
+ }
+
+ msgType, _ := msg["type"].(string)
+ switch msgType {
+ case "system":
+ if subtype, ok := msg["subtype"].(string); ok && subtype == "init" {
+ if sid, ok := msg["session_id"].(string); ok {
+ sessionID = sid
+ }
+ }
+ case "rate_limit_event":
+ if info, ok := msg["rate_limit_info"].(map[string]interface{}); ok {
+ status, _ := info["status"].(string)
+ if status == "rejected" {
+ streamErr = fmt.Errorf("claude rate limit reached (rejected): %v", msg)
+ // Immediately break since we can't continue anyway
+ break Loop
+ }
+ }
+ case "assistant":
+ if errStr, ok := msg["error"].(string); ok && errStr == "rate_limit" {
+ streamErr = fmt.Errorf("claude rate limit reached: %v", msg)
+ }
+ case "result":
+ if isErr, _ := msg["is_error"].(bool); isErr {
+ result, _ := msg["result"].(string)
+ if result != "" {
+ streamErr = fmt.Errorf("claude task failed: %s", result)
+ } else {
+ streamErr = fmt.Errorf("claude task failed (is_error=true in result)")
+ }
+ }
+ // Take total_cost_usd from the result message; the legacy cost_usd check after the switch still runs and overrides it if present.
+ if cost, ok := msg["total_cost_usd"].(float64); ok {
+ totalCost = cost
+ }
+ case "user":
+ // Detect permission-denial tool_results. These occur when permission_mode
+ // is not bypassPermissions and claude exits 0 without completing its task.
+ if err := permissionDenialError(msg); err != nil && streamErr == nil {
+ streamErr = err
+ }
+ }
+
+ // Legacy cost field used by older claude versions.
+ if cost, ok := msg["cost_usd"].(float64); ok {
+ totalCost = cost
+ }
+ }
+ if err := scanner.Err(); err != nil && streamErr == nil {
+ streamErr = fmt.Errorf("reading claude stdout: %w", err)
+ }
+
+ return totalCost, sessionID, streamErr
+}
+
+
+// permissionDenialError inspects a "user" stream message for tool_result entries
+// that were denied due to missing permissions. Returns an error if found.
+func permissionDenialError(msg map[string]interface{}) error {
+ message, ok := msg["message"].(map[string]interface{})
+ if !ok {
+ return nil
+ }
+ content, ok := message["content"].([]interface{})
+ if !ok {
+ return nil
+ }
+ for _, item := range content {
+ itemMap, ok := item.(map[string]interface{})
+ if !ok {
+ continue
+ }
+ if itemMap["type"] != "tool_result" {
+ continue
+ }
+ if isErr, _ := itemMap["is_error"].(bool); !isErr {
+ continue
+ }
+ text, _ := itemMap["content"].(string)
+ if strings.Contains(text, "requested permissions") || strings.Contains(text, "haven't granted") {
+ return fmt.Errorf("permission denied by host: %s", text)
+ }
+ }
+ return nil
+}
+
+// tailFile returns the last n lines of the file at path, or empty string if
+// the file cannot be read. Used to surface subprocess stderr on failure.
+func tailFile(path string, n int) string {
+ f, err := os.Open(path)
+ if err != nil {
+ return ""
+ }
+ defer f.Close()
+
+ var lines []string
+ scanner := bufio.NewScanner(f)
+ for scanner.Scan() {
+ lines = append(lines, scanner.Text())
+ if len(lines) > n {
+ lines = lines[1:]
+ }
+ }
+ return strings.Join(lines, "\n")
+}
+
+// readFileTail returns the last maxBytes bytes of the file at path as a whitespace-trimmed
+// string, or empty string if the file cannot be read. Used to surface agent stderr on failure.
+func readFileTail(path string, maxBytes int64) string {
+ f, err := os.Open(path)
+ if err != nil {
+ return ""
+ }
+ defer f.Close()
+ fi, err := f.Stat()
+ if err != nil {
+ return ""
+ }
+ offset := fi.Size() - maxBytes
+ if offset < 0 {
+ offset = 0
+ }
+ buf := make([]byte, fi.Size()-offset)
+ n, err := f.ReadAt(buf, offset)
+ if err != nil && n == 0 {
+ return ""
+ }
+ return strings.TrimSpace(string(buf[:n]))
+}
+
+func gitSafe(args ...string) []string {
+ return append([]string{
+ "-c", "safe.directory=*",
+ "-c", "commit.gpgsign=false",
+ "-c", "tag.gpgsign=false",
+ }, args...)
+}
+
+// isCompletionReport returns true when a question-file JSON looks like a
+// completion report rather than a real user question. Heuristic: no options
+// (or empty options) and no "?" anywhere in the text.
+func isCompletionReport(questionJSON string) bool {
+ var q struct {
+ Text string `json:"text"`
+ Options []string `json:"options"`
+ }
+ if err := json.Unmarshal([]byte(questionJSON), &q); err != nil {
+ return false
+ }
+ return len(q.Options) == 0 && !strings.Contains(q.Text, "?")
+}
+
+// extractQuestionText returns the "text" field from a question-file JSON, or
+// the raw string if parsing fails.
+func extractQuestionText(questionJSON string) string {
+ var q struct {
+ Text string `json:"text"`
+ }
+ if err := json.Unmarshal([]byte(questionJSON), &q); err != nil {
+ return questionJSON
+ }
+ return strings.TrimSpace(q.Text)
+}
diff --git a/internal/executor/preamble.go b/internal/executor/preamble.go
index f5dba2b..b949986 100644
--- a/internal/executor/preamble.go
+++ b/internal/executor/preamble.go
@@ -45,6 +45,7 @@ The sandbox is rejected if there are any uncommitted modifications.
- One commit is fine. Multiple focused commits are also fine.
- If you realise the task was already done and you made no changes, that is also fine — just exit cleanly without committing.
- Do not exit with uncommitted edits.
+- **CRITICAL:** Run ALL git commands from your current directory — do NOT use absolute paths or "cd <project_path> && git ...". Your working directory IS the project. Using absolute paths bypasses the sandbox and breaks commit tracking.
---
diff --git a/internal/executor/preamble_test.go b/internal/executor/preamble_test.go
index 984f786..5c31b4f 100644
--- a/internal/executor/preamble_test.go
+++ b/internal/executor/preamble_test.go
@@ -22,3 +22,10 @@ func TestPlanningPreamble_SummaryInstructsEchoToFile(t *testing.T) {
t.Error("planningPreamble should show example of writing to $CLAUDOMATOR_SUMMARY_FILE via echo")
}
}
+
+func TestPlanningPreamble_GitDiscipline_ForbidsAbsolutePaths(t *testing.T) {
+ // Agents must not bypass the sandbox by using absolute project paths in git commands.
+ if !strings.Contains(planningPreamble, "do NOT use absolute paths") {
+ t.Error("planningPreamble should warn agents not to use absolute paths in git commands")
+ }
+}
diff --git a/internal/executor/question.go b/internal/executor/question.go
index 9a2b55d..0ae1b08 100644
--- a/internal/executor/question.go
+++ b/internal/executor/question.go
@@ -5,92 +5,8 @@ import (
"encoding/json"
"io"
"log/slog"
- "sync"
)
-// QuestionHandler is called when an agent invokes AskUserQuestion.
-// Implementations should broadcast the question and block until an answer arrives.
-type QuestionHandler interface {
- HandleQuestion(taskID, toolUseID string, input json.RawMessage) (string, error)
-}
-
-// PendingQuestion holds state for a question awaiting a user answer.
-type PendingQuestion struct {
- TaskID string `json:"task_id"`
- ToolUseID string `json:"tool_use_id"`
- Input json.RawMessage `json:"input"`
- AnswerCh chan string `json:"-"`
-}
-
-// QuestionRegistry tracks pending questions across running tasks.
-type QuestionRegistry struct {
- mu sync.Mutex
- questions map[string]*PendingQuestion // keyed by toolUseID
-}
-
-// NewQuestionRegistry creates a new registry.
-func NewQuestionRegistry() *QuestionRegistry {
- return &QuestionRegistry{
- questions: make(map[string]*PendingQuestion),
- }
-}
-
-// Register adds a pending question and returns its answer channel.
-func (qr *QuestionRegistry) Register(taskID, toolUseID string, input json.RawMessage) chan string {
- ch := make(chan string, 1)
- qr.mu.Lock()
- qr.questions[toolUseID] = &PendingQuestion{
- TaskID: taskID,
- ToolUseID: toolUseID,
- Input: input,
- AnswerCh: ch,
- }
- qr.mu.Unlock()
- return ch
-}
-
-// Answer delivers an answer for a pending question. Returns false if no such question exists.
-func (qr *QuestionRegistry) Answer(toolUseID, answer string) bool {
- qr.mu.Lock()
- pq, ok := qr.questions[toolUseID]
- if ok {
- delete(qr.questions, toolUseID)
- }
- qr.mu.Unlock()
- if !ok {
- return false
- }
- pq.AnswerCh <- answer
- return true
-}
-
-// Get returns a pending question by tool_use_id, or nil.
-func (qr *QuestionRegistry) Get(toolUseID string) *PendingQuestion {
- qr.mu.Lock()
- defer qr.mu.Unlock()
- return qr.questions[toolUseID]
-}
-
-// PendingForTask returns all pending questions for a given task.
-func (qr *QuestionRegistry) PendingForTask(taskID string) []*PendingQuestion {
- qr.mu.Lock()
- defer qr.mu.Unlock()
- var result []*PendingQuestion
- for _, pq := range qr.questions {
- if pq.TaskID == taskID {
- result = append(result, pq)
- }
- }
- return result
-}
-
-// Remove removes a question without answering it (e.g., on task cancellation).
-func (qr *QuestionRegistry) Remove(toolUseID string) {
- qr.mu.Lock()
- delete(qr.questions, toolUseID)
- qr.mu.Unlock()
-}
-
// extractAskUserQuestion parses a stream-json line and returns the tool_use_id and input
// if the line is an assistant event containing an AskUserQuestion tool_use.
func extractAskUserQuestion(line []byte) (string, json.RawMessage) {
diff --git a/internal/executor/question_test.go b/internal/executor/question_test.go
index d0fbed9..6686c15 100644
--- a/internal/executor/question_test.go
+++ b/internal/executor/question_test.go
@@ -9,64 +9,6 @@ import (
"testing"
)
-func TestQuestionRegistry_RegisterAndAnswer(t *testing.T) {
- qr := NewQuestionRegistry()
-
- ch := qr.Register("task-1", "toolu_abc", json.RawMessage(`{"question":"color?"}`))
-
- // Answer should unblock the channel.
- go func() {
- ok := qr.Answer("toolu_abc", "blue")
- if !ok {
- t.Error("Answer returned false, expected true")
- }
- }()
-
- answer := <-ch
- if answer != "blue" {
- t.Errorf("want 'blue', got %q", answer)
- }
-
- // Question should be removed after answering.
- if qr.Get("toolu_abc") != nil {
- t.Error("question should be removed after answering")
- }
-}
-
-func TestQuestionRegistry_AnswerUnknown(t *testing.T) {
- qr := NewQuestionRegistry()
- ok := qr.Answer("nonexistent", "anything")
- if ok {
- t.Error("expected false for unknown question")
- }
-}
-
-func TestQuestionRegistry_PendingForTask(t *testing.T) {
- qr := NewQuestionRegistry()
- qr.Register("task-1", "toolu_1", json.RawMessage(`{}`))
- qr.Register("task-1", "toolu_2", json.RawMessage(`{}`))
- qr.Register("task-2", "toolu_3", json.RawMessage(`{}`))
-
- pending := qr.PendingForTask("task-1")
- if len(pending) != 2 {
- t.Errorf("want 2 pending for task-1, got %d", len(pending))
- }
-
- pending2 := qr.PendingForTask("task-2")
- if len(pending2) != 1 {
- t.Errorf("want 1 pending for task-2, got %d", len(pending2))
- }
-}
-
-func TestQuestionRegistry_Remove(t *testing.T) {
- qr := NewQuestionRegistry()
- qr.Register("task-1", "toolu_x", json.RawMessage(`{}`))
- qr.Remove("toolu_x")
- if qr.Get("toolu_x") != nil {
- t.Error("question should be removed")
- }
-}
-
func TestExtractAskUserQuestion_DetectsQuestion(t *testing.T) {
// Simulate a stream-json assistant event containing an AskUserQuestion tool_use.
event := map[string]interface{}{
diff --git a/internal/executor/ratelimit.go b/internal/executor/ratelimit.go
index 109aa49..ee9a336 100644
--- a/internal/executor/ratelimit.go
+++ b/internal/executor/ratelimit.go
@@ -13,5 +13,9 @@ func isQuotaExhausted(err error) bool {
strings.Contains(msg, "you've hit your limit") ||
strings.Contains(msg, "you have hit your limit") ||
strings.Contains(msg, "rate limit reached (rejected)") ||
- strings.Contains(msg, "status: rejected")
+ strings.Contains(msg, "status: rejected") ||
+ // Gemini CLI quota exhaustion
+ strings.Contains(msg, "terminalquotaerror") ||
+ strings.Contains(msg, "exhausted your daily quota") ||
+ strings.Contains(msg, "generate_content_free_tier_requests")
}
diff --git a/internal/executor/stream_test.go b/internal/executor/stream_test.go
index 10eb858..11a6178 100644
--- a/internal/executor/stream_test.go
+++ b/internal/executor/stream_test.go
@@ -12,7 +12,7 @@ func streamLine(json string) string { return json + "\n" }
func TestParseStream_ResultIsError_ReturnsError(t *testing.T) {
input := streamLine(`{"type":"result","subtype":"error_during_execution","is_error":true,"result":"something went wrong"}`)
- _, err := parseStream(strings.NewReader(input), io.Discard, slog.New(slog.NewTextHandler(io.Discard, nil)))
+ _, _, err := parseStream(strings.NewReader(input), io.Discard, slog.New(slog.NewTextHandler(io.Discard, nil)))
if err == nil {
t.Fatal("expected error when result.is_error=true, got nil")
}
@@ -27,7 +27,7 @@ func TestParseStream_PermissionDenied_ReturnsError(t *testing.T) {
input := streamLine(`{"type":"user","message":{"role":"user","content":[{"type":"tool_result","is_error":true,"content":"Claude requested permissions to write to /foo/bar.go, but you haven't granted it yet.","tool_use_id":"tu_abc"}]}}`) +
streamLine(`{"type":"result","subtype":"success","is_error":false,"result":"I need permission","total_cost_usd":0.1}`)
- _, err := parseStream(strings.NewReader(input), io.Discard, slog.New(slog.NewTextHandler(io.Discard, nil)))
+ _, _, err := parseStream(strings.NewReader(input), io.Discard, slog.New(slog.NewTextHandler(io.Discard, nil)))
if err == nil {
t.Fatal("expected error for permission denial, got nil")
}
@@ -40,7 +40,7 @@ func TestParseStream_Success_ReturnsNilError(t *testing.T) {
input := streamLine(`{"type":"assistant","message":{"content":[{"type":"text","text":"Done."}]}}`) +
streamLine(`{"type":"result","subtype":"success","is_error":false,"result":"All tests pass.","total_cost_usd":0.05}`)
- _, err := parseStream(strings.NewReader(input), io.Discard, slog.New(slog.NewTextHandler(io.Discard, nil)))
+ _, _, err := parseStream(strings.NewReader(input), io.Discard, slog.New(slog.NewTextHandler(io.Discard, nil)))
if err != nil {
t.Fatalf("expected nil error for success stream, got: %v", err)
}
@@ -49,7 +49,7 @@ func TestParseStream_Success_ReturnsNilError(t *testing.T) {
func TestParseStream_ExtractsCostFromResultMessage(t *testing.T) {
input := streamLine(`{"type":"result","subtype":"success","is_error":false,"result":"done","total_cost_usd":1.2345}`)
- cost, err := parseStream(strings.NewReader(input), io.Discard, slog.New(slog.NewTextHandler(io.Discard, nil)))
+ cost, _, err := parseStream(strings.NewReader(input), io.Discard, slog.New(slog.NewTextHandler(io.Discard, nil)))
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
@@ -62,7 +62,7 @@ func TestParseStream_ExtractsCostFromLegacyCostUSD(t *testing.T) {
// Some versions emit cost_usd at the top level rather than total_cost_usd.
input := streamLine(`{"type":"result","subtype":"success","is_error":false,"result":"done","cost_usd":0.99}`)
- cost, err := parseStream(strings.NewReader(input), io.Discard, slog.New(slog.NewTextHandler(io.Discard, nil)))
+ cost, _, err := parseStream(strings.NewReader(input), io.Discard, slog.New(slog.NewTextHandler(io.Discard, nil)))
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
@@ -78,8 +78,21 @@ func TestParseStream_NonToolResultIsError_DoesNotFail(t *testing.T) {
input := streamLine(`{"type":"user","message":{"role":"user","content":[{"type":"tool_result","is_error":true,"content":"exit status 1","tool_use_id":"tu_xyz"}]}}`) +
streamLine(`{"type":"result","subtype":"success","is_error":false,"result":"Fixed it.","total_cost_usd":0.2}`)
- _, err := parseStream(strings.NewReader(input), io.Discard, slog.New(slog.NewTextHandler(io.Discard, nil)))
+ _, _, err := parseStream(strings.NewReader(input), io.Discard, slog.New(slog.NewTextHandler(io.Discard, nil)))
if err != nil {
t.Fatalf("non-permission tool errors should not fail the task, got: %v", err)
}
}
+
+func TestParseStream_ExtractsSessionID(t *testing.T) {
+ input := streamLine(`{"type":"system","subtype":"init","session_id":"sess-999"}`) +
+ streamLine(`{"type":"result","subtype":"success","is_error":false,"result":"ok","total_cost_usd":0.01}`)
+
+ _, sid, err := parseStream(strings.NewReader(input), io.Discard, slog.New(slog.NewTextHandler(io.Discard, nil)))
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ if sid != "sess-999" {
+ t.Errorf("want session ID sess-999, got %q", sid)
+ }
+}
diff --git a/internal/notify/vapid.go b/internal/notify/vapid.go
new file mode 100644
index 0000000..684bf4d
--- /dev/null
+++ b/internal/notify/vapid.go
@@ -0,0 +1,25 @@
+package notify
+
+import (
+ "encoding/base64"
+
+ webpush "github.com/SherClockHolmes/webpush-go"
+)
+
+// GenerateVAPIDKeys generates a VAPID key pair for web push notifications.
+// Returns the base64url-encoded public and private keys.
+// Note: webpush.GenerateVAPIDKeys returns (privateKey, publicKey) — we swap here.
+func GenerateVAPIDKeys() (publicKey, privateKey string, err error) {
+ privateKey, publicKey, err = webpush.GenerateVAPIDKeys()
+ return
+}
+
+// ValidateVAPIDPublicKey reports whether key is a valid VAPID public key:
+// a base64url-encoded 65-byte uncompressed P-256 point (starts with 0x04).
+func ValidateVAPIDPublicKey(key string) bool {
+ b, err := base64.RawURLEncoding.DecodeString(key)
+ if err != nil {
+ return false
+ }
+ return len(b) == 65 && b[0] == 0x04
+}
diff --git a/internal/notify/vapid_test.go b/internal/notify/vapid_test.go
new file mode 100644
index 0000000..a45047d
--- /dev/null
+++ b/internal/notify/vapid_test.go
@@ -0,0 +1,64 @@
+package notify
+
+import (
+ "encoding/base64"
+ "testing"
+)
+
+// TestValidateVAPIDPublicKey verifies that ValidateVAPIDPublicKey accepts valid
+// public keys and rejects private keys, empty strings, and invalid base64.
+func TestValidateVAPIDPublicKey(t *testing.T) {
+ pub, priv, err := GenerateVAPIDKeys()
+ if err != nil {
+ t.Fatalf("GenerateVAPIDKeys: %v", err)
+ }
+ if !ValidateVAPIDPublicKey(pub) {
+ t.Error("valid public key should pass validation")
+ }
+ if ValidateVAPIDPublicKey(priv) {
+ t.Error("private key (32 bytes) should fail public key validation")
+ }
+ if ValidateVAPIDPublicKey("") {
+ t.Error("empty string should fail validation")
+ }
+ if ValidateVAPIDPublicKey("notbase64!!!") {
+ t.Error("invalid base64 should fail validation")
+ }
+}
+
+// TestGenerateVAPIDKeys_PublicKeyIs65Bytes verifies that the public key returned
+// by GenerateVAPIDKeys is a 65-byte uncompressed P256 EC point (base64url, no padding = 87 chars)
+// and the private key is 32 bytes (43 chars). Previously the return values were swapped.
+func TestGenerateVAPIDKeys_PublicKeyIs65Bytes(t *testing.T) {
+ pub, priv, err := GenerateVAPIDKeys()
+ if err != nil {
+ t.Fatalf("GenerateVAPIDKeys: %v", err)
+ }
+
+ // Public key: 65 bytes → 87 base64url chars (no padding).
+ if len(pub) != 87 {
+ t.Errorf("public key: want 87 chars (65 bytes), got %d chars (%q)", len(pub), pub)
+ }
+ pubBytes, err := base64.RawURLEncoding.DecodeString(pub)
+ if err != nil {
+ t.Fatalf("public key base64url decode: %v", err)
+ }
+ if len(pubBytes) != 65 {
+ t.Errorf("public key bytes: want 65, got %d", len(pubBytes))
+ }
+ if pubBytes[0] != 0x04 {
+ t.Errorf("public key first byte: want 0x04 (uncompressed point), got 0x%02x", pubBytes[0])
+ }
+
+ // Private key: 32 bytes → 43 base64url chars (no padding).
+ if len(priv) != 43 {
+ t.Errorf("private key: want 43 chars (32 bytes), got %d chars (%q)", len(priv), priv)
+ }
+ privBytes, err := base64.RawURLEncoding.DecodeString(priv)
+ if err != nil {
+ t.Fatalf("private key base64url decode: %v", err)
+ }
+ if len(privBytes) != 32 {
+ t.Errorf("private key bytes: want 32, got %d", len(privBytes))
+ }
+}
diff --git a/internal/notify/webpush.go b/internal/notify/webpush.go
new file mode 100644
index 0000000..e118a43
--- /dev/null
+++ b/internal/notify/webpush.go
@@ -0,0 +1,106 @@
+package notify
+
+import (
+ "encoding/json"
+ "fmt"
+ "log/slog"
+
+ webpush "github.com/SherClockHolmes/webpush-go"
+ "github.com/thepeterstone/claudomator/internal/storage"
+)
+
+// PushSubscriptionStore is the minimal storage interface needed by WebPushNotifier.
+// Notify calls it once per event to obtain the list of delivery targets.
+type PushSubscriptionStore interface {
+ // ListPushSubscriptions returns the stored subscriptions, or an error that
+ // Notify wraps and returns without sending anything.
+ ListPushSubscriptions() ([]storage.PushSubscription, error)
+}
+
+// WebPushNotifier sends web push notifications to all registered subscribers.
+// Its fields are read by Notify on every call.
+type WebPushNotifier struct {
+ Store PushSubscriptionStore // source of the subscriptions to deliver to
+ VAPIDPublicKey string // handed to the webpush library as Options.VAPIDPublicKey
+ VAPIDPrivateKey string // handed to the webpush library as Options.VAPIDPrivateKey
+ VAPIDEmail string // handed to the webpush library as Options.Subscriber
+ Logger *slog.Logger // receives per-subscription send failures and error statuses
+}
+
+// notificationContent derives urgency, title, body, and tag from a notify Event.
+// It is unexported; the tests in this package call it directly.
+//
+// The tag is always "task-" followed by ev.TaskID. ev.Status selects the rest:
+//   - BLOCKED: urgency "urgent", title "Needs input".
+//   - FAILED, BUDGET_EXCEEDED, TIMED_OUT: urgency "high", title "Task failed";
+//     the body includes ev.Error when it is non-empty.
+//   - COMPLETED: urgency "low", title "Task done"; the body includes
+//     ev.CostUSD formatted to two decimals.
+//   - any other status: urgency "normal", title "Task update"; the body is
+//     the task name followed by the raw status.
+func notificationContent(ev Event) (urgency, title, body, tag string) {
+ tag = "task-" + ev.TaskID
+ switch ev.Status {
+ case "BLOCKED":
+ urgency = "urgent"
+ title = "Needs input"
+ body = fmt.Sprintf("%s is waiting for your response", ev.TaskName)
+ case "FAILED", "BUDGET_EXCEEDED", "TIMED_OUT":
+ urgency = "high"
+ title = "Task failed"
+ if ev.Error != "" {
+ body = fmt.Sprintf("%s failed: %s", ev.TaskName, ev.Error)
+ } else {
+ body = fmt.Sprintf("%s failed", ev.TaskName)
+ }
+ case "COMPLETED":
+ urgency = "low"
+ title = "Task done"
+ body = fmt.Sprintf("%s completed ($%.2f)", ev.TaskName, ev.CostUSD)
+ default:
+ urgency = "normal"
+ title = "Task update"
+ body = fmt.Sprintf("%s: %s", ev.TaskName, ev.Status)
+ }
+ return
+}
+
+// Notify sends a web push notification describing ev to every registered
+// subscriber.
+//
+// It returns a wrapped error if the subscriptions cannot be listed or the
+// payload cannot be marshaled. Delivery is best-effort per subscriber: a failed
+// send is logged and the loop continues, and the error of the last failed send
+// (if any) is returned. A response status of 400 or above is logged as a
+// warning but does not affect the return value. With no subscribers Notify
+// returns nil without doing any work.
+func (n *WebPushNotifier) Notify(ev Event) error {
+	subs, err := n.Store.ListPushSubscriptions()
+	if err != nil {
+		return fmt.Errorf("listing push subscriptions: %w", err)
+	}
+	if len(subs) == 0 {
+		return nil
+	}
+
+	// A nil Logger would panic on the first failed send or error status; fall
+	// back to the process-wide default so an unset field is not fatal.
+	logger := n.Logger
+	if logger == nil {
+		logger = slog.Default()
+	}
+
+	urgency, title, body, tag := notificationContent(ev)
+
+	payload := map[string]string{
+		"title": title,
+		"body":  body,
+		"tag":   tag,
+	}
+	data, err := json.Marshal(payload)
+	if err != nil {
+		return fmt.Errorf("marshaling push payload: %w", err)
+	}
+
+	// The same options are shared by every send for this event.
+	opts := &webpush.Options{
+		Subscriber:      n.VAPIDEmail,
+		VAPIDPublicKey:  n.VAPIDPublicKey,
+		VAPIDPrivateKey: n.VAPIDPrivateKey,
+		Urgency:         webpush.Urgency(urgency),
+		TTL:             86400,
+	}
+
+	var lastErr error
+	for _, sub := range subs {
+		wSub := &webpush.Subscription{
+			Endpoint: sub.Endpoint,
+			Keys: webpush.Keys{
+				P256dh: sub.P256DHKey,
+				Auth:   sub.AuthKey,
+			},
+		}
+		resp, sendErr := webpush.SendNotification(data, wSub, opts)
+		if sendErr != nil {
+			logger.Error("webpush send failed", "endpoint", sub.Endpoint, "error", sendErr)
+			lastErr = sendErr
+			continue
+		}
+		// Closed immediately rather than deferred so bodies do not pile up
+		// until the whole loop has finished.
+		resp.Body.Close()
+		if resp.StatusCode >= 400 {
+			logger.Warn("webpush returned error status", "endpoint", sub.Endpoint, "status", resp.StatusCode)
+		}
+	}
+	return lastErr
+}
diff --git a/internal/notify/webpush_test.go b/internal/notify/webpush_test.go
new file mode 100644
index 0000000..594305e
--- /dev/null
+++ b/internal/notify/webpush_test.go
@@ -0,0 +1,191 @@
+package notify
+
+import (
+ "encoding/json"
+ "io"
+ "log/slog"
+ "net/http"
+ "net/http/httptest"
+ "os"
+ "sync"
+ "testing"
+
+ "github.com/thepeterstone/claudomator/internal/storage"
+)
+
+// fakePushStore is a test double for PushSubscriptionStore that serves
+// subscriptions from a slice held in memory.
+type fakePushStore struct {
+	mu   sync.Mutex
+	subs []storage.PushSubscription
+}
+
+// ListPushSubscriptions returns a copy of the stored subscriptions, taken while
+// holding the mutex, so callers never share the backing array with the store.
+func (f *fakePushStore) ListPushSubscriptions() ([]storage.PushSubscription, error) {
+	f.mu.Lock()
+	snapshot := make([]storage.PushSubscription, len(f.subs))
+	copy(snapshot, f.subs)
+	f.mu.Unlock()
+	return snapshot, nil
+}
+
+// TestWebPushNotifier_NoSubscriptions_NoError checks that Notify returns nil
+// when the store holds no subscriptions.
+func TestWebPushNotifier_NoSubscriptions_NoError(t *testing.T) {
+	handler := slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})
+	notifier := &WebPushNotifier{
+		Store:           &fakePushStore{},
+		VAPIDPublicKey:  "testpub",
+		VAPIDPrivateKey: "testpriv",
+		VAPIDEmail:      "mailto:test@example.com",
+		Logger:          slog.New(handler),
+	}
+
+	err := notifier.Notify(Event{TaskID: "t1", TaskName: "test", Status: "COMPLETED"})
+	if err != nil {
+		t.Errorf("expected no error with empty store, got: %v", err)
+	}
+}
+
+// TestWebPushNotifier_UrgencyMapping checks the urgency that
+// notificationContent assigns to each task status, including the fallback
+// used for a status without a dedicated case.
+func TestWebPushNotifier_UrgencyMapping(t *testing.T) {
+	cases := []struct {
+		status      string
+		wantUrgency string
+	}{
+		{status: "BLOCKED", wantUrgency: "urgent"},
+		{status: "FAILED", wantUrgency: "high"},
+		{status: "BUDGET_EXCEEDED", wantUrgency: "high"},
+		{status: "TIMED_OUT", wantUrgency: "high"},
+		{status: "COMPLETED", wantUrgency: "low"},
+		{status: "RUNNING", wantUrgency: "normal"},
+	}
+
+	for _, tc := range cases {
+		t.Run(tc.status, func(t *testing.T) {
+			ev := Event{
+				Status:   tc.status,
+				TaskName: "mytask",
+				Error:    "some error",
+				CostUSD:  0.12,
+			}
+			got, _, _, _ := notificationContent(ev)
+			if got == tc.wantUrgency {
+				return
+			}
+			t.Errorf("status %q: want urgency %q, got %q", tc.status, tc.wantUrgency, got)
+		})
+	}
+}
+
+// TestWebPushNotifier_SendsToSubscription is a smoke test: it runs Notify
+// against a stored subscription whose endpoint is a local mock push server and
+// checks only that the call completes without panicking. The subscription keys
+// are placeholders, so the send may fail before any HTTP request is made, and
+// any body the server does receive is encrypted; neither delivery nor payload
+// contents are asserted.
+func TestWebPushNotifier_SendsToSubscription(t *testing.T) {
+	var mu sync.Mutex
+	var captured []byte
+
+	// Mock push server — just record the body.
+	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		body, _ := io.ReadAll(r.Body)
+		mu.Lock()
+		captured = body
+		mu.Unlock()
+		w.WriteHeader(http.StatusCreated)
+	}))
+	defer srv.Close()
+
+	// Generate real VAPID keys for a valid (but minimal) send test.
+	pub, priv, err := GenerateVAPIDKeys()
+	if err != nil {
+		t.Fatalf("GenerateVAPIDKeys: %v", err)
+	}
+
+	logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))
+
+	// Use a fake subscription pointing at our mock server. The webpush library
+	// will POST to the subscription endpoint. The keys are all-'A' placeholders
+	// (intended as zero bytes) — the library encrypts the payload before
+	// sending, so the mock server just needs to accept.
+	store := &fakePushStore{
+		subs: []storage.PushSubscription{
+			{
+				ID:        "sub-1",
+				Endpoint:  srv.URL,
+				P256DHKey: "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA", // placeholder p256dh key
+				AuthKey:   "AAAAAAAAAAAAAAAAAAA=",                                                                     // placeholder auth secret
+			},
+		},
+	}
+
+	n := &WebPushNotifier{
+		Store:           store,
+		VAPIDPublicKey:  pub,
+		VAPIDPrivateKey: priv,
+		VAPIDEmail:      "mailto:test@example.com",
+		Logger:          logger,
+	}
+
+	ev := Event{
+		TaskID:   "task-abc",
+		TaskName: "myTask",
+		Status:   "COMPLETED",
+		CostUSD:  0.42,
+	}
+
+	// The error is deliberately discarded: the placeholder keys may make the
+	// library fail before sending, and this test checks structure, not success.
+	_ = n.Notify(ev)
+
+	// Report what the mock server saw. There is nothing further to assert: the
+	// body is ciphertext when present and legitimately absent when the send
+	// failed. (The previous check here could never fire.)
+	mu.Lock()
+	defer mu.Unlock()
+	t.Logf("mock push server captured %d request body bytes", len(captured))
+}
+
+// TestNotificationContent_TitleAndBody verifies the title and the body that
+// notificationContent produces for each status with a dedicated case. Every
+// event uses task name "mytask", error "err" and cost 0.05, so the expected
+// bodies below follow directly from the format strings in notificationContent.
+func TestNotificationContent_TitleAndBody(t *testing.T) {
+	tests := []struct {
+		status    string
+		wantTitle string
+		wantBody  string
+	}{
+		{"BLOCKED", "Needs input", "mytask is waiting for your response"},
+		{"FAILED", "Task failed", "mytask failed: err"},
+		{"BUDGET_EXCEEDED", "Task failed", "mytask failed: err"},
+		{"TIMED_OUT", "Task failed", "mytask failed: err"},
+		{"COMPLETED", "Task done", "mytask completed ($0.05)"},
+	}
+	for _, tc := range tests {
+		t.Run(tc.status, func(t *testing.T) {
+			_, title, body, _ := notificationContent(Event{
+				Status:   tc.status,
+				TaskName: "mytask",
+				Error:    "err",
+				CostUSD:  0.05,
+			})
+			if title != tc.wantTitle {
+				t.Errorf("status %q: want title %q, got %q", tc.status, tc.wantTitle, title)
+			}
+			if body != tc.wantBody {
+				t.Errorf("status %q: want body %q, got %q", tc.status, tc.wantBody, body)
+			}
+		})
+	}
+}
+
+// TestWebPushNotifier_PayloadJSON builds the same title/body/tag map that
+// Notify sends and checks that it survives a JSON encode/decode round trip.
+func TestWebPushNotifier_PayloadJSON(t *testing.T) {
+	urgency, title, body, tag := notificationContent(Event{
+		TaskID:   "t1",
+		TaskName: "myTask",
+		Status:   "COMPLETED",
+		CostUSD:  0.33,
+	})
+	for _, field := range []string{urgency, title, body, tag} {
+		if field == "" {
+			t.Error("all notification fields should be non-empty")
+			break
+		}
+	}
+
+	encoded, err := json.Marshal(map[string]string{"title": title, "body": body, "tag": tag})
+	if err != nil {
+		t.Fatalf("marshal payload: %v", err)
+	}
+
+	var decoded map[string]string
+	if err := json.Unmarshal(encoded, &decoded); err != nil {
+		t.Fatalf("unmarshal payload: %v", err)
+	}
+	if got := decoded["title"]; got != title {
+		t.Errorf("title roundtrip failed")
+	}
+}
diff --git a/internal/storage/db.go b/internal/storage/db.go
index ce60e2f..4adc1ba 100644
--- a/internal/storage/db.go
+++ b/internal/storage/db.go
@@ -8,7 +8,6 @@ import (
"time"
"github.com/thepeterstone/claudomator/internal/task"
- _ "github.com/mattn/go-sqlite3"
)
type DB struct {
@@ -20,6 +19,10 @@ func Open(path string) (*DB, error) {
if err != nil {
return nil, fmt.Errorf("opening database: %w", err)
}
+ // SQLite only allows one concurrent writer. Limiting to one open connection
+ // prevents "database is locked" errors when multiple goroutines write
+ // simultaneously via database/sql's connection pool.
+ db.SetMaxOpenConns(1)
s := &DB{db: db}
if err := s.migrate(); err != nil {
db.Close()
@@ -86,6 +89,54 @@ func (s *DB) migrate() error {
`ALTER TABLE executions ADD COLUMN changestats_json TEXT`,
`ALTER TABLE executions ADD COLUMN commits_json TEXT NOT NULL DEFAULT '[]'`,
`ALTER TABLE tasks ADD COLUMN elaboration_input TEXT`,
+ `ALTER TABLE tasks ADD COLUMN project TEXT`,
+ `ALTER TABLE tasks ADD COLUMN repository_url TEXT`,
+ `CREATE TABLE IF NOT EXISTS push_subscriptions (
+ id TEXT PRIMARY KEY,
+ endpoint TEXT NOT NULL UNIQUE,
+ p256dh_key TEXT NOT NULL,
+ auth_key TEXT NOT NULL,
+ created_at DATETIME DEFAULT CURRENT_TIMESTAMP
+ )`,
+ `CREATE TABLE IF NOT EXISTS settings (
+ key TEXT PRIMARY KEY,
+ value TEXT NOT NULL
+ )`,
+ `CREATE TABLE IF NOT EXISTS agent_events (
+ id TEXT PRIMARY KEY,
+ agent TEXT NOT NULL,
+ event TEXT NOT NULL,
+ timestamp DATETIME NOT NULL,
+ until DATETIME,
+ reason TEXT
+ )`,
+ `CREATE INDEX IF NOT EXISTS idx_agent_events_agent ON agent_events(agent)`,
+ `CREATE INDEX IF NOT EXISTS idx_agent_events_timestamp ON agent_events(timestamp)`,
+ `CREATE TABLE IF NOT EXISTS projects (
+ id TEXT PRIMARY KEY,
+ name TEXT NOT NULL,
+ remote_url TEXT NOT NULL DEFAULT '',
+ local_path TEXT NOT NULL DEFAULT '',
+ type TEXT NOT NULL DEFAULT 'web',
+ deploy_script TEXT NOT NULL DEFAULT '',
+ created_at DATETIME NOT NULL,
+ updated_at DATETIME NOT NULL
+ )`,
+ `CREATE TABLE IF NOT EXISTS stories (
+ id TEXT PRIMARY KEY,
+ name TEXT NOT NULL,
+ project_id TEXT NOT NULL DEFAULT '',
+ branch_name TEXT NOT NULL DEFAULT '',
+ deploy_config TEXT NOT NULL DEFAULT '',
+ validation_json TEXT NOT NULL DEFAULT '',
+ status TEXT NOT NULL DEFAULT 'PENDING',
+ created_at DATETIME NOT NULL,
+ updated_at DATETIME NOT NULL
+ )`,
+ `ALTER TABLE tasks ADD COLUMN story_id TEXT`,
+ `ALTER TABLE tasks ADD COLUMN acceptance_criteria TEXT NOT NULL DEFAULT ''`,
+ `ALTER TABLE tasks ADD COLUMN checker_for_task_id TEXT NOT NULL DEFAULT ''`,
+ `ALTER TABLE tasks ADD COLUMN checker_report TEXT NOT NULL DEFAULT ''`,
`ALTER TABLE executions ADD COLUMN tokens_in INTEGER`,
`ALTER TABLE executions ADD COLUMN tokens_out INTEGER`,
}
@@ -125,24 +176,25 @@ func (s *DB) CreateTask(t *task.Task) error {
}
_, err = s.db.Exec(`
- INSERT INTO tasks (id, name, description, elaboration_input, config_json, priority, timeout_ns, retry_json, tags_json, depends_on_json, parent_task_id, state, created_at, updated_at)
- VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
- t.ID, t.Name, t.Description, t.ElaborationInput, string(configJSON), string(t.Priority),
+ INSERT INTO tasks (id, name, description, elaboration_input, project, repository_url, config_json, priority, timeout_ns, retry_json, tags_json, depends_on_json, parent_task_id, state, created_at, updated_at, story_id, acceptance_criteria, checker_for_task_id, checker_report)
+ VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
+ t.ID, t.Name, t.Description, t.ElaborationInput, t.Project, t.RepositoryURL, string(configJSON), string(t.Priority),
t.Timeout.Duration.Nanoseconds(), string(retryJSON), string(tagsJSON), string(depsJSON),
- t.ParentTaskID, string(t.State), t.CreatedAt.UTC(), t.UpdatedAt.UTC(),
+ t.ParentTaskID, string(t.State), t.CreatedAt.UTC(), t.UpdatedAt.UTC(), t.StoryID,
+ t.AcceptanceCriteria, t.CheckerForTaskID, t.CheckerReport,
)
return err
}
// GetTask retrieves a task by ID.
func (s *DB) GetTask(id string) (*task.Task, error) {
- row := s.db.QueryRow(`SELECT id, name, description, elaboration_input, config_json, priority, timeout_ns, retry_json, tags_json, depends_on_json, parent_task_id, state, created_at, updated_at, rejection_comment, question_json, summary, interactions_json FROM tasks WHERE id = ?`, id)
+ row := s.db.QueryRow(`SELECT id, name, description, elaboration_input, project, repository_url, config_json, priority, timeout_ns, retry_json, tags_json, depends_on_json, parent_task_id, state, created_at, updated_at, rejection_comment, question_json, summary, interactions_json, story_id, acceptance_criteria, checker_for_task_id, checker_report FROM tasks WHERE id = ?`, id)
return scanTask(row)
}
// ListTasks returns tasks matching the given filter.
func (s *DB) ListTasks(filter TaskFilter) ([]*task.Task, error) {
- query := `SELECT id, name, description, elaboration_input, config_json, priority, timeout_ns, retry_json, tags_json, depends_on_json, parent_task_id, state, created_at, updated_at, rejection_comment, question_json, summary, interactions_json FROM tasks WHERE 1=1`
+ query := `SELECT id, name, description, elaboration_input, project, repository_url, config_json, priority, timeout_ns, retry_json, tags_json, depends_on_json, parent_task_id, state, created_at, updated_at, rejection_comment, question_json, summary, interactions_json, story_id, acceptance_criteria, checker_for_task_id, checker_report FROM tasks WHERE 1=1`
var args []interface{}
if filter.State != "" {
@@ -178,7 +230,7 @@ func (s *DB) ListTasks(filter TaskFilter) ([]*task.Task, error) {
// ListSubtasks returns all tasks whose parent_task_id matches the given ID.
func (s *DB) ListSubtasks(parentID string) ([]*task.Task, error) {
- rows, err := s.db.Query(`SELECT id, name, description, elaboration_input, config_json, priority, timeout_ns, retry_json, tags_json, depends_on_json, parent_task_id, state, created_at, updated_at, rejection_comment, question_json, summary, interactions_json FROM tasks WHERE parent_task_id = ? ORDER BY created_at ASC`, parentID)
+ rows, err := s.db.Query(`SELECT id, name, description, elaboration_input, project, repository_url, config_json, priority, timeout_ns, retry_json, tags_json, depends_on_json, parent_task_id, state, created_at, updated_at, rejection_comment, question_json, summary, interactions_json, story_id, acceptance_criteria, checker_for_task_id, checker_report FROM tasks WHERE parent_task_id = ? ORDER BY created_at ASC`, parentID)
if err != nil {
return nil, err
}
@@ -231,7 +283,7 @@ func (s *DB) ResetTaskForRetry(id string) (*task.Task, error) {
}
defer tx.Rollback() //nolint:errcheck
- t, err := scanTask(tx.QueryRow(`SELECT id, name, description, elaboration_input, config_json, priority, timeout_ns, retry_json, tags_json, depends_on_json, parent_task_id, state, created_at, updated_at, rejection_comment, question_json, summary, interactions_json FROM tasks WHERE id = ?`, id))
+ t, err := scanTask(tx.QueryRow(`SELECT id, name, description, elaboration_input, project, repository_url, config_json, priority, timeout_ns, retry_json, tags_json, depends_on_json, parent_task_id, state, created_at, updated_at, rejection_comment, question_json, summary, interactions_json, story_id, acceptance_criteria, checker_for_task_id, checker_report FROM tasks WHERE id = ?`, id))
if err != nil {
if err == sql.ErrNoRows {
return nil, fmt.Errorf("task %q not found", id)
@@ -292,9 +344,10 @@ func (s *DB) RejectTask(id, comment string) error {
// TaskUpdate holds the fields that UpdateTask may change.
type TaskUpdate struct {
- Name string
- Description string
- Config task.AgentConfig
+ Name string
+ Description string
+ RepositoryURL string
+ Config task.AgentConfig
Priority task.Priority
TimeoutNS int64
Retry task.RetryConfig
@@ -333,13 +386,11 @@ func (s *DB) UpdateTask(id string, u TaskUpdate) error {
now := time.Now().UTC()
result, err := s.db.Exec(`
UPDATE tasks
- SET name = ?, description = ?, config_json = ?, priority = ?, timeout_ns = ?,
+ SET name = ?, description = ?, repository_url = ?, config_json = ?, priority = ?, timeout_ns = ?,
retry_json = ?, tags_json = ?, depends_on_json = ?, state = ?, updated_at = ?
WHERE id = ?`,
- u.Name, u.Description, string(configJSON), string(u.Priority), u.TimeoutNS,
- string(retryJSON), string(tagsJSON), string(depsJSON), string(task.StatePending), now,
- id,
- )
+ u.Name, u.Description, u.RepositoryURL, configJSON, string(u.Priority), u.TimeoutNS,
+ retryJSON, tagsJSON, depsJSON, string(task.StatePending), now, id)
if err != nil {
return err
}
@@ -376,6 +427,8 @@ func (s *DB) GetMaxUpdatedAt() (time.Time, error) {
"2006-01-02T15:04:05Z07:00",
time.RFC3339,
"2006-01-02 15:04:05",
+ "2006-01-02 15:04:05 +0000 UTC",
+ "2006-01-02 15:04:05.999999999 +0000 UTC",
}
for _, f := range formats {
parsed, err := time.Parse(f, t.String)
@@ -417,6 +470,55 @@ type Execution struct {
Summary string
}
+// CreateExecutionAndSetRunning records a new execution and moves its task to
+// RUNNING inside one transaction, so a crash can never leave a PENDING task
+// with an orphaned RUNNING execution record.
+//
+// It fails, without writing anything, when the task does not exist or when
+// its current state may not transition to RUNNING.
+func (s *DB) CreateExecutionAndSetRunning(e *Execution) error {
+	tx, err := s.db.Begin()
+	if err != nil {
+		return err
+	}
+	// Undoes any partial work on every early return; its result is ignored.
+	defer tx.Rollback() //nolint:errcheck
+
+	// Step 1: make sure the task exists and may enter RUNNING.
+	var stateNow string
+	err = tx.QueryRow(`SELECT state FROM tasks WHERE id = ?`, e.TaskID).Scan(&stateNow)
+	switch {
+	case err == sql.ErrNoRows:
+		return fmt.Errorf("task %q not found", e.TaskID)
+	case err != nil:
+		return err
+	}
+	if !task.ValidTransition(task.State(stateNow), task.StateRunning) {
+		return fmt.Errorf("invalid state transition %s → RUNNING for task %q", stateNow, e.TaskID)
+	}
+
+	// Step 2: insert the execution row. Commits are stored as a JSON array,
+	// "[]" when there are none; changestats_json starts out NULL.
+	commitsJSON := "[]"
+	if len(e.Commits) > 0 {
+		encoded, err := json.Marshal(e.Commits)
+		if err != nil {
+			return fmt.Errorf("marshaling commits: %w", err)
+		}
+		commitsJSON = string(encoded)
+	}
+	const insertExecution = `
+ INSERT INTO executions (id, task_id, start_time, end_time, exit_code, status, stdout_path, stderr_path, artifact_dir, cost_usd, error_msg, session_id, sandbox_dir, changestats_json, commits_json)
+ VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, NULL, ?)`
+	_, err = tx.Exec(insertExecution,
+		e.ID, e.TaskID, e.StartTime.UTC(), e.EndTime.UTC(), e.ExitCode, e.Status,
+		e.StdoutPath, e.StderrPath, e.ArtifactDir, e.CostUSD, e.ErrorMsg, e.SessionID, e.SandboxDir, commitsJSON,
+	)
+	if err != nil {
+		return err
+	}
+
+	// Step 3: flip the task to RUNNING and bump its updated_at.
+	_, err = tx.Exec(`UPDATE tasks SET state = ?, updated_at = ? WHERE id = ?`, string(task.StateRunning), time.Now().UTC(), e.TaskID)
+	if err != nil {
+		return err
+	}
+
+	return tx.Commit()
+}
+
// CreateExecution inserts an execution record.
func (s *DB) CreateExecution(e *Execution) error {
var changestatsJSON *string
@@ -544,6 +646,141 @@ type RecentExecution struct {
StdoutPath string `json:"stdout_path"`
}
+// ThroughputBucket is one time-bucket of execution counts by outcome.
+// QueryDashboardStats produces one bucket per distinct start-time hour.
+type ThroughputBucket struct {
+ Hour string `json:"hour"` // RFC3339 truncated to hour
+ Completed int `json:"completed"` // executions with status COMPLETED or READY
+ Failed int `json:"failed"` // executions with status FAILED, TIMED_OUT or BUDGET_EXCEEDED
+ Other int `json:"other"` // executions with any other status the query does not exclude
+}
+
+// BillingDay is the aggregated cost and run count for a calendar day.
+// QueryDashboardStats groups executions by the date of their start_time.
+type BillingDay struct {
+ Day string `json:"day"` // YYYY-MM-DD
+ CostUSD float64 `json:"cost_usd"` // sum of cost_usd for the day; 0 when every cost is NULL
+ Runs int `json:"runs"` // number of executions started that day, regardless of status
+}
+
+// FailedExecution is a failed/timed-out/budget-exceeded execution with its error.
+// Category is not stored; QueryDashboardStats derives it with classifyError.
+type FailedExecution struct {
+ ID string `json:"id"` // execution ID
+ TaskID string `json:"task_id"` // ID of the task the execution belongs to
+ TaskName string `json:"task_name"` // name of that task, joined from the tasks table
+ Status string `json:"status"` // FAILED, TIMED_OUT or BUDGET_EXCEEDED
+ ErrorMsg string `json:"error_msg"` // stored error message; empty when the column is NULL
+ Category string `json:"category"` // quota | timeout | rate_limit | git | failed
+ StartedAt time.Time `json:"started_at"` // the execution's start_time
+}
+
+// DashboardStats is returned by QueryDashboardStats.
+// QueryDashboardStats initialises all three slices to empty, non-nil values.
+type DashboardStats struct {
+ Throughput []ThroughputBucket `json:"throughput"` // per-hour outcome counts, oldest hour first
+ Billing []BillingDay `json:"billing"` // per-day cost and run count, oldest day first
+ Failures []FailedExecution `json:"failures"` // failed executions, newest first
+}
+
+// QueryDashboardStats returns pre-aggregated stats for the given window.
+//
+// Only executions whose start_time is at or after since (converted to UTC) are
+// considered. Three queries run one after another: per-hour throughput,
+// per-day billing, and the most recent failures. The first query or scan error
+// aborts the call and is returned unwrapped with a nil result.
+func (s *DB) QueryDashboardStats(since time.Time) (*DashboardStats, error) {
+ // Start from empty, non-nil slices so a window with no data still yields
+ // empty lists rather than nil ones.
+ stats := &DashboardStats{
+ Throughput: []ThroughputBucket{},
+ Billing: []BillingDay{},
+ Failures: []FailedExecution{},
+ }
+
+ // Throughput: completions per hour bucket
+ // Executions still RUNNING, QUEUED or PENDING are left out entirely.
+ tpRows, err := s.db.Query(`
+ SELECT strftime('%Y-%m-%dT%H:00:00Z', start_time) as hour,
+ SUM(CASE WHEN status IN ('COMPLETED','READY') THEN 1 ELSE 0 END),
+ SUM(CASE WHEN status IN ('FAILED','TIMED_OUT','BUDGET_EXCEEDED') THEN 1 ELSE 0 END),
+ SUM(CASE WHEN status NOT IN ('COMPLETED','READY','FAILED','TIMED_OUT','BUDGET_EXCEEDED') THEN 1 ELSE 0 END)
+ FROM executions
+ WHERE start_time >= ? AND status NOT IN ('RUNNING','QUEUED','PENDING')
+ GROUP BY hour ORDER BY hour ASC`, since.UTC())
+ if err != nil {
+ return nil, err
+ }
+ defer tpRows.Close()
+ for tpRows.Next() {
+ var b ThroughputBucket
+ if err := tpRows.Scan(&b.Hour, &b.Completed, &b.Failed, &b.Other); err != nil {
+ return nil, err
+ }
+ stats.Throughput = append(stats.Throughput, b)
+ }
+ if err := tpRows.Err(); err != nil {
+ return nil, err
+ }
+
+ // Billing: cost per day
+ // Unlike throughput, this counts every execution in the window whatever its status.
+ billRows, err := s.db.Query(`
+ SELECT date(start_time) as day, COALESCE(SUM(cost_usd),0), COUNT(*)
+ FROM executions
+ WHERE start_time >= ?
+ GROUP BY day ORDER BY day ASC`, since.UTC())
+ if err != nil {
+ return nil, err
+ }
+ defer billRows.Close()
+ for billRows.Next() {
+ var b BillingDay
+ if err := billRows.Scan(&b.Day, &b.CostUSD, &b.Runs); err != nil {
+ return nil, err
+ }
+ stats.Billing = append(stats.Billing, b)
+ }
+ if err := billRows.Err(); err != nil {
+ return nil, err
+ }
+
+ // Failures: recent failed executions with error messages
+ // Newest first and capped at 50 rows; the inner join drops executions
+ // whose task row no longer exists.
+ failRows, err := s.db.Query(`
+ SELECT e.id, e.task_id, t.name, e.status, COALESCE(e.error_msg,''), e.start_time
+ FROM executions e JOIN tasks t ON e.task_id = t.id
+ WHERE e.start_time >= ? AND e.status IN ('FAILED','TIMED_OUT','BUDGET_EXCEEDED')
+ ORDER BY e.start_time DESC LIMIT 50`, since.UTC())
+ if err != nil {
+ return nil, err
+ }
+ defer failRows.Close()
+ for failRows.Next() {
+ var f FailedExecution
+ if err := failRows.Scan(&f.ID, &f.TaskID, &f.TaskName, &f.Status, &f.ErrorMsg, &f.StartedAt); err != nil {
+ return nil, err
+ }
+ // The category is computed here from status and message, not read from the database.
+ f.Category = classifyError(f.Status, f.ErrorMsg)
+ stats.Failures = append(stats.Failures, f)
+ }
+ if err := failRows.Err(); err != nil {
+ return nil, err
+ }
+
+ return stats, nil
+}
+
+// classifyError maps a status + error message to a human category:
+// "timeout", "quota", "rate_limit", "git" or "failed".
+//
+// The statuses TIMED_OUT and BUDGET_EXCEEDED decide the category on their own.
+// Otherwise the lower-cased message is searched for keyword groups in a fixed
+// order — quota, rate limit, git, timeout — and the first group with a match
+// wins; a message matching none of them is "failed".
+func classifyError(status, msg string) string {
+	switch status {
+	case "TIMED_OUT":
+		return "timeout"
+	case "BUDGET_EXCEEDED":
+		return "quota"
+	}
+
+	lowered := strings.ToLower(msg)
+	// mentions reports whether the message contains any of the given keywords.
+	mentions := func(keywords ...string) bool {
+		for _, kw := range keywords {
+			if strings.Contains(lowered, kw) {
+				return true
+			}
+		}
+		return false
+	}
+
+	switch {
+	case mentions("quota", "exhausted", "terminalquota"):
+		return "quota"
+	case mentions("rate limit", "429", "too many requests"):
+		return "rate_limit"
+	case mentions("git push", "git pull"):
+		return "git"
+	case mentions("timeout", "deadline"):
+		return "timeout"
+	default:
+		return "failed"
+	}
+}
+
// ListRecentExecutions returns executions since the given time, joined with task names.
// If taskID is non-empty, only executions for that task are returned.
func (s *DB) ListRecentExecutions(since time.Time, limit int, taskID string) ([]*RecentExecution, error) {
@@ -600,6 +837,24 @@ func (s *DB) UpdateTaskSummary(taskID, summary string) error {
return err
}
+// UpdateTaskCheckerReport stores report in the task's checker_report column
+// and sets updated_at to the current UTC time. A task ID that matches no row
+// is not reported as an error.
+func (s *DB) UpdateTaskCheckerReport(id, report string) error {
+	const q = `UPDATE tasks SET checker_report = ?, updated_at = ? WHERE id = ?`
+	if _, err := s.db.Exec(q, report, time.Now().UTC(), id); err != nil {
+		return err
+	}
+	return nil
+}
+
+// GetCheckerTask looks up the task whose checker_for_task_id equals
+// checkedTaskID. It returns (nil, nil) when there is none; if several match,
+// the query's LIMIT 1 picks one of them.
+func (s *DB) GetCheckerTask(checkedTaskID string) (*task.Task, error) {
+	const q = `SELECT id, name, description, elaboration_input, project, repository_url, config_json, priority, timeout_ns, retry_json, tags_json, depends_on_json, parent_task_id, state, created_at, updated_at, rejection_comment, question_json, summary, interactions_json, story_id, acceptance_criteria, checker_for_task_id, checker_report FROM tasks WHERE checker_for_task_id = ? LIMIT 1`
+	checker, err := scanTask(s.db.QueryRow(q, checkedTaskID))
+	if err == sql.ErrNoRows {
+		// "No checker" is an expected outcome, not a failure.
+		return nil, nil
+	}
+	return checker, err
+}
+
// AppendTaskInteraction appends a Q&A interaction to the task's interaction history.
func (s *DB) AppendTaskInteraction(taskID string, interaction task.Interaction) error {
tx, err := s.db.Begin()
@@ -682,17 +937,35 @@ func scanTask(row scanner) (*task.Task, error) {
timeoutNS int64
parentTaskID sql.NullString
elaborationInput sql.NullString
+ project sql.NullString
+ repositoryURL sql.NullString
rejectionComment sql.NullString
questionJSON sql.NullString
summary sql.NullString
interactionsJSON sql.NullString
+ storyID sql.NullString
+ acceptanceCriteria sql.NullString
+ checkerForTaskID sql.NullString
+ checkerReport sql.NullString
+ )
+ err := row.Scan(
+ &t.ID, &t.Name, &t.Description, &elaborationInput, &project, &repositoryURL,
+ &configJSON, &priority, &timeoutNS, &retryJSON, &tagsJSON, &depsJSON,
+ &parentTaskID, &state, &t.CreatedAt, &t.UpdatedAt,
+ &rejectionComment, &questionJSON, &summary, &interactionsJSON, &storyID,
+ &acceptanceCriteria, &checkerForTaskID, &checkerReport,
)
- err := row.Scan(&t.ID, &t.Name, &t.Description, &elaborationInput, &configJSON, &priority, &timeoutNS, &retryJSON, &tagsJSON, &depsJSON, &parentTaskID, &state, &t.CreatedAt, &t.UpdatedAt, &rejectionComment, &questionJSON, &summary, &interactionsJSON)
t.ParentTaskID = parentTaskID.String
t.ElaborationInput = elaborationInput.String
+ t.Project = project.String
+ t.RepositoryURL = repositoryURL.String
t.RejectionComment = rejectionComment.String
t.QuestionJSON = questionJSON.String
t.Summary = summary.String
+ t.StoryID = storyID.String
+ t.AcceptanceCriteria = acceptanceCriteria.String
+ t.CheckerForTaskID = checkerForTaskID.String
+ t.CheckerReport = checkerReport.String
if err != nil {
return nil, err
}
@@ -772,3 +1045,263 @@ func (s *DB) UpdateExecutionChangestats(execID string, stats *task.Changestats)
func scanExecutionRows(rows *sql.Rows) (*Execution, error) {
return scanExecution(rows)
}
+
+// PushSubscription represents a browser push subscription.
+// It mirrors one row of the push_subscriptions table.
+type PushSubscription struct {
+ ID string `json:"id"` // primary key of the row
+ Endpoint string `json:"endpoint"` // push service URL; unique per row
+ P256DHKey string `json:"p256dh_key"` // stored p256dh_key column
+ AuthKey string `json:"auth_key"` // stored auth_key column
+ CreatedAt time.Time `json:"created_at"` // row creation time; zero when the stored value cannot be parsed
+}
+
+// SavePushSubscription stores sub, keyed by endpoint. When a row with the same
+// endpoint already exists, its id and both keys are overwritten in place.
+// sub.CreatedAt is not written.
+func (s *DB) SavePushSubscription(sub PushSubscription) error {
+	const upsert = `
+ INSERT INTO push_subscriptions (id, endpoint, p256dh_key, auth_key)
+ VALUES (?, ?, ?, ?)
+ ON CONFLICT(endpoint) DO UPDATE SET
+ id = excluded.id,
+ p256dh_key = excluded.p256dh_key,
+ auth_key = excluded.auth_key`
+	if _, err := s.db.Exec(upsert, sub.ID, sub.Endpoint, sub.P256DHKey, sub.AuthKey); err != nil {
+		return err
+	}
+	return nil
+}
+
+// DeletePushSubscription deletes the subscription registered for endpoint.
+// An endpoint that matches no row is not reported as an error.
+func (s *DB) DeletePushSubscription(endpoint string) error {
+	const q = `DELETE FROM push_subscriptions WHERE endpoint = ?`
+	if _, err := s.db.Exec(q, endpoint); err != nil {
+		return err
+	}
+	return nil
+}
+
+// ListPushSubscriptions returns every stored push subscription ordered by
+// creation time. The result is an empty, non-nil slice when there are none.
+func (s *DB) ListPushSubscriptions() ([]PushSubscription, error) {
+	rows, err := s.db.Query(`SELECT id, endpoint, p256dh_key, auth_key, created_at FROM push_subscriptions ORDER BY created_at`)
+	if err != nil {
+		return nil, err
+	}
+	defer rows.Close()
+
+	// created_at is read as text and tried against these layouts in order.
+	layouts := [...]string{time.RFC3339, "2006-01-02 15:04:05", "2006-01-02T15:04:05Z"}
+
+	subs := []PushSubscription{}
+	for rows.Next() {
+		var (
+			sub        PushSubscription
+			rawCreated string
+		)
+		if err := rows.Scan(&sub.ID, &sub.Endpoint, &sub.P256DHKey, &sub.AuthKey, &rawCreated); err != nil {
+			return nil, err
+		}
+		// An unparseable timestamp is tolerated: CreatedAt stays the zero time.
+		for _, layout := range layouts {
+			parsed, parseErr := time.Parse(layout, rawCreated)
+			if parseErr != nil {
+				continue
+			}
+			sub.CreatedAt = parsed
+			break
+		}
+		subs = append(subs, sub)
+	}
+	return subs, rows.Err()
+}
+
+// GetSetting looks up key in the settings table. A missing key is not an
+// error: it yields ("", nil).
+func (s *DB) GetSetting(key string) (string, error) {
+	var stored string
+	row := s.db.QueryRow(`SELECT value FROM settings WHERE key = ?`, key)
+	switch err := row.Scan(&stored); err {
+	case sql.ErrNoRows:
+		return "", nil
+	default:
+		return stored, err
+	}
+}
+
+// SetSetting writes value under key in the settings table, replacing the
+// stored value when the key already exists.
+func (s *DB) SetSetting(key, value string) error {
+	const upsert = `INSERT INTO settings (key, value) VALUES (?, ?) ON CONFLICT(key) DO UPDATE SET value = excluded.value`
+	if _, err := s.db.Exec(upsert, key, value); err != nil {
+		return err
+	}
+	return nil
+}
+
+// AgentEvent records a rate-limit state change for an agent.
+// It mirrors one row of the agent_events table.
+type AgentEvent struct {
+ ID string `json:"id"` // primary key of the row
+ Agent string `json:"agent"` // agent the event refers to
+ Event string `json:"event"` // "rate_limited" | "available"
+ Timestamp time.Time `json:"timestamp"` // when the event occurred; written in UTC
+ Until *time.Time `json:"until,omitempty"` // non-nil for "rate_limited" events
+ Reason string `json:"reason"` // "transient" | "quota"
+}
+
+// RecordAgentEvent appends e to the agent_events table. The timestamp is
+// stored in UTC, and a nil Until is stored as NULL.
+func (s *DB) RecordAgentEvent(e AgentEvent) error {
+	const insert = `INSERT INTO agent_events (id, agent, event, timestamp, until, reason) VALUES (?, ?, ?, ?, ?, ?)`
+	if _, err := s.db.Exec(insert, e.ID, e.Agent, e.Event, e.Timestamp.UTC(), timeOrNull(e.Until), e.Reason); err != nil {
+		return err
+	}
+	return nil
+}
+
+// ListAgentEvents returns the agent events recorded at or after since, newest
+// first, capped at 500 rows. The slice is nil when nothing matches.
+func (s *DB) ListAgentEvents(since time.Time) ([]AgentEvent, error) {
+	const q = `SELECT id, agent, event, timestamp, until, reason FROM agent_events WHERE timestamp >= ? ORDER BY timestamp DESC LIMIT 500`
+	rows, err := s.db.Query(q, since.UTC())
+	if err != nil {
+		return nil, err
+	}
+	defer rows.Close()
+
+	var events []AgentEvent
+	for rows.Next() {
+		// until and reason are nullable columns, so they are scanned through
+		// the sql.Null* wrappers.
+		var (
+			ev     AgentEvent
+			until  sql.NullTime
+			reason sql.NullString
+		)
+		if err := rows.Scan(&ev.ID, &ev.Agent, &ev.Event, &ev.Timestamp, &until, &reason); err != nil {
+			return nil, err
+		}
+		if until.Valid {
+			deadline := until.Time
+			ev.Until = &deadline
+		}
+		ev.Reason = reason.String
+		events = append(events, ev)
+	}
+	return events, rows.Err()
+}
+
+// timeOrNull converts an optional time into a SQL argument: the time in UTC
+// when t is set, or an untyped nil (stored as NULL) when it is not.
+func timeOrNull(t *time.Time) interface{} {
+	if t != nil {
+		return t.UTC()
+	}
+	return nil
+}
+
+// CreateProject inserts p as a new row in the projects table, stamping both
+// created_at and updated_at with the current UTC time.
+func (s *DB) CreateProject(p *task.Project) error {
+	const insert = `INSERT INTO projects (id, name, remote_url, local_path, type, deploy_script, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?)`
+	stamp := time.Now().UTC()
+	if _, err := s.db.Exec(insert, p.ID, p.Name, p.RemoteURL, p.LocalPath, p.Type, p.DeployScript, stamp, stamp); err != nil {
+		return err
+	}
+	return nil
+}
+
+// GetProject loads the project with the given ID. The scan error is returned
+// unchanged, so a missing project surfaces as sql.ErrNoRows.
+func (s *DB) GetProject(id string) (*task.Project, error) {
+	const q = `SELECT id, name, remote_url, local_path, type, deploy_script FROM projects WHERE id = ?`
+	var p task.Project
+	err := s.db.QueryRow(q, id).Scan(&p.ID, &p.Name, &p.RemoteURL, &p.LocalPath, &p.Type, &p.DeployScript)
+	if err != nil {
+		return nil, err
+	}
+	return &p, nil
+}
+
+// ListProjects returns all projects ordered by name; the result is a nil slice when the table is empty.
+func (s *DB) ListProjects() ([]*task.Project, error) {
+ rows, err := s.db.Query(`SELECT id, name, remote_url, local_path, type, deploy_script FROM projects ORDER BY name`)
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close()
+ var projects []*task.Project
+ for rows.Next() {
+ p := &task.Project{}
+ if err := rows.Scan(&p.ID, &p.Name, &p.RemoteURL, &p.LocalPath, &p.Type, &p.DeployScript); err != nil {
+ return nil, err
+ }
+ projects = append(projects, p)
+ }
+ return projects, rows.Err()
+}
+
+// UpdateProject overwrites every non-ID column of the row matching p.ID and refreshes updated_at; the affected-row count is not checked.
+func (s *DB) UpdateProject(p *task.Project) error {
+ now := time.Now().UTC()
+ _, err := s.db.Exec(
+ `UPDATE projects SET name = ?, remote_url = ?, local_path = ?, type = ?, deploy_script = ?, updated_at = ? WHERE id = ?`,
+ p.Name, p.RemoteURL, p.LocalPath, p.Type, p.DeployScript, now, p.ID,
+ )
+ return err
+}
+
+// UpsertProject inserts or updates a project by ID (used for seeding); on conflict created_at is left untouched and updated_at is refreshed.
+func (s *DB) UpsertProject(p *task.Project) error {
+ now := time.Now().UTC()
+ _, err := s.db.Exec(
+ `INSERT INTO projects (id, name, remote_url, local_path, type, deploy_script, created_at, updated_at)
+ VALUES (?, ?, ?, ?, ?, ?, ?, ?)
+ ON CONFLICT(id) DO UPDATE SET name=excluded.name, remote_url=excluded.remote_url,
+ local_path=excluded.local_path, type=excluded.type, deploy_script=excluded.deploy_script, updated_at=excluded.updated_at`,
+ p.ID, p.Name, p.RemoteURL, p.LocalPath, p.Type, p.DeployScript, now, now,
+ )
+ return err
+}
+
+// CreateStory inserts st as a new stories row; the timestamps written are the current UTC time, not st.CreatedAt/st.UpdatedAt.
+func (s *DB) CreateStory(st *task.Story) error {
+ now := time.Now().UTC()
+ _, err := s.db.Exec(
+ `INSERT INTO stories (id, name, project_id, branch_name, deploy_config, validation_json, status, created_at, updated_at)
+ VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)`,
+ st.ID, st.Name, st.ProjectID, st.BranchName, st.DeployConfig, st.ValidationJSON, string(st.Status), now, now,
+ )
+ return err
+}
+
+// GetStory retrieves a story by ID, converting the stored status string to a task.StoryState; a Scan error yields (nil, err).
+func (s *DB) GetStory(id string) (*task.Story, error) {
+ row := s.db.QueryRow(`SELECT id, name, project_id, branch_name, deploy_config, validation_json, status, created_at, updated_at FROM stories WHERE id = ?`, id)
+ st := &task.Story{}
+ var status string
+ if err := row.Scan(&st.ID, &st.Name, &st.ProjectID, &st.BranchName, &st.DeployConfig, &st.ValidationJSON, &status, &st.CreatedAt, &st.UpdatedAt); err != nil {
+ return nil, err
+ }
+ st.Status = task.StoryState(status)
+ return st, nil
+}
+
+// ListStories returns all stories ordered by creation time descending, with each status converted to a task.StoryState.
+func (s *DB) ListStories() ([]*task.Story, error) {
+ rows, err := s.db.Query(`SELECT id, name, project_id, branch_name, deploy_config, validation_json, status, created_at, updated_at FROM stories ORDER BY created_at DESC`)
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close()
+ var stories []*task.Story
+ for rows.Next() {
+ st := &task.Story{}
+ var status string
+ if err := rows.Scan(&st.ID, &st.Name, &st.ProjectID, &st.BranchName, &st.DeployConfig, &st.ValidationJSON, &status, &st.CreatedAt, &st.UpdatedAt); err != nil {
+ return nil, err
+ }
+ st.Status = task.StoryState(status)
+ stories = append(stories, st)
+ }
+ return stories, rows.Err()
+}
+
+// UpdateStoryStatus sets the status and refreshes updated_at for the story with the given id; it neither validates the transition nor checks that a row matched.
+func (s *DB) UpdateStoryStatus(id string, status task.StoryState) error {
+ now := time.Now().UTC()
+ _, err := s.db.Exec(`UPDATE stories SET status = ?, updated_at = ? WHERE id = ?`, string(status), now, id)
+ return err
+}
+
+// ListTasksByStory returns all tasks whose story_id equals storyID, ordered by creation time ascending; each row is decoded by scanTaskRows.
+func (s *DB) ListTasksByStory(storyID string) ([]*task.Task, error) {
+ rows, err := s.db.Query(
+ `SELECT id, name, description, elaboration_input, project, repository_url, config_json, priority, timeout_ns, retry_json, tags_json, depends_on_json, parent_task_id, state, created_at, updated_at, rejection_comment, question_json, summary, interactions_json, story_id, acceptance_criteria, checker_for_task_id, checker_report FROM tasks WHERE story_id = ? ORDER BY created_at ASC`,
+ storyID,
+ )
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close()
+ var tasks []*task.Task
+ for rows.Next() {
+ t, err := scanTaskRows(rows)
+ if err != nil {
+ return nil, err
+ }
+ tasks = append(tasks, t)
+ }
+ return tasks, rows.Err()
+}
diff --git a/internal/storage/db_test.go b/internal/storage/db_test.go
index 752c5b1..0e67e02 100644
--- a/internal/storage/db_test.go
+++ b/internal/storage/db_test.go
@@ -41,7 +41,6 @@ func TestCreateTask_AndGetTask(t *testing.T) {
Type: "claude",
Model: "sonnet",
Instructions: "do it",
- ProjectDir: "/tmp",
MaxBudgetUSD: 2.5,
},
Priority: task.PriorityHigh,
@@ -990,6 +989,128 @@ func TestAppendTaskInteraction_NotFound(t *testing.T) {
}
}
+func TestCreateTask_Project_RoundTrip(t *testing.T) { // the Project field must survive a CreateTask/GetTask round trip
+ db := testDB(t)
+ now := time.Now().UTC().Truncate(time.Second)
+
+ tk := &task.Task{
+ ID: "proj-1",
+ Name: "Project Task",
+ Project: "my-project",
+ Agent: task.AgentConfig{Type: "claude", Instructions: "do it"},
+ Priority: task.PriorityNormal,
+ Tags: []string{},
+ DependsOn: []string{},
+ Retry: task.RetryConfig{MaxAttempts: 1, Backoff: "linear"},
+ State: task.StatePending,
+ CreatedAt: now,
+ UpdatedAt: now,
+ }
+ if err := db.CreateTask(tk); err != nil {
+ t.Fatalf("creating task: %v", err)
+ }
+
+ got, err := db.GetTask("proj-1")
+ if err != nil {
+ t.Fatalf("getting task: %v", err)
+ }
+ if got.Project != "my-project" {
+ t.Errorf("project: want %q, got %q", "my-project", got.Project)
+ }
+}
+
+// ── Push subscription tests ───────────────────────────────────────────────────
+
+func TestPushSubscription_SaveAndList(t *testing.T) { // a saved subscription is listed back with endpoint and both keys intact
+ db := testDB(t)
+
+ sub := PushSubscription{
+ ID: "sub-1",
+ Endpoint: "https://push.example.com/endpoint1",
+ P256DHKey: "p256dhkey1",
+ AuthKey: "authkey1",
+ }
+ if err := db.SavePushSubscription(sub); err != nil {
+ t.Fatalf("SavePushSubscription: %v", err)
+ }
+
+ subs, err := db.ListPushSubscriptions()
+ if err != nil {
+ t.Fatalf("ListPushSubscriptions: %v", err)
+ }
+ if len(subs) != 1 {
+ t.Fatalf("want 1 subscription, got %d", len(subs))
+ }
+ if subs[0].Endpoint != sub.Endpoint {
+ t.Errorf("endpoint: want %q, got %q", sub.Endpoint, subs[0].Endpoint)
+ }
+ if subs[0].P256DHKey != sub.P256DHKey {
+ t.Errorf("p256dh_key: want %q, got %q", sub.P256DHKey, subs[0].P256DHKey)
+ }
+ if subs[0].AuthKey != sub.AuthKey {
+ t.Errorf("auth_key: want %q, got %q", sub.AuthKey, subs[0].AuthKey)
+ }
+}
+
+func TestPushSubscription_Delete(t *testing.T) { // deleting by endpoint (not by ID) leaves the list empty
+ db := testDB(t)
+
+ sub := PushSubscription{
+ ID: "sub-del",
+ Endpoint: "https://push.example.com/todelete",
+ P256DHKey: "key",
+ AuthKey: "auth",
+ }
+ if err := db.SavePushSubscription(sub); err != nil {
+ t.Fatalf("SavePushSubscription: %v", err)
+ }
+
+ if err := db.DeletePushSubscription(sub.Endpoint); err != nil {
+ t.Fatalf("DeletePushSubscription: %v", err)
+ }
+
+ subs, err := db.ListPushSubscriptions()
+ if err != nil {
+ t.Fatalf("ListPushSubscriptions: %v", err)
+ }
+ if len(subs) != 0 {
+ t.Errorf("want 0 subscriptions after delete, got %d", len(subs))
+ }
+}
+
+func TestPushSubscription_UniqueEndpoint(t *testing.T) { // two saves with the same endpoint but different IDs must leave one row
+ db := testDB(t)
+
+ sub := PushSubscription{
+ ID: "sub-uq",
+ Endpoint: "https://push.example.com/unique",
+ P256DHKey: "key1",
+ AuthKey: "auth1",
+ }
+ if err := db.SavePushSubscription(sub); err != nil {
+ t.Fatalf("SavePushSubscription first: %v", err)
+ }
+
+ // Save again with the same endpoint under a different ID — this should update or replace, not error.
+ sub2 := PushSubscription{
+ ID: "sub-uq2",
+ Endpoint: "https://push.example.com/unique",
+ P256DHKey: "key2",
+ AuthKey: "auth2",
+ }
+ if err := db.SavePushSubscription(sub2); err != nil {
+ t.Fatalf("SavePushSubscription second (upsert): %v", err)
+ }
+
+ subs, err := db.ListPushSubscriptions()
+ if err != nil {
+ t.Fatalf("ListPushSubscriptions: %v", err)
+ }
+ if len(subs) != 1 {
+ t.Errorf("want 1 subscription after upsert, got %d", len(subs))
+ }
+}
+
func TestExecution_StoreAndRetrieveChangestats(t *testing.T) {
db := testDB(t)
now := time.Now().UTC().Truncate(time.Second)
@@ -1032,3 +1153,252 @@ func TestExecution_StoreAndRetrieveChangestats(t *testing.T) {
}
}
+func TestCreateProject(t *testing.T) { // a created project can be read back by ID with Name and LocalPath intact
+ db := testDB(t)
+ defer db.Close()
+
+ p := &task.Project{
+ ID: "proj-1",
+ Name: "claudomator",
+ RemoteURL: "/bare/claudomator.git",
+ LocalPath: "/workspace/claudomator",
+ Type: "web",
+ }
+ if err := db.CreateProject(p); err != nil {
+ t.Fatalf("CreateProject: %v", err)
+ }
+ got, err := db.GetProject("proj-1")
+ if err != nil {
+ t.Fatalf("GetProject: %v", err)
+ }
+ if got.Name != "claudomator" {
+ t.Errorf("Name: want claudomator, got %q", got.Name)
+ }
+ if got.LocalPath != "/workspace/claudomator" {
+ t.Errorf("LocalPath: want /workspace/claudomator, got %q", got.LocalPath)
+ }
+}
+
+func TestListProjects(t *testing.T) { // checks only the number of projects returned, not their order or contents
+ db := testDB(t)
+ defer db.Close()
+
+ for _, p := range []*task.Project{
+ {ID: "p1", Name: "alpha", Type: "web"},
+ {ID: "p2", Name: "beta", Type: "android"},
+ } {
+ if err := db.CreateProject(p); err != nil {
+ t.Fatalf("CreateProject: %v", err)
+ }
+ }
+ list, err := db.ListProjects()
+ if err != nil {
+ t.Fatalf("ListProjects: %v", err)
+ }
+ if len(list) != 2 {
+ t.Errorf("want 2 projects, got %d", len(list))
+ }
+}
+
+// TestUpdateProject verifies that UpdateProject persists a changed name.
+func TestUpdateProject(t *testing.T) {
+ db := testDB(t)
+ defer db.Close()
+
+ p := &task.Project{ID: "p1", Name: "original", Type: "web"}
+ if err := db.CreateProject(p); err != nil {
+ t.Fatalf("CreateProject: %v", err)
+ }
+ p.Name = "updated"
+ if err := db.UpdateProject(p); err != nil {
+ t.Fatalf("UpdateProject: %v", err)
+ }
+ // Check the read-back error: GetProject returns a nil project on failure,
+ // and dereferencing it would panic instead of reporting a test failure.
+ got, err := db.GetProject("p1")
+ if err != nil {
+ t.Fatalf("GetProject: %v", err)
+ }
+ if got.Name != "updated" {
+ t.Errorf("Name after update: want updated, got %q", got.Name)
+ }
+}
+
+func TestCreateStory(t *testing.T) { // smoke test: CreateStory succeeds with only ID, Name and Status set
+ db := testDB(t)
+ st := &task.Story{
+ ID: "story-1",
+ Name: "My Story",
+ Status: task.StoryPending,
+ }
+ if err := db.CreateStory(st); err != nil {
+ t.Fatalf("CreateStory: %v", err)
+ }
+}
+
+func TestGetStory(t *testing.T) { // Name, ProjectID and Status must round-trip through CreateStory/GetStory
+ db := testDB(t)
+ st := &task.Story{
+ ID: "story-2",
+ Name: "Get Story",
+ ProjectID: "proj-1",
+ Status: task.StoryPending,
+ }
+ if err := db.CreateStory(st); err != nil {
+ t.Fatalf("CreateStory: %v", err)
+ }
+ got, err := db.GetStory("story-2")
+ if err != nil {
+ t.Fatalf("GetStory: %v", err)
+ }
+ if got.Name != "Get Story" {
+ t.Errorf("Name: want 'Get Story', got %q", got.Name)
+ }
+ if got.ProjectID != "proj-1" {
+ t.Errorf("ProjectID: want 'proj-1', got %q", got.ProjectID)
+ }
+ if got.Status != task.StoryPending {
+ t.Errorf("Status: want PENDING, got %q", got.Status)
+ }
+}
+
+func TestListStories(t *testing.T) { // checks only the number of stories returned, not their ordering
+ db := testDB(t)
+ for _, name := range []string{"A", "B", "C"} {
+ if err := db.CreateStory(&task.Story{ID: name, Name: name, Status: task.StoryPending}); err != nil {
+ t.Fatalf("CreateStory %s: %v", name, err)
+ }
+ }
+ stories, err := db.ListStories()
+ if err != nil {
+ t.Fatalf("ListStories: %v", err)
+ }
+ if len(stories) != 3 {
+ t.Errorf("want 3 stories, got %d", len(stories))
+ }
+}
+
+// TestUpdateStoryStatus verifies that UpdateStoryStatus persists the new status.
+func TestUpdateStoryStatus(t *testing.T) {
+ db := testDB(t)
+ st := &task.Story{ID: "story-upd", Name: "Upd", Status: task.StoryPending}
+ if err := db.CreateStory(st); err != nil {
+ t.Fatalf("CreateStory: %v", err)
+ }
+ if err := db.UpdateStoryStatus("story-upd", task.StoryInProgress); err != nil {
+ t.Fatalf("UpdateStoryStatus: %v", err)
+ }
+ // Check the read-back error: GetStory returns a nil story on failure,
+ // and dereferencing it would panic instead of reporting a test failure.
+ got, err := db.GetStory("story-upd")
+ if err != nil {
+ t.Fatalf("GetStory: %v", err)
+ }
+ if got.Status != task.StoryInProgress {
+ t.Errorf("Status: want IN_PROGRESS, got %q", got.Status)
+ }
+}
+
+func TestListTasksByStory(t *testing.T) { // two tasks created under one story are both returned with that StoryID
+ db := testDB(t)
+ now := time.Now().UTC()
+
+ if err := db.CreateStory(&task.Story{ID: "story-tasks", Name: "S", Status: task.StoryPending}); err != nil {
+ t.Fatalf("CreateStory: %v", err)
+ }
+
+ makeTask := func(id string) *task.Task { // builds a minimal pending task attached to story-tasks
+ return &task.Task{
+ ID: id,
+ Name: id,
+ StoryID: "story-tasks",
+ Agent: task.AgentConfig{Type: "claude"},
+ Priority: task.PriorityNormal,
+ Tags: []string{},
+ DependsOn: []string{},
+ Retry: task.RetryConfig{MaxAttempts: 1},
+ State: task.StatePending,
+ CreatedAt: now,
+ UpdatedAt: now,
+ }
+ }
+
+ if err := db.CreateTask(makeTask("t1")); err != nil {
+ t.Fatal(err)
+ }
+ if err := db.CreateTask(makeTask("t2")); err != nil {
+ t.Fatal(err)
+ }
+
+ tasks, err := db.ListTasksByStory("story-tasks")
+ if err != nil {
+ t.Fatalf("ListTasksByStory: %v", err)
+ }
+ if len(tasks) != 2 {
+ t.Errorf("want 2 tasks, got %d", len(tasks))
+ }
+ for _, tk := range tasks {
+ if tk.StoryID != "story-tasks" {
+ t.Errorf("task %s: StoryID want 'story-tasks', got %q", tk.ID, tk.StoryID)
+ }
+ }
+}
+
+func TestUpdateTaskCheckerReport(t *testing.T) { // a report written by UpdateTaskCheckerReport is visible through GetTask
+ db := testDB(t)
+ tk := &task.Task{
+ ID: "cr-1", Name: "orig", RepositoryURL: "https://github.com/x/y",
+ Agent: task.AgentConfig{Type: "claude", Instructions: "x"},
+ Priority: task.PriorityNormal,
+ Retry: task.RetryConfig{MaxAttempts: 1, Backoff: "linear"},
+ Tags: []string{}, DependsOn: []string{},
+ State: task.StatePending, CreatedAt: time.Now().UTC(), UpdatedAt: time.Now().UTC(),
+ }
+ if err := db.CreateTask(tk); err != nil {
+ t.Fatalf("CreateTask: %v", err)
+ }
+ if err := db.UpdateTaskCheckerReport("cr-1", "Tests failed: missing endpoint"); err != nil {
+ t.Fatalf("UpdateTaskCheckerReport: %v", err)
+ }
+ got, err := db.GetTask("cr-1")
+ if err != nil {
+ t.Fatalf("GetTask: %v", err)
+ }
+ if got.CheckerReport != "Tests failed: missing endpoint" {
+ t.Errorf("expected checker report, got %q", got.CheckerReport)
+ }
+}
+
+func TestGetCheckerTask(t *testing.T) { // GetCheckerTask finds the task whose CheckerForTaskID matches, else (nil, nil)
+ db := testDB(t)
+ checked := &task.Task{
+ ID: "chk-orig", Name: "orig", RepositoryURL: "https://github.com/x/y",
+ Agent: task.AgentConfig{Type: "claude", Instructions: "x"},
+ Priority: task.PriorityNormal,
+ Retry: task.RetryConfig{MaxAttempts: 1, Backoff: "linear"},
+ Tags: []string{}, DependsOn: []string{},
+ State: task.StatePending, CreatedAt: time.Now().UTC(), UpdatedAt: time.Now().UTC(),
+ }
+ if err := db.CreateTask(checked); err != nil {
+ t.Fatalf("CreateTask checked: %v", err)
+ }
+ checker := &task.Task{
+ ID: "chk-checker", Name: "Check: orig", CheckerForTaskID: "chk-orig",
+ RepositoryURL: "https://github.com/x/y",
+ Agent: task.AgentConfig{Type: "claude", Instructions: "validate"},
+ Priority: task.PriorityNormal,
+ Retry: task.RetryConfig{MaxAttempts: 1, Backoff: "linear"},
+ Tags: []string{}, DependsOn: []string{},
+ State: task.StatePending, CreatedAt: time.Now().UTC(), UpdatedAt: time.Now().UTC(),
+ }
+ if err := db.CreateTask(checker); err != nil {
+ t.Fatalf("CreateTask checker: %v", err)
+ }
+
+ // Should find the checker task.
+ got, err := db.GetCheckerTask("chk-orig")
+ if err != nil {
+ t.Fatalf("GetCheckerTask: %v", err)
+ }
+ if got == nil || got.ID != "chk-checker" {
+ t.Errorf("expected checker task ID chk-checker, got %v", got)
+ }
+
+ // Should return nil (and no error) when no checker exists.
+ none, err := db.GetCheckerTask("nonexistent")
+ if err != nil {
+ t.Fatalf("GetCheckerTask nonexistent: %v", err)
+ }
+ if none != nil {
+ t.Errorf("expected nil for task with no checker, got %v", none)
+ }
+}
+
diff --git a/internal/storage/seed.go b/internal/storage/seed.go
new file mode 100644
index 0000000..c2df84f
--- /dev/null
+++ b/internal/storage/seed.go
@@ -0,0 +1,62 @@
+package storage
+
+import (
+ "os/exec"
+ "strings"
+
+ "github.com/thepeterstone/claudomator/internal/task"
+)
+
+// SeedProjects upserts the default project registry (claudomator, nav, doot, modal-shell) on startup, overwriting those rows and stopping at the first error.
+func (s *DB) SeedProjects() error {
+ projects := []*task.Project{
+ {
+ ID: "claudomator",
+ Name: "claudomator",
+ LocalPath: "/workspace/claudomator",
+ RemoteURL: localBareRemote("/workspace/claudomator"),
+ Type: "web",
+ DeployScript: "/workspace/claudomator/scripts/deploy",
+ },
+ {
+ ID: "nav",
+ Name: "nav",
+ LocalPath: "/workspace/nav",
+ RemoteURL: localBareRemote("/workspace/nav"),
+ Type: "android",
+ },
+ {
+ ID: "doot",
+ Name: "doot",
+ LocalPath: "/workspace/doot",
+ RemoteURL: localBareRemote("/workspace/doot"),
+ Type: "web",
+ DeployScript: "/workspace/doot/scripts/deploy",
+ },
+ {
+ ID: "modal-shell",
+ Name: "modal-shell",
+ LocalPath: "/workspace/modal-shell",
+ RemoteURL: localBareRemote("/workspace/modal-shell"),
+ Type: "web",
+ },
+ }
+ for _, p := range projects {
+ if err := s.UpsertProject(p); err != nil {
+ return err // projects upserted before the failure are not rolled back here
+ }
+ }
+ return nil
+}
+
+// localBareRemote runs `git -C dir remote get-url local` and returns its trimmed output,
+// falling back to dir itself if the command fails or prints nothing.
+func localBareRemote(dir string) string {
+ out, err := exec.Command("git", "-C", dir, "remote", "get-url", "local").Output()
+ if err == nil {
+ if url := strings.TrimSpace(string(out)); url != "" {
+ return url
+ }
+ }
+ return dir
+}
diff --git a/internal/storage/sqlite_cgo.go b/internal/storage/sqlite_cgo.go
new file mode 100644
index 0000000..0956852
--- /dev/null
+++ b/internal/storage/sqlite_cgo.go
@@ -0,0 +1,5 @@
+//go:build cgo
+
+package storage
+
+import _ "github.com/mattn/go-sqlite3"
diff --git a/internal/storage/sqlite_nocgo.go b/internal/storage/sqlite_nocgo.go
new file mode 100644
index 0000000..9862440
--- /dev/null
+++ b/internal/storage/sqlite_nocgo.go
@@ -0,0 +1,21 @@
+//go:build !cgo
+
+package storage
+
+import (
+ "database/sql"
+ "database/sql/driver"
+
+ modernc "modernc.org/sqlite"
+)
+
+// Register the modernc pure-Go SQLite driver under the "sqlite3" name so that
+// the rest of the codebase can use sql.Open("sqlite3", ...) regardless of
+// whether CGO is enabled. This file is only built when cgo is disabled.
+func init() {
+ sql.Register("sqlite3", &modernc.Driver{}) // sql.Register panics if this name is already registered
+}
+
+// Compile-time assertion that *modernc.Driver satisfies driver.Driver: the
+// assignment to the blank identifier fails to build if the interface is not met.
+var _ driver.Driver = (*modernc.Driver)(nil)
diff --git a/internal/task/project.go b/internal/task/project.go
new file mode 100644
index 0000000..bd8a4fb
--- /dev/null
+++ b/internal/task/project.go
@@ -0,0 +1,11 @@
+package task
+
+// Project represents a registered codebase that agents can operate on; every field is serialized to JSON under its snake_case name.
+type Project struct {
+ ID string `json:"id"`
+ Name string `json:"name"`
+ RemoteURL string `json:"remote_url"`
+ LocalPath string `json:"local_path"`
+ Type string `json:"type"` // "web" | "android"
+ DeployScript string `json:"deploy_script"` // optional path or command
+}
diff --git a/internal/task/story.go b/internal/task/story.go
new file mode 100644
index 0000000..536bda1
--- /dev/null
+++ b/internal/task/story.go
@@ -0,0 +1,41 @@
+package task
+
+import "time"
+
+type StoryState string // StoryState is the status of a Story, represented by its upper-case string value
+
+const ( // the StoryState values referenced by validStoryTransitions
+ StoryPending StoryState = "PENDING"
+ StoryInProgress StoryState = "IN_PROGRESS"
+ StoryShippable StoryState = "SHIPPABLE"
+ StoryDeployed StoryState = "DEPLOYED"
+ StoryValidating StoryState = "VALIDATING"
+ StoryReviewReady StoryState = "REVIEW_READY"
+ StoryNeedsFix StoryState = "NEEDS_FIX"
+)
+
+type Story struct { // Story carries a name, project reference, branch, deploy/validation config, status and timestamps
+ ID string `json:"id"`
+ Name string `json:"name"`
+ ProjectID string `json:"project_id"`
+ BranchName string `json:"branch_name"`
+ DeployConfig string `json:"deploy_config"`
+ ValidationJSON string `json:"validation_json"` // NOTE(review): presumably a JSON document kept as a raw string — confirm
+ Status StoryState `json:"status"`
+ CreatedAt time.Time `json:"created_at"`
+ UpdatedAt time.Time `json:"updated_at"`
+}
+
+var validStoryTransitions = map[StoryState]map[StoryState]bool{ // from-state -> set of allowed to-states
+ StoryPending: {StoryInProgress: true},
+ StoryInProgress: {StoryShippable: true, StoryNeedsFix: true},
+ StoryShippable: {StoryDeployed: true},
+ StoryDeployed: {StoryValidating: true},
+ StoryValidating: {StoryReviewReady: true, StoryNeedsFix: true},
+ StoryReviewReady: {}, // no outgoing transitions
+ StoryNeedsFix: {StoryInProgress: true}, // a fix sends the story back to IN_PROGRESS
+}
+
+func ValidStoryTransition(from, to StoryState) bool { // ValidStoryTransition reports whether from -> to is listed in validStoryTransitions
+ return validStoryTransitions[from][to] // an unknown from state yields a nil inner map, whose lookup returns false
+}
diff --git a/internal/task/story_test.go b/internal/task/story_test.go
new file mode 100644
index 0000000..38d0290
--- /dev/null
+++ b/internal/task/story_test.go
@@ -0,0 +1,42 @@
+package task
+
+import "testing"
+
+func TestValidStoryTransition_Valid(t *testing.T) { // every edge present in validStoryTransitions must be accepted
+ cases := []struct {
+ from StoryState
+ to StoryState
+ }{
+ {StoryPending, StoryInProgress},
+ {StoryInProgress, StoryShippable},
+ {StoryInProgress, StoryNeedsFix},
+ {StoryNeedsFix, StoryInProgress},
+ {StoryShippable, StoryDeployed},
+ {StoryDeployed, StoryValidating},
+ {StoryValidating, StoryReviewReady},
+ {StoryValidating, StoryNeedsFix},
+ }
+ for _, tc := range cases {
+ if !ValidStoryTransition(tc.from, tc.to) {
+ t.Errorf("expected valid transition %s → %s", tc.from, tc.to)
+ }
+ }
+}
+
+func TestValidStoryTransition_Invalid(t *testing.T) { // skipped steps, backward moves and anything out of REVIEW_READY are rejected
+ cases := []struct {
+ from StoryState
+ to StoryState
+ }{
+ {StoryPending, StoryDeployed},
+ {StoryReviewReady, StoryPending},
+ {StoryReviewReady, StoryInProgress},
+ {StoryReviewReady, StoryShippable},
+ {StoryShippable, StoryPending},
+ }
+ for _, tc := range cases {
+ if ValidStoryTransition(tc.from, tc.to) {
+ t.Errorf("expected invalid transition %s → %s", tc.from, tc.to)
+ }
+ }
+}
diff --git a/internal/task/task.go b/internal/task/task.go
index fd1dde6..935a238 100644
--- a/internal/task/task.go
+++ b/internal/task/task.go
@@ -32,13 +32,14 @@ type AgentConfig struct {
Model string `yaml:"model" json:"model"`
ContextFiles []string `yaml:"context_files" json:"context_files"`
Instructions string `yaml:"instructions" json:"instructions"`
- ProjectDir string `yaml:"project_dir" json:"project_dir"`
+ ContainerImage string `yaml:"container_image" json:"container_image"`
MaxBudgetUSD float64 `yaml:"max_budget_usd" json:"max_budget_usd"`
PermissionMode string `yaml:"permission_mode" json:"permission_mode"`
AllowedTools []string `yaml:"allowed_tools" json:"allowed_tools"`
DisallowedTools []string `yaml:"disallowed_tools" json:"disallowed_tools"`
SystemPromptAppend string `yaml:"system_prompt_append" json:"system_prompt_append"`
AdditionalArgs []string `yaml:"additional_args" json:"additional_args"`
+ ProjectDir string `yaml:"project_dir" json:"project_dir,omitempty"`
SkipPlanning bool `yaml:"skip_planning" json:"skip_planning"`
// Local-runner sampling controls. Pointer for Temperature so a 0 value can
@@ -79,12 +80,19 @@ type Task struct {
ParentTaskID string `yaml:"parent_task_id" json:"parent_task_id"`
Name string `yaml:"name" json:"name"`
Description string `yaml:"description" json:"description"`
+ Project string `yaml:"project" json:"project"` // Human-readable project name
+ RepositoryURL string `yaml:"repository_url" json:"repository_url"`
Agent AgentConfig `yaml:"agent" json:"agent"`
Timeout Duration `yaml:"timeout" json:"timeout"`
Retry RetryConfig `yaml:"retry" json:"retry"`
Priority Priority `yaml:"priority" json:"priority"`
Tags []string `yaml:"tags" json:"tags"`
DependsOn []string `yaml:"depends_on" json:"depends_on"`
+ StoryID string `yaml:"-" json:"story_id,omitempty"`
+ BranchName string `yaml:"-" json:"branch_name,omitempty"`
+ AcceptanceCriteria string `yaml:"-" json:"acceptance_criteria,omitempty"`
+ CheckerForTaskID string `yaml:"-" json:"checker_for_task_id,omitempty"`
+ CheckerReport string `yaml:"-" json:"checker_report,omitempty"`
State State `yaml:"-" json:"state"`
RejectionComment string `yaml:"-" json:"rejection_comment,omitempty"`
QuestionJSON string `yaml:"-" json:"question,omitempty"`
@@ -130,7 +138,7 @@ type BatchFile struct {
// BLOCKED may advance to READY when all subtasks complete, or back to QUEUED on user answer.
var validTransitions = map[State]map[State]bool{
StatePending: {StateQueued: true, StateCancelled: true},
- StateQueued: {StateRunning: true, StateCancelled: true},
+ StateQueued: {StateRunning: true, StateCancelled: true, StateReady: true}, // READY: parent task completed via subtask delegation
StateRunning: {StateReady: true, StateCompleted: true, StateFailed: true, StateTimedOut: true, StateCancelled: true, StateBudgetExceeded: true, StateBlocked: true},
StateReady: {StateCompleted: true, StatePending: true},
StateFailed: {StateQueued: true}, // retry
diff --git a/internal/task/task_test.go b/internal/task/task_test.go
index 15ba019..e6a17b8 100644
--- a/internal/task/task_test.go
+++ b/internal/task/task_test.go
@@ -100,3 +100,31 @@ func TestDuration_MarshalYAML(t *testing.T) {
t.Errorf("expected '15m0s', got %v", v)
}
}
+
+func TestTask_ProjectField(t *testing.T) { // Project is settable directly and is populated from the YAML "project" key by Parse
+ t.Run("struct assignment", func(t *testing.T) {
+ task := Task{Project: "my-project"}
+ if task.Project != "my-project" {
+ t.Errorf("expected Project 'my-project', got %q", task.Project)
+ }
+ })
+
+ t.Run("yaml parsing", func(t *testing.T) {
+ yaml := `
+name: "Test Task"
+project: my-project
+agent:
+ instructions: "Do something"
+`
+ tasks, err := Parse([]byte(yaml))
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ if len(tasks) != 1 {
+ t.Fatalf("expected 1 task, got %d", len(tasks))
+ }
+ if tasks[0].Project != "my-project" {
+ t.Errorf("expected Project 'my-project', got %q", tasks[0].Project)
+ }
+ })
+}
diff --git a/internal/task/validator.go b/internal/task/validator.go
index 003fab9..43e482e 100644
--- a/internal/task/validator.go
+++ b/internal/task/validator.go
@@ -29,6 +29,9 @@ func Validate(t *Task) error {
if t.Name == "" {
ve.Add("name is required")
}
+ if t.RepositoryURL == "" {
+ ve.Add("repository_url is required")
+ }
if t.Agent.Instructions == "" {
ve.Add("agent.instructions is required")
}
diff --git a/internal/task/validator_test.go b/internal/task/validator_test.go
index 657d93f..2c6735c 100644
--- a/internal/task/validator_test.go
+++ b/internal/task/validator_test.go
@@ -9,10 +9,10 @@ func validTask() *Task {
return &Task{
ID: "test-id",
Name: "Valid Task",
+ RepositoryURL: "https://github.com/user/repo",
Agent: AgentConfig{
Type: "claude",
Instructions: "do something",
- ProjectDir: "/tmp",
},
Priority: PriorityNormal,
Retry: RetryConfig{MaxAttempts: 1, Backoff: "exponential"},