summaryrefslogtreecommitdiff
path: root/internal/api/server_test.go
diff options
context:
space:
mode:
authorPeter Stone <thepeterstone@gmail.com>2026-03-08 20:40:31 +0000
committerPeter Stone <thepeterstone@gmail.com>2026-03-08 20:40:31 +0000
commit417034be7f745062901a940d1a021f6d85be496e (patch)
tree666956207b58c915090f6641891304156cf93670 /internal/api/server_test.go
parent181a37698410b68e00a885593b6f2b7acf21f4b4 (diff)
api: SetAPIToken, SetNotifier, questionStore, per-IP rate limiter
- Extract questionStore interface for testability of handleAnswerQuestion - Add SetAPIToken/SetNotifier methods for post-construction wiring - Extract processResult() from forwardResults() for direct testability - Add ipRateLimiter with token-bucket per IP; applied to /elaborate and /validate - Fix tests for running-task deletion and retry-limit that relied on invalid state transitions in setup Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Diffstat (limited to 'internal/api/server_test.go')
-rw-r--r--internal/api/server_test.go344
1 files changed, 334 insertions, 10 deletions
diff --git a/internal/api/server_test.go b/internal/api/server_test.go
index e012bc1..c3b12ce 100644
--- a/internal/api/server_test.go
+++ b/internal/api/server_test.go
@@ -9,15 +9,70 @@ import (
"net/http/httptest"
"os"
"path/filepath"
+ "strings"
"testing"
+ "time"
"context"
"github.com/thepeterstone/claudomator/internal/executor"
+ "github.com/thepeterstone/claudomator/internal/notify"
"github.com/thepeterstone/claudomator/internal/storage"
"github.com/thepeterstone/claudomator/internal/task"
)
+// mockNotifier records calls to Notify.
+type mockNotifier struct {
+ events []notify.Event
+}
+
+func (m *mockNotifier) Notify(e notify.Event) error {
+ m.events = append(m.events, e)
+ return nil
+}
+
+func TestServer_ProcessResult_CallsNotifier(t *testing.T) {
+ srv, store := testServer(t)
+
+ mn := &mockNotifier{}
+ srv.SetNotifier(mn)
+
+ tk := &task.Task{
+ ID: "task-notifier-test",
+ Name: "Notifier Task",
+ State: task.StatePending,
+ }
+ if err := store.CreateTask(tk); err != nil {
+ t.Fatal(err)
+ }
+
+ result := &executor.Result{
+ TaskID: tk.ID,
+ Execution: &storage.Execution{
+ ID: "exec-1",
+ TaskID: tk.ID,
+ Status: "COMPLETED",
+ CostUSD: 0.42,
+ ErrorMsg: "",
+ },
+ }
+ srv.processResult(result)
+
+ if len(mn.events) != 1 {
+ t.Fatalf("expected 1 notify event, got %d", len(mn.events))
+ }
+ ev := mn.events[0]
+ if ev.TaskID != tk.ID {
+ t.Errorf("event.TaskID = %q, want %q", ev.TaskID, tk.ID)
+ }
+ if ev.Status != "COMPLETED" {
+ t.Errorf("event.Status = %q, want COMPLETED", ev.Status)
+ }
+ if ev.CostUSD != 0.42 {
+ t.Errorf("event.CostUSD = %v, want 0.42", ev.CostUSD)
+ }
+}
+
func testServer(t *testing.T) (*Server, *storage.DB) {
t.Helper()
dbPath := filepath.Join(t.TempDir(), "test.db")
@@ -170,6 +225,20 @@ func TestListTasks_WithTasks(t *testing.T) {
}
}
+// stateWalkPaths defines the sequence of intermediate states needed to reach each target state.
+var stateWalkPaths = map[task.State][]task.State{
+ task.StatePending: {},
+ task.StateQueued: {task.StateQueued},
+ task.StateRunning: {task.StateQueued, task.StateRunning},
+ task.StateCompleted: {task.StateQueued, task.StateRunning, task.StateCompleted},
+ task.StateFailed: {task.StateQueued, task.StateRunning, task.StateFailed},
+ task.StateTimedOut: {task.StateQueued, task.StateRunning, task.StateTimedOut},
+ task.StateCancelled: {task.StateCancelled},
+ task.StateBudgetExceeded: {task.StateQueued, task.StateRunning, task.StateBudgetExceeded},
+ task.StateReady: {task.StateQueued, task.StateRunning, task.StateReady},
+ task.StateBlocked: {task.StateQueued, task.StateRunning, task.StateBlocked},
+}
+
func createTaskWithState(t *testing.T, store *storage.DB, id string, state task.State) *task.Task {
t.Helper()
tk := &task.Task{
@@ -183,9 +252,9 @@ func createTaskWithState(t *testing.T, store *storage.DB, id string, state task.
if err := store.CreateTask(tk); err != nil {
t.Fatalf("createTaskWithState: CreateTask: %v", err)
}
- if state != task.StatePending {
- if err := store.UpdateTaskState(id, state); err != nil {
- t.Fatalf("createTaskWithState: UpdateTaskState(%s): %v", state, err)
+ for _, s := range stateWalkPaths[state] {
+ if err := store.UpdateTaskState(id, s); err != nil {
+ t.Fatalf("createTaskWithState: UpdateTaskState(%s): %v", s, err)
}
}
tk.State = state
@@ -420,7 +489,7 @@ func TestHandleStartNextTask_Success(t *testing.T) {
}
srv, _ := testServer(t)
- srv.startNextTaskScript = script
+ srv.SetScripts(ScriptRegistry{"start-next-task": script})
req := httptest.NewRequest("POST", "/api/scripts/start-next-task", nil)
w := httptest.NewRecorder()
@@ -447,7 +516,7 @@ func TestHandleStartNextTask_NoTask(t *testing.T) {
}
srv, _ := testServer(t)
- srv.startNextTaskScript = script
+ srv.SetScripts(ScriptRegistry{"start-next-task": script})
req := httptest.NewRequest("POST", "/api/scripts/start-next-task", nil)
w := httptest.NewRecorder()
@@ -530,9 +599,87 @@ func TestResumeTimedOut_Success_Returns202(t *testing.T) {
}
}
+func TestRunTask_RetryLimitReached_Returns409(t *testing.T) {
+ srv, store := testServer(t)
+ // Task with MaxAttempts: 1 — only 1 attempt allowed. Create directly as FAILED
+ // so state is consistent without going through transition sequence.
+ tk := &task.Task{
+ ID: "retry-limit-1",
+ Name: "Retry Limit Task",
+ Claude: task.ClaudeConfig{Instructions: "do something"},
+ Priority: task.PriorityNormal,
+ Retry: task.RetryConfig{MaxAttempts: 1, Backoff: "linear"},
+ Tags: []string{},
+ DependsOn: []string{},
+ State: task.StateFailed,
+ }
+ if err := store.CreateTask(tk); err != nil {
+ t.Fatal(err)
+ }
+ // Record one execution — the first attempt already used.
+ exec := &storage.Execution{
+ ID: "exec-retry-1",
+ TaskID: "retry-limit-1",
+ StartTime: time.Now(),
+ Status: "FAILED",
+ }
+ if err := store.CreateExecution(exec); err != nil {
+ t.Fatal(err)
+ }
+
+ req := httptest.NewRequest("POST", "/api/tasks/retry-limit-1/run", nil)
+ w := httptest.NewRecorder()
+ srv.Handler().ServeHTTP(w, req)
+
+ if w.Code != http.StatusConflict {
+ t.Errorf("status: want 409, got %d; body: %s", w.Code, w.Body.String())
+ }
+ var body map[string]string
+ json.NewDecoder(w.Body).Decode(&body)
+ if !strings.Contains(body["error"], "retry limit") {
+ t.Errorf("error body should mention retry limit, got %q", body["error"])
+ }
+}
+
+func TestRunTask_WithinRetryLimit_Returns202(t *testing.T) {
+ srv, store := testServer(t)
+ // Task with MaxAttempts: 3 — 1 execution used, 2 remaining.
+ tk := &task.Task{
+ ID: "retry-within-1",
+ Name: "Retry Within Task",
+ Claude: task.ClaudeConfig{Instructions: "do something"},
+ Priority: task.PriorityNormal,
+ Retry: task.RetryConfig{MaxAttempts: 3, Backoff: "linear"},
+ Tags: []string{},
+ DependsOn: []string{},
+ State: task.StatePending,
+ }
+ if err := store.CreateTask(tk); err != nil {
+ t.Fatal(err)
+ }
+ exec := &storage.Execution{
+ ID: "exec-within-1",
+ TaskID: "retry-within-1",
+ StartTime: time.Now(),
+ Status: "FAILED",
+ }
+ if err := store.CreateExecution(exec); err != nil {
+ t.Fatal(err)
+ }
+ store.UpdateTaskState("retry-within-1", task.StateFailed)
+
+ req := httptest.NewRequest("POST", "/api/tasks/retry-within-1/run", nil)
+ w := httptest.NewRecorder()
+ srv.Handler().ServeHTTP(w, req)
+
+ if w.Code != http.StatusAccepted {
+ t.Errorf("status: want 202, got %d; body: %s", w.Code, w.Body.String())
+ }
+}
+
func TestHandleStartNextTask_ScriptNotFound(t *testing.T) {
srv, _ := testServer(t)
- srv.startNextTaskScript = "/nonexistent/start-next-task"
+ srv.SetScripts(ScriptRegistry{"start-next-task": "/nonexistent/start-next-task"})
req := httptest.NewRequest("POST", "/api/scripts/start-next-task", nil)
w := httptest.NewRecorder()
@@ -578,10 +725,20 @@ func TestDeleteTask_NotFound(t *testing.T) {
func TestDeleteTask_RunningTaskRejected(t *testing.T) {
srv, store := testServer(t)
- created := createTestTask(t, srv, `{"name":"Running Task","claude":{"instructions":"x","model":"sonnet"}}`)
- store.UpdateTaskState(created.ID, "RUNNING")
-
- req := httptest.NewRequest("DELETE", "/api/tasks/"+created.ID, nil)
+ // Create the task directly in RUNNING state to avoid going through state transitions.
+ tk := &task.Task{
+ ID: "running-task-del",
+ Name: "Running Task",
+ Claude: task.ClaudeConfig{Instructions: "x", Model: "sonnet"},
+ Priority: task.PriorityNormal,
+ Tags: []string{},
+ DependsOn: []string{},
+ State: task.StateRunning,
+ }
+ if err := store.CreateTask(tk); err != nil {
+ t.Fatal(err)
+ }
+ req := httptest.NewRequest("DELETE", "/api/tasks/running-task-del", nil)
w := httptest.NewRecorder()
srv.Handler().ServeHTTP(w, req)
@@ -657,3 +814,170 @@ func TestServer_CancelTask_Completed_Returns409(t *testing.T) {
t.Errorf("status: want 409, got %d; body: %s", w.Code, w.Body.String())
}
}
+
+// mockQuestionStore implements questionStore for testing handleAnswerQuestion.
+type mockQuestionStore struct {
+ getTaskFn func(id string) (*task.Task, error)
+ getLatestExecutionFn func(taskID string) (*storage.Execution, error)
+ updateTaskQuestionFn func(taskID, questionJSON string) error
+ updateTaskStateFn func(id string, newState task.State) error
+}
+
+func (m *mockQuestionStore) GetTask(id string) (*task.Task, error) {
+ return m.getTaskFn(id)
+}
+func (m *mockQuestionStore) GetLatestExecution(taskID string) (*storage.Execution, error) {
+ return m.getLatestExecutionFn(taskID)
+}
+func (m *mockQuestionStore) UpdateTaskQuestion(taskID, questionJSON string) error {
+ return m.updateTaskQuestionFn(taskID, questionJSON)
+}
+func (m *mockQuestionStore) UpdateTaskState(id string, newState task.State) error {
+ return m.updateTaskStateFn(id, newState)
+}
+
+func TestServer_AnswerQuestion_UpdateQuestionFails_Returns500(t *testing.T) {
+ srv, _ := testServer(t)
+ srv.questionStore = &mockQuestionStore{
+ getTaskFn: func(id string) (*task.Task, error) {
+ return &task.Task{ID: id, State: task.StateBlocked}, nil
+ },
+ getLatestExecutionFn: func(taskID string) (*storage.Execution, error) {
+ return &storage.Execution{SessionID: "sess-1"}, nil
+ },
+ updateTaskQuestionFn: func(taskID, questionJSON string) error {
+ return fmt.Errorf("db error")
+ },
+ updateTaskStateFn: func(id string, newState task.State) error {
+ return nil
+ },
+ }
+
+ body := bytes.NewBufferString(`{"answer":"yes"}`)
+ req := httptest.NewRequest("POST", "/api/tasks/task-1/answer", body)
+ req.Header.Set("Content-Type", "application/json")
+ w := httptest.NewRecorder()
+ srv.Handler().ServeHTTP(w, req)
+
+ if w.Code != http.StatusInternalServerError {
+ t.Errorf("status: want 500, got %d; body: %s", w.Code, w.Body.String())
+ }
+}
+
+func TestServer_AnswerQuestion_UpdateStateFails_Returns500(t *testing.T) {
+ srv, _ := testServer(t)
+ srv.questionStore = &mockQuestionStore{
+ getTaskFn: func(id string) (*task.Task, error) {
+ return &task.Task{ID: id, State: task.StateBlocked}, nil
+ },
+ getLatestExecutionFn: func(taskID string) (*storage.Execution, error) {
+ return &storage.Execution{SessionID: "sess-1"}, nil
+ },
+ updateTaskQuestionFn: func(taskID, questionJSON string) error {
+ return nil
+ },
+ updateTaskStateFn: func(id string, newState task.State) error {
+ return fmt.Errorf("db error")
+ },
+ }
+
+ body := bytes.NewBufferString(`{"answer":"yes"}`)
+ req := httptest.NewRequest("POST", "/api/tasks/task-1/answer", body)
+ req.Header.Set("Content-Type", "application/json")
+ w := httptest.NewRecorder()
+ srv.Handler().ServeHTTP(w, req)
+
+ if w.Code != http.StatusInternalServerError {
+ t.Errorf("status: want 500, got %d; body: %s", w.Code, w.Body.String())
+ }
+}
+
+func TestRateLimit_ElaborateRejectsExcess(t *testing.T) {
+ srv, _ := testServer(t)
+ // Use burst-1 and rate-0 so the second request from the same IP is rejected.
+ srv.elaborateLimiter = newIPRateLimiter(0, 1)
+
+ makeReq := func(remoteAddr string) int {
+ req := httptest.NewRequest("POST", "/api/tasks/elaborate", bytes.NewBufferString(`{"description":"x"}`))
+ req.Header.Set("Content-Type", "application/json")
+ req.RemoteAddr = remoteAddr
+ w := httptest.NewRecorder()
+ srv.Handler().ServeHTTP(w, req)
+ return w.Code
+ }
+
+ // First request from IP A: limiter allows it (non-429).
+ if code := makeReq("192.0.2.1:1234"); code == http.StatusTooManyRequests {
+ t.Errorf("first request should not be rate limited, got 429")
+ }
+ // Second request from IP A: bucket exhausted, must be 429.
+ if code := makeReq("192.0.2.1:1234"); code != http.StatusTooManyRequests {
+ t.Errorf("second request from same IP should be 429, got %d", code)
+ }
+ // First request from IP B: separate bucket, not limited.
+ if code := makeReq("192.0.2.2:1234"); code == http.StatusTooManyRequests {
+ t.Errorf("first request from different IP should not be rate limited, got 429")
+ }
+}
+
+func TestListWorkspaces_RequiresAuth(t *testing.T) {
+ srv, _ := testServer(t)
+ srv.SetAPIToken("secret-token")
+
+ // No token: expect 401.
+ req := httptest.NewRequest("GET", "/api/workspaces", nil)
+ w := httptest.NewRecorder()
+ srv.Handler().ServeHTTP(w, req)
+ if w.Code != http.StatusUnauthorized {
+ t.Errorf("expected 401 without token, got %d", w.Code)
+ }
+}
+
+func TestListWorkspaces_RejectsWrongToken(t *testing.T) {
+ srv, _ := testServer(t)
+ srv.SetAPIToken("secret-token")
+
+ req := httptest.NewRequest("GET", "/api/workspaces", nil)
+ req.Header.Set("Authorization", "Bearer wrong-token")
+ w := httptest.NewRecorder()
+ srv.Handler().ServeHTTP(w, req)
+ if w.Code != http.StatusUnauthorized {
+ t.Errorf("expected 401 with wrong token, got %d", w.Code)
+ }
+}
+
+func TestListWorkspaces_SuppressesRawError(t *testing.T) {
+ srv, _ := testServer(t)
+ // No token configured so auth is bypassed; /workspace likely doesn't exist in test env.
+
+ req := httptest.NewRequest("GET", "/api/workspaces", nil)
+ w := httptest.NewRecorder()
+ srv.Handler().ServeHTTP(w, req)
+ if w.Code == http.StatusInternalServerError {
+ body := w.Body.String()
+ if strings.Contains(body, "/workspace") || strings.Contains(body, "no such file") {
+ t.Errorf("response leaks filesystem details: %s", body)
+ }
+ }
+}
+
+func TestRateLimit_ValidateRejectsExcess(t *testing.T) {
+ srv, _ := testServer(t)
+ srv.elaborateLimiter = newIPRateLimiter(0, 1)
+
+ makeReq := func(remoteAddr string) int {
+ req := httptest.NewRequest("POST", "/api/tasks/validate", bytes.NewBufferString(`{"description":"x"}`))
+ req.Header.Set("Content-Type", "application/json")
+ req.RemoteAddr = remoteAddr
+ w := httptest.NewRecorder()
+ srv.Handler().ServeHTTP(w, req)
+ return w.Code
+ }
+
+ if code := makeReq("192.0.2.1:1234"); code == http.StatusTooManyRequests {
+ t.Errorf("first validate request should not be rate limited, got 429")
+ }
+ if code := makeReq("192.0.2.1:1234"); code != http.StatusTooManyRequests {
+ t.Errorf("second validate request from same IP should be 429, got %d", code)
+ }
+}