summaryrefslogtreecommitdiff
path: root/internal/llm/client_test.go
diff options
context:
space:
mode:
authorPeter Stone <thepeterstone@gmail.com>2026-05-01 22:14:37 -1000
committerGitHub <noreply@github.com>2026-05-01 22:14:37 -1000
commit99115d8158137083239c45e5a860b718ff4cefa1 (patch)
tree1bf3bd0505eea79375c67af83c7c5fe8c0f274ff /internal/llm/client_test.go
parentc2aa026f6ce1c9e216b99d74f294fc133d5fcddd (diff)
parent50f8fe8c1ff8b82e0bd399e5776e58bda3e57d1c (diff)
Merge pull request #1 from thepeterstone/claude/local-oss-model-agents-MEBqj
Local OSS models as a third runner (epic)
Diffstat (limited to 'internal/llm/client_test.go')
-rw-r--r--internal/llm/client_test.go159
1 files changed, 159 insertions, 0 deletions
diff --git a/internal/llm/client_test.go b/internal/llm/client_test.go
new file mode 100644
index 0000000..8257836
--- /dev/null
+++ b/internal/llm/client_test.go
@@ -0,0 +1,159 @@
+package llm
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "io"
+ "net/http"
+ "net/http/httptest"
+ "strings"
+ "sync/atomic"
+ "testing"
+ "time"
+)
+
+func TestChat_ParsesCompletion(t *testing.T) {
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if r.URL.Path != "/v1/chat/completions" {
+ t.Errorf("unexpected path %q", r.URL.Path)
+ }
+ if r.Header.Get("Authorization") != "Bearer test-key" {
+ t.Errorf("missing/wrong bearer header: %q", r.Header.Get("Authorization"))
+ }
+ var body openAIRequest
+ if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
+ t.Fatalf("decode body: %v", err)
+ }
+ if body.Model != "test-model" {
+ t.Errorf("model: want test-model got %q", body.Model)
+ }
+ if len(body.Messages) != 1 || body.Messages[0].Content != "hello" {
+ t.Errorf("messages mismatch: %+v", body.Messages)
+ }
+ if body.ResponseFormat == nil || body.ResponseFormat.Type != "json_object" {
+ t.Errorf("expected response_format json_object, got %+v", body.ResponseFormat)
+ }
+ w.Header().Set("Content-Type", "application/json")
+ fmt.Fprintln(w, `{
+ "model": "test-model",
+ "choices": [{"message": {"role": "assistant", "content": "world"}, "finish_reason": "stop"}],
+ "usage": {"prompt_tokens": 4, "completion_tokens": 7}
+ }`)
+ }))
+ defer srv.Close()
+
+ c := &Client{Endpoint: srv.URL + "/v1", Model: "test-model", APIKey: "test-key"}
+ resp, err := c.Chat(context.Background(), ChatRequest{
+ Messages: []Message{{Role: "user", Content: "hello"}},
+ ResponseJSON: true,
+ })
+ if err != nil {
+ t.Fatalf("Chat: %v", err)
+ }
+ if resp.Content != "world" {
+ t.Errorf("content: want world got %q", resp.Content)
+ }
+ if resp.PromptTokens != 4 || resp.OutputTokens != 7 {
+ t.Errorf("tokens mismatch: %+v", resp)
+ }
+ if resp.FinishReason != "stop" {
+ t.Errorf("finish_reason: want stop got %q", resp.FinishReason)
+ }
+}
+
+func TestChatStream_ParsesSSE(t *testing.T) {
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("Content-Type", "text/event-stream")
+ flusher, _ := w.(http.Flusher)
+ chunks := []string{
+ `{"model":"test-model","choices":[{"delta":{"content":"Hel"},"finish_reason":""}]}`,
+ `{"model":"test-model","choices":[{"delta":{"content":"lo, "},"finish_reason":""}]}`,
+ `{"model":"test-model","choices":[{"delta":{"content":"world"},"finish_reason":"stop"}]}`,
+ `{"model":"test-model","choices":[],"usage":{"prompt_tokens":3,"completion_tokens":5}}`,
+ }
+ for _, c := range chunks {
+ fmt.Fprintf(w, "data: %s\n\n", c)
+ if flusher != nil {
+ flusher.Flush()
+ }
+ }
+ fmt.Fprint(w, "data: [DONE]\n\n")
+ }))
+ defer srv.Close()
+
+ c := &Client{Endpoint: srv.URL + "/v1", Model: "test-model"}
+
+ var deltas []string
+ resp, err := c.ChatStream(context.Background(),
+ ChatRequest{Messages: []Message{{Role: "user", Content: "hi"}}},
+ func(d string) { deltas = append(deltas, d) },
+ )
+ if err != nil {
+ t.Fatalf("ChatStream: %v", err)
+ }
+ if got := strings.Join(deltas, ""); got != "Hello, world" {
+ t.Errorf("aggregated deltas: want %q got %q", "Hello, world", got)
+ }
+ if resp.Content != "Hello, world" {
+ t.Errorf("content: want %q got %q", "Hello, world", resp.Content)
+ }
+ if resp.PromptTokens != 3 || resp.OutputTokens != 5 {
+ t.Errorf("tokens: %+v", resp)
+ }
+ if resp.FinishReason != "stop" {
+ t.Errorf("finish_reason: want stop got %q", resp.FinishReason)
+ }
+}
+
+func TestChat_RetriesOn429(t *testing.T) {
+ var calls int32
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ n := atomic.AddInt32(&calls, 1)
+ if n == 1 {
+ w.Header().Set("Retry-After", "1")
+ http.Error(w, "slow down", http.StatusTooManyRequests)
+ return
+ }
+ w.Header().Set("Content-Type", "application/json")
+ fmt.Fprintln(w, `{
+ "model":"m","choices":[{"message":{"role":"assistant","content":"ok"},"finish_reason":"stop"}],
+ "usage":{"prompt_tokens":1,"completion_tokens":1}
+ }`)
+ }))
+ defer srv.Close()
+
+ c := &Client{
+ Endpoint: srv.URL + "/v1",
+ Model: "m",
+ HTTPClient: &http.Client{Timeout: 5 * time.Second},
+ }
+ ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+ defer cancel()
+ resp, err := c.Chat(ctx, ChatRequest{Messages: []Message{{Role: "user", Content: "hi"}}})
+ if err != nil {
+ t.Fatalf("Chat: %v", err)
+ }
+ if resp.Content != "ok" {
+ t.Errorf("content: want ok got %q", resp.Content)
+ }
+ if got := atomic.LoadInt32(&calls); got != 2 {
+ t.Errorf("expected 2 server calls (1 retry), got %d", got)
+ }
+}
+
+// Sanity: errFromStatus produces a string that retry.IsRateLimitError matches.
+func TestErrFromStatus_RateLimitMarker(t *testing.T) {
+ resp := &http.Response{
+ StatusCode: http.StatusTooManyRequests,
+ Header: http.Header{"Retry-After": []string{"30"}},
+ }
+ body, _ := io.ReadAll(strings.NewReader("limit hit"))
+ err := errFromStatus(resp, body)
+ if !strings.Contains(strings.ToLower(err.Error()), "rate limit") {
+ t.Errorf("error should contain 'rate limit', got: %v", err)
+ }
+ if !strings.Contains(err.Error(), "retry-after: 30") {
+ t.Errorf("error should embed retry-after, got: %v", err)
+ }
+}