summaryrefslogtreecommitdiff
path: root/internal/api
diff options
context:
space:
mode:
Diffstat (limited to 'internal/api')
-rw-r--r--internal/api/server.go58
-rw-r--r--internal/api/server_test.go144
2 files changed, 187 insertions, 15 deletions
diff --git a/internal/api/server.go b/internal/api/server.go
index 59d59eb..65b0181 100644
--- a/internal/api/server.go
+++ b/internal/api/server.go
@@ -397,14 +397,15 @@ func (s *Server) handleHealth(w http.ResponseWriter, r *http.Request) {
func (s *Server) handleCreateTask(w http.ResponseWriter, r *http.Request) {
var input struct {
- Name string `json:"name"`
- Description string `json:"description"`
- Agent task.AgentConfig `json:"agent"`
- Claude task.AgentConfig `json:"claude"` // legacy alias
- Timeout string `json:"timeout"`
- Priority string `json:"priority"`
- Tags []string `json:"tags"`
- ParentTaskID string `json:"parent_task_id"`
+ Name string `json:"name"`
+ Description string `json:"description"`
+ ElaborationInput string `json:"elaboration_input"`
+ Agent task.AgentConfig `json:"agent"`
+ Claude task.AgentConfig `json:"claude"` // legacy alias
+ Timeout string `json:"timeout"`
+ Priority string `json:"priority"`
+ Tags []string `json:"tags"`
+ ParentTaskID string `json:"parent_task_id"`
}
if err := json.NewDecoder(r.Body).Decode(&input); err != nil {
writeJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid JSON: " + err.Error()})
@@ -418,10 +419,11 @@ func (s *Server) handleCreateTask(w http.ResponseWriter, r *http.Request) {
now := time.Now().UTC()
t := &task.Task{
- ID: uuid.New().String(),
- Name: input.Name,
- Description: input.Description,
- Agent: input.Agent,
+ ID: uuid.New().String(),
+ Name: input.Name,
+ Description: input.Description,
+ ElaborationInput: input.ElaborationInput,
+ Agent: input.Agent,
Priority: task.Priority(input.Priority),
Tags: input.Tags,
DependsOn: []string{},
@@ -515,8 +517,16 @@ func (s *Server) handleGetTask(w http.ResponseWriter, r *http.Request) {
}
func (s *Server) handleRunTask(w http.ResponseWriter, r *http.Request) {
id := r.PathValue("id")
- agent := r.URL.Query().Get("agent")
+ agentParam := r.URL.Query().Get("agent") // Use a different name to avoid confusion
+ // 1. Retrieve the original task to preserve agent config if not "auto".
+ originalTask, err := s.store.GetTask(id)
+ if err != nil {
+ writeJSON(w, http.StatusNotFound, map[string]string{"error": "task not found"})
+ return
+ }
+
+ // 2. Reset the task for retry, which clears the agent config.
t, err := s.store.ResetTaskForRetry(id)
if err != nil {
if strings.Contains(err.Error(), "not found") {
@@ -531,9 +541,27 @@ func (s *Server) handleRunTask(w http.ResponseWriter, r *http.Request) {
return
}
- if agent != "" && agent != "auto" {
- t.Agent.Type = agent
+ // 3. Restore original agent type and model if not explicitly overridden by query parameter.
+ // Only restore if original task had a specific agent type set and query parameter is not overriding it.
+ if originalTask.Agent.Type != "" && agentParam == "" {
+ t.Agent.Type = originalTask.Agent.Type
+ t.Agent.Model = originalTask.Agent.Model
+ }
+
+ // 4. Handle agent query parameter override.
+ if agentParam != "" && agentParam != "auto" {
+ t.Agent.Type = agentParam
+ }
+
+ // 5. Update task agent in DB if it has changed from the reset (only if originalTask.Agent.Type was explicitly set, or agentParam was set).
+ if originalTask.Agent.Type != t.Agent.Type || originalTask.Agent.Model != t.Agent.Model {
+ if err := s.store.UpdateTaskAgent(t.ID, t.Agent); err != nil {
+ s.logger.Error("failed to update task agent config", "error", err, "taskID", t.ID)
+ writeJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()})
+ return
+ }
}
+ // The task `t` now has the correct agent configuration.
if err := s.pool.Submit(context.Background(), t); err != nil {
writeJSON(w, http.StatusServiceUnavailable, map[string]string{"error": fmt.Sprintf("executor pool: %v", err)})
diff --git a/internal/api/server_test.go b/internal/api/server_test.go
index d090313..a670f33 100644
--- a/internal/api/server_test.go
+++ b/internal/api/server_test.go
@@ -132,6 +132,150 @@ func pollState(t *testing.T, store *storage.DB, taskID string, wantState task.St
return ""
}
+func testServerWithGeminiMockRunner(t *testing.T) (*Server, *storage.DB) {
+ t.Helper()
+ dbPath := filepath.Join(t.TempDir(), "test.db")
+ store, err := storage.Open(dbPath)
+ if err != nil {
+ t.Fatal(err)
+ }
+ t.Cleanup(func() { store.Close() })
+
+ logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelDebug}))
+
+ // Create the mock gemini binary script.
+ mockBinDir := t.TempDir()
+ mockGeminiPath := filepath.Join(mockBinDir, "mock-gemini-binary.sh")
+ mockScriptContent := `#!/bin/bash
+# Mock gemini binary that outputs stream-json wrapped in markdown to stdout.
+echo "```json"
+echo "{\"type\":\"content_block_start\",\"content_block\":{\"text\":\"Hello, Gemini!\",\"type\":\"text\"}}"
+echo "{\"type\":\"content_block_delta\",\"content_block\":{\"text\":\" How are you?\"}}"
+echo "{\"type\":\"content_block_end\"}"
+echo "{\"type\":\"message_delta\",\"message\":{\"role\":\"model\"}}"
+echo "{\"type\":\"message_end\"}"
+echo "```"
+exit 0
+`
+ if err := os.WriteFile(mockGeminiPath, []byte(mockScriptContent), 0755); err != nil {
+ t.Fatalf("writing mock gemini script: %v", err)
+ }
+
+ // Configure GeminiRunner to use the mock script.
+ geminiRunner := &executor.GeminiRunner{
+ BinaryPath: mockGeminiPath,
+ Logger: logger,
+ LogDir: t.TempDir(), // Ensure log directory is temporary for test
+ APIURL: "http://localhost:8080", // Placeholder, not used by this mock
+ }
+
+ runners := map[string]executor.Runner{
+ "claude": &mockRunner{}, // Keep mock for claude to not interfere
+ "gemini": geminiRunner,
+ }
+ pool := executor.NewPool(2, runners, store, logger)
+ srv := NewServer(store, pool, logger, "claude", "gemini") // Pass original binary paths
+ return srv, store
+}
+
+// TestGeminiLogs_ParsedCorrectly verifies that Gemini's markdown-wrapped stream-json
+// output is correctly unwrapped and parsed before being written to stdout.log
+// and exposed via the /api/tasks/{id}/executions/{exec-id}/log endpoint.
+func TestGeminiLogs_ParsedCorrectly(t *testing.T) {
+ srv, store := testServerWithGeminiMockRunner(t)
+
+ // Expected parsed JSON lines.
+ expectedParsedLogs := []string{
+ `{"type":"content_block_start","content_block":{"text":"Hello, Gemini!","type":"text"}}`,
+ `{"type":"content_block_delta","content_block":{"text":" How are you?"}}`,
+ `{"type":"content_block_end"}`,
+ `{"type":"message_delta","message":{"role":"model"}}`,
+ `{"type":"message_end"}`,
+ }
+
+ // 1. Create a task with Gemini agent.
+ tk := createTestTask(t, srv, `{
+ "name": "Gemini Log Test Task",
+ "description": "Test Gemini log parsing",
+ "agent": {
+ "type": "gemini",
+ "instructions": "generate some output",
+ "model": "gemini-2.5-flash-lite"
+ }
+ }`)
+
+ // 2. Run the task.
+ req := httptest.NewRequest("POST", "/api/tasks/"+tk.ID+"/run", nil)
+ w := httptest.NewRecorder()
+ srv.Handler().ServeHTTP(w, req)
+
+ if w.Code != http.StatusAccepted {
+ t.Fatalf("run task status: want 202, got %d; body: %s", w.Code, w.Body.String())
+ }
+
+ // 3. Wait for the task to complete.
+ pollState(t, store, tk.ID, task.StateCompleted, 2*time.Second)
+
+ // Re-fetch the task to ensure we have the updated execution details.
+ updatedTask, err := store.GetTask(tk.ID)
+ if err != nil {
+ t.Fatalf("re-fetching task: %v", err)
+ }
+
+ // 4. Get the execution details to find the log path.
+ executions, err := store.ListExecutions(updatedTask.ID)
+ if err != nil {
+ t.Fatalf("listing executions: %v", err)
+ }
+ if len(executions) != 1 {
+ t.Fatalf("want 1 execution, got %d", len(executions))
+ }
+ exec := executions[0]
+ t.Logf("Re-fetched execution: %+v", exec) // Log the entire execution struct
+
+ // 5. Verify the content of stdout.log directly.
+ t.Logf("Attempting to read stdout.log from: %q", exec.StdoutPath)
+ stdoutContent, err := os.ReadFile(exec.StdoutPath)
+ if err != nil {
+ t.Fatalf("reading stdout.log: %v", err)
+ }
+ stdoutLines := strings.Split(strings.TrimSpace(string(stdoutContent)), "\n")
+ if len(stdoutLines) != len(expectedParsedLogs) {
+ t.Errorf("stdout.log line count: want %d, got %d\nContent:\n%s", len(expectedParsedLogs), len(stdoutLines), stdoutContent)
+ }
+ for i, line := range stdoutLines {
+ if i >= len(expectedParsedLogs) {
+ break
+ }
+ if line != expectedParsedLogs[i] {
+ t.Errorf("stdout.log line %d: want %q, got %q", i, expectedParsedLogs[i], line)
+ }
+ }
+
+ // 6. Verify the content retrieved via the API endpoint.
+ req = httptest.NewRequest("GET", "/api/tasks/"+tk.ID+"/executions/"+exec.ID+"/log", nil)
+ w = httptest.NewRecorder()
+ srv.Handler().ServeHTTP(w, req)
+
+ if w.Code != http.StatusOK {
+ t.Fatalf("GET /log status: want 200, got %d; body: %s", w.Code, w.Body.String())
+ }
+
+ apiLogContent := strings.TrimSpace(w.Body.String())
+ apiLogLines := strings.Split(apiLogContent, "\n")
+ if len(apiLogLines) != len(expectedParsedLogs) {
+ t.Errorf("API log line count: want %d, got %d\nContent:\n%s", len(expectedParsedLogs), len(apiLogLines), apiLogContent)
+ }
+ for i, line := range apiLogLines {
+ if i >= len(expectedParsedLogs) {
+ break
+ }
+ if line != expectedParsedLogs[i] {
+ t.Errorf("API log line %d: want %q, got %q", i, expectedParsedLogs[i], line)
+ }
+ }
+}
+
func TestListWorkspaces_UsesConfiguredRoot(t *testing.T) {
srv, _ := testServer(t)