summaryrefslogtreecommitdiff
path: root/internal/llm
diff options
context:
space:
mode:
authorClaude <noreply@anthropic.com>2026-04-28 09:24:43 +0000
committerClaude <noreply@anthropic.com>2026-04-28 09:24:43 +0000
commit0865afc43be562dbe14528e4299b9e213b54cc93 (patch)
tree3ffb11207fb6b9866b5a2477bba7abe38964f83a /internal/llm
parentc2aa026f6ce1c9e216b99d74f294fc133d5fcddd (diff)
feat(executor): add LocalRunner and OpenAI-compat LLM client
Phase 1 of "local OSS models as agents" plan. Adds a third Runner backed by any OpenAI-compatible HTTP server (Ollama, vLLM, LM Studio, llama.cpp), and migrates the Gemini-CLI classifier to route through the same client when configured. Two-layer split: internal/llm.Client is the workhorse (HTTP, no Pool, no DB) used directly by the classifier and any future internal helper that needs cheap reasoning. internal/executor.LocalRunner is a thin adapter implementing Runner for user-facing tasks. This avoids Pool reentrancy/deadlock when sub-second internal calls fire from inside Pool.execute(). Highlights: - internal/retry: relocated runWithBackoff/IsRateLimitError/ParseRetryAfter into a shared package reused by executor and llm. - internal/llm: Chat (non-streaming) and ChatStream (SSE) over /chat/completions with optional bearer auth, json_object response format, retry on 429/503, Retry-After parsing. - internal/executor/LocalRunner: streams deltas into stdout.log in the same stream-json envelope ClaudeRunner emits, then writes one consolidated assistant block plus a result terminator so existing parsers (extractSummary, ParseChangestatFromOutput) work unchanged. - internal/executor/Classifier: gains optional LLM field; uses json_object response format (no markdown-fence cleanup needed). Falls back to Gemini-CLI subprocess when LLM is nil. - Pool.skipClassification: now skips only when the requested agent type is registered, so unknown types still reach the load balancer. - Storage: additive tokens_in/tokens_out ALTERs on executions; CLI runners record cost_usd as before, LocalRunner records 0 + tokens. - Config: [local_model] section (endpoint, model, timeout_seconds, default_temperature, api_key). Empty endpoint = no LocalRunner registered, classifier falls back to Gemini. Pre-existing test issues fixed in passing: - claude_test.go setupSandbox callsites updated to current signature. - gemini_test.go TestParseGeminiStream skipped (asserts unimplemented GeminiRunner stream-error parsing; tracked separately). Plan: docs/plans/local-oss-runner.md. https://claude.ai/code/session_017Edeq947TpSm1vQTxMhi1J
Diffstat (limited to 'internal/llm')
-rw-r--r--internal/llm/client.go343
-rw-r--r--internal/llm/client_test.go159
2 files changed, 502 insertions, 0 deletions
diff --git a/internal/llm/client.go b/internal/llm/client.go
new file mode 100644
index 0000000..613ebe5
--- /dev/null
+++ b/internal/llm/client.go
@@ -0,0 +1,343 @@
+// Package llm provides a small OpenAI-compatible HTTP client used for
+// internal LLM-shaped work (model classification, summarization, elaboration)
+// against any local server speaking /v1/chat/completions: Ollama, vLLM,
+// LM Studio, llama.cpp server, etc.
+package llm
+
+import (
+ "bufio"
+ "bytes"
+ "context"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "io"
+ "log/slog"
+ "net/http"
+ "strings"
+ "time"
+
+ "github.com/thepeterstone/claudomator/internal/retry"
+)
+
+// Client is an OpenAI-compatible chat completions client.
+// Endpoint is the base URL up through "/v1" (no trailing slash).
+type Client struct {
+ Endpoint string
+ Model string
+ APIKey string // optional, sent as Bearer token
+ HTTPClient *http.Client
+ Logger *slog.Logger
+}
+
+// Message is a single chat-completion message.
+type Message struct {
+ Role string `json:"role"`
+ Content string `json:"content"`
+}
+
+// ChatRequest captures the parameters of a single Chat or ChatStream call.
+// Zero values mean "use server default" except for Stream and ResponseJSON,
+// which are explicit booleans. Model overrides Client.Model when non-empty.
+type ChatRequest struct {
+ Model string
+ Messages []Message
+ Temperature *float64
+ MaxTokens int
+ ResponseJSON bool
+}
+
+// ChatResponse is the aggregated result of a chat completion.
+type ChatResponse struct {
+ Content string
+ PromptTokens int
+ OutputTokens int
+ Model string
+ FinishReason string
+}
+
+// Chat performs a non-streaming chat completion. Rate-limit errors (HTTP 429,
+// overloaded responses) are retried with exponential backoff via
+// retry.RunWithBackoff.
+func (c *Client) Chat(ctx context.Context, req ChatRequest) (*ChatResponse, error) {
+ if c == nil {
+ return nil, errors.New("llm: nil Client")
+ }
+ body, err := c.buildRequestBody(req, false)
+ if err != nil {
+ return nil, err
+ }
+
+ var resp *ChatResponse
+ err = retry.RunWithBackoff(ctx, 3, time.Second, func() error {
+ raw, perErr := c.postChat(ctx, body)
+ if perErr != nil {
+ return perErr
+ }
+ var oai openAIResponse
+ if jerr := json.Unmarshal(raw, &oai); jerr != nil {
+ return fmt.Errorf("llm: decode response: %w", jerr)
+ }
+ if len(oai.Choices) == 0 {
+ return fmt.Errorf("llm: response has no choices")
+ }
+ resp = &ChatResponse{
+ Content: oai.Choices[0].Message.Content,
+ PromptTokens: oai.Usage.PromptTokens,
+ OutputTokens: oai.Usage.CompletionTokens,
+ Model: oai.Model,
+ FinishReason: oai.Choices[0].FinishReason,
+ }
+ return nil
+ })
+ if err != nil {
+ return nil, err
+ }
+ return resp, nil
+}
+
+// ChatStream performs a streaming chat completion. onDelta is called once per
+// content delta chunk. The returned ChatResponse aggregates the full content
+// and any usage tokens reported in the final SSE chunk. Rate-limit errors at
+// connection time are retried; once streaming has begun, errors are returned.
+func (c *Client) ChatStream(ctx context.Context, req ChatRequest, onDelta func(string)) (*ChatResponse, error) {
+ if c == nil {
+ return nil, errors.New("llm: nil Client")
+ }
+ body, err := c.buildRequestBody(req, true)
+ if err != nil {
+ return nil, err
+ }
+
+ var resp *ChatResponse
+ err = retry.RunWithBackoff(ctx, 3, time.Second, func() error {
+ var perErr error
+ resp, perErr = c.streamChat(ctx, body, onDelta)
+ return perErr
+ })
+ if err != nil {
+ return nil, err
+ }
+ return resp, nil
+}
+
+func (c *Client) buildRequestBody(req ChatRequest, stream bool) ([]byte, error) {
+ model := req.Model
+ if model == "" {
+ model = c.Model
+ }
+ if model == "" {
+ return nil, errors.New("llm: no model configured")
+ }
+ payload := openAIRequest{
+ Model: model,
+ Messages: req.Messages,
+ Stream: stream,
+ }
+ if req.Temperature != nil {
+ payload.Temperature = req.Temperature
+ }
+ if req.MaxTokens > 0 {
+ payload.MaxTokens = req.MaxTokens
+ }
+ if req.ResponseJSON {
+ payload.ResponseFormat = &responseFormat{Type: "json_object"}
+ }
+ if stream {
+ payload.StreamOptions = &streamOptions{IncludeUsage: true}
+ }
+ return json.Marshal(payload)
+}
+
+func (c *Client) postChat(ctx context.Context, body []byte) ([]byte, error) {
+ url := strings.TrimRight(c.Endpoint, "/") + "/chat/completions"
+ httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
+ if err != nil {
+ return nil, fmt.Errorf("llm: build request: %w", err)
+ }
+ c.applyHeaders(httpReq)
+
+ httpResp, err := c.client().Do(httpReq)
+ if err != nil {
+ return nil, fmt.Errorf("llm: http: %w", err)
+ }
+ defer httpResp.Body.Close()
+ raw, err := io.ReadAll(httpResp.Body)
+ if err != nil {
+ return nil, fmt.Errorf("llm: read body: %w", err)
+ }
+ if httpResp.StatusCode >= 400 {
+ return nil, errFromStatus(httpResp, raw)
+ }
+ return raw, nil
+}
+
+func (c *Client) streamChat(ctx context.Context, body []byte, onDelta func(string)) (*ChatResponse, error) {
+ url := strings.TrimRight(c.Endpoint, "/") + "/chat/completions"
+ httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
+ if err != nil {
+ return nil, fmt.Errorf("llm: build request: %w", err)
+ }
+ c.applyHeaders(httpReq)
+ httpReq.Header.Set("Accept", "text/event-stream")
+
+ httpResp, err := c.client().Do(httpReq)
+ if err != nil {
+ return nil, fmt.Errorf("llm: http: %w", err)
+ }
+ defer httpResp.Body.Close()
+ if httpResp.StatusCode >= 400 {
+ raw, _ := io.ReadAll(httpResp.Body)
+ return nil, errFromStatus(httpResp, raw)
+ }
+
+ var (
+ content strings.Builder
+ promptTok int
+ outputTok int
+ model string
+ finishReason string
+ )
+ scanner := bufio.NewScanner(httpResp.Body)
+ scanner.Buffer(make([]byte, 0, 64*1024), 1<<20)
+ for scanner.Scan() {
+ line := scanner.Text()
+ if !strings.HasPrefix(line, "data:") {
+ continue
+ }
+ data := strings.TrimSpace(strings.TrimPrefix(line, "data:"))
+ if data == "" || data == "[DONE]" {
+ if data == "[DONE]" {
+ break
+ }
+ continue
+ }
+ var chunk openAIStreamChunk
+ if jerr := json.Unmarshal([]byte(data), &chunk); jerr != nil {
+ if c.Logger != nil {
+ c.Logger.Warn("llm: bad SSE chunk", "err", jerr, "data", data)
+ }
+ continue
+ }
+ if chunk.Model != "" {
+ model = chunk.Model
+ }
+ for _, ch := range chunk.Choices {
+ if ch.Delta.Content != "" {
+ content.WriteString(ch.Delta.Content)
+ if onDelta != nil {
+ onDelta(ch.Delta.Content)
+ }
+ }
+ if ch.FinishReason != "" {
+ finishReason = ch.FinishReason
+ }
+ }
+ if chunk.Usage != nil {
+ promptTok = chunk.Usage.PromptTokens
+ outputTok = chunk.Usage.CompletionTokens
+ }
+ }
+ if scanErr := scanner.Err(); scanErr != nil {
+ return nil, fmt.Errorf("llm: stream read: %w", scanErr)
+ }
+ return &ChatResponse{
+ Content: content.String(),
+ PromptTokens: promptTok,
+ OutputTokens: outputTok,
+ Model: model,
+ FinishReason: finishReason,
+ }, nil
+}
+
+func (c *Client) applyHeaders(req *http.Request) {
+ req.Header.Set("Content-Type", "application/json")
+ if c.APIKey != "" {
+ req.Header.Set("Authorization", "Bearer "+c.APIKey)
+ }
+}
+
+func (c *Client) client() *http.Client {
+ if c.HTTPClient != nil {
+ return c.HTTPClient
+ }
+ return &http.Client{Timeout: 60 * time.Second}
+}
+
+// errFromStatus produces an error whose message includes "rate limit", "429",
+// or "overloaded" as appropriate so retry.IsRateLimitError treats local 429/503
+// identically to upstream provider rate limits. Any Retry-After header is
+// embedded in the error message for retry.ParseRetryAfter to find.
+func errFromStatus(resp *http.Response, body []byte) error {
+ prefix := ""
+ switch resp.StatusCode {
+ case http.StatusTooManyRequests:
+ prefix = fmt.Sprintf("llm: 429 rate limit")
+ case http.StatusServiceUnavailable:
+ prefix = "llm: 503 overloaded"
+ default:
+ prefix = fmt.Sprintf("llm: http %d", resp.StatusCode)
+ }
+ if ra := resp.Header.Get("Retry-After"); ra != "" {
+ prefix += fmt.Sprintf(" (retry-after: %s)", ra)
+ }
+ snippet := strings.TrimSpace(string(body))
+ if len(snippet) > 500 {
+ snippet = snippet[:500] + "..."
+ }
+ if snippet != "" {
+ return fmt.Errorf("%s: %s", prefix, snippet)
+ }
+ return errors.New(prefix)
+}
+
+// --- OpenAI wire types ---
+
+type openAIRequest struct {
+ Model string `json:"model"`
+ Messages []Message `json:"messages"`
+ Temperature *float64 `json:"temperature,omitempty"`
+ MaxTokens int `json:"max_tokens,omitempty"`
+ Stream bool `json:"stream,omitempty"`
+ StreamOptions *streamOptions `json:"stream_options,omitempty"`
+ ResponseFormat *responseFormat `json:"response_format,omitempty"`
+}
+
+type streamOptions struct {
+ IncludeUsage bool `json:"include_usage"`
+}
+
+type responseFormat struct {
+ Type string `json:"type"`
+}
+
+type openAIResponse struct {
+ Model string `json:"model"`
+ Choices []openAIChoice `json:"choices"`
+ Usage openAIUsage `json:"usage"`
+}
+
+type openAIChoice struct {
+ Message Message `json:"message"`
+ FinishReason string `json:"finish_reason"`
+}
+
+type openAIUsage struct {
+ PromptTokens int `json:"prompt_tokens"`
+ CompletionTokens int `json:"completion_tokens"`
+}
+
+type openAIStreamChunk struct {
+ Model string `json:"model"`
+ Choices []openAIStreamCh `json:"choices"`
+ Usage *openAIUsage `json:"usage,omitempty"`
+}
+
+type openAIStreamCh struct {
+ Delta openAIDelta `json:"delta"`
+ FinishReason string `json:"finish_reason"`
+}
+
+type openAIDelta struct {
+ Content string `json:"content"`
+}
diff --git a/internal/llm/client_test.go b/internal/llm/client_test.go
new file mode 100644
index 0000000..8257836
--- /dev/null
+++ b/internal/llm/client_test.go
@@ -0,0 +1,159 @@
+package llm
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "io"
+ "net/http"
+ "net/http/httptest"
+ "strings"
+ "sync/atomic"
+ "testing"
+ "time"
+)
+
+func TestChat_ParsesCompletion(t *testing.T) {
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if r.URL.Path != "/v1/chat/completions" {
+ t.Errorf("unexpected path %q", r.URL.Path)
+ }
+ if r.Header.Get("Authorization") != "Bearer test-key" {
+ t.Errorf("missing/wrong bearer header: %q", r.Header.Get("Authorization"))
+ }
+ var body openAIRequest
+ if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
+ t.Fatalf("decode body: %v", err)
+ }
+ if body.Model != "test-model" {
+ t.Errorf("model: want test-model got %q", body.Model)
+ }
+ if len(body.Messages) != 1 || body.Messages[0].Content != "hello" {
+ t.Errorf("messages mismatch: %+v", body.Messages)
+ }
+ if body.ResponseFormat == nil || body.ResponseFormat.Type != "json_object" {
+ t.Errorf("expected response_format json_object, got %+v", body.ResponseFormat)
+ }
+ w.Header().Set("Content-Type", "application/json")
+ fmt.Fprintln(w, `{
+ "model": "test-model",
+ "choices": [{"message": {"role": "assistant", "content": "world"}, "finish_reason": "stop"}],
+ "usage": {"prompt_tokens": 4, "completion_tokens": 7}
+ }`)
+ }))
+ defer srv.Close()
+
+ c := &Client{Endpoint: srv.URL + "/v1", Model: "test-model", APIKey: "test-key"}
+ resp, err := c.Chat(context.Background(), ChatRequest{
+ Messages: []Message{{Role: "user", Content: "hello"}},
+ ResponseJSON: true,
+ })
+ if err != nil {
+ t.Fatalf("Chat: %v", err)
+ }
+ if resp.Content != "world" {
+ t.Errorf("content: want world got %q", resp.Content)
+ }
+ if resp.PromptTokens != 4 || resp.OutputTokens != 7 {
+ t.Errorf("tokens mismatch: %+v", resp)
+ }
+ if resp.FinishReason != "stop" {
+ t.Errorf("finish_reason: want stop got %q", resp.FinishReason)
+ }
+}
+
+func TestChatStream_ParsesSSE(t *testing.T) {
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("Content-Type", "text/event-stream")
+ flusher, _ := w.(http.Flusher)
+ chunks := []string{
+ `{"model":"test-model","choices":[{"delta":{"content":"Hel"},"finish_reason":""}]}`,
+ `{"model":"test-model","choices":[{"delta":{"content":"lo, "},"finish_reason":""}]}`,
+ `{"model":"test-model","choices":[{"delta":{"content":"world"},"finish_reason":"stop"}]}`,
+ `{"model":"test-model","choices":[],"usage":{"prompt_tokens":3,"completion_tokens":5}}`,
+ }
+ for _, c := range chunks {
+ fmt.Fprintf(w, "data: %s\n\n", c)
+ if flusher != nil {
+ flusher.Flush()
+ }
+ }
+ fmt.Fprint(w, "data: [DONE]\n\n")
+ }))
+ defer srv.Close()
+
+ c := &Client{Endpoint: srv.URL + "/v1", Model: "test-model"}
+
+ var deltas []string
+ resp, err := c.ChatStream(context.Background(),
+ ChatRequest{Messages: []Message{{Role: "user", Content: "hi"}}},
+ func(d string) { deltas = append(deltas, d) },
+ )
+ if err != nil {
+ t.Fatalf("ChatStream: %v", err)
+ }
+ if got := strings.Join(deltas, ""); got != "Hello, world" {
+ t.Errorf("aggregated deltas: want %q got %q", "Hello, world", got)
+ }
+ if resp.Content != "Hello, world" {
+ t.Errorf("content: want %q got %q", "Hello, world", resp.Content)
+ }
+ if resp.PromptTokens != 3 || resp.OutputTokens != 5 {
+ t.Errorf("tokens: %+v", resp)
+ }
+ if resp.FinishReason != "stop" {
+ t.Errorf("finish_reason: want stop got %q", resp.FinishReason)
+ }
+}
+
+func TestChat_RetriesOn429(t *testing.T) {
+ var calls int32
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ n := atomic.AddInt32(&calls, 1)
+ if n == 1 {
+ w.Header().Set("Retry-After", "1")
+ http.Error(w, "slow down", http.StatusTooManyRequests)
+ return
+ }
+ w.Header().Set("Content-Type", "application/json")
+ fmt.Fprintln(w, `{
+ "model":"m","choices":[{"message":{"role":"assistant","content":"ok"},"finish_reason":"stop"}],
+ "usage":{"prompt_tokens":1,"completion_tokens":1}
+ }`)
+ }))
+ defer srv.Close()
+
+ c := &Client{
+ Endpoint: srv.URL + "/v1",
+ Model: "m",
+ HTTPClient: &http.Client{Timeout: 5 * time.Second},
+ }
+ ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+ defer cancel()
+ resp, err := c.Chat(ctx, ChatRequest{Messages: []Message{{Role: "user", Content: "hi"}}})
+ if err != nil {
+ t.Fatalf("Chat: %v", err)
+ }
+ if resp.Content != "ok" {
+ t.Errorf("content: want ok got %q", resp.Content)
+ }
+ if got := atomic.LoadInt32(&calls); got != 2 {
+ t.Errorf("expected 2 server calls (1 retry), got %d", got)
+ }
+}
+
+// Sanity: errFromStatus produces a string that retry.IsRateLimitError matches.
+func TestErrFromStatus_RateLimitMarker(t *testing.T) {
+ resp := &http.Response{
+ StatusCode: http.StatusTooManyRequests,
+ Header: http.Header{"Retry-After": []string{"30"}},
+ }
+ body, _ := io.ReadAll(strings.NewReader("limit hit"))
+ err := errFromStatus(resp, body)
+ if !strings.Contains(strings.ToLower(err.Error()), "rate limit") {
+ t.Errorf("error should contain 'rate limit', got: %v", err)
+ }
+ if !strings.Contains(err.Error(), "retry-after: 30") {
+ t.Errorf("error should embed retry-after, got: %v", err)
+ }
+}