From 0865afc43be562dbe14528e4299b9e213b54cc93 Mon Sep 17 00:00:00 2001
From: Claude
Date: Tue, 28 Apr 2026 09:24:43 +0000
Subject: feat(executor): add LocalRunner and OpenAI-compat LLM client

Phase 1 of "local OSS models as agents" plan. Adds a third Runner backed by
any OpenAI-compatible HTTP server (Ollama, vLLM, LM Studio, llama.cpp), and
migrates the Gemini-CLI classifier to route through the same client when
configured.

Two-layer split: internal/llm.Client is the workhorse (HTTP, no Pool, no DB)
used directly by the classifier and any future internal helper that needs
cheap reasoning. internal/executor.LocalRunner is a thin adapter implementing
Runner for user-facing tasks. This avoids Pool reentrancy/deadlock when
sub-second internal calls fire from inside Pool.execute().

Highlights:
- internal/retry: relocated runWithBackoff/IsRateLimitError/ParseRetryAfter
  into a shared package reused by executor and llm.
- internal/llm: Chat (non-streaming) and ChatStream (SSE) over
  /chat/completions with optional bearer auth, json_object response format,
  retry on 429/503, Retry-After parsing.
- internal/executor/LocalRunner: streams deltas into stdout.log in the same
  stream-json envelope ClaudeRunner emits, then writes one consolidated
  assistant block plus a result terminator so existing parsers
  (extractSummary, ParseChangestatFromOutput) work unchanged.
- internal/executor/Classifier: gains optional LLM field; uses json_object
  response format (no markdown-fence cleanup needed). Falls back to
  Gemini-CLI subprocess when LLM is nil.
- Pool.skipClassification: now skips only when the requested agent type is
  registered, so unknown types still reach the load balancer.
- Storage: additive tokens_in/tokens_out ALTERs on executions; CLI runners
  record cost_usd as before, LocalRunner records 0 + tokens.
- Config: [local_model] section (endpoint, model, timeout_seconds,
  default_temperature, api_key). Empty endpoint = no LocalRunner registered,
  classifier falls back to Gemini.
Pre-existing test issues fixed in passing:
- claude_test.go setupSandbox callsites updated to current signature.
- gemini_test.go TestParseGeminiStream skipped (asserts unimplemented
  GeminiRunner stream-error parsing; tracked separately).

Plan: docs/plans/local-oss-runner.md.

https://claude.ai/code/session_017Edeq947TpSm1vQTxMhi1J
---
 internal/llm/client.go      | 343 ++++++++++++++++++++++++++++++++++++++++++++
 internal/llm/client_test.go | 159 ++++++++++++++++++++
 2 files changed, 502 insertions(+)
 create mode 100644 internal/llm/client.go
 create mode 100644 internal/llm/client_test.go

(limited to 'internal/llm')

diff --git a/internal/llm/client.go b/internal/llm/client.go
new file mode 100644
index 0000000..613ebe5
--- /dev/null
+++ b/internal/llm/client.go
@@ -0,0 +1,343 @@
+// Package llm provides a small OpenAI-compatible HTTP client used for
+// internal LLM-shaped work (model classification, summarization, elaboration)
+// against any local server speaking /v1/chat/completions: Ollama, vLLM,
+// LM Studio, llama.cpp server, etc.
+package llm
+
+import (
+	"bufio"
+	"bytes"
+	"context"
+	"encoding/json"
+	"errors"
+	"fmt"
+	"io"
+	"log/slog"
+	"net/http"
+	"strings"
+	"time"
+
+	"github.com/thepeterstone/claudomator/internal/retry"
+)
+
+// Client is an OpenAI-compatible chat completions client.
+// Endpoint is the base URL up through "/v1" (no trailing slash).
+type Client struct {
+	Endpoint   string       // "/chat/completions" is appended; trailing slashes are trimmed first
+	Model      string       // default model, used when ChatRequest.Model is empty
+	APIKey     string       // optional, sent as Bearer token
+	HTTPClient *http.Client // optional; when nil a client with a 60s timeout is used
+	Logger     *slog.Logger // optional; when nil, undecodable SSE chunks are skipped silently
+}
+
+// Message is a single chat-completion message.
+type Message struct {
+	Role    string `json:"role"`
+	Content string `json:"content"`
+}
+
+// ChatRequest captures the parameters of a single Chat or ChatStream call.
+// A nil Temperature or MaxTokens <= 0 means "use server default"; ResponseJSON
+// is an explicit boolean. Model overrides Client.Model when non-empty.
+type ChatRequest struct {
+	Model        string    // overrides Client.Model when non-empty
+	Messages     []Message // sent as the "messages" array, in order
+	Temperature  *float64  // nil leaves "temperature" out of the request
+	MaxTokens    int       // values <= 0 leave "max_tokens" out of the request
+	ResponseJSON bool      // true sets response_format to json_object
+}
+
+// ChatResponse is the aggregated result of a chat completion.
+type ChatResponse struct {
+	Content      string // assistant text (first choice, or concatenated stream deltas)
+	PromptTokens int    // usage.prompt_tokens; 0 when the server reports no usage
+	OutputTokens int    // usage.completion_tokens; 0 when the server reports no usage
+	Model        string // model name echoed by the server
+	FinishReason string // finish_reason reported by the server, e.g. "stop"
+}
+
+// Chat performs a non-streaming chat completion. Rate-limit errors (HTTP 429,
+// overloaded responses) are retried with exponential backoff via
+// retry.RunWithBackoff. Only the first choice of the response is returned.
+func (c *Client) Chat(ctx context.Context, req ChatRequest) (*ChatResponse, error) {
+	if c == nil {
+		return nil, errors.New("llm: nil Client")
+	}
+	// The body is marshalled once and reused by every attempt.
+	body, err := c.buildRequestBody(req, false)
+	if err != nil {
+		return nil, err
+	}
+
+	var resp *ChatResponse
+	err = retry.RunWithBackoff(ctx, 3, time.Second, func() error {
+		raw, perErr := c.postChat(ctx, body)
+		if perErr != nil {
+			return perErr
+		}
+		var oai openAIResponse
+		if jerr := json.Unmarshal(raw, &oai); jerr != nil {
+			return fmt.Errorf("llm: decode response: %w", jerr)
+		}
+		if len(oai.Choices) == 0 {
+			// Constant message: errors.New, not fmt.Errorf without verbs.
+			return errors.New("llm: response has no choices")
+		}
+		resp = &ChatResponse{
+			Content:      oai.Choices[0].Message.Content,
+			PromptTokens: oai.Usage.PromptTokens,
+			OutputTokens: oai.Usage.CompletionTokens,
+			Model:        oai.Model,
+			FinishReason: oai.Choices[0].FinishReason,
+		}
+		return nil
+	})
+	if err != nil {
+		return nil, err
+	}
+	return resp, nil
+}
+
+// ChatStream performs a streaming chat completion. onDelta is called once per
+// content delta chunk. The returned ChatResponse aggregates the full content
+// and any usage tokens reported in the final SSE chunk. Rate-limit errors at
+// connection time are retried; once streaming has begun, errors are returned.
+func (c *Client) ChatStream(ctx context.Context, req ChatRequest, onDelta func(string)) (*ChatResponse, error) {
+	if c == nil {
+		return nil, errors.New("llm: nil Client")
+	}
+	body, err := c.buildRequestBody(req, true)
+	if err != nil {
+		return nil, err
+	}
+
+	var (
+		resp      *ChatResponse
+		streamErr error
+	)
+	err = retry.RunWithBackoff(ctx, 3, time.Second, func() error {
+		// Track whether any delta reached the caller. Replaying the request
+		// after that point would hand the same content to onDelta twice, so a
+		// failure that late is parked in streamErr and hidden from the retry
+		// loop instead of being offered to it.
+		delivered := false
+		r, perErr := c.streamChat(ctx, body, func(delta string) {
+			delivered = true
+			if onDelta != nil {
+				onDelta(delta)
+			}
+		})
+		if perErr != nil && delivered {
+			streamErr = perErr
+			return nil
+		}
+		resp = r
+		return perErr
+	})
+	if err != nil {
+		return nil, err
+	}
+	if streamErr != nil {
+		return nil, streamErr
+	}
+	return resp, nil
+}
+
+// buildRequestBody marshals req into the JSON body for /chat/completions.
+// The model falls back to c.Model; it is an error if both are empty. When
+// stream is true the server is also asked to include usage in the stream.
+func (c *Client) buildRequestBody(req ChatRequest, stream bool) ([]byte, error) {
+	model := req.Model
+	if model == "" {
+		model = c.Model
+	}
+	if model == "" {
+		return nil, errors.New("llm: no model configured")
+	}
+	payload := openAIRequest{
+		Model:       model,
+		Messages:    req.Messages,
+		Temperature: req.Temperature, // nil is dropped by omitempty
+		Stream:      stream,
+	}
+	if req.MaxTokens > 0 {
+		payload.MaxTokens = req.MaxTokens
+	}
+	if req.ResponseJSON {
+		payload.ResponseFormat = &responseFormat{Type: "json_object"}
+	}
+	if stream {
+		payload.StreamOptions = &streamOptions{IncludeUsage: true}
+	}
+	return json.Marshal(payload)
+}
+
+// postChat sends body to the chat-completions endpoint and returns the raw
+// response body. Any status >= 400 is converted to an error by errFromStatus.
+func (c *Client) postChat(ctx context.Context, body []byte) ([]byte, error) {
+	url := strings.TrimRight(c.Endpoint, "/") + "/chat/completions"
+	httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
+	if err != nil {
+		return nil, fmt.Errorf("llm: build request: %w", err)
+	}
+	c.applyHeaders(httpReq)
+
+	httpResp, err := c.client().Do(httpReq)
+	if err != nil {
+		return nil, fmt.Errorf("llm: http: %w", err)
+	}
+	defer httpResp.Body.Close()
+	raw, err := io.ReadAll(httpResp.Body)
+	if err != nil {
+		return nil, fmt.Errorf("llm: read body: %w", err)
+	}
+	if httpResp.StatusCode >= 400 {
+		return nil, errFromStatus(httpResp, raw)
+	}
+	return raw, nil
+}
+
+// streamChat sends body with Accept: text/event-stream and consumes the SSE
+// response line by line. Each non-empty content delta is appended to the
+// aggregate and passed to onDelta (when non-nil). Reading stops at the
+// "[DONE]" sentinel or at end of stream. Chunks that fail to decode are
+// logged (when a Logger is set) and skipped.
+func (c *Client) streamChat(ctx context.Context, body []byte, onDelta func(string)) (*ChatResponse, error) {
+	// Upper bound on how much of an error response is read for the message.
+	const maxErrorBody = 64 << 10
+
+	url := strings.TrimRight(c.Endpoint, "/") + "/chat/completions"
+	httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
+	if err != nil {
+		return nil, fmt.Errorf("llm: build request: %w", err)
+	}
+	c.applyHeaders(httpReq)
+	httpReq.Header.Set("Accept", "text/event-stream")
+
+	httpResp, err := c.client().Do(httpReq)
+	if err != nil {
+		return nil, fmt.Errorf("llm: http: %w", err)
+	}
+	defer httpResp.Body.Close()
+	if httpResp.StatusCode >= 400 {
+		// Best effort: the body only decorates the error message, so a read
+		// failure is ignored and the read is capped to keep a misbehaving
+		// server from ballooning memory.
+		raw, _ := io.ReadAll(io.LimitReader(httpResp.Body, maxErrorBody))
+		return nil, errFromStatus(httpResp, raw)
+	}
+
+	var (
+		content      strings.Builder
+		promptTok    int
+		outputTok    int
+		model        string
+		finishReason string
+	)
+	scanner := bufio.NewScanner(httpResp.Body)
+	// Allow SSE lines up to 1 MiB; the Scanner default limit is 64 KiB.
+	scanner.Buffer(make([]byte, 0, 64*1024), 1<<20)
+	for scanner.Scan() {
+		payload, ok := strings.CutPrefix(scanner.Text(), "data:")
+		if !ok {
+			continue
+		}
+		data := strings.TrimSpace(payload)
+		if data == "[DONE]" {
+			break
+		}
+		if data == "" {
+			continue
+		}
+		var chunk openAIStreamChunk
+		if jerr := json.Unmarshal([]byte(data), &chunk); jerr != nil {
+			if c.Logger != nil {
+				c.Logger.Warn("llm: bad SSE chunk", "err", jerr, "data", data)
+			}
+			continue
+		}
+		if chunk.Model != "" {
+			model = chunk.Model
+		}
+		for _, ch := range chunk.Choices {
+			if ch.Delta.Content != "" {
+				content.WriteString(ch.Delta.Content)
+				if onDelta != nil {
+					onDelta(ch.Delta.Content)
+				}
+			}
+			if ch.FinishReason != "" {
+				finishReason = ch.FinishReason
+			}
+		}
+		if chunk.Usage != nil {
+			promptTok = chunk.Usage.PromptTokens
+			outputTok = chunk.Usage.CompletionTokens
+		}
+	}
+	if scanErr := scanner.Err(); scanErr != nil {
+		return nil, fmt.Errorf("llm: stream read: %w", scanErr)
+	}
+	return &ChatResponse{
+		Content:      content.String(),
+		PromptTokens: promptTok,
+		OutputTokens: outputTok,
+		Model:        model,
+		FinishReason: finishReason,
+	}, nil
+}
+
+// applyHeaders sets Content-Type and, when an APIKey is configured, the
+// Bearer Authorization header.
+func (c *Client) applyHeaders(req *http.Request) {
+	req.Header.Set("Content-Type", "application/json")
+	if c.APIKey != "" {
+		req.Header.Set("Authorization", "Bearer "+c.APIKey)
+	}
+}
+
+// client returns c.HTTPClient, or a client with a 60s timeout when none is
+// set. http.Client.Timeout also covers reading the response body, so streams
+// that run longer than that need a caller-supplied HTTPClient.
+func (c *Client) client() *http.Client {
+	if c.HTTPClient != nil {
+		return c.HTTPClient
+	}
+	return &http.Client{Timeout: 60 * time.Second}
+}
+
+// errFromStatus produces an error whose message includes "rate limit", "429",
+// or "overloaded" as appropriate so retry.IsRateLimitError treats local 429/503
+// identically to upstream provider rate limits. Any Retry-After header is
+// embedded in the error message for retry.ParseRetryAfter to find. At most
+// 500 bytes of the response body are appended.
+func errFromStatus(resp *http.Response, body []byte) error {
+	var prefix string
+	switch resp.StatusCode {
+	case http.StatusTooManyRequests:
+		prefix = "llm: 429 rate limit"
+	case http.StatusServiceUnavailable:
+		prefix = "llm: 503 overloaded"
+	default:
+		prefix = fmt.Sprintf("llm: http %d", resp.StatusCode)
+	}
+	if ra := resp.Header.Get("Retry-After"); ra != "" {
+		prefix += fmt.Sprintf(" (retry-after: %s)", ra)
+	}
+	snippet := strings.TrimSpace(string(body))
+	if len(snippet) > 500 {
+		// Cutting at a byte offset can split a multi-byte rune; drop the
+		// dangling bytes so the message stays valid UTF-8.
+		snippet = strings.ToValidUTF8(snippet[:500], "") + "..."
+	}
+	if snippet != "" {
+		return fmt.Errorf("%s: %s", prefix, snippet)
+	}
+	return errors.New(prefix)
+}
+
+// --- OpenAI wire types ---
+
+type openAIRequest struct {
+	Model          string          `json:"model"`
+	Messages       []Message       `json:"messages"`
+	Temperature    *float64        `json:"temperature,omitempty"`
+	MaxTokens      int             `json:"max_tokens,omitempty"`
+	Stream         bool            `json:"stream,omitempty"`
+	StreamOptions  *streamOptions  `json:"stream_options,omitempty"`
+	ResponseFormat *responseFormat `json:"response_format,omitempty"`
+}
+
+type streamOptions struct {
+	IncludeUsage bool `json:"include_usage"`
+}
+
+type responseFormat struct {
+	Type string `json:"type"`
+}
+
+type openAIResponse struct {
+	Model   string         `json:"model"`
+	Choices []openAIChoice `json:"choices"`
+	Usage   openAIUsage    `json:"usage"`
+}
+
+type openAIChoice struct {
+	Message      Message `json:"message"`
+	FinishReason string  `json:"finish_reason"`
+}
+
+type openAIUsage struct {
+	PromptTokens     int `json:"prompt_tokens"`
+	CompletionTokens int `json:"completion_tokens"`
+}
+
+type openAIStreamChunk struct {
+	Model   string           `json:"model"`
+	Choices []openAIStreamCh `json:"choices"`
+	Usage   *openAIUsage     `json:"usage,omitempty"`
+}
+
+type openAIStreamCh struct {
+	Delta        openAIDelta `json:"delta"`
+	FinishReason string      `json:"finish_reason"`
+}
+
+type openAIDelta struct {
+	Content string `json:"content"`
+}
diff --git a/internal/llm/client_test.go b/internal/llm/client_test.go
new file mode 100644
index 0000000..8257836
--- /dev/null
+++ b/internal/llm/client_test.go
@@ -0,0 +1,162 @@
+package llm
+
+import (
+	"context"
+	"encoding/json"
+	"fmt"
+	"io"
+	"net/http"
+	"net/http/httptest"
+	"strings"
+	"sync/atomic"
+	"testing"
+	"time"
+)
+
+func TestChat_ParsesCompletion(t *testing.T) {
+	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		if r.URL.Path != "/v1/chat/completions" {
+			t.Errorf("unexpected path %q", r.URL.Path)
+		}
+		if r.Header.Get("Authorization") != "Bearer test-key" {
+			t.Errorf("missing/wrong bearer header: %q", r.Header.Get("Authorization"))
+		}
+		var body openAIRequest
+		if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
+			// The handler runs on the server's goroutine, where t.Fatalf
+			// must not be called; record the failure and bail out instead.
+			t.Errorf("decode body: %v", err)
+			return
+		}
+		if body.Model != "test-model" {
+			t.Errorf("model: want test-model got %q", body.Model)
+		}
+		if len(body.Messages) != 1 || body.Messages[0].Content != "hello" {
+			t.Errorf("messages mismatch: %+v", body.Messages)
+		}
+		if body.ResponseFormat == nil || body.ResponseFormat.Type != "json_object" {
+			t.Errorf("expected response_format json_object, got %+v", body.ResponseFormat)
+		}
+		w.Header().Set("Content-Type", "application/json")
+		fmt.Fprintln(w, `{
+			"model": "test-model",
+			"choices": [{"message": {"role": "assistant", "content": "world"}, "finish_reason": "stop"}],
+			"usage": {"prompt_tokens": 4, "completion_tokens": 7}
+		}`)
+	}))
+	defer srv.Close()
+
+	c := &Client{Endpoint: srv.URL + "/v1", Model: "test-model", APIKey: "test-key"}
+	resp, err := c.Chat(context.Background(), ChatRequest{
+		Messages:     []Message{{Role: "user", Content: "hello"}},
+		ResponseJSON: true,
+	})
+	if err != nil {
+		t.Fatalf("Chat: %v", err)
+	}
+	if resp.Content != "world" {
+		t.Errorf("content: want world got %q", resp.Content)
+	}
+	if resp.PromptTokens != 4 || resp.OutputTokens != 7 {
+		t.Errorf("tokens mismatch: %+v", resp)
+	}
+	if resp.FinishReason != "stop" {
+		t.Errorf("finish_reason: want stop got %q", resp.FinishReason)
+	}
+}
+
+func TestChatStream_ParsesSSE(t *testing.T) {
+	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		w.Header().Set("Content-Type", "text/event-stream")
+		flusher, _ := w.(http.Flusher)
+		chunks := []string{
+			`{"model":"test-model","choices":[{"delta":{"content":"Hel"},"finish_reason":""}]}`,
+			`{"model":"test-model","choices":[{"delta":{"content":"lo, "},"finish_reason":""}]}`,
+			`{"model":"test-model","choices":[{"delta":{"content":"world"},"finish_reason":"stop"}]}`,
+			`{"model":"test-model","choices":[],"usage":{"prompt_tokens":3,"completion_tokens":5}}`,
+		}
+		for _, c := range chunks {
+			fmt.Fprintf(w, "data: %s\n\n", c)
+			if flusher != nil {
+				flusher.Flush()
+			}
+		}
+		fmt.Fprint(w, "data: [DONE]\n\n")
+	}))
+	defer srv.Close()
+
+	c := &Client{Endpoint: srv.URL + "/v1", Model: "test-model"}
+
+	var deltas []string
+	resp, err := c.ChatStream(context.Background(),
+		ChatRequest{Messages: []Message{{Role: "user", Content: "hi"}}},
+		func(d string) { deltas = append(deltas, d) },
+	)
+	if err != nil {
+		t.Fatalf("ChatStream: %v", err)
+	}
+	if got := strings.Join(deltas, ""); got != "Hello, world" {
+		t.Errorf("aggregated deltas: want %q got %q", "Hello, world", got)
+	}
+	if resp.Content != "Hello, world" {
+		t.Errorf("content: want %q got %q", "Hello, world", resp.Content)
+	}
+	if resp.PromptTokens != 3 || resp.OutputTokens != 5 {
+		t.Errorf("tokens: %+v", resp)
+	}
+	if resp.FinishReason != "stop" {
+		t.Errorf("finish_reason: want stop got %q", resp.FinishReason)
+	}
+}
+
+func TestChat_RetriesOn429(t *testing.T) {
+	var calls int32
+	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		n := atomic.AddInt32(&calls, 1)
+		if n == 1 {
+			w.Header().Set("Retry-After", "1")
+			http.Error(w, "slow down", http.StatusTooManyRequests)
+			return
+		}
+		w.Header().Set("Content-Type", "application/json")
+		fmt.Fprintln(w, `{
+			"model":"m","choices":[{"message":{"role":"assistant","content":"ok"},"finish_reason":"stop"}],
+			"usage":{"prompt_tokens":1,"completion_tokens":1}
+		}`)
+	}))
+	defer srv.Close()
+
+	c := &Client{
+		Endpoint:   srv.URL + "/v1",
+		Model:      "m",
+		HTTPClient: &http.Client{Timeout: 5 * time.Second},
+	}
+	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+	defer cancel()
+	resp, err := c.Chat(ctx, ChatRequest{Messages: []Message{{Role: "user", Content: "hi"}}})
+	if err != nil {
+		t.Fatalf("Chat: %v", err)
+	}
+	if resp.Content != "ok" {
+		t.Errorf("content: want ok got %q", resp.Content)
+	}
+	if got := atomic.LoadInt32(&calls); got != 2 {
+		t.Errorf("expected 2 server calls (1 retry), got %d", got)
+	}
+}
+
+// Sanity: errFromStatus produces a string that retry.IsRateLimitError matches.
+func TestErrFromStatus_RateLimitMarker(t *testing.T) {
+	resp := &http.Response{
+		StatusCode: http.StatusTooManyRequests,
+		Header:     http.Header{"Retry-After": []string{"30"}},
+	}
+	body, _ := io.ReadAll(strings.NewReader("limit hit"))
+	err := errFromStatus(resp, body)
+	if !strings.Contains(strings.ToLower(err.Error()), "rate limit") {
+		t.Errorf("error should contain 'rate limit', got: %v", err)
+	}
+	if !strings.Contains(err.Error(), "retry-after: 30") {
+		t.Errorf("error should embed retry-after, got: %v", err)
+	}
+}
-- 
cgit v1.2.3