summaryrefslogtreecommitdiff
path: root/internal/api
diff options
context:
space:
mode:
Diffstat (limited to 'internal/api')
-rw-r--r--internal/api/executions.go4
-rw-r--r--internal/api/executions_test.go20
-rw-r--r--internal/api/logs.go33
-rw-r--r--internal/api/logs_test.go71
4 files changed, 113 insertions, 15 deletions
diff --git a/internal/api/executions.go b/internal/api/executions.go
index d9214c0..114425e 100644
--- a/internal/api/executions.go
+++ b/internal/api/executions.go
@@ -21,12 +21,16 @@ func (s *Server) handleListRecentExecutions(w http.ResponseWriter, r *http.Reque
}
}
+ const maxLimit = 1000
limit := 50
if v := r.URL.Query().Get("limit"); v != "" {
if n, err := strconv.Atoi(v); err == nil && n > 0 {
limit = n
}
}
+ if limit > maxLimit {
+ limit = maxLimit
+ }
taskID := r.URL.Query().Get("task_id")
diff --git a/internal/api/executions_test.go b/internal/api/executions_test.go
index a2bba21..45548ad 100644
--- a/internal/api/executions_test.go
+++ b/internal/api/executions_test.go
@@ -258,6 +258,26 @@ func TestGetExecutionLog_FollowSSEHeaders(t *testing.T) {
}
}
+func TestListRecentExecutions_LimitClamped(t *testing.T) {
+ srv, _ := testServer(t)
+
+ req := httptest.NewRequest("GET", "/api/executions?limit=10000000", nil)
+ w := httptest.NewRecorder()
+ srv.Handler().ServeHTTP(w, req)
+
+ if w.Code != http.StatusOK {
+ t.Fatalf("status: want 200, got %d; body: %s", w.Code, w.Body.String())
+ }
+ // The handler should not pass limit > maxLimit to the store.
+ // We verify indirectly: if the query param is accepted without error and
+ // does not cause a panic or 500, the clamp is in effect.
+ // A direct assertion requires a mock store; here we check the response is valid.
+ var execs []storage.RecentExecution
+ if err := json.NewDecoder(w.Body).Decode(&execs); err != nil {
+ t.Fatalf("decoding response: %v", err)
+ }
+}
+
func TestListTasks_ReturnsStateField(t *testing.T) {
srv, store := testServer(t)
createTaskWithState(t, store, "state-check", task.StateRunning)
diff --git a/internal/api/logs.go b/internal/api/logs.go
index 1ba4b00..4e63489 100644
--- a/internal/api/logs.go
+++ b/internal/api/logs.go
@@ -36,9 +36,10 @@ var terminalStates = map[string]bool{
}
type logStreamMsg struct {
- Type string `json:"type"`
- Message *logAssistMsg `json:"message,omitempty"`
- CostUSD float64 `json:"cost_usd,omitempty"`
+ Type string `json:"type"`
+ Message *logAssistMsg `json:"message,omitempty"`
+ CostUSD float64 `json:"cost_usd,omitempty"`
+ TotalCostUSD float64 `json:"total_cost_usd,omitempty"`
}
type logAssistMsg struct {
@@ -258,29 +259,31 @@ func emitLogLine(w http.ResponseWriter, flusher http.Flusher, line []byte) {
return
}
for _, block := range msg.Message.Content {
- var event map[string]string
+ var data []byte
switch block.Type {
case "text":
- event = map[string]string{"type": "text", "content": block.Text}
+ data, _ = json.Marshal(map[string]string{"type": "text", "content": block.Text})
case "tool_use":
- summary := string(block.Input)
- if len(summary) > 80 {
- summary = summary[:80]
- }
- event = map[string]string{"type": "tool_use", "content": fmt.Sprintf("%s(%s)", block.Name, summary)}
+ data, _ = json.Marshal(struct {
+ Type string `json:"type"`
+ Name string `json:"name"`
+ Input json.RawMessage `json:"input,omitempty"`
+ }{Type: "tool_use", Name: block.Name, Input: block.Input})
default:
continue
}
- data, _ := json.Marshal(event)
fmt.Fprintf(w, "data: %s\n\n", data)
flusher.Flush()
}
case "result":
- event := map[string]string{
- "type": "cost",
- "content": fmt.Sprintf("%g", msg.CostUSD),
+ cost := msg.TotalCostUSD
+ if cost == 0 {
+ cost = msg.CostUSD
}
- data, _ := json.Marshal(event)
+ data, _ := json.Marshal(struct {
+ Type string `json:"type"`
+ TotalCost float64 `json:"total_cost"`
+ }{Type: "cost", TotalCost: cost})
fmt.Fprintf(w, "data: %s\n\n", data)
flusher.Flush()
}
diff --git a/internal/api/logs_test.go b/internal/api/logs_test.go
index 52fa168..6c6be05 100644
--- a/internal/api/logs_test.go
+++ b/internal/api/logs_test.go
@@ -1,6 +1,7 @@
package api
import (
+ "encoding/json"
"errors"
"net/http"
"net/http/httptest"
@@ -293,6 +294,76 @@ func TestHandleStreamTaskLogs_TerminalExecution_EmitsEventsAndDone(t *testing.T)
}
}
+// TestEmitLogLine_ToolUse_EmitsNameField verifies that emitLogLine emits a tool_use SSE event
+// with a "name" field matching the tool name so the web UI can display it as "[ToolName]".
+func TestEmitLogLine_ToolUse_EmitsNameField(t *testing.T) {
+ line := []byte(`{"type":"assistant","message":{"content":[{"type":"tool_use","name":"Bash","input":{"command":"ls -la"}}]}}`)
+
+ w := httptest.NewRecorder()
+ emitLogLine(w, w, line)
+
+ body := w.Body.String()
+ var found bool
+ for _, chunk := range strings.Split(body, "\n\n") {
+ chunk = strings.TrimSpace(chunk)
+ if !strings.HasPrefix(chunk, "data: ") {
+ continue
+ }
+ jsonStr := strings.TrimPrefix(chunk, "data: ")
+ var e map[string]interface{}
+ if err := json.Unmarshal([]byte(jsonStr), &e); err != nil {
+ continue
+ }
+ if e["type"] == "tool_use" {
+ if e["name"] != "Bash" {
+ t.Errorf("tool_use event name: want Bash, got %v", e["name"])
+ }
+ if e["input"] == nil {
+ t.Error("tool_use event input: expected non-nil")
+ }
+ found = true
+ }
+ }
+ if !found {
+ t.Errorf("no tool_use event found in SSE output:\n%s", body)
+ }
+}
+
+// TestEmitLogLine_Cost_EmitsTotalCostField verifies that emitLogLine emits a cost SSE event
+// with a numeric "total_cost" field so the web UI can display it correctly.
+func TestEmitLogLine_Cost_EmitsTotalCostField(t *testing.T) {
+ line := []byte(`{"type":"result","total_cost_usd":0.0042}`)
+
+ w := httptest.NewRecorder()
+ emitLogLine(w, w, line)
+
+ body := w.Body.String()
+ var found bool
+ for _, chunk := range strings.Split(body, "\n\n") {
+ chunk = strings.TrimSpace(chunk)
+ if !strings.HasPrefix(chunk, "data: ") {
+ continue
+ }
+ jsonStr := strings.TrimPrefix(chunk, "data: ")
+ var e map[string]interface{}
+ if err := json.Unmarshal([]byte(jsonStr), &e); err != nil {
+ continue
+ }
+ if e["type"] == "cost" {
+ if e["total_cost"] == nil {
+ t.Error("cost event total_cost: expected non-nil numeric field")
+ }
+ if v, ok := e["total_cost"].(float64); !ok || v != 0.0042 {
+ t.Errorf("cost event total_cost: want 0.0042, got %v", e["total_cost"])
+ }
+ found = true
+ }
+ }
+ if !found {
+ t.Errorf("no cost event found in SSE output:\n%s", body)
+ }
+}
+
// TestHandleStreamTaskLogs_RunningExecution_LiveTails verifies that a RUNNING execution is
// live-tailed and a done event is emitted once it transitions to a terminal state.
func TestHandleStreamTaskLogs_RunningExecution_LiveTails(t *testing.T) {