summaryrefslogtreecommitdiff
path: root/internal/executor/question_test.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/executor/question_test.go')
-rw-r--r--internal/executor/question_test.go253
1 files changed, 253 insertions, 0 deletions
diff --git a/internal/executor/question_test.go b/internal/executor/question_test.go
new file mode 100644
index 0000000..d0fbed9
--- /dev/null
+++ b/internal/executor/question_test.go
@@ -0,0 +1,253 @@
+package executor
+
+import (
+ "bytes"
+ "encoding/json"
+ "io"
+ "log/slog"
+ "strings"
+ "testing"
+)
+
+func TestQuestionRegistry_RegisterAndAnswer(t *testing.T) {
+ qr := NewQuestionRegistry()
+
+ ch := qr.Register("task-1", "toolu_abc", json.RawMessage(`{"question":"color?"}`))
+
+ // Answer should unblock the channel.
+ go func() {
+ ok := qr.Answer("toolu_abc", "blue")
+ if !ok {
+ t.Error("Answer returned false, expected true")
+ }
+ }()
+
+ answer := <-ch
+ if answer != "blue" {
+ t.Errorf("want 'blue', got %q", answer)
+ }
+
+ // Question should be removed after answering.
+ if qr.Get("toolu_abc") != nil {
+ t.Error("question should be removed after answering")
+ }
+}
+
+func TestQuestionRegistry_AnswerUnknown(t *testing.T) {
+ qr := NewQuestionRegistry()
+ ok := qr.Answer("nonexistent", "anything")
+ if ok {
+ t.Error("expected false for unknown question")
+ }
+}
+
+func TestQuestionRegistry_PendingForTask(t *testing.T) {
+ qr := NewQuestionRegistry()
+ qr.Register("task-1", "toolu_1", json.RawMessage(`{}`))
+ qr.Register("task-1", "toolu_2", json.RawMessage(`{}`))
+ qr.Register("task-2", "toolu_3", json.RawMessage(`{}`))
+
+ pending := qr.PendingForTask("task-1")
+ if len(pending) != 2 {
+ t.Errorf("want 2 pending for task-1, got %d", len(pending))
+ }
+
+ pending2 := qr.PendingForTask("task-2")
+ if len(pending2) != 1 {
+ t.Errorf("want 1 pending for task-2, got %d", len(pending2))
+ }
+}
+
+func TestQuestionRegistry_Remove(t *testing.T) {
+ qr := NewQuestionRegistry()
+ qr.Register("task-1", "toolu_x", json.RawMessage(`{}`))
+ qr.Remove("toolu_x")
+ if qr.Get("toolu_x") != nil {
+ t.Error("question should be removed")
+ }
+}
+
+func TestExtractAskUserQuestion_DetectsQuestion(t *testing.T) {
+ // Simulate a stream-json assistant event containing an AskUserQuestion tool_use.
+ event := map[string]interface{}{
+ "type": "assistant",
+ "message": map[string]interface{}{
+ "content": []interface{}{
+ map[string]interface{}{
+ "type": "tool_use",
+ "id": "toolu_01ABC",
+ "name": "AskUserQuestion",
+ "input": map[string]interface{}{
+ "questions": []interface{}{
+ map[string]interface{}{
+ "question": "Which color?",
+ "header": "Color",
+ "options": []interface{}{
+ map[string]interface{}{"label": "red", "description": "Red color"},
+ map[string]interface{}{"label": "blue", "description": "Blue color"},
+ },
+ "multiSelect": false,
+ },
+ },
+ },
+ },
+ },
+ },
+ }
+ line, _ := json.Marshal(event)
+
+ toolUseID, input := extractAskUserQuestion(line)
+ if toolUseID != "toolu_01ABC" {
+ t.Errorf("toolUseID: want 'toolu_01ABC', got %q", toolUseID)
+ }
+ if input == nil {
+ t.Fatal("input should not be nil")
+ }
+}
+
+func TestExtractAskUserQuestion_IgnoresOtherTools(t *testing.T) {
+ event := map[string]interface{}{
+ "type": "assistant",
+ "message": map[string]interface{}{
+ "content": []interface{}{
+ map[string]interface{}{
+ "type": "tool_use",
+ "id": "toolu_01XYZ",
+ "name": "Read",
+ "input": map[string]interface{}{"file_path": "/foo"},
+ },
+ },
+ },
+ }
+ line, _ := json.Marshal(event)
+
+ toolUseID, input := extractAskUserQuestion(line)
+ if toolUseID != "" || input != nil {
+ t.Error("should not detect non-AskUserQuestion tool_use")
+ }
+}
+
+func TestExtractAskUserQuestion_IgnoresNonAssistant(t *testing.T) {
+ event := map[string]interface{}{
+ "type": "system",
+ "subtype": "init",
+ }
+ line, _ := json.Marshal(event)
+
+ toolUseID, input := extractAskUserQuestion(line)
+ if toolUseID != "" || input != nil {
+ t.Error("should not detect from non-assistant events")
+ }
+}
+
+func TestStreamAndParseQuestions_DetectsQuestionAndCost(t *testing.T) {
+ // Build a stream with an assistant event containing AskUserQuestion and a result with cost.
+ assistantEvent := map[string]interface{}{
+ "type": "assistant",
+ "message": map[string]interface{}{
+ "content": []interface{}{
+ map[string]interface{}{
+ "type": "tool_use",
+ "id": "toolu_Q1",
+ "name": "AskUserQuestion",
+ "input": map[string]interface{}{
+ "questions": []interface{}{
+ map[string]interface{}{
+ "question": "Pick a number",
+ "header": "Num",
+ "options": []interface{}{
+ map[string]interface{}{"label": "1", "description": "One"},
+ map[string]interface{}{"label": "2", "description": "Two"},
+ },
+ "multiSelect": false,
+ },
+ },
+ },
+ },
+ },
+ },
+ }
+ resultEvent := map[string]interface{}{
+ "type": "result",
+ "cost_usd": 0.05,
+ }
+
+ var buf bytes.Buffer
+ json.NewEncoder(&buf).Encode(assistantEvent)
+ json.NewEncoder(&buf).Encode(resultEvent)
+
+ logger := slog.New(slog.NewTextHandler(io.Discard, nil))
+ var questions []questionDetected
+ onQuestion := func(toolUseID string, input json.RawMessage) {
+ questions = append(questions, questionDetected{toolUseID, input})
+ }
+
+ cost := streamAndParseWithQuestions(strings.NewReader(buf.String()), io.Discard, logger, onQuestion)
+
+ if cost != 0.05 {
+ t.Errorf("cost: want 0.05, got %f", cost)
+ }
+ if len(questions) != 1 {
+ t.Fatalf("want 1 question detected, got %d", len(questions))
+ }
+ if questions[0].toolUseID != "toolu_Q1" {
+ t.Errorf("toolUseID: want 'toolu_Q1', got %q", questions[0].toolUseID)
+ }
+}
+
+type questionDetected struct {
+ toolUseID string
+ input json.RawMessage
+}
+
+func TestBuildToolResultMessage_Format(t *testing.T) {
+ msg := buildToolResultMessage("toolu_123", "blue")
+
+ var parsed map[string]interface{}
+ if err := json.Unmarshal(msg, &parsed); err != nil {
+ t.Fatalf("invalid JSON: %v", err)
+ }
+
+ // Should have type "user" with message containing tool_result
+ msgObj, ok := parsed["message"].(map[string]interface{})
+ if !ok {
+ t.Fatal("missing 'message' field")
+ }
+ content, ok := msgObj["content"].([]interface{})
+ if !ok || len(content) == 0 {
+ t.Fatal("missing content array")
+ }
+
+ block := content[0].(map[string]interface{})
+ if block["type"] != "tool_result" {
+ t.Errorf("type: want 'tool_result', got %v", block["type"])
+ }
+ if block["tool_use_id"] != "toolu_123" {
+ t.Errorf("tool_use_id: want 'toolu_123', got %v", block["tool_use_id"])
+ }
+
+ // The content should contain the answer JSON
+ resultContent, ok := block["content"].(string)
+ if !ok {
+ t.Fatal("content should be a string")
+ }
+ var answerData map[string]interface{}
+ if err := json.Unmarshal([]byte(resultContent), &answerData); err != nil {
+ t.Fatalf("answer content is not valid JSON: %v", err)
+ }
+ answers, ok := answerData["answers"].(map[string]interface{})
+ if !ok {
+ t.Fatal("missing answers in result content")
+ }
+ // At least one answer key should have the value "blue"
+ found := false
+ for _, v := range answers {
+ if v == "blue" {
+ found = true
+ break
+ }
+ }
+ if !found {
+ t.Errorf("expected 'blue' in answers, got %v", answers)
+ }
+}