summaryrefslogtreecommitdiff
path: root/internal/executor/classifier_test.go
blob: 84fffcfe7cda4b3be0c64ed321731138f36260ec (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
package executor

import (
	"context"
	"encoding/json"
	"fmt"
	"net/http"
	"net/http/httptest"
	"os"
	"strings"
	"testing"

	"github.com/thepeterstone/claudomator/internal/llm"
)

// TestClassifier_Classify_Mock tests the classifier with a mocked gemini binary.
//
// The mock is a shell script that ignores its arguments and prints a canned
// JSON envelope whose "response" field holds the classification JSON. Only
// GeminiBinaryPath is configured (no LLM client), so this presumably exercises
// the gemini-binary branch of Classify.
func TestClassifier_Classify_Mock(t *testing.T) {
	// The mock relies on a #!/bin/sh shebang; skip where no POSIX shell is
	// available (e.g. Windows) instead of failing with an opaque exec error.
	if _, err := os.Stat("/bin/sh"); err != nil {
		t.Skipf("mock gemini binary requires /bin/sh: %v", err)
	}

	// Create a temporary executable mock binary.
	mockBinary := filepathJoin(t.TempDir(), "mock-gemini")
	mockContent := `#!/bin/sh
echo '{"response": "{\"agent_type\": \"gemini\", \"model\": \"gemini-2.5-flash-lite\", \"reason\": \"test reason\"}"}'
`
	if err := os.WriteFile(mockBinary, []byte(mockContent), 0755); err != nil {
		t.Fatalf("write mock binary: %v", err)
	}

	c := &Classifier{GeminiBinaryPath: mockBinary}
	status := SystemStatus{
		ActiveTasks: map[string]int{"claude": 5, "gemini": 1},
		RateLimited: map[string]bool{"claude": false, "gemini": false},
	}

	cls, err := c.Classify(context.Background(), "Test Task", "Test Instructions", status, "gemini")
	if err != nil {
		t.Fatalf("Classify failed: %v", err)
	}

	if cls.AgentType != "gemini" {
		t.Errorf("expected gemini, got %q", cls.AgentType)
	}
	if cls.Model != "gemini-2.5-flash-lite" {
		t.Errorf("expected gemini-2.5-flash-lite, got %q", cls.Model)
	}
	// The mock also emits a reason; make sure it survives parsing.
	if !strings.Contains(cls.Reason, "test reason") {
		t.Errorf("Reason mismatch: %q", cls.Reason)
	}
}

// TestClassifier_Classify_LLM tests classification through a local OpenAI-compatible LLM.
//
// A fake server checks that the classifier requests the json_object response
// format. It then replies with a canned chat completion whose assistant
// content is the classification JSON.
func TestClassifier_Classify_LLM(t *testing.T) {
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// This handler runs on an httptest server goroutine, not the test
		// goroutine. It must therefore never call t.Fatal/t.FailNow. Report
		// with t.Errorf and fail the HTTP request instead of returning a
		// valid completion.
		var body struct {
			ResponseFormat *struct {
				Type string `json:"type"`
			} `json:"response_format"`
		}
		if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
			t.Errorf("decode body: %v", err)
			http.Error(w, "decode body: "+err.Error(), http.StatusBadRequest)
			return
		}
		// Verify the classifier asked for JSON mode.
		if body.ResponseFormat == nil || body.ResponseFormat.Type != "json_object" {
			t.Errorf("classifier should request json_object response format")
		}

		w.Header().Set("Content-Type", "application/json")
		fmt.Fprintln(w, `{
			"model":"local-fast",
			"choices":[{"message":{"role":"assistant","content":"{\"agent_type\":\"claude\",\"model\":\"claude-haiku-4-5-20251001\",\"reason\":\"trivial task\"}"},"finish_reason":"stop"}],
			"usage":{"prompt_tokens":10,"completion_tokens":15}
		}`)
	}))
	defer srv.Close()

	c := &Classifier{
		LLM: &llm.Client{Endpoint: srv.URL + "/v1", Model: "local-fast"},
	}
	status := SystemStatus{
		ActiveTasks: map[string]int{"claude": 1, "gemini": 0},
		RateLimited: map[string]bool{},
	}

	cls, err := c.Classify(context.Background(), "List files", "ls -la", status, "claude")
	if err != nil {
		t.Fatalf("Classify: %v", err)
	}
	if cls.AgentType != "claude" {
		t.Errorf("AgentType: want claude got %q", cls.AgentType)
	}
	if cls.Model != "claude-haiku-4-5-20251001" {
		t.Errorf("Model: want claude-haiku-4-5-20251001 got %q", cls.Model)
	}
	if !strings.Contains(cls.Reason, "trivial") {
		t.Errorf("Reason mismatch: %q", cls.Reason)
	}
}

// TestClassifier_LLMTakesPrecedence_OverGemini ensures the LLM path is preferred when both are configured.
//
// GeminiBinaryPath points at a path that does not exist. The expected model
// name appears only in the fake LLM server's reply, so a match shows the LLM
// client was used.
func TestClassifier_LLMTakesPrecedence_OverGemini(t *testing.T) {
	const completion = `{"model":"x","choices":[{"message":{"content":"{\"agent_type\":\"claude\",\"model\":\"claude-sonnet-4-6\",\"reason\":\"r\"}"},"finish_reason":"stop"}],"usage":{}}`

	// The fake server ignores the request and always answers with the canned completion.
	handler := func(w http.ResponseWriter, _ *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		fmt.Fprintln(w, completion)
	}
	server := httptest.NewServer(http.HandlerFunc(handler))
	defer server.Close()

	classifier := &Classifier{
		GeminiBinaryPath: "/nonexistent/gemini-binary-should-not-be-called",
		LLM:              &llm.Client{Endpoint: server.URL + "/v1", Model: "x"},
	}

	got, err := classifier.Classify(context.Background(), "n", "i", SystemStatus{}, "claude")
	if err != nil {
		t.Fatalf("Classify: %v", err)
	}
	if want := "claude-sonnet-4-6"; got.Model != want {
		t.Errorf("expected LLM path; got Model=%q", got.Model)
	}
}

// filepathJoin joins path elements with the OS path separator.
//
// Unlike filepath.Join, it does not clean the result. Empty elements and
// redundant separators are kept verbatim. With no elements it returns "".
func filepathJoin(elems ...string) string {
	// strings.Join sizes the result once instead of growing a string
	// quadratically with += in a loop.
	return strings.Join(elems, string(os.PathSeparator))
}