diff options
| author | Peter Stone <thepeterstone@gmail.com> | 2026-03-08 20:40:31 +0000 |
|---|---|---|
| committer | Peter Stone <thepeterstone@gmail.com> | 2026-03-08 20:40:31 +0000 |
| commit | 417034be7f745062901a940d1a021f6d85be496e (patch) | |
| tree | 666956207b58c915090f6641891304156cf93670 /internal/api/server_test.go | |
| parent | 181a37698410b68e00a885593b6f2b7acf21f4b4 (diff) | |
api: SetAPIToken, SetNotifier, questionStore, per-IP rate limiter
- Extract questionStore interface for testability of handleAnswerQuestion
- Add SetAPIToken/SetNotifier methods for post-construction wiring
- Extract processResult() from forwardResults() for direct testability
- Add ipRateLimiter with token-bucket per IP; applied to /elaborate and /validate
- Fix tests for running-task deletion and retry-limit that relied on
invalid state transitions in setup
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Diffstat (limited to 'internal/api/server_test.go')
| -rw-r--r-- | internal/api/server_test.go | 344 |
1 files changed, 334 insertions, 10 deletions
diff --git a/internal/api/server_test.go b/internal/api/server_test.go index e012bc1..c3b12ce 100644 --- a/internal/api/server_test.go +++ b/internal/api/server_test.go @@ -9,15 +9,70 @@ import ( "net/http/httptest" "os" "path/filepath" + "strings" "testing" + "time" "context" "github.com/thepeterstone/claudomator/internal/executor" + "github.com/thepeterstone/claudomator/internal/notify" "github.com/thepeterstone/claudomator/internal/storage" "github.com/thepeterstone/claudomator/internal/task" ) +// mockNotifier records calls to Notify. +type mockNotifier struct { + events []notify.Event +} + +func (m *mockNotifier) Notify(e notify.Event) error { + m.events = append(m.events, e) + return nil +} + +func TestServer_ProcessResult_CallsNotifier(t *testing.T) { + srv, store := testServer(t) + + mn := &mockNotifier{} + srv.SetNotifier(mn) + + tk := &task.Task{ + ID: "task-notifier-test", + Name: "Notifier Task", + State: task.StatePending, + } + if err := store.CreateTask(tk); err != nil { + t.Fatal(err) + } + + result := &executor.Result{ + TaskID: tk.ID, + Execution: &storage.Execution{ + ID: "exec-1", + TaskID: tk.ID, + Status: "COMPLETED", + CostUSD: 0.42, + ErrorMsg: "", + }, + } + srv.processResult(result) + + if len(mn.events) != 1 { + t.Fatalf("expected 1 notify event, got %d", len(mn.events)) + } + ev := mn.events[0] + if ev.TaskID != tk.ID { + t.Errorf("event.TaskID = %q, want %q", ev.TaskID, tk.ID) + } + if ev.Status != "COMPLETED" { + t.Errorf("event.Status = %q, want COMPLETED", ev.Status) + } + if ev.CostUSD != 0.42 { + t.Errorf("event.CostUSD = %v, want 0.42", ev.CostUSD) + } +} + func testServer(t *testing.T) (*Server, *storage.DB) { t.Helper() dbPath := filepath.Join(t.TempDir(), "test.db") @@ -170,6 +225,20 @@ func TestListTasks_WithTasks(t *testing.T) { } } +// stateWalkPaths defines the sequence of intermediate states needed to reach each target state. +var stateWalkPaths = map[task.State][]task.State{ + task.StatePending: {}, + task.StateQueued: {task.StateQueued}, + task.StateRunning: {task.StateQueued, task.StateRunning}, + task.StateCompleted: {task.StateQueued, task.StateRunning, task.StateCompleted}, + task.StateFailed: {task.StateQueued, task.StateRunning, task.StateFailed}, + task.StateTimedOut: {task.StateQueued, task.StateRunning, task.StateTimedOut}, + task.StateCancelled: {task.StateCancelled}, + task.StateBudgetExceeded: {task.StateQueued, task.StateRunning, task.StateBudgetExceeded}, + task.StateReady: {task.StateQueued, task.StateRunning, task.StateReady}, + task.StateBlocked: {task.StateQueued, task.StateRunning, task.StateBlocked}, +} + func createTaskWithState(t *testing.T, store *storage.DB, id string, state task.State) *task.Task { t.Helper() tk := &task.Task{ @@ -183,9 +252,9 @@ func createTaskWithState(t *testing.T, store *storage.DB, id string, state task. if err := store.CreateTask(tk); err != nil { t.Fatalf("createTaskWithState: CreateTask: %v", err) } - if state != task.StatePending { - if err := store.UpdateTaskState(id, state); err != nil { - t.Fatalf("createTaskWithState: UpdateTaskState(%s): %v", state, err) + for _, s := range stateWalkPaths[state] { + if err := store.UpdateTaskState(id, s); err != nil { + t.Fatalf("createTaskWithState: UpdateTaskState(%s): %v", s, err) } } tk.State = state @@ -420,7 +489,7 @@ func TestHandleStartNextTask_Success(t *testing.T) { } srv, _ := testServer(t) - srv.startNextTaskScript = script + srv.SetScripts(ScriptRegistry{"start-next-task": script}) req := httptest.NewRequest("POST", "/api/scripts/start-next-task", nil) w := httptest.NewRecorder() @@ -447,7 +516,7 @@ func TestHandleStartNextTask_NoTask(t *testing.T) { } srv, _ := testServer(t) - srv.startNextTaskScript = script + srv.SetScripts(ScriptRegistry{"start-next-task": script}) req := httptest.NewRequest("POST", "/api/scripts/start-next-task", nil) w := httptest.NewRecorder() @@ -530,9 +599,87 @@ func TestResumeTimedOut_Success_Returns202(t *testing.T) { } } +func TestRunTask_RetryLimitReached_Returns409(t *testing.T) { + srv, store := testServer(t) + // Task with MaxAttempts: 1 — only 1 attempt allowed. Create directly as FAILED + // so state is consistent without going through transition sequence. + tk := &task.Task{ + ID: "retry-limit-1", + Name: "Retry Limit Task", + Claude: task.ClaudeConfig{Instructions: "do something"}, + Priority: task.PriorityNormal, + Retry: task.RetryConfig{MaxAttempts: 1, Backoff: "linear"}, + Tags: []string{}, + DependsOn: []string{}, + State: task.StateFailed, + } + if err := store.CreateTask(tk); err != nil { + t.Fatal(err) + } + // Record one execution — the first attempt already used. + exec := &storage.Execution{ + ID: "exec-retry-1", + TaskID: "retry-limit-1", + StartTime: time.Now(), + Status: "FAILED", + } + if err := store.CreateExecution(exec); err != nil { + t.Fatal(err) + } + + req := httptest.NewRequest("POST", "/api/tasks/retry-limit-1/run", nil) + w := httptest.NewRecorder() + srv.Handler().ServeHTTP(w, req) + + if w.Code != http.StatusConflict { + t.Errorf("status: want 409, got %d; body: %s", w.Code, w.Body.String()) + } + var body map[string]string + json.NewDecoder(w.Body).Decode(&body) + if !strings.Contains(body["error"], "retry limit") { + t.Errorf("error body should mention retry limit, got %q", body["error"]) + } +} + +func TestRunTask_WithinRetryLimit_Returns202(t *testing.T) { + srv, store := testServer(t) + // Task with MaxAttempts: 3 — 1 execution used, 2 remaining. + tk := &task.Task{ + ID: "retry-within-1", + Name: "Retry Within Task", + Claude: task.ClaudeConfig{Instructions: "do something"}, + Priority: task.PriorityNormal, + Retry: task.RetryConfig{MaxAttempts: 3, Backoff: "linear"}, + Tags: []string{}, + DependsOn: []string{}, + State: task.StatePending, + } + if err := store.CreateTask(tk); err != nil { + t.Fatal(err) + } + exec := &storage.Execution{ + ID: "exec-within-1", + TaskID: "retry-within-1", + StartTime: time.Now(), + Status: "FAILED", + } + if err := store.CreateExecution(exec); err != nil { + t.Fatal(err) + } + store.UpdateTaskState("retry-within-1", task.StateFailed) + + req := httptest.NewRequest("POST", "/api/tasks/retry-within-1/run", nil) + w := httptest.NewRecorder() + srv.Handler().ServeHTTP(w, req) + + if w.Code != http.StatusAccepted { + t.Errorf("status: want 202, got %d; body: %s", w.Code, w.Body.String()) + } +} + func TestHandleStartNextTask_ScriptNotFound(t *testing.T) { srv, _ := testServer(t) - srv.startNextTaskScript = "/nonexistent/start-next-task" + srv.SetScripts(ScriptRegistry{"start-next-task": "/nonexistent/start-next-task"}) req := httptest.NewRequest("POST", "/api/scripts/start-next-task", nil) w := httptest.NewRecorder() @@ -578,10 +725,20 @@ func TestDeleteTask_NotFound(t *testing.T) { func TestDeleteTask_RunningTaskRejected(t *testing.T) { srv, store := testServer(t) - created := createTestTask(t, srv, `{"name":"Running Task","claude":{"instructions":"x","model":"sonnet"}}`) - store.UpdateTaskState(created.ID, "RUNNING") - - req := httptest.NewRequest("DELETE", "/api/tasks/"+created.ID, nil) + // Create the task directly in RUNNING state to avoid going through state transitions. + tk := &task.Task{ + ID: "running-task-del", + Name: "Running Task", + Claude: task.ClaudeConfig{Instructions: "x", Model: "sonnet"}, + Priority: task.PriorityNormal, + Tags: []string{}, + DependsOn: []string{}, + State: task.StateRunning, + } + if err := store.CreateTask(tk); err != nil { + t.Fatal(err) + } + req := httptest.NewRequest("DELETE", "/api/tasks/running-task-del", nil) w := httptest.NewRecorder() srv.Handler().ServeHTTP(w, req) @@ -657,3 +814,170 @@ func TestServer_CancelTask_Completed_Returns409(t *testing.T) { t.Errorf("status: want 409, got %d; body: %s", w.Code, w.Body.String()) } } + +// mockQuestionStore implements questionStore for testing handleAnswerQuestion. +type mockQuestionStore struct { + getTaskFn func(id string) (*task.Task, error) + getLatestExecutionFn func(taskID string) (*storage.Execution, error) + updateTaskQuestionFn func(taskID, questionJSON string) error + updateTaskStateFn func(id string, newState task.State) error +} + +func (m *mockQuestionStore) GetTask(id string) (*task.Task, error) { + return m.getTaskFn(id) +} +func (m *mockQuestionStore) GetLatestExecution(taskID string) (*storage.Execution, error) { + return m.getLatestExecutionFn(taskID) +} +func (m *mockQuestionStore) UpdateTaskQuestion(taskID, questionJSON string) error { + return m.updateTaskQuestionFn(taskID, questionJSON) +} +func (m *mockQuestionStore) UpdateTaskState(id string, newState task.State) error { + return m.updateTaskStateFn(id, newState) +} + +func TestServer_AnswerQuestion_UpdateQuestionFails_Returns500(t *testing.T) { + srv, _ := testServer(t) + srv.questionStore = &mockQuestionStore{ + getTaskFn: func(id string) (*task.Task, error) { + return &task.Task{ID: id, State: task.StateBlocked}, nil + }, + getLatestExecutionFn: func(taskID string) (*storage.Execution, error) { + return &storage.Execution{SessionID: "sess-1"}, nil + }, + updateTaskQuestionFn: func(taskID, questionJSON string) error { + return fmt.Errorf("db error") + }, + updateTaskStateFn: func(id string, newState task.State) error { + return nil + }, + } + + body := bytes.NewBufferString(`{"answer":"yes"}`) + req := httptest.NewRequest("POST", "/api/tasks/task-1/answer", body) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + srv.Handler().ServeHTTP(w, req) + + if w.Code != http.StatusInternalServerError { + t.Errorf("status: want 500, got %d; body: %s", w.Code, w.Body.String()) + } +} + +func TestServer_AnswerQuestion_UpdateStateFails_Returns500(t *testing.T) { + srv, _ := testServer(t) + srv.questionStore = &mockQuestionStore{ + getTaskFn: func(id string) (*task.Task, error) { + return &task.Task{ID: id, State: task.StateBlocked}, nil + }, + getLatestExecutionFn: func(taskID string) (*storage.Execution, error) { + return &storage.Execution{SessionID: "sess-1"}, nil + }, + updateTaskQuestionFn: func(taskID, questionJSON string) error { + return nil + }, + updateTaskStateFn: func(id string, newState task.State) error { + return fmt.Errorf("db error") + }, + } + + body := bytes.NewBufferString(`{"answer":"yes"}`) + req := httptest.NewRequest("POST", "/api/tasks/task-1/answer", body) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + srv.Handler().ServeHTTP(w, req) + + if w.Code != http.StatusInternalServerError { + t.Errorf("status: want 500, got %d; body: %s", w.Code, w.Body.String()) + } +} + +func TestRateLimit_ElaborateRejectsExcess(t *testing.T) { + srv, _ := testServer(t) + // Use burst-1 and rate-0 so the second request from the same IP is rejected. + srv.elaborateLimiter = newIPRateLimiter(0, 1) + + makeReq := func(remoteAddr string) int { + req := httptest.NewRequest("POST", "/api/tasks/elaborate", bytes.NewBufferString(`{"description":"x"}`)) + req.Header.Set("Content-Type", "application/json") + req.RemoteAddr = remoteAddr + w := httptest.NewRecorder() + srv.Handler().ServeHTTP(w, req) + return w.Code + } + + // First request from IP A: limiter allows it (non-429). + if code := makeReq("192.0.2.1:1234"); code == http.StatusTooManyRequests { + t.Errorf("first request should not be rate limited, got 429") + } + // Second request from IP A: bucket exhausted, must be 429. + if code := makeReq("192.0.2.1:1234"); code != http.StatusTooManyRequests { + t.Errorf("second request from same IP should be 429, got %d", code) + } + // First request from IP B: separate bucket, not limited. + if code := makeReq("192.0.2.2:1234"); code == http.StatusTooManyRequests { + t.Errorf("first request from different IP should not be rate limited, got 429") + } +} + +func TestListWorkspaces_RequiresAuth(t *testing.T) { + srv, _ := testServer(t) + srv.SetAPIToken("secret-token") + + // No token: expect 401. + req := httptest.NewRequest("GET", "/api/workspaces", nil) + w := httptest.NewRecorder() + srv.Handler().ServeHTTP(w, req) + if w.Code != http.StatusUnauthorized { + t.Errorf("expected 401 without token, got %d", w.Code) + } +} + +func TestListWorkspaces_RejectsWrongToken(t *testing.T) { + srv, _ := testServer(t) + srv.SetAPIToken("secret-token") + + req := httptest.NewRequest("GET", "/api/workspaces", nil) + req.Header.Set("Authorization", "Bearer wrong-token") + w := httptest.NewRecorder() + srv.Handler().ServeHTTP(w, req) + if w.Code != http.StatusUnauthorized { + t.Errorf("expected 401 with wrong token, got %d", w.Code) + } +} + +func TestListWorkspaces_SuppressesRawError(t *testing.T) { + srv, _ := testServer(t) + // No token configured so auth is bypassed; /workspace likely doesn't exist in test env. + + req := httptest.NewRequest("GET", "/api/workspaces", nil) + w := httptest.NewRecorder() + srv.Handler().ServeHTTP(w, req) + if w.Code == http.StatusInternalServerError { + body := w.Body.String() + if strings.Contains(body, "/workspace") || strings.Contains(body, "no such file") { + t.Errorf("response leaks filesystem details: %s", body) + } + } +} + +func TestRateLimit_ValidateRejectsExcess(t *testing.T) { + srv, _ := testServer(t) + srv.elaborateLimiter = newIPRateLimiter(0, 1) + + makeReq := func(remoteAddr string) int { + req := httptest.NewRequest("POST", "/api/tasks/validate", bytes.NewBufferString(`{"description":"x"}`)) + req.Header.Set("Content-Type", "application/json") + req.RemoteAddr = remoteAddr + w := httptest.NewRecorder() + srv.Handler().ServeHTTP(w, req) + return w.Code + } + + if code := makeReq("192.0.2.1:1234"); code == http.StatusTooManyRequests { + t.Errorf("first validate request should not be rate limited, got 429") + } + if code := makeReq("192.0.2.1:1234"); code != http.StatusTooManyRequests { + t.Errorf("second validate request from same IP should be 429, got %d", code) + } +} |
