diff options
Diffstat (limited to 'internal/api')
| -rw-r--r-- | internal/api/elaborate.go | 5 | ||||
| -rw-r--r-- | internal/api/ratelimit.go | 99 | ||||
| -rw-r--r-- | internal/api/server.go | 143 | ||||
| -rw-r--r-- | internal/api/server_test.go | 344 | ||||
| -rw-r--r-- | internal/api/validate.go | 5 |
5 files changed, 557 insertions, 39 deletions
diff --git a/internal/api/elaborate.go b/internal/api/elaborate.go index e480e00..8a18dee 100644 --- a/internal/api/elaborate.go +++ b/internal/api/elaborate.go @@ -85,6 +85,11 @@ func (s *Server) claudeBinaryPath() string { } func (s *Server) handleElaborateTask(w http.ResponseWriter, r *http.Request) { + if s.elaborateLimiter != nil && !s.elaborateLimiter.allow(realIP(r)) { + writeJSON(w, http.StatusTooManyRequests, map[string]string{"error": "rate limit exceeded"}) + return + } + var input struct { Prompt string `json:"prompt"` ProjectDir string `json:"project_dir"` diff --git a/internal/api/ratelimit.go b/internal/api/ratelimit.go new file mode 100644 index 0000000..089354c --- /dev/null +++ b/internal/api/ratelimit.go @@ -0,0 +1,99 @@ +package api + +import ( + "net" + "net/http" + "sync" + "time" +) + +// ipRateLimiter provides per-IP token-bucket rate limiting. +type ipRateLimiter struct { + mu sync.Mutex + limiters map[string]*tokenBucket + rate float64 // tokens replenished per second + burst int // maximum token capacity +} + +// newIPRateLimiter creates a limiter with the given replenishment rate (tokens/sec) +// and burst capacity. Use rate=0 to disable replenishment (tokens never refill). +func newIPRateLimiter(rate float64, burst int) *ipRateLimiter { + return &ipRateLimiter{ + limiters: make(map[string]*tokenBucket), + rate: rate, + burst: burst, + } +} + +func (l *ipRateLimiter) allow(ip string) bool { + l.mu.Lock() + b, ok := l.limiters[ip] + if !ok { + b = &tokenBucket{ + tokens: float64(l.burst), + capacity: float64(l.burst), + rate: l.rate, + lastTime: time.Now(), + } + l.limiters[ip] = b + } + l.mu.Unlock() + return b.allow() +} + +// middleware wraps h with per-IP rate limiting, returning 429 when exceeded. +func (l *ipRateLimiter) middleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ip := realIP(r) + if !l.allow(ip) { + writeJSON(w, http.StatusTooManyRequests, map[string]string{"error": "rate limit exceeded"}) + return + } + next.ServeHTTP(w, r) + }) +} + +// tokenBucket is a simple token-bucket rate limiter for a single key. +type tokenBucket struct { + mu sync.Mutex + tokens float64 + capacity float64 + rate float64 // tokens per second + lastTime time.Time +} + +func (b *tokenBucket) allow() bool { + b.mu.Lock() + defer b.mu.Unlock() + now := time.Now() + if !b.lastTime.IsZero() { + elapsed := now.Sub(b.lastTime).Seconds() + b.tokens = min(b.capacity, b.tokens+elapsed*b.rate) + } + b.lastTime = now + if b.tokens >= 1.0 { + b.tokens-- + return true + } + return false +} + +// realIP extracts the client's real IP from a request. +func realIP(r *http.Request) string { + if xff := r.Header.Get("X-Forwarded-For"); xff != "" { + for i, c := range xff { + if c == ',' { + return xff[:i] + } + } + return xff + } + if xri := r.Header.Get("X-Real-IP"); xri != "" { + return xri + } + host, _, err := net.SplitHostPort(r.RemoteAddr) + if err != nil { + return r.RemoteAddr + } + return host +} diff --git a/internal/api/server.go b/internal/api/server.go index af4710b..833be8b 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -7,39 +7,63 @@ import ( "log/slog" "net/http" "os" + "strings" "time" "github.com/thepeterstone/claudomator/internal/executor" + "github.com/thepeterstone/claudomator/internal/notify" "github.com/thepeterstone/claudomator/internal/storage" "github.com/thepeterstone/claudomator/internal/task" webui "github.com/thepeterstone/claudomator/web" "github.com/google/uuid" ) +// questionStore is the minimal storage interface needed by handleAnswerQuestion. +type questionStore interface { + GetTask(id string) (*task.Task, error) + GetLatestExecution(taskID string) (*storage.Execution, error) + UpdateTaskQuestion(taskID, questionJSON string) error + UpdateTaskState(id string, newState task.State) error +} + // Server provides the REST API and WebSocket endpoint for Claudomator. type Server struct { store *storage.DB - logStore logStore // injectable for tests; defaults to store - taskLogStore taskLogStore // injectable for tests; defaults to store + logStore logStore // injectable for tests; defaults to store + taskLogStore taskLogStore // injectable for tests; defaults to store + questionStore questionStore // injectable for tests; defaults to store pool *executor.Pool hub *Hub logger *slog.Logger mux *http.ServeMux - claudeBinPath string // path to claude binary; defaults to "claude" - elaborateCmdPath string // overrides claudeBinPath; used in tests - validateCmdPath string // overrides claudeBinPath for validate; used in tests - startNextTaskScript string // path to start-next-task script; overridden in tests - deployScript string // path to deploy script; overridden in tests - workDir string // working directory injected into elaborate system prompt + claudeBinPath string // path to claude binary; defaults to "claude" + elaborateCmdPath string // overrides claudeBinPath; used in tests + validateCmdPath string // overrides claudeBinPath for validate; used in tests + scripts ScriptRegistry // optional; maps endpoint name → script path + workDir string // working directory injected into elaborate system prompt + notifier notify.Notifier + apiToken string // if non-empty, required for WebSocket (and REST) connections + elaborateLimiter *ipRateLimiter // per-IP rate limiter for elaborate/validate endpoints +} + +// SetAPIToken configures a bearer token that must be supplied to access the API. +func (s *Server) SetAPIToken(token string) { + s.apiToken = token +} + +// SetNotifier configures a notifier that is called on every task completion. +func (s *Server) SetNotifier(n notify.Notifier) { + s.notifier = n } func NewServer(store *storage.DB, pool *executor.Pool, logger *slog.Logger, claudeBinPath string) *Server { wd, _ := os.Getwd() s := &Server{ - store: store, - logStore: store, - taskLogStore: store, - pool: pool, + store: store, + logStore: store, + taskLogStore: store, + questionStore: store, + pool: pool, hub: NewHub(), logger: logger, mux: http.NewServeMux(), @@ -84,8 +108,7 @@ func (s *Server) routes() { s.mux.HandleFunc("DELETE /api/templates/{id}", s.handleDeleteTemplate) s.mux.HandleFunc("POST /api/tasks/{id}/answer", s.handleAnswerQuestion) s.mux.HandleFunc("POST /api/tasks/{id}/resume", s.handleResumeTimedOutTask) - s.mux.HandleFunc("POST /api/scripts/start-next-task", s.handleStartNextTask) - s.mux.HandleFunc("POST /api/scripts/deploy", s.handleDeploy) + s.mux.HandleFunc("POST /api/scripts/{name}", s.handleScript) s.mux.HandleFunc("GET /api/ws", s.handleWebSocket) s.mux.HandleFunc("GET /api/workspaces", s.handleListWorkspaces) s.mux.HandleFunc("GET /api/health", s.handleHealth) @@ -95,17 +118,44 @@ func (s *Server) routes() { // forwardResults listens on the executor pool's result channel and broadcasts via WebSocket. func (s *Server) forwardResults() { for result := range s.pool.Results() { - event := map[string]interface{}{ - "type": "task_completed", - "task_id": result.TaskID, - "status": result.Execution.Status, - "exit_code": result.Execution.ExitCode, - "cost_usd": result.Execution.CostUSD, - "error": result.Execution.ErrorMsg, - "timestamp": time.Now().UTC(), + s.processResult(result) + } +} + +// processResult broadcasts a task completion event via WebSocket and calls the notifier if set. +func (s *Server) processResult(result *executor.Result) { + event := map[string]interface{}{ + "type": "task_completed", + "task_id": result.TaskID, + "status": result.Execution.Status, + "exit_code": result.Execution.ExitCode, + "cost_usd": result.Execution.CostUSD, + "error": result.Execution.ErrorMsg, + "timestamp": time.Now().UTC(), + } + data, _ := json.Marshal(event) + s.hub.Broadcast(data) + + if s.notifier != nil { + var taskName string + if t, err := s.store.GetTask(result.TaskID); err == nil { + taskName = t.Name + } + var dur string + if !result.Execution.StartTime.IsZero() && !result.Execution.EndTime.IsZero() { + dur = result.Execution.EndTime.Sub(result.Execution.StartTime).String() + } + ne := notify.Event{ + TaskID: result.TaskID, + TaskName: taskName, + Status: result.Execution.Status, + CostUSD: result.Execution.CostUSD, + Duration: dur, + Error: result.Execution.ErrorMsg, + } + if err := s.notifier.Notify(ne); err != nil { + s.logger.Error("notifier failed", "error", err) } - data, _ := json.Marshal(event) - s.hub.Broadcast(data) } } @@ -167,7 +217,7 @@ func (s *Server) handleCancelTask(w http.ResponseWriter, r *http.Request) { func (s *Server) handleAnswerQuestion(w http.ResponseWriter, r *http.Request) { taskID := r.PathValue("id") - tk, err := s.store.GetTask(taskID) + tk, err := s.questionStore.GetTask(taskID) if err != nil { writeJSON(w, http.StatusNotFound, map[string]string{"error": "task not found"}) return @@ -190,15 +240,21 @@ func (s *Server) handleAnswerQuestion(w http.ResponseWriter, r *http.Request) { } // Look up the session ID from the most recent execution. - latest, err := s.store.GetLatestExecution(taskID) + latest, err := s.questionStore.GetLatestExecution(taskID) if err != nil || latest.SessionID == "" { writeJSON(w, http.StatusInternalServerError, map[string]string{"error": "no resumable session found"}) return } // Clear the question and transition to QUEUED. - s.store.UpdateTaskQuestion(taskID, "") - s.store.UpdateTaskState(taskID, task.StateQueued) + if err := s.questionStore.UpdateTaskQuestion(taskID, ""); err != nil { + writeJSON(w, http.StatusInternalServerError, map[string]string{"error": "failed to clear question"}) + return + } + if err := s.questionStore.UpdateTaskState(taskID, task.StateQueued); err != nil { + writeJSON(w, http.StatusInternalServerError, map[string]string{"error": "failed to queue task"}) + return + } // Submit a resume execution. resumeExec := &storage.Execution{ @@ -254,9 +310,23 @@ func (s *Server) handleResumeTimedOutTask(w http.ResponseWriter, r *http.Request } func (s *Server) handleListWorkspaces(w http.ResponseWriter, r *http.Request) { + if s.apiToken != "" { + token := r.URL.Query().Get("token") + if token == "" { + auth := r.Header.Get("Authorization") + if strings.HasPrefix(auth, "Bearer ") { + token = strings.TrimPrefix(auth, "Bearer ") + } + } + if token != s.apiToken { + http.Error(w, "unauthorized", http.StatusUnauthorized) + return + } + } + entries, err := os.ReadDir("/workspace") if err != nil { - writeJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) + writeJSON(w, http.StatusInternalServerError, map[string]string{"error": "failed to list workspaces"}) return } var dirs []string @@ -370,6 +440,21 @@ func (s *Server) handleRunTask(w http.ResponseWriter, r *http.Request) { return } + // Enforce retry limit for non-initial runs (PENDING is the initial state). + if t.State != task.StatePending { + execs, err := s.store.ListExecutions(id) + if err != nil { + writeJSON(w, http.StatusInternalServerError, map[string]string{"error": "failed to count executions"}) + return + } + if t.Retry.MaxAttempts > 0 && len(execs) >= t.Retry.MaxAttempts { + writeJSON(w, http.StatusConflict, map[string]string{ + "error": fmt.Sprintf("retry limit reached (%d/%d attempts used)", len(execs), t.Retry.MaxAttempts), + }) + return + } + } + if err := s.store.UpdateTaskState(id, task.StateQueued); err != nil { writeJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) return diff --git a/internal/api/server_test.go b/internal/api/server_test.go index e012bc1..c3b12ce 100644 --- a/internal/api/server_test.go +++ b/internal/api/server_test.go @@ -9,15 +9,70 @@ import ( "net/http/httptest" "os" "path/filepath" + "strings" "testing" + "time" "context" "github.com/thepeterstone/claudomator/internal/executor" + "github.com/thepeterstone/claudomator/internal/notify" "github.com/thepeterstone/claudomator/internal/storage" "github.com/thepeterstone/claudomator/internal/task" ) +// mockNotifier records calls to Notify. +type mockNotifier struct { + events []notify.Event +} + +func (m *mockNotifier) Notify(e notify.Event) error { + m.events = append(m.events, e) + return nil +} + +func TestServer_ProcessResult_CallsNotifier(t *testing.T) { + srv, store := testServer(t) + + mn := &mockNotifier{} + srv.SetNotifier(mn) + + tk := &task.Task{ + ID: "task-notifier-test", + Name: "Notifier Task", + State: task.StatePending, + } + if err := store.CreateTask(tk); err != nil { + t.Fatal(err) + } + + result := &executor.Result{ + TaskID: tk.ID, + Execution: &storage.Execution{ + ID: "exec-1", + TaskID: tk.ID, + Status: "COMPLETED", + CostUSD: 0.42, + ErrorMsg: "", + }, + } + srv.processResult(result) + + if len(mn.events) != 1 { + t.Fatalf("expected 1 notify event, got %d", len(mn.events)) + } + ev := mn.events[0] + if ev.TaskID != tk.ID { + t.Errorf("event.TaskID = %q, want %q", ev.TaskID, tk.ID) + } + if ev.Status != "COMPLETED" { + t.Errorf("event.Status = %q, want COMPLETED", ev.Status) + } + if ev.CostUSD != 0.42 { + t.Errorf("event.CostUSD = %v, want 0.42", ev.CostUSD) + } +} + func testServer(t *testing.T) (*Server, *storage.DB) { t.Helper() dbPath := filepath.Join(t.TempDir(), "test.db") @@ -170,6 +225,20 @@ func TestListTasks_WithTasks(t *testing.T) { } } +// stateWalkPaths defines the sequence of intermediate states needed to reach each target state. +var stateWalkPaths = map[task.State][]task.State{ + task.StatePending: {}, + task.StateQueued: {task.StateQueued}, + task.StateRunning: {task.StateQueued, task.StateRunning}, + task.StateCompleted: {task.StateQueued, task.StateRunning, task.StateCompleted}, + task.StateFailed: {task.StateQueued, task.StateRunning, task.StateFailed}, + task.StateTimedOut: {task.StateQueued, task.StateRunning, task.StateTimedOut}, + task.StateCancelled: {task.StateCancelled}, + task.StateBudgetExceeded: {task.StateQueued, task.StateRunning, task.StateBudgetExceeded}, + task.StateReady: {task.StateQueued, task.StateRunning, task.StateReady}, + task.StateBlocked: {task.StateQueued, task.StateRunning, task.StateBlocked}, +} + func createTaskWithState(t *testing.T, store *storage.DB, id string, state task.State) *task.Task { t.Helper() tk := &task.Task{ @@ -183,9 +252,9 @@ func createTaskWithState(t *testing.T, store *storage.DB, id string, state task. if err := store.CreateTask(tk); err != nil { t.Fatalf("createTaskWithState: CreateTask: %v", err) } - if state != task.StatePending { - if err := store.UpdateTaskState(id, state); err != nil { - t.Fatalf("createTaskWithState: UpdateTaskState(%s): %v", state, err) + for _, s := range stateWalkPaths[state] { + if err := store.UpdateTaskState(id, s); err != nil { + t.Fatalf("createTaskWithState: UpdateTaskState(%s): %v", s, err) } } tk.State = state @@ -420,7 +489,7 @@ func TestHandleStartNextTask_Success(t *testing.T) { } srv, _ := testServer(t) - srv.startNextTaskScript = script + srv.SetScripts(ScriptRegistry{"start-next-task": script}) req := httptest.NewRequest("POST", "/api/scripts/start-next-task", nil) w := httptest.NewRecorder() @@ -447,7 +516,7 @@ func TestHandleStartNextTask_NoTask(t *testing.T) { } srv, _ := testServer(t) - srv.startNextTaskScript = script + srv.SetScripts(ScriptRegistry{"start-next-task": script}) req := httptest.NewRequest("POST", "/api/scripts/start-next-task", nil) w := httptest.NewRecorder() @@ -530,9 +599,87 @@ func TestResumeTimedOut_Success_Returns202(t *testing.T) { } } +func TestRunTask_RetryLimitReached_Returns409(t *testing.T) { + srv, store := testServer(t) + // Task with MaxAttempts: 1 — only 1 attempt allowed. Create directly as FAILED + // so state is consistent without going through transition sequence. + tk := &task.Task{ + ID: "retry-limit-1", + Name: "Retry Limit Task", + Claude: task.ClaudeConfig{Instructions: "do something"}, + Priority: task.PriorityNormal, + Retry: task.RetryConfig{MaxAttempts: 1, Backoff: "linear"}, + Tags: []string{}, + DependsOn: []string{}, + State: task.StateFailed, + } + if err := store.CreateTask(tk); err != nil { + t.Fatal(err) + } + // Record one execution — the first attempt already used. + exec := &storage.Execution{ + ID: "exec-retry-1", + TaskID: "retry-limit-1", + StartTime: time.Now(), + Status: "FAILED", + } + if err := store.CreateExecution(exec); err != nil { + t.Fatal(err) + } + + req := httptest.NewRequest("POST", "/api/tasks/retry-limit-1/run", nil) + w := httptest.NewRecorder() + srv.Handler().ServeHTTP(w, req) + + if w.Code != http.StatusConflict { + t.Errorf("status: want 409, got %d; body: %s", w.Code, w.Body.String()) + } + var body map[string]string + json.NewDecoder(w.Body).Decode(&body) + if !strings.Contains(body["error"], "retry limit") { + t.Errorf("error body should mention retry limit, got %q", body["error"]) + } +} + +func TestRunTask_WithinRetryLimit_Returns202(t *testing.T) { + srv, store := testServer(t) + // Task with MaxAttempts: 3 — 1 execution used, 2 remaining. + tk := &task.Task{ + ID: "retry-within-1", + Name: "Retry Within Task", + Claude: task.ClaudeConfig{Instructions: "do something"}, + Priority: task.PriorityNormal, + Retry: task.RetryConfig{MaxAttempts: 3, Backoff: "linear"}, + Tags: []string{}, + DependsOn: []string{}, + State: task.StatePending, + } + if err := store.CreateTask(tk); err != nil { + t.Fatal(err) + } + exec := &storage.Execution{ + ID: "exec-within-1", + TaskID: "retry-within-1", + StartTime: time.Now(), + Status: "FAILED", + } + if err := store.CreateExecution(exec); err != nil { + t.Fatal(err) + } + store.UpdateTaskState("retry-within-1", task.StateFailed) + + req := httptest.NewRequest("POST", "/api/tasks/retry-within-1/run", nil) + w := httptest.NewRecorder() + srv.Handler().ServeHTTP(w, req) + + if w.Code != http.StatusAccepted { + t.Errorf("status: want 202, got %d; body: %s", w.Code, w.Body.String()) + } +} + func TestHandleStartNextTask_ScriptNotFound(t *testing.T) { srv, _ := testServer(t) - srv.startNextTaskScript = "/nonexistent/start-next-task" + srv.SetScripts(ScriptRegistry{"start-next-task": "/nonexistent/start-next-task"}) req := httptest.NewRequest("POST", "/api/scripts/start-next-task", nil) w := httptest.NewRecorder() @@ -578,10 +725,20 @@ func TestDeleteTask_NotFound(t *testing.T) { func TestDeleteTask_RunningTaskRejected(t *testing.T) { srv, store := testServer(t) - created := createTestTask(t, srv, `{"name":"Running Task","claude":{"instructions":"x","model":"sonnet"}}`) - store.UpdateTaskState(created.ID, "RUNNING") - - req := httptest.NewRequest("DELETE", "/api/tasks/"+created.ID, nil) + // Create the task directly in RUNNING state to avoid going through state transitions. + tk := &task.Task{ + ID: "running-task-del", + Name: "Running Task", + Claude: task.ClaudeConfig{Instructions: "x", Model: "sonnet"}, + Priority: task.PriorityNormal, + Tags: []string{}, + DependsOn: []string{}, + State: task.StateRunning, + } + if err := store.CreateTask(tk); err != nil { + t.Fatal(err) + } + req := httptest.NewRequest("DELETE", "/api/tasks/running-task-del", nil) w := httptest.NewRecorder() srv.Handler().ServeHTTP(w, req) @@ -657,3 +814,170 @@ func TestServer_CancelTask_Completed_Returns409(t *testing.T) { t.Errorf("status: want 409, got %d; body: %s", w.Code, w.Body.String()) } } + +// mockQuestionStore implements questionStore for testing handleAnswerQuestion. +type mockQuestionStore struct { + getTaskFn func(id string) (*task.Task, error) + getLatestExecutionFn func(taskID string) (*storage.Execution, error) + updateTaskQuestionFn func(taskID, questionJSON string) error + updateTaskStateFn func(id string, newState task.State) error +} + +func (m *mockQuestionStore) GetTask(id string) (*task.Task, error) { + return m.getTaskFn(id) +} +func (m *mockQuestionStore) GetLatestExecution(taskID string) (*storage.Execution, error) { + return m.getLatestExecutionFn(taskID) +} +func (m *mockQuestionStore) UpdateTaskQuestion(taskID, questionJSON string) error { + return m.updateTaskQuestionFn(taskID, questionJSON) +} +func (m *mockQuestionStore) UpdateTaskState(id string, newState task.State) error { + return m.updateTaskStateFn(id, newState) +} + +func TestServer_AnswerQuestion_UpdateQuestionFails_Returns500(t *testing.T) { + srv, _ := testServer(t) + srv.questionStore = &mockQuestionStore{ + getTaskFn: func(id string) (*task.Task, error) { + return &task.Task{ID: id, State: task.StateBlocked}, nil + }, + getLatestExecutionFn: func(taskID string) (*storage.Execution, error) { + return &storage.Execution{SessionID: "sess-1"}, nil + }, + updateTaskQuestionFn: func(taskID, questionJSON string) error { + return fmt.Errorf("db error") + }, + updateTaskStateFn: func(id string, newState task.State) error { + return nil + }, + } + + body := bytes.NewBufferString(`{"answer":"yes"}`) + req := httptest.NewRequest("POST", "/api/tasks/task-1/answer", body) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + srv.Handler().ServeHTTP(w, req) + + if w.Code != http.StatusInternalServerError { + t.Errorf("status: want 500, got %d; body: %s", w.Code, w.Body.String()) + } +} + +func TestServer_AnswerQuestion_UpdateStateFails_Returns500(t *testing.T) { + srv, _ := testServer(t) + srv.questionStore = &mockQuestionStore{ + getTaskFn: func(id string) (*task.Task, error) { + return &task.Task{ID: id, State: task.StateBlocked}, nil + }, + getLatestExecutionFn: func(taskID string) (*storage.Execution, error) { + return &storage.Execution{SessionID: "sess-1"}, nil + }, + updateTaskQuestionFn: func(taskID, questionJSON string) error { + return nil + }, + updateTaskStateFn: func(id string, newState task.State) error { + return fmt.Errorf("db error") + }, + } + + body := bytes.NewBufferString(`{"answer":"yes"}`) + req := httptest.NewRequest("POST", "/api/tasks/task-1/answer", body) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + srv.Handler().ServeHTTP(w, req) + + if w.Code != http.StatusInternalServerError { + t.Errorf("status: want 500, got %d; body: %s", w.Code, w.Body.String()) + } +} + +func TestRateLimit_ElaborateRejectsExcess(t *testing.T) { + srv, _ := testServer(t) + // Use burst-1 and rate-0 so the second request from the same IP is rejected. + srv.elaborateLimiter = newIPRateLimiter(0, 1) + + makeReq := func(remoteAddr string) int { + req := httptest.NewRequest("POST", "/api/tasks/elaborate", bytes.NewBufferString(`{"description":"x"}`)) + req.Header.Set("Content-Type", "application/json") + req.RemoteAddr = remoteAddr + w := httptest.NewRecorder() + srv.Handler().ServeHTTP(w, req) + return w.Code + } + + // First request from IP A: limiter allows it (non-429). + if code := makeReq("192.0.2.1:1234"); code == http.StatusTooManyRequests { + t.Errorf("first request should not be rate limited, got 429") + } + // Second request from IP A: bucket exhausted, must be 429. + if code := makeReq("192.0.2.1:1234"); code != http.StatusTooManyRequests { + t.Errorf("second request from same IP should be 429, got %d", code) + } + // First request from IP B: separate bucket, not limited. + if code := makeReq("192.0.2.2:1234"); code == http.StatusTooManyRequests { + t.Errorf("first request from different IP should not be rate limited, got 429") + } +} + +func TestListWorkspaces_RequiresAuth(t *testing.T) { + srv, _ := testServer(t) + srv.SetAPIToken("secret-token") + + // No token: expect 401. + req := httptest.NewRequest("GET", "/api/workspaces", nil) + w := httptest.NewRecorder() + srv.Handler().ServeHTTP(w, req) + if w.Code != http.StatusUnauthorized { + t.Errorf("expected 401 without token, got %d", w.Code) + } +} + +func TestListWorkspaces_RejectsWrongToken(t *testing.T) { + srv, _ := testServer(t) + srv.SetAPIToken("secret-token") + + req := httptest.NewRequest("GET", "/api/workspaces", nil) + req.Header.Set("Authorization", "Bearer wrong-token") + w := httptest.NewRecorder() + srv.Handler().ServeHTTP(w, req) + if w.Code != http.StatusUnauthorized { + t.Errorf("expected 401 with wrong token, got %d", w.Code) + } +} + +func TestListWorkspaces_SuppressesRawError(t *testing.T) { + srv, _ := testServer(t) + // No token configured so auth is bypassed; /workspace likely doesn't exist in test env. + + req := httptest.NewRequest("GET", "/api/workspaces", nil) + w := httptest.NewRecorder() + srv.Handler().ServeHTTP(w, req) + if w.Code == http.StatusInternalServerError { + body := w.Body.String() + if strings.Contains(body, "/workspace") || strings.Contains(body, "no such file") { + t.Errorf("response leaks filesystem details: %s", body) + } + } +} + +func TestRateLimit_ValidateRejectsExcess(t *testing.T) { + srv, _ := testServer(t) + srv.elaborateLimiter = newIPRateLimiter(0, 1) + + makeReq := func(remoteAddr string) int { + req := httptest.NewRequest("POST", "/api/tasks/validate", bytes.NewBufferString(`{"description":"x"}`)) + req.Header.Set("Content-Type", "application/json") + req.RemoteAddr = remoteAddr + w := httptest.NewRecorder() + srv.Handler().ServeHTTP(w, req) + return w.Code + } + + if code := makeReq("192.0.2.1:1234"); code == http.StatusTooManyRequests { + t.Errorf("first validate request should not be rate limited, got 429") + } + if code := makeReq("192.0.2.1:1234"); code != http.StatusTooManyRequests { + t.Errorf("second validate request from same IP should be 429, got %d", code) + } +} diff --git a/internal/api/validate.go b/internal/api/validate.go index 4b691a9..0fcdb47 100644 --- a/internal/api/validate.go +++ b/internal/api/validate.go @@ -52,6 +52,11 @@ func (s *Server) validateBinaryPath() string { } func (s *Server) handleValidateTask(w http.ResponseWriter, r *http.Request) { + if s.elaborateLimiter != nil && !s.elaborateLimiter.allow(realIP(r)) { + writeJSON(w, http.StatusTooManyRequests, map[string]string{"error": "rate limit exceeded"}) + return + } + var input struct { Name string `json:"name"` Claude struct { |
