diff options
| author | Peter Stone <thepeterstone@gmail.com> | 2026-03-10 09:09:32 +0000 |
|---|---|---|
| committer | Peter Stone <thepeterstone@gmail.com> | 2026-03-10 09:29:02 +0000 |
| commit | 63ccc3380df10cab066e08b40ea41ee1b51bb651 (patch) | |
| tree | 1cab595d980ead075bd79bc10789be423f132088 | |
| parent | 0676f0f2e6d1ba371806ca4b808a4993027d86ea (diff) | |
feat: include project context in elaborator prompt
The elaborator now reads CLAUDE.md and SESSION_STATE.md from the project directory (if they exist) and prepends their content to the user prompt. This allows the AI to generate tasks that are more context-aware.
| -rw-r--r-- | internal/api/elaborate.go | 28 | ||||
| -rw-r--r-- | internal/api/elaborate_test.go | 71 |
2 files changed, 98 insertions, 1 deletions
diff --git a/internal/api/elaborate.go b/internal/api/elaborate.go index 2f6c707..5954e29 100644 --- a/internal/api/elaborate.go +++ b/internal/api/elaborate.go @@ -6,7 +6,9 @@ import ( "encoding/json" "fmt" "net/http" + "os" "os/exec" + "path/filepath" "strings" "time" ) @@ -86,6 +88,23 @@ func (s *Server) claudeBinaryPath() string { return "claude" } +func readProjectContext(workDir string) string { + if workDir == "" { + return "" + } + var sb strings.Builder + for _, filename := range []string{"CLAUDE.md", "SESSION_STATE.md"} { + path := filepath.Join(workDir, filename) + if data, err := os.ReadFile(path); err == nil { + if sb.Len() > 0 { + sb.WriteString("\n\n") + } + sb.WriteString(fmt.Sprintf("--- %s ---\n%s", filename, string(data))) + } + } + return sb.String() +} + func (s *Server) handleElaborateTask(w http.ResponseWriter, r *http.Request) { if s.elaborateLimiter != nil && !s.elaborateLimiter.allow(realIP(r)) { writeJSON(w, http.StatusTooManyRequests, map[string]string{"error": "rate limit exceeded"}) @@ -110,16 +129,23 @@ func (s *Server) handleElaborateTask(w http.ResponseWriter, r *http.Request) { workDir = input.ProjectDir } + projectContext := readProjectContext(workDir) + fullPrompt := input.Prompt + if projectContext != "" { + fullPrompt = fmt.Sprintf("Project context from %s:\n%s\n\nUser request: %s", workDir, projectContext, input.Prompt) + } + ctx, cancel := context.WithTimeout(r.Context(), elaborateTimeout) defer cancel() cmd := exec.CommandContext(ctx, s.claudeBinaryPath(), - "-p", input.Prompt, + "-p", fullPrompt, "--system-prompt", buildElaboratePrompt(workDir), "--output-format", "json", "--model", "haiku", ) + var stdout, stderr bytes.Buffer cmd.Stdout = &stdout cmd.Stderr = &stderr diff --git a/internal/api/elaborate_test.go b/internal/api/elaborate_test.go index b33ca11..114e75e 100644 --- a/internal/api/elaborate_test.go +++ b/internal/api/elaborate_test.go @@ -189,3 +189,74 @@ func TestElaborateTask_InvalidJSONFromClaude(t *testing.T) { t.Error("expected error message in response") } } + +func createFakeClaudeCapturingArgs(t *testing.T, output string, exitCode int, argsFile string) string { + t.Helper() + dir := t.TempDir() + outputFile := filepath.Join(dir, "output.json") + if err := os.WriteFile(outputFile, []byte(output), 0600); err != nil { + t.Fatal(err) + } + script := filepath.Join(dir, "claude") + // Use printf to handle arguments safely + content := fmt.Sprintf("#!/bin/sh\nprintf \"%%s\\n\" \"$@\" > %q\ncat %q\nexit %d\n", argsFile, outputFile, exitCode) + if err := os.WriteFile(script, []byte(content), 0755); err != nil { + t.Fatal(err) + } + return script +} + +func TestElaborateTask_WithProjectContext(t *testing.T) { + srv, _ := testServer(t) + + // Create a temporary workspace with CLAUDE.md and SESSION_STATE.md + workDir := t.TempDir() + claudeContent := "Claude context info" + sessionContent := "Session state info" + if err := os.WriteFile(filepath.Join(workDir, "CLAUDE.md"), []byte(claudeContent), 0600); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(filepath.Join(workDir, "SESSION_STATE.md"), []byte(sessionContent), 0600); err != nil { + t.Fatal(err) + } + + // Capture arguments passed to claude + argsFile := filepath.Join(t.TempDir(), "args.txt") + + task := elaboratedTask{ + Name: "Task with context", + Agent: elaboratedAgent{ + Instructions: "Instructions", + }, + } + taskJSON, _ := json.Marshal(task) + wrapper := map[string]string{"result": string(taskJSON)} + wrapperJSON, _ := json.Marshal(wrapper) + + // Modified createFakeClaude to capture arguments + srv.elaborateCmdPath = createFakeClaudeCapturingArgs(t, string(wrapperJSON), 0, argsFile) + + body := fmt.Sprintf(`{"prompt":"do something", "project_dir":"%s"}`, workDir) + req := httptest.NewRequest("POST", "/api/tasks/elaborate", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + + srv.Handler().ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("status: want 200, got %d; body: %s", w.Code, w.Body.String()) + } + + // Check if captured arguments contain the context + capturedArgs, err := os.ReadFile(argsFile) + if err != nil { + t.Fatal(err) + } + argsStr := string(capturedArgs) + if !strings.Contains(argsStr, claudeContent) { + t.Errorf("expected arguments to contain CLAUDE.md content, got %s", argsStr) + } + if !strings.Contains(argsStr, sessionContent) { + t.Errorf("expected arguments to contain SESSION_STATE.md content, got %s", argsStr) + } +} |
