summaryrefslogtreecommitdiff
path: root/internal/executor/question.go
blob: 0ae1b08b4b98f6c6ca39ebd7189384a08d6f76bc (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
package executor

import (
	"bufio"
	"encoding/json"
	"io"
	"log/slog"
)

// extractAskUserQuestion inspects a single stream-json line. If the line is a
// JSON object whose "type" is "assistant" and whose message.content contains a
// block with type "tool_use" and name "AskUserQuestion", it returns that
// block's "id" and raw "input" (first matching block only). For any other
// line, including lines that fail to decode, it returns "" and nil.
func extractAskUserQuestion(line []byte) (string, json.RawMessage) {
	type contentBlock struct {
		Type  string          `json:"type"`
		ID    string          `json:"id"`
		Name  string          `json:"name"`
		Input json.RawMessage `json:"input"`
	}
	var ev struct {
		Type    string `json:"type"`
		Message struct {
			Content []contentBlock `json:"content"`
		} `json:"message"`
	}

	if json.Unmarshal(line, &ev) != nil || ev.Type != "assistant" {
		return "", nil
	}

	blocks := ev.Message.Content
	for i := range blocks {
		b := &blocks[i]
		if b.Type != "tool_use" || b.Name != "AskUserQuestion" {
			continue
		}
		return b.ID, b.Input
	}
	return "", nil
}

// streamAndParseWithQuestions reads newline-delimited stream-json from r,
// mirrors every byte it reads to w, and returns the last top-level numeric
// "cost_usd" value seen (0 if none). For each line that extractAskUserQuestion
// recognizes as an AskUserQuestion tool_use, it calls onQuestion (when
// non-nil) with the tool_use id and raw input, synchronously and in order.
//
// r is always read until EOF or a read/write error, so w receives the whole
// stream even when a line cannot be parsed. A line with no newline within the
// first 1 MiB is skipped (still copied to w) instead of aborting the read.
// Aborting would leave the rest of r unread: w would be truncated, later
// cost/question events lost, and a producer writing into r could block.
//
// logger may be nil. When set, it receives a warning for each skipped line
// and an error if reading r or writing w fails.
func streamAndParseWithQuestions(r io.Reader, w io.Writer, logger *slog.Logger, onQuestion func(string, json.RawMessage)) float64 {
	// Upper bound on a single parsed line (same limit as a 1 MiB scanner buffer).
	const maxLineSize = 1024 * 1024

	// The tee writes each chunk to w as it is read, before any parsing happens.
	br := bufio.NewReaderSize(io.TeeReader(r, w), maxLineSize)

	var totalCost float64
	for {
		line, err := br.ReadSlice('\n')
		if err == bufio.ErrBufferFull {
			// No newline within maxLineSize bytes: consume the remainder of
			// this line (still mirrored to w via the tee) without parsing it.
			skipped := len(line)
			for err == bufio.ErrBufferFull {
				line, err = br.ReadSlice('\n')
				skipped += len(line)
			}
			line = nil
			if logger != nil {
				logger.Warn("skipping over-long stream-json line", "bytes", skipped, "limit", maxLineSize)
			}
		}

		if len(line) > 0 {
			// line aliases br's buffer and is only valid until the next read.
			// It is fully handled here, and json.Unmarshal copies what it keeps.
			if toolUseID, input := extractAskUserQuestion(line); toolUseID != "" && onQuestion != nil {
				onQuestion(toolUseID, input)
			}

			// Only cost_usd is needed, so avoid decoding the whole event.
			// Non-JSON lines, null, and non-numeric values leave totalCost as is.
			var msg struct {
				CostUSD *float64 `json:"cost_usd"`
			}
			if json.Unmarshal(line, &msg) == nil && msg.CostUSD != nil {
				totalCost = *msg.CostUSD
			}
		}

		if err != nil {
			if err != io.EOF && logger != nil {
				logger.Error("reading stream-json output", "err", err)
			}
			return totalCost
		}
	}
}

// buildToolResultMessage encodes the user message that answers the
// AskUserQuestion tool_use identified by toolUseID.
//
// The answer is wrapped as {"answers":{"answer":answer}} and JSON-encoded.
// That encoded text is then placed, as a string, in the "content" of a single
// "tool_result" block inside a message with role "user". The result is one
// JSON object produced by json.Marshal (no trailing newline).
func buildToolResultMessage(toolUseID, answer string) []byte {
	// json.Marshal emits struct fields in declaration order. Fields below are
	// declared in alphabetical JSON-key order to keep the output byte-stable.
	type answers struct {
		Answer string `json:"answer"`
	}
	type toolResult struct {
		Content   string `json:"content"`
		ToolUseID string `json:"tool_use_id"`
		Type      string `json:"type"`
	}
	type userMessage struct {
		Content []toolResult `json:"content"`
		Role    string       `json:"role"`
	}

	// Marshal cannot fail for values made only of strings, so the errors
	// below are deliberately discarded.
	inner, _ := json.Marshal(struct {
		Answers answers `json:"answers"`
	}{Answers: answers{Answer: answer}})

	encoded, _ := json.Marshal(struct {
		Message userMessage `json:"message"`
	}{Message: userMessage{
		Content: []toolResult{{
			Content:   string(inner),
			ToolUseID: toolUseID,
			Type:      "tool_result",
		}},
		Role: "user",
	}})
	return encoded
}