diff options
Diffstat (limited to 'internal')
| -rw-r--r-- | internal/api/executions.go | 4 | ||||
| -rw-r--r-- | internal/api/executions_test.go | 20 | ||||
| -rw-r--r-- | internal/api/logs.go | 33 | ||||
| -rw-r--r-- | internal/api/logs_test.go | 71 |
4 files changed, 113 insertions, 15 deletions
diff --git a/internal/api/executions.go b/internal/api/executions.go index d9214c0..114425e 100644 --- a/internal/api/executions.go +++ b/internal/api/executions.go @@ -21,12 +21,16 @@ func (s *Server) handleListRecentExecutions(w http.ResponseWriter, r *http.Reque } } + const maxLimit = 1000 limit := 50 if v := r.URL.Query().Get("limit"); v != "" { if n, err := strconv.Atoi(v); err == nil && n > 0 { limit = n } } + if limit > maxLimit { + limit = maxLimit + } taskID := r.URL.Query().Get("task_id") diff --git a/internal/api/executions_test.go b/internal/api/executions_test.go index a2bba21..45548ad 100644 --- a/internal/api/executions_test.go +++ b/internal/api/executions_test.go @@ -258,6 +258,26 @@ func TestGetExecutionLog_FollowSSEHeaders(t *testing.T) { } } +func TestListRecentExecutions_LimitClamped(t *testing.T) { + srv, _ := testServer(t) + + req := httptest.NewRequest("GET", "/api/executions?limit=10000000", nil) + w := httptest.NewRecorder() + srv.Handler().ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("status: want 200, got %d; body: %s", w.Code, w.Body.String()) + } + // The handler should not pass limit > maxLimit to the store. + // We verify indirectly: if the query param is accepted without error and + // does not cause a panic or 500, the clamp is in effect. + // A direct assertion requires a mock store; here we check the response is valid. + var execs []storage.RecentExecution + if err := json.NewDecoder(w.Body).Decode(&execs); err != nil { + t.Fatalf("decoding response: %v", err) + } +} + func TestListTasks_ReturnsStateField(t *testing.T) { srv, store := testServer(t) createTaskWithState(t, store, "state-check", task.StateRunning) diff --git a/internal/api/logs.go b/internal/api/logs.go index 1ba4b00..4e63489 100644 --- a/internal/api/logs.go +++ b/internal/api/logs.go @@ -36,9 +36,10 @@ var terminalStates = map[string]bool{ } type logStreamMsg struct { - Type string `json:"type"` - Message *logAssistMsg `json:"message,omitempty"` - CostUSD float64 `json:"cost_usd,omitempty"` + Type string `json:"type"` + Message *logAssistMsg `json:"message,omitempty"` + CostUSD float64 `json:"cost_usd,omitempty"` + TotalCostUSD float64 `json:"total_cost_usd,omitempty"` } type logAssistMsg struct { @@ -258,29 +259,31 @@ func emitLogLine(w http.ResponseWriter, flusher http.Flusher, line []byte) { return } for _, block := range msg.Message.Content { - var event map[string]string + var data []byte switch block.Type { case "text": - event = map[string]string{"type": "text", "content": block.Text} + data, _ = json.Marshal(map[string]string{"type": "text", "content": block.Text}) case "tool_use": - summary := string(block.Input) - if len(summary) > 80 { - summary = summary[:80] - } - event = map[string]string{"type": "tool_use", "content": fmt.Sprintf("%s(%s)", block.Name, summary)} + data, _ = json.Marshal(struct { + Type string `json:"type"` + Name string `json:"name"` + Input json.RawMessage `json:"input,omitempty"` + }{Type: "tool_use", Name: block.Name, Input: block.Input}) default: continue } - data, _ := json.Marshal(event) fmt.Fprintf(w, "data: %s\n\n", data) flusher.Flush() } case "result": - event := map[string]string{ - "type": "cost", - "content": fmt.Sprintf("%g", msg.CostUSD), + cost := msg.TotalCostUSD + if cost == 0 { + cost = msg.CostUSD } - data, _ := json.Marshal(event) + data, _ := json.Marshal(struct { + Type string `json:"type"` + TotalCost float64 `json:"total_cost"` + }{Type: "cost", TotalCost: cost}) fmt.Fprintf(w, "data: %s\n\n", data) flusher.Flush() } diff --git a/internal/api/logs_test.go b/internal/api/logs_test.go index 52fa168..6c6be05 100644 --- a/internal/api/logs_test.go +++ b/internal/api/logs_test.go @@ -1,6 +1,7 @@ package api import ( + "encoding/json" "errors" "net/http" "net/http/httptest" @@ -293,6 +294,76 @@ func TestHandleStreamTaskLogs_TerminalExecution_EmitsEventsAndDone(t *testing.T) } } +// TestEmitLogLine_ToolUse_EmitsNameField verifies that emitLogLine emits a tool_use SSE event +// with a "name" field matching the tool name so the web UI can display it as "[ToolName]". +func TestEmitLogLine_ToolUse_EmitsNameField(t *testing.T) { + line := []byte(`{"type":"assistant","message":{"content":[{"type":"tool_use","name":"Bash","input":{"command":"ls -la"}}]}}`) + + w := httptest.NewRecorder() + emitLogLine(w, w, line) + + body := w.Body.String() + var found bool + for _, chunk := range strings.Split(body, "\n\n") { + chunk = strings.TrimSpace(chunk) + if !strings.HasPrefix(chunk, "data: ") { + continue + } + jsonStr := strings.TrimPrefix(chunk, "data: ") + var e map[string]interface{} + if err := json.Unmarshal([]byte(jsonStr), &e); err != nil { + continue + } + if e["type"] == "tool_use" { + if e["name"] != "Bash" { + t.Errorf("tool_use event name: want Bash, got %v", e["name"]) + } + if e["input"] == nil { + t.Error("tool_use event input: expected non-nil") + } + found = true + } + } + if !found { + t.Errorf("no tool_use event found in SSE output:\n%s", body) + } +} + +// TestEmitLogLine_Cost_EmitsTotalCostField verifies that emitLogLine emits a cost SSE event +// with a numeric "total_cost" field so the web UI can display it correctly. +func TestEmitLogLine_Cost_EmitsTotalCostField(t *testing.T) { + line := []byte(`{"type":"result","total_cost_usd":0.0042}`) + + w := httptest.NewRecorder() + emitLogLine(w, w, line) + + body := w.Body.String() + var found bool + for _, chunk := range strings.Split(body, "\n\n") { + chunk = strings.TrimSpace(chunk) + if !strings.HasPrefix(chunk, "data: ") { + continue + } + jsonStr := strings.TrimPrefix(chunk, "data: ") + var e map[string]interface{} + if err := json.Unmarshal([]byte(jsonStr), &e); err != nil { + continue + } + if e["type"] == "cost" { + if e["total_cost"] == nil { + t.Error("cost event total_cost: expected non-nil numeric field") + } + if v, ok := e["total_cost"].(float64); !ok || v != 0.0042 { + t.Errorf("cost event total_cost: want 0.0042, got %v", e["total_cost"]) + } + found = true + } + } + if !found { + t.Errorf("no cost event found in SSE output:\n%s", body) + } +} + // TestHandleStreamTaskLogs_RunningExecution_LiveTails verifies that a RUNNING execution is // live-tailed and a done event is emitted once it transitions to a terminal state. func TestHandleStreamTaskLogs_RunningExecution_LiveTails(t *testing.T) { |
