package executor

import (
	"bytes"
	"encoding/json"
	"io"
	"log/slog"
	"strings"
	"testing"
	"time"
)

// TestQuestionRegistry_RegisterAndAnswer checks that Answer delivers the
// answer on the channel returned by Register and that the question is no
// longer retrievable once it has been answered.
func TestQuestionRegistry_RegisterAndAnswer(t *testing.T) {
	qr := NewQuestionRegistry()
	ch := qr.Register("task-1", "toolu_abc", json.RawMessage(`{"question":"color?"}`))

	// Answer runs in its own goroutine because it may block until the answer
	// is received. Its result comes back over okCh instead of being reported
	// with t.Error inside the goroutine. That way the test waits for Answer to
	// return before it ends (logging after a test completes panics) and before
	// it inspects the registry.
	okCh := make(chan bool, 1)
	go func() {
		okCh <- qr.Answer("toolu_abc", "blue")
	}()

	// Bound the wait so a broken Answer fails the test instead of hanging it.
	select {
	case answer := <-ch:
		if answer != "blue" {
			t.Errorf("want 'blue', got %q", answer)
		}
	case <-time.After(time.Second):
		t.Fatal("timed out waiting for the answer on the registered channel")
	}

	if ok := <-okCh; !ok {
		t.Error("Answer returned false, expected true")
	}

	// Answer has fully returned, so any cleanup it does has already happened.
	if qr.Get("toolu_abc") != nil {
		t.Error("question should be removed after answering")
	}
}

// TestQuestionRegistry_AnswerUnknown checks that answering an unregistered
// tool-use ID reports failure.
func TestQuestionRegistry_AnswerUnknown(t *testing.T) {
	qr := NewQuestionRegistry()
	if ok := qr.Answer("nonexistent", "anything"); ok {
		t.Error("expected false for unknown question")
	}
}

// TestQuestionRegistry_PendingForTask checks that pending questions are
// grouped by the task they were registered under.
func TestQuestionRegistry_PendingForTask(t *testing.T) {
	qr := NewQuestionRegistry()
	qr.Register("task-1", "toolu_1", json.RawMessage(`{}`))
	qr.Register("task-1", "toolu_2", json.RawMessage(`{}`))
	qr.Register("task-2", "toolu_3", json.RawMessage(`{}`))

	if pending := qr.PendingForTask("task-1"); len(pending) != 2 {
		t.Errorf("want 2 pending for task-1, got %d", len(pending))
	}
	if pending := qr.PendingForTask("task-2"); len(pending) != 1 {
		t.Errorf("want 1 pending for task-2, got %d", len(pending))
	}
	// A task that never registered anything has nothing pending.
	if pending := qr.PendingForTask("task-3"); len(pending) != 0 {
		t.Errorf("want 0 pending for task-3, got %d", len(pending))
	}
}

// TestQuestionRegistry_Remove checks that a removed question can no longer be
// looked up or answered.
func TestQuestionRegistry_Remove(t *testing.T) {
	qr := NewQuestionRegistry()
	qr.Register("task-1", "toolu_x", json.RawMessage(`{}`))
	qr.Remove("toolu_x")

	if qr.Get("toolu_x") != nil {
		t.Error("question should be removed")
	}
	// After removal the ID is unknown, so Answer should fail just as it does
	// in TestQuestionRegistry_AnswerUnknown.
	if qr.Answer("toolu_x", "anything") {
		t.Error("Answer should return false for a removed question")
	}
}

// TestExtractAskUserQuestion_DetectsQuestion checks that an assistant event
// containing an AskUserQuestion tool_use yields its ID and input.
func TestExtractAskUserQuestion_DetectsQuestion(t *testing.T) {
	// Simulate a stream-json assistant event containing an AskUserQuestion tool_use.
	event := map[string]any{
		"type": "assistant",
		"message": map[string]any{
			"content": []any{
				map[string]any{
					"type": "tool_use",
					"id":   "toolu_01ABC",
					"name": "AskUserQuestion",
					"input": map[string]any{
						"questions": []any{
							map[string]any{
								"question": "Which color?",
								"header":   "Color",
								"options": []any{
									map[string]any{"label": "red", "description": "Red color"},
									map[string]any{"label": "blue", "description": "Blue color"},
								},
								"multiSelect": false,
							},
						},
					},
				},
			},
		},
	}
	line, err := json.Marshal(event)
	if err != nil {
		t.Fatalf("marshal event: %v", err)
	}

	toolUseID, input := extractAskUserQuestion(line)
	if toolUseID != "toolu_01ABC" {
		t.Errorf("toolUseID: want 'toolu_01ABC', got %q", toolUseID)
	}
	if input == nil {
		t.Fatal("input should not be nil")
	}
}

// TestExtractAskUserQuestion_IgnoresOtherTools checks that tool_use blocks
// for tools other than AskUserQuestion are not reported.
func TestExtractAskUserQuestion_IgnoresOtherTools(t *testing.T) {
	event := map[string]any{
		"type": "assistant",
		"message": map[string]any{
			"content": []any{
				map[string]any{
					"type":  "tool_use",
					"id":    "toolu_01XYZ",
					"name":  "Read",
					"input": map[string]any{"file_path": "/foo"},
				},
			},
		},
	}
	line, err := json.Marshal(event)
	if err != nil {
		t.Fatalf("marshal event: %v", err)
	}

	toolUseID, input := extractAskUserQuestion(line)
	if toolUseID != "" || input != nil {
		t.Error("should not detect non-AskUserQuestion tool_use")
	}
}

// TestExtractAskUserQuestion_IgnoresNonAssistant checks that events whose
// type is not "assistant" are ignored.
func TestExtractAskUserQuestion_IgnoresNonAssistant(t *testing.T) {
	event := map[string]any{
		"type":    "system",
		"subtype": "init",
	}
	line, err := json.Marshal(event)
	if err != nil {
		t.Fatalf("marshal event: %v", err)
	}

	toolUseID, input := extractAskUserQuestion(line)
	if toolUseID != "" || input != nil {
		t.Error("should not detect from non-assistant events")
	}
}

// TestStreamAndParseQuestions_DetectsQuestionAndCost checks that the stream
// parser invokes the question callback for AskUserQuestion and returns the
// cost from the result event.
func TestStreamAndParseQuestions_DetectsQuestionAndCost(t *testing.T) {
	// Build a stream with an assistant event containing AskUserQuestion and a result with cost.
	assistantEvent := map[string]any{
		"type": "assistant",
		"message": map[string]any{
			"content": []any{
				map[string]any{
					"type": "tool_use",
					"id":   "toolu_Q1",
					"name": "AskUserQuestion",
					"input": map[string]any{
						"questions": []any{
							map[string]any{
								"question": "Pick a number",
								"header":   "Num",
								"options": []any{
									map[string]any{"label": "1", "description": "One"},
									map[string]any{"label": "2", "description": "Two"},
								},
								"multiSelect": false,
							},
						},
					},
				},
			},
		},
	}
	resultEvent := map[string]any{
		"type":     "result",
		"cost_usd": 0.05,
	}

	// Encoder.Encode writes one JSON value per line, which gives the
	// newline-delimited stream the parser consumes.
	var buf bytes.Buffer
	enc := json.NewEncoder(&buf)
	for _, ev := range []map[string]any{assistantEvent, resultEvent} {
		if err := enc.Encode(ev); err != nil {
			t.Fatalf("encode event: %v", err)
		}
	}

	logger := slog.New(slog.NewTextHandler(io.Discard, nil))
	var questions []questionDetected
	onQuestion := func(toolUseID string, input json.RawMessage) {
		questions = append(questions, questionDetected{toolUseID, input})
	}

	cost := streamAndParseWithQuestions(strings.NewReader(buf.String()), io.Discard, logger, onQuestion)
	// 0.05 survives the JSON round-trip exactly, so direct comparison is safe.
	if cost != 0.05 {
		t.Errorf("cost: want 0.05, got %f", cost)
	}
	if len(questions) != 1 {
		t.Fatalf("want 1 question detected, got %d", len(questions))
	}
	if questions[0].toolUseID != "toolu_Q1" {
		t.Errorf("toolUseID: want 'toolu_Q1', got %q", questions[0].toolUseID)
	}
	if len(questions[0].input) == 0 {
		t.Error("question input passed to callback should not be empty")
	}
}

// questionDetected records one invocation of the onQuestion callback.
type questionDetected struct {
	toolUseID string
	input     json.RawMessage
}

// TestBuildToolResultMessage_Format checks the JSON shape of the tool_result
// message that carries a user's answer back for a given tool-use ID.
func TestBuildToolResultMessage_Format(t *testing.T) {
	msg := buildToolResultMessage("toolu_123", "blue")

	var parsed map[string]any
	if err := json.Unmarshal(msg, &parsed); err != nil {
		t.Fatalf("invalid JSON: %v", err)
	}

	// Should have type "user" with message containing tool_result.
	if parsed["type"] != "user" {
		t.Errorf("type: want 'user', got %v", parsed["type"])
	}
	msgObj, ok := parsed["message"].(map[string]any)
	if !ok {
		t.Fatal("missing 'message' field")
	}
	content, ok := msgObj["content"].([]any)
	if !ok || len(content) == 0 {
		t.Fatal("missing content array")
	}
	// Checked assertion: a malformed block should fail the test, not panic it.
	block, ok := content[0].(map[string]any)
	if !ok {
		t.Fatalf("content[0]: want object, got %T", content[0])
	}
	if block["type"] != "tool_result" {
		t.Errorf("type: want 'tool_result', got %v", block["type"])
	}
	if block["tool_use_id"] != "toolu_123" {
		t.Errorf("tool_use_id: want 'toolu_123', got %v", block["tool_use_id"])
	}

	// The content should contain the answer JSON.
	resultContent, ok := block["content"].(string)
	if !ok {
		t.Fatal("content should be a string")
	}
	var answerData map[string]any
	if err := json.Unmarshal([]byte(resultContent), &answerData); err != nil {
		t.Fatalf("answer content is not valid JSON: %v", err)
	}
	answers, ok := answerData["answers"].(map[string]any)
	if !ok {
		t.Fatal("missing answers in result content")
	}

	// At least one answer key should have the value "blue".
	found := false
	for _, v := range answers {
		if v == "blue" {
			found = true
			break
		}
	}
	if !found {
		t.Errorf("expected 'blue' in answers, got %v", answers)
	}
}