diff options
Diffstat (limited to 'internal/executor/gemini.go')
| -rw-r--r-- | internal/executor/gemini.go | 87 |
1 files changed, 62 insertions, 25 deletions
diff --git a/internal/executor/gemini.go b/internal/executor/gemini.go index d79c47d..7f2f54f 100644 --- a/internal/executor/gemini.go +++ b/internal/executor/gemini.go @@ -2,6 +2,7 @@ package executor import ( "context" + "encoding/json" "fmt" "io" "log/slog" @@ -117,16 +118,21 @@ func (r *GeminiRunner) execOnce(ctx context.Context, args []string, workingDir, var streamErr error + var streamCost float64 var wg sync.WaitGroup wg.Add(1) go func() { defer wg.Done() - _, streamErr = parseGeminiStream(stdoutR, stdoutFile, r.Logger) + streamCost, streamErr = parseGeminiStream(stdoutR, stdoutFile, r.Logger) stdoutR.Close() }() wg.Wait() // Wait for parseGeminiStream to finish + if streamCost > 0 { + e.CostUSD = streamCost + } + // Set a dummy exit code for this simulated run e.ExitCode = 0 @@ -136,9 +142,10 @@ func (r *GeminiRunner) execOnce(ctx context.Context, args []string, workingDir, return nil } -// parseGeminiStream reads streaming JSON from the gemini CLI, unwraps markdown -// code blocks, writes the inner JSON to w, and returns (costUSD, error). -// For now, it focuses on unwrapping and writing, not detailed parsing of cost/errors. +// parseGeminiStream reads streaming JSON from the gemini CLI, strips markdown +// code fences if the output is wrapped in them, writes the inner stream-json +// to w, and returns (costUSD, error). If a `result` event has `is_error: true`, +// an error wrapping the result message is returned. func parseGeminiStream(r io.Reader, w io.Writer, logger *slog.Logger) (float64, error) { fullOutput, err := io.ReadAll(r) if err != nil { @@ -146,31 +153,61 @@ func parseGeminiStream(r io.Reader, w io.Writer, logger *slog.Logger) (float64, } logger.Debug("parseGeminiStream: raw output received", "output", string(fullOutput)) - outputStr := strings.TrimSpace(string(fullOutput)) // Trim leading/trailing whitespace/newlines from the whole output - - jsonContent := outputStr // Default to raw output if no markdown block is found or malformed - jsonStartIdx := strings.Index(outputStr, "```json") - if jsonStartIdx != -1 { - // Found "```json", now look for the closing "```" - jsonEndIdx := strings.LastIndex(outputStr, "```") - if jsonEndIdx != -1 && jsonEndIdx > jsonStartIdx { - // Extract content between the markdown fences. - jsonContent = outputStr[jsonStartIdx+len("```json"):jsonEndIdx] - jsonContent = strings.TrimSpace(jsonContent) // Trim again after extraction, to remove potential inner newlines - } else { - logger.Warn("Malformed markdown JSON block from Gemini (missing closing ``` or invalid structure), falling back to raw output.", "outputLength", len(outputStr)) + inner := stripGeminiFences(string(fullOutput), logger) + if _, writeErr := w.Write([]byte(inner)); writeErr != nil { + return 0, fmt.Errorf("writing gemini output: %w", writeErr) + } + + // Walk lines looking for a result event so we can surface errors and cost. + var ( + cost float64 + errMsg string + isError bool + ) + for _, raw := range strings.Split(inner, "\n") { + line := strings.TrimSpace(raw) + if line == "" { + continue + } + var evt struct { + Type string `json:"type"` + IsError bool `json:"is_error"` + Result string `json:"result"` + Cost float64 `json:"total_cost_usd"` + } + if err := json.Unmarshal([]byte(line), &evt); err != nil { + continue + } + if evt.Type == "result" { + if evt.Cost > 0 { + cost = evt.Cost + } + if evt.IsError { + isError = true + errMsg = evt.Result + } } - } else { - logger.Warn("No markdown JSON block found from Gemini, falling back to raw output.", "outputLength", len(outputStr)) } - - // Write the (possibly extracted and trimmed) JSON content to the writer. - _, writeErr := w.Write([]byte(jsonContent)) - if writeErr != nil { - return 0, fmt.Errorf("writing extracted gemini json: %w", writeErr) + if isError { + return cost, fmt.Errorf("gemini reported error: %s", errMsg) } + return cost, nil +} - return 0, nil // For now, no cost/error parsing for Gemini stream +// stripGeminiFences removes a surrounding ```json ... ``` markdown block if +// present, returning the trimmed inner content. If no markdown fence is +// found, the input is returned verbatim (no whitespace trimming) so callers +// that expect byte-exact pass-through behavior get it. +func stripGeminiFences(raw string, logger *slog.Logger) string { + trimmed := strings.TrimSpace(raw) + if start := strings.Index(trimmed, "```json"); start != -1 { + if end := strings.LastIndex(trimmed, "```"); end > start { + return strings.TrimSpace(trimmed[start+len("```json") : end]) + } + logger.Warn("malformed gemini markdown block (missing closing fence); using raw output", "len", len(trimmed)) + return trimmed + } + return raw } func (r *GeminiRunner) buildArgs(t *task.Task, e *storage.Execution, questionFile string) []string { |
