package llm import ( "context" "encoding/json" "fmt" "io" "net/http" "net/http/httptest" "strings" "sync/atomic" "testing" "time" ) func TestChat_ParsesCompletion(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Path != "/v1/chat/completions" { t.Errorf("unexpected path %q", r.URL.Path) } if r.Header.Get("Authorization") != "Bearer test-key" { t.Errorf("missing/wrong bearer header: %q", r.Header.Get("Authorization")) } var body openAIRequest if err := json.NewDecoder(r.Body).Decode(&body); err != nil { t.Fatalf("decode body: %v", err) } if body.Model != "test-model" { t.Errorf("model: want test-model got %q", body.Model) } if len(body.Messages) != 1 || body.Messages[0].Content != "hello" { t.Errorf("messages mismatch: %+v", body.Messages) } if body.ResponseFormat == nil || body.ResponseFormat.Type != "json_object" { t.Errorf("expected response_format json_object, got %+v", body.ResponseFormat) } w.Header().Set("Content-Type", "application/json") fmt.Fprintln(w, `{ "model": "test-model", "choices": [{"message": {"role": "assistant", "content": "world"}, "finish_reason": "stop"}], "usage": {"prompt_tokens": 4, "completion_tokens": 7} }`) })) defer srv.Close() c := &Client{Endpoint: srv.URL + "/v1", Model: "test-model", APIKey: "test-key"} resp, err := c.Chat(context.Background(), ChatRequest{ Messages: []Message{{Role: "user", Content: "hello"}}, ResponseJSON: true, }) if err != nil { t.Fatalf("Chat: %v", err) } if resp.Content != "world" { t.Errorf("content: want world got %q", resp.Content) } if resp.PromptTokens != 4 || resp.OutputTokens != 7 { t.Errorf("tokens mismatch: %+v", resp) } if resp.FinishReason != "stop" { t.Errorf("finish_reason: want stop got %q", resp.FinishReason) } } func TestChatStream_ParsesSSE(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "text/event-stream") flusher, _ := w.(http.Flusher) chunks := []string{ `{"model":"test-model","choices":[{"delta":{"content":"Hel"},"finish_reason":""}]}`, `{"model":"test-model","choices":[{"delta":{"content":"lo, "},"finish_reason":""}]}`, `{"model":"test-model","choices":[{"delta":{"content":"world"},"finish_reason":"stop"}]}`, `{"model":"test-model","choices":[],"usage":{"prompt_tokens":3,"completion_tokens":5}}`, } for _, c := range chunks { fmt.Fprintf(w, "data: %s\n\n", c) if flusher != nil { flusher.Flush() } } fmt.Fprint(w, "data: [DONE]\n\n") })) defer srv.Close() c := &Client{Endpoint: srv.URL + "/v1", Model: "test-model"} var deltas []string resp, err := c.ChatStream(context.Background(), ChatRequest{Messages: []Message{{Role: "user", Content: "hi"}}}, func(d string) { deltas = append(deltas, d) }, ) if err != nil { t.Fatalf("ChatStream: %v", err) } if got := strings.Join(deltas, ""); got != "Hello, world" { t.Errorf("aggregated deltas: want %q got %q", "Hello, world", got) } if resp.Content != "Hello, world" { t.Errorf("content: want %q got %q", "Hello, world", resp.Content) } if resp.PromptTokens != 3 || resp.OutputTokens != 5 { t.Errorf("tokens: %+v", resp) } if resp.FinishReason != "stop" { t.Errorf("finish_reason: want stop got %q", resp.FinishReason) } } func TestChat_RetriesOn429(t *testing.T) { var calls int32 srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { n := atomic.AddInt32(&calls, 1) if n == 1 { w.Header().Set("Retry-After", "1") http.Error(w, "slow down", http.StatusTooManyRequests) return } w.Header().Set("Content-Type", "application/json") fmt.Fprintln(w, `{ "model":"m","choices":[{"message":{"role":"assistant","content":"ok"},"finish_reason":"stop"}], "usage":{"prompt_tokens":1,"completion_tokens":1} }`) })) defer srv.Close() c := &Client{ Endpoint: srv.URL + "/v1", Model: "m", HTTPClient: &http.Client{Timeout: 5 * time.Second}, } ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() resp, err := c.Chat(ctx, ChatRequest{Messages: []Message{{Role: "user", Content: "hi"}}}) if err != nil { t.Fatalf("Chat: %v", err) } if resp.Content != "ok" { t.Errorf("content: want ok got %q", resp.Content) } if got := atomic.LoadInt32(&calls); got != 2 { t.Errorf("expected 2 server calls (1 retry), got %d", got) } } // Sanity: errFromStatus produces a string that retry.IsRateLimitError matches. func TestErrFromStatus_RateLimitMarker(t *testing.T) { resp := &http.Response{ StatusCode: http.StatusTooManyRequests, Header: http.Header{"Retry-After": []string{"30"}}, } body, _ := io.ReadAll(strings.NewReader("limit hit")) err := errFromStatus(resp, body) if !strings.Contains(strings.ToLower(err.Error()), "rate limit") { t.Errorf("error should contain 'rate limit', got: %v", err) } if !strings.Contains(err.Error(), "retry-after: 30") { t.Errorf("error should embed retry-after, got: %v", err) } }