summaryrefslogtreecommitdiff
path: root/internal/llm/client.go
diff options
context:
space:
mode:
authorPeter Stone <thepeterstone@gmail.com>2026-05-01 22:14:37 -1000
committerGitHub <noreply@github.com>2026-05-01 22:14:37 -1000
commit99115d8158137083239c45e5a860b718ff4cefa1 (patch)
tree1bf3bd0505eea79375c67af83c7c5fe8c0f274ff /internal/llm/client.go
parentc2aa026f6ce1c9e216b99d74f294fc133d5fcddd (diff)
parent50f8fe8c1ff8b82e0bd399e5776e58bda3e57d1c (diff)
Merge pull request #1 from thepeterstone/claude/local-oss-model-agents-MEBqj
Local OSS models as a third runner (epic)
Diffstat (limited to 'internal/llm/client.go')
-rw-r--r--internal/llm/client.go343
1 files changed, 343 insertions, 0 deletions
diff --git a/internal/llm/client.go b/internal/llm/client.go
new file mode 100644
index 0000000..613ebe5
--- /dev/null
+++ b/internal/llm/client.go
@@ -0,0 +1,343 @@
+// Package llm provides a small OpenAI-compatible HTTP client used for
+// internal LLM-shaped work (model classification, summarization, elaboration)
+// against any local server speaking /v1/chat/completions: Ollama, vLLM,
+// LM Studio, llama.cpp server, etc.
+package llm
+
+import (
+ "bufio"
+ "bytes"
+ "context"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "io"
+ "log/slog"
+ "net/http"
+ "strings"
+ "time"
+
+ "github.com/thepeterstone/claudomator/internal/retry"
+)
+
+// Client is an OpenAI-compatible chat completions client.
+// Endpoint is the base URL up through "/v1" (no trailing slash).
+type Client struct {
+ Endpoint string
+ Model string
+ APIKey string // optional, sent as Bearer token
+ HTTPClient *http.Client
+ Logger *slog.Logger
+}
+
+// Message is a single chat-completion message.
+type Message struct {
+ Role string `json:"role"`
+ Content string `json:"content"`
+}
+
+// ChatRequest captures the parameters of a single Chat or ChatStream call.
+// Zero values mean "use server default" except for Stream and ResponseJSON,
+// which are explicit booleans. Model overrides Client.Model when non-empty.
+type ChatRequest struct {
+ Model string
+ Messages []Message
+ Temperature *float64
+ MaxTokens int
+ ResponseJSON bool
+}
+
+// ChatResponse is the aggregated result of a chat completion.
+type ChatResponse struct {
+ Content string
+ PromptTokens int
+ OutputTokens int
+ Model string
+ FinishReason string
+}
+
+// Chat performs a non-streaming chat completion. Rate-limit errors (HTTP 429,
+// overloaded responses) are retried with exponential backoff via
+// retry.RunWithBackoff.
+func (c *Client) Chat(ctx context.Context, req ChatRequest) (*ChatResponse, error) {
+ if c == nil {
+ return nil, errors.New("llm: nil Client")
+ }
+ body, err := c.buildRequestBody(req, false)
+ if err != nil {
+ return nil, err
+ }
+
+ var resp *ChatResponse
+ err = retry.RunWithBackoff(ctx, 3, time.Second, func() error {
+ raw, perErr := c.postChat(ctx, body)
+ if perErr != nil {
+ return perErr
+ }
+ var oai openAIResponse
+ if jerr := json.Unmarshal(raw, &oai); jerr != nil {
+ return fmt.Errorf("llm: decode response: %w", jerr)
+ }
+ if len(oai.Choices) == 0 {
+ return fmt.Errorf("llm: response has no choices")
+ }
+ resp = &ChatResponse{
+ Content: oai.Choices[0].Message.Content,
+ PromptTokens: oai.Usage.PromptTokens,
+ OutputTokens: oai.Usage.CompletionTokens,
+ Model: oai.Model,
+ FinishReason: oai.Choices[0].FinishReason,
+ }
+ return nil
+ })
+ if err != nil {
+ return nil, err
+ }
+ return resp, nil
+}
+
+// ChatStream performs a streaming chat completion. onDelta is called once per
+// content delta chunk. The returned ChatResponse aggregates the full content
+// and any usage tokens reported in the final SSE chunk. Rate-limit errors at
+// connection time are retried; once streaming has begun, errors are returned.
+func (c *Client) ChatStream(ctx context.Context, req ChatRequest, onDelta func(string)) (*ChatResponse, error) {
+ if c == nil {
+ return nil, errors.New("llm: nil Client")
+ }
+ body, err := c.buildRequestBody(req, true)
+ if err != nil {
+ return nil, err
+ }
+
+ var resp *ChatResponse
+ err = retry.RunWithBackoff(ctx, 3, time.Second, func() error {
+ var perErr error
+ resp, perErr = c.streamChat(ctx, body, onDelta)
+ return perErr
+ })
+ if err != nil {
+ return nil, err
+ }
+ return resp, nil
+}
+
+func (c *Client) buildRequestBody(req ChatRequest, stream bool) ([]byte, error) {
+ model := req.Model
+ if model == "" {
+ model = c.Model
+ }
+ if model == "" {
+ return nil, errors.New("llm: no model configured")
+ }
+ payload := openAIRequest{
+ Model: model,
+ Messages: req.Messages,
+ Stream: stream,
+ }
+ if req.Temperature != nil {
+ payload.Temperature = req.Temperature
+ }
+ if req.MaxTokens > 0 {
+ payload.MaxTokens = req.MaxTokens
+ }
+ if req.ResponseJSON {
+ payload.ResponseFormat = &responseFormat{Type: "json_object"}
+ }
+ if stream {
+ payload.StreamOptions = &streamOptions{IncludeUsage: true}
+ }
+ return json.Marshal(payload)
+}
+
+func (c *Client) postChat(ctx context.Context, body []byte) ([]byte, error) {
+ url := strings.TrimRight(c.Endpoint, "/") + "/chat/completions"
+ httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
+ if err != nil {
+ return nil, fmt.Errorf("llm: build request: %w", err)
+ }
+ c.applyHeaders(httpReq)
+
+ httpResp, err := c.client().Do(httpReq)
+ if err != nil {
+ return nil, fmt.Errorf("llm: http: %w", err)
+ }
+ defer httpResp.Body.Close()
+ raw, err := io.ReadAll(httpResp.Body)
+ if err != nil {
+ return nil, fmt.Errorf("llm: read body: %w", err)
+ }
+ if httpResp.StatusCode >= 400 {
+ return nil, errFromStatus(httpResp, raw)
+ }
+ return raw, nil
+}
+
+func (c *Client) streamChat(ctx context.Context, body []byte, onDelta func(string)) (*ChatResponse, error) {
+ url := strings.TrimRight(c.Endpoint, "/") + "/chat/completions"
+ httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
+ if err != nil {
+ return nil, fmt.Errorf("llm: build request: %w", err)
+ }
+ c.applyHeaders(httpReq)
+ httpReq.Header.Set("Accept", "text/event-stream")
+
+ httpResp, err := c.client().Do(httpReq)
+ if err != nil {
+ return nil, fmt.Errorf("llm: http: %w", err)
+ }
+ defer httpResp.Body.Close()
+ if httpResp.StatusCode >= 400 {
+ raw, _ := io.ReadAll(httpResp.Body)
+ return nil, errFromStatus(httpResp, raw)
+ }
+
+ var (
+ content strings.Builder
+ promptTok int
+ outputTok int
+ model string
+ finishReason string
+ )
+ scanner := bufio.NewScanner(httpResp.Body)
+ scanner.Buffer(make([]byte, 0, 64*1024), 1<<20)
+ for scanner.Scan() {
+ line := scanner.Text()
+ if !strings.HasPrefix(line, "data:") {
+ continue
+ }
+ data := strings.TrimSpace(strings.TrimPrefix(line, "data:"))
+ if data == "" || data == "[DONE]" {
+ if data == "[DONE]" {
+ break
+ }
+ continue
+ }
+ var chunk openAIStreamChunk
+ if jerr := json.Unmarshal([]byte(data), &chunk); jerr != nil {
+ if c.Logger != nil {
+ c.Logger.Warn("llm: bad SSE chunk", "err", jerr, "data", data)
+ }
+ continue
+ }
+ if chunk.Model != "" {
+ model = chunk.Model
+ }
+ for _, ch := range chunk.Choices {
+ if ch.Delta.Content != "" {
+ content.WriteString(ch.Delta.Content)
+ if onDelta != nil {
+ onDelta(ch.Delta.Content)
+ }
+ }
+ if ch.FinishReason != "" {
+ finishReason = ch.FinishReason
+ }
+ }
+ if chunk.Usage != nil {
+ promptTok = chunk.Usage.PromptTokens
+ outputTok = chunk.Usage.CompletionTokens
+ }
+ }
+ if scanErr := scanner.Err(); scanErr != nil {
+ return nil, fmt.Errorf("llm: stream read: %w", scanErr)
+ }
+ return &ChatResponse{
+ Content: content.String(),
+ PromptTokens: promptTok,
+ OutputTokens: outputTok,
+ Model: model,
+ FinishReason: finishReason,
+ }, nil
+}
+
+func (c *Client) applyHeaders(req *http.Request) {
+ req.Header.Set("Content-Type", "application/json")
+ if c.APIKey != "" {
+ req.Header.Set("Authorization", "Bearer "+c.APIKey)
+ }
+}
+
+func (c *Client) client() *http.Client {
+ if c.HTTPClient != nil {
+ return c.HTTPClient
+ }
+ return &http.Client{Timeout: 60 * time.Second}
+}
+
+// errFromStatus produces an error whose message includes "rate limit", "429",
+// or "overloaded" as appropriate so retry.IsRateLimitError treats local 429/503
+// identically to upstream provider rate limits. Any Retry-After header is
+// embedded in the error message for retry.ParseRetryAfter to find.
+func errFromStatus(resp *http.Response, body []byte) error {
+ prefix := ""
+ switch resp.StatusCode {
+ case http.StatusTooManyRequests:
+ prefix = fmt.Sprintf("llm: 429 rate limit")
+ case http.StatusServiceUnavailable:
+ prefix = "llm: 503 overloaded"
+ default:
+ prefix = fmt.Sprintf("llm: http %d", resp.StatusCode)
+ }
+ if ra := resp.Header.Get("Retry-After"); ra != "" {
+ prefix += fmt.Sprintf(" (retry-after: %s)", ra)
+ }
+ snippet := strings.TrimSpace(string(body))
+ if len(snippet) > 500 {
+ snippet = snippet[:500] + "..."
+ }
+ if snippet != "" {
+ return fmt.Errorf("%s: %s", prefix, snippet)
+ }
+ return errors.New(prefix)
+}
+
+// --- OpenAI wire types ---
+
+type openAIRequest struct {
+ Model string `json:"model"`
+ Messages []Message `json:"messages"`
+ Temperature *float64 `json:"temperature,omitempty"`
+ MaxTokens int `json:"max_tokens,omitempty"`
+ Stream bool `json:"stream,omitempty"`
+ StreamOptions *streamOptions `json:"stream_options,omitempty"`
+ ResponseFormat *responseFormat `json:"response_format,omitempty"`
+}
+
+type streamOptions struct {
+ IncludeUsage bool `json:"include_usage"`
+}
+
+type responseFormat struct {
+ Type string `json:"type"`
+}
+
+type openAIResponse struct {
+ Model string `json:"model"`
+ Choices []openAIChoice `json:"choices"`
+ Usage openAIUsage `json:"usage"`
+}
+
+type openAIChoice struct {
+ Message Message `json:"message"`
+ FinishReason string `json:"finish_reason"`
+}
+
+type openAIUsage struct {
+ PromptTokens int `json:"prompt_tokens"`
+ CompletionTokens int `json:"completion_tokens"`
+}
+
+type openAIStreamChunk struct {
+ Model string `json:"model"`
+ Choices []openAIStreamCh `json:"choices"`
+ Usage *openAIUsage `json:"usage,omitempty"`
+}
+
+type openAIStreamCh struct {
+ Delta openAIDelta `json:"delta"`
+ FinishReason string `json:"finish_reason"`
+}
+
+type openAIDelta struct {
+ Content string `json:"content"`
+}