summaryrefslogtreecommitdiff
path: root/internal
diff options
context:
space:
mode:
authorPeter Stone <thepeterstone@gmail.com>2026-03-10 09:09:32 +0000
committerPeter Stone <thepeterstone@gmail.com>2026-03-10 09:29:02 +0000
commit63ccc3380df10cab066e08b40ea41ee1b51bb651 (patch)
tree1cab595d980ead075bd79bc10789be423f132088 /internal
parent0676f0f2e6d1ba371806ca4b808a4993027d86ea (diff)
feat: include project context in elaborator prompt
The elaborator now reads CLAUDE.md and SESSION_STATE.md from the project directory (if they exist) and prepends their content to the user prompt. This allows the AI to generate tasks that are more context-aware.
Diffstat (limited to 'internal')
-rw-r--r--internal/api/elaborate.go28
-rw-r--r--internal/api/elaborate_test.go71
2 files changed, 98 insertions, 1 deletions
diff --git a/internal/api/elaborate.go b/internal/api/elaborate.go
index 2f6c707..5954e29 100644
--- a/internal/api/elaborate.go
+++ b/internal/api/elaborate.go
@@ -6,7 +6,9 @@ import (
"encoding/json"
"fmt"
"net/http"
+ "os"
"os/exec"
+ "path/filepath"
"strings"
"time"
)
@@ -86,6 +88,23 @@ func (s *Server) claudeBinaryPath() string {
return "claude"
}
+func readProjectContext(workDir string) string {
+ if workDir == "" {
+ return ""
+ }
+ var sb strings.Builder
+ for _, filename := range []string{"CLAUDE.md", "SESSION_STATE.md"} {
+ path := filepath.Join(workDir, filename)
+ if data, err := os.ReadFile(path); err == nil {
+ if sb.Len() > 0 {
+ sb.WriteString("\n\n")
+ }
+ sb.WriteString(fmt.Sprintf("--- %s ---\n%s", filename, string(data)))
+ }
+ }
+ return sb.String()
+}
+
func (s *Server) handleElaborateTask(w http.ResponseWriter, r *http.Request) {
if s.elaborateLimiter != nil && !s.elaborateLimiter.allow(realIP(r)) {
writeJSON(w, http.StatusTooManyRequests, map[string]string{"error": "rate limit exceeded"})
@@ -110,16 +129,23 @@ func (s *Server) handleElaborateTask(w http.ResponseWriter, r *http.Request) {
workDir = input.ProjectDir
}
+ projectContext := readProjectContext(workDir)
+ fullPrompt := input.Prompt
+ if projectContext != "" {
+ fullPrompt = fmt.Sprintf("Project context from %s:\n%s\n\nUser request: %s", workDir, projectContext, input.Prompt)
+ }
+
ctx, cancel := context.WithTimeout(r.Context(), elaborateTimeout)
defer cancel()
cmd := exec.CommandContext(ctx, s.claudeBinaryPath(),
- "-p", input.Prompt,
+ "-p", fullPrompt,
"--system-prompt", buildElaboratePrompt(workDir),
"--output-format", "json",
"--model", "haiku",
)
+
var stdout, stderr bytes.Buffer
cmd.Stdout = &stdout
cmd.Stderr = &stderr
diff --git a/internal/api/elaborate_test.go b/internal/api/elaborate_test.go
index b33ca11..114e75e 100644
--- a/internal/api/elaborate_test.go
+++ b/internal/api/elaborate_test.go
@@ -189,3 +189,74 @@ func TestElaborateTask_InvalidJSONFromClaude(t *testing.T) {
t.Error("expected error message in response")
}
}
+
+func createFakeClaudeCapturingArgs(t *testing.T, output string, exitCode int, argsFile string) string {
+ t.Helper()
+ dir := t.TempDir()
+ outputFile := filepath.Join(dir, "output.json")
+ if err := os.WriteFile(outputFile, []byte(output), 0600); err != nil {
+ t.Fatal(err)
+ }
+ script := filepath.Join(dir, "claude")
+ // Use printf to handle arguments safely
+ content := fmt.Sprintf("#!/bin/sh\nprintf \"%%s\\n\" \"$@\" > %q\ncat %q\nexit %d\n", argsFile, outputFile, exitCode)
+ if err := os.WriteFile(script, []byte(content), 0755); err != nil {
+ t.Fatal(err)
+ }
+ return script
+}
+
+func TestElaborateTask_WithProjectContext(t *testing.T) {
+ srv, _ := testServer(t)
+
+ // Create a temporary workspace with CLAUDE.md and SESSION_STATE.md
+ workDir := t.TempDir()
+ claudeContent := "Claude context info"
+ sessionContent := "Session state info"
+ if err := os.WriteFile(filepath.Join(workDir, "CLAUDE.md"), []byte(claudeContent), 0600); err != nil {
+ t.Fatal(err)
+ }
+ if err := os.WriteFile(filepath.Join(workDir, "SESSION_STATE.md"), []byte(sessionContent), 0600); err != nil {
+ t.Fatal(err)
+ }
+
+ // Capture arguments passed to claude
+ argsFile := filepath.Join(t.TempDir(), "args.txt")
+
+ task := elaboratedTask{
+ Name: "Task with context",
+ Agent: elaboratedAgent{
+ Instructions: "Instructions",
+ },
+ }
+ taskJSON, _ := json.Marshal(task)
+ wrapper := map[string]string{"result": string(taskJSON)}
+ wrapperJSON, _ := json.Marshal(wrapper)
+
+ // Modified createFakeClaude to capture arguments
+ srv.elaborateCmdPath = createFakeClaudeCapturingArgs(t, string(wrapperJSON), 0, argsFile)
+
+ body := fmt.Sprintf(`{"prompt":"do something", "project_dir":"%s"}`, workDir)
+ req := httptest.NewRequest("POST", "/api/tasks/elaborate", bytes.NewBufferString(body))
+ req.Header.Set("Content-Type", "application/json")
+ w := httptest.NewRecorder()
+
+ srv.Handler().ServeHTTP(w, req)
+
+ if w.Code != http.StatusOK {
+ t.Fatalf("status: want 200, got %d; body: %s", w.Code, w.Body.String())
+ }
+
+ // Check if captured arguments contain the context
+ capturedArgs, err := os.ReadFile(argsFile)
+ if err != nil {
+ t.Fatal(err)
+ }
+ argsStr := string(capturedArgs)
+ if !strings.Contains(argsStr, claudeContent) {
+ t.Errorf("expected arguments to contain CLAUDE.md content, got %s", argsStr)
+ }
+ if !strings.Contains(argsStr, sessionContent) {
+ t.Errorf("expected arguments to contain SESSION_STATE.md content, got %s", argsStr)
+ }
+}