summary refs log tree commit diff
path: root/internal/api/elaborate_local_test.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/api/elaborate_local_test.go')
-rw-r--r-- internal/api/elaborate_local_test.go 214
1 files changed, 214 insertions, 0 deletions
diff --git a/internal/api/elaborate_local_test.go b/internal/api/elaborate_local_test.go
new file mode 100644
index 0000000..09a8f9e
--- /dev/null
+++ b/internal/api/elaborate_local_test.go
@@ -0,0 +1,214 @@
+package api
+
+import (
+ "bytes"
+ "context"
+ "encoding/json"
+ "fmt"
+ "net/http"
+ "net/http/httptest"
+ "strings"
+ "sync/atomic"
+ "testing"
+
+ "github.com/thepeterstone/claudomator/internal/llm"
+)
+
+// fakeChatCompletionsServer starts an httptest server that answers every
+// request (path and method are not checked) with a chat-completions JSON body
+// whose assistant message content is assistantContent. It returns the server
+// and a counter, incremented atomically per request, for asserting call counts.
+func fakeChatCompletionsServer(t *testing.T, assistantContent string) (*httptest.Server, *int32) {
+ t.Helper()
+ var calls int32
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ atomic.AddInt32(&calls, 1) // atomic: handlers run on the server's goroutines, not the test's
+ w.Header().Set("Content-Type", "application/json")
+ // Marshal the content into a JSON string literal so it embeds safely in the body below.
+ escaped, _ := json.Marshal(assistantContent) // error deliberately ignored; the input is a plain string
+ fmt.Fprintf(w, `{
+ "model":"local",
+ "choices":[{"message":{"role":"assistant","content":%s},"finish_reason":"stop"}],
+ "usage":{"prompt_tokens":10,"completion_tokens":50}
+ }`, string(escaped))
+ }))
+ t.Cleanup(srv.Close) // close the server when the calling test finishes
+ return srv, &calls
+}
+
+func TestElaborateWithLocal_ParsesValidResponse(t *testing.T) { // happy path: the fake LLM replies with valid task JSON
+ taskBody, _ := json.Marshal(elaboratedTask{ // canned assistant reply; marshal error ignored
+ Name: "Test elaborated task",
+ Description: "From local llm",
+ Agent: elaboratedAgent{
+ Type: "claude",
+ Model: "sonnet",
+ Instructions: "Run go build.",
+ MaxBudgetUSD: 0.25,
+ AllowedTools: []string{"Bash"},
+ },
+ Timeout: "10m",
+ Priority: "normal",
+ Tags: []string{"build"},
+ })
+ srv, calls := fakeChatCompletionsServer(t, string(taskBody))
+
+ c := &llm.Client{Endpoint: srv.URL + "/v1", Model: "fake"} // client aimed at the fake server
+ result, err := elaborateWithLocal(context.Background(), c, "/some/dir", "build the project")
+ if err != nil {
+ t.Fatalf("elaborateWithLocal: %v", err)
+ }
+ if result.Name != "Test elaborated task" { // fields must match the canned reply
+ t.Errorf("Name: %q", result.Name)
+ }
+ if result.Agent.Instructions != "Run go build." {
+ t.Errorf("Instructions: %q", result.Agent.Instructions)
+ }
+ if got := atomic.LoadInt32(calls); got != 1 { // exactly one request should have reached the server
+ t.Errorf("expected 1 call, got %d", got)
+ }
+}
+
+func TestElaborateWithLocal_NilClient(t *testing.T) { // a nil client must yield an error
+ _, err := elaborateWithLocal(context.Background(), nil, "", "p")
+ if err == nil || !strings.Contains(err.Error(), "no client") { // error text must contain "no client"
+ t.Errorf("expected nil-client error, got %v", err)
+ }
+}
+
+func TestElaborateWithLocal_BadJSON(t *testing.T) { // assistant content that is not JSON must yield an error
+ srv, _ := fakeChatCompletionsServer(t, "this is not JSON at all") // call counter not needed here
+ c := &llm.Client{Endpoint: srv.URL + "/v1", Model: "fake"}
+ _, err := elaborateWithLocal(context.Background(), c, "", "p")
+ if err == nil || !strings.Contains(err.Error(), "parse JSON") { // error text must contain "parse JSON"
+ t.Errorf("expected parse error, got %v", err)
+ }
+}
+
+// TestElaborateTask_LocalLLMPreferred verifies that POST /api/tasks/elaborate is
+// answered from the local LLM once SetLLM is called, with the Claude path unusable.
+func TestElaborateTask_LocalLLMPreferred(t *testing.T) {
+ srv, _ := testServer(t)
+
+ taskBody, _ := json.Marshal(elaboratedTask{ // reply the fake local LLM will return
+ Name: "Local-elaborated",
+ Description: "From local",
+ Agent: elaboratedAgent{
+ Type: "claude",
+ Model: "sonnet",
+ Instructions: "Do work. Tests pass when complete.",
+ MaxBudgetUSD: 0.25,
+ AllowedTools: []string{"Bash"},
+ },
+ Timeout: "10m",
+ Priority: "normal",
+ })
+ llmSrv, _ := fakeChatCompletionsServer(t, string(taskBody)) // NOTE(review): call counter discarded; local use is only inferred from the response
+ srv.SetLLM(&llm.Client{Endpoint: llmSrv.URL + "/v1", Model: "fake"})
+ // Point the Claude command at a nonexistent path, so invoking it would fail.
+ srv.elaborateCmdPath = "/nonexistent/claude-should-not-run"
+
+ body := `{"prompt":"do work"}`
+ req := httptest.NewRequest("POST", "/api/tasks/elaborate", bytes.NewBufferString(body))
+ req.Header.Set("Content-Type", "application/json")
+ w := httptest.NewRecorder()
+ srv.Handler().ServeHTTP(w, req) // drive the handler in-process; no listener involved
+
+ if w.Code != http.StatusOK {
+ t.Fatalf("status: want 200, got %d; body: %s", w.Code, w.Body.String())
+ }
+ var got elaboratedTask
+ if err := json.NewDecoder(w.Body).Decode(&got); err != nil {
+ t.Fatalf("decode response: %v", err)
+ }
+ if got.Name != "Local-elaborated" { // the name identifies which backend produced the task
+ t.Errorf("Name: want Local-elaborated got %q", got.Name)
+ }
+}
+
+// TestElaborateTask_LocalFails_FallsBackToClaude verifies that the elaborate
+// endpoint still returns the fake Claude result when the local LLM answers 500.
+func TestElaborateTask_LocalFails_FallsBackToClaude(t *testing.T) {
+ srv, _ := testServer(t)
+
+ // Local LLM server that always 500s.
+ failSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ http.Error(w, "boom", http.StatusInternalServerError) // NOTE(review): no counter here, so the test never proves this server was hit
+ }))
+ t.Cleanup(failSrv.Close)
+ srv.SetLLM(&llm.Client{Endpoint: failSrv.URL + "/v1", Model: "fake"})
+
+ // Configure a working fake Claude binary.
+ taskBody, _ := json.Marshal(elaboratedTask{
+ Name: "Claude-fallback",
+ Description: "From claude after local failed",
+ Agent: elaboratedAgent{
+ Type: "claude",
+ Model: "sonnet",
+ Instructions: "Run tests.",
+ MaxBudgetUSD: 0.25,
+ AllowedTools: []string{"Bash"},
+ },
+ Timeout: "10m",
+ Priority: "normal",
+ })
+ wrapper, _ := json.Marshal(map[string]string{"result": string(taskBody)}) // task JSON carried as a string under "result"
+ srv.elaborateCmdPath = createFakeClaude(t, string(wrapper), 0) // presumably output and exit code; confirm against createFakeClaude
+
+ body := `{"prompt":"run tests"}`
+ req := httptest.NewRequest("POST", "/api/tasks/elaborate", bytes.NewBufferString(body))
+ req.Header.Set("Content-Type", "application/json")
+ w := httptest.NewRecorder()
+ srv.Handler().ServeHTTP(w, req)
+
+ if w.Code != http.StatusOK {
+ t.Fatalf("status: want 200, got %d; body: %s", w.Code, w.Body.String())
+ }
+ var got elaboratedTask
+ if err := json.NewDecoder(w.Body).Decode(&got); err != nil {
+ t.Fatalf("decode response: %v", err)
+ }
+ if got.Name != "Claude-fallback" { // this name only appears in the fake Claude output
+ t.Errorf("Name: want Claude-fallback (fallback path) got %q", got.Name)
+ }
+}
+
+// TestElaborateTask_NoLocalLLM_UsesClaude verifies that, when SetLLM is never
+// called, the elaborate endpoint returns the task emitted by the fake Claude.
+func TestElaborateTask_NoLocalLLM_UsesClaude(t *testing.T) {
+ srv, _ := testServer(t)
+
+ taskBody, _ := json.Marshal(elaboratedTask{ // task the fake Claude binary will emit
+ Name: "Claude-only",
+ Description: "no local llm configured",
+ Agent: elaboratedAgent{
+ Type: "claude",
+ Model: "sonnet",
+ Instructions: "Do work.",
+ MaxBudgetUSD: 0.25,
+ AllowedTools: []string{"Bash"},
+ },
+ Timeout: "10m",
+ Priority: "normal",
+ })
+ wrapper, _ := json.Marshal(map[string]string{"result": string(taskBody)}) // task JSON carried as a string under "result"
+ srv.elaborateCmdPath = createFakeClaude(t, string(wrapper), 0) // presumably output and exit code; confirm against createFakeClaude
+
+ body := `{"prompt":"do work"}`
+ req := httptest.NewRequest("POST", "/api/tasks/elaborate", bytes.NewBufferString(body))
+ req.Header.Set("Content-Type", "application/json")
+ w := httptest.NewRecorder()
+ srv.Handler().ServeHTTP(w, req)
+
+ if w.Code != http.StatusOK {
+ t.Fatalf("status: want 200, got %d; body: %s", w.Code, w.Body.String())
+ }
+ var got elaboratedTask
+ if err := json.NewDecoder(w.Body).Decode(&got); err != nil {
+ t.Fatalf("decode response: %v", err)
+ }
+ if got.Name != "Claude-only" { // the name set in the fake Claude output above
+ t.Errorf("Name: %q", got.Name)
+ }
+}
+