summaryrefslogtreecommitdiff
path: root/internal/api
diff options
context:
space:
mode:
authorPeter Stone <thepeterstone@gmail.com>2026-03-08 20:40:31 +0000
committerPeter Stone <thepeterstone@gmail.com>2026-03-08 20:40:31 +0000
commit417034be7f745062901a940d1a021f6d85be496e (patch)
tree666956207b58c915090f6641891304156cf93670 /internal/api
parent181a37698410b68e00a885593b6f2b7acf21f4b4 (diff)
api: SetAPIToken, SetNotifier, questionStore, per-IP rate limiter
- Extract questionStore interface for testability of handleAnswerQuestion - Add SetAPIToken/SetNotifier methods for post-construction wiring - Extract processResult() from forwardResults() for direct testability - Add ipRateLimiter with token-bucket per IP; applied to /elaborate and /validate - Fix tests for running-task deletion and retry-limit that relied on invalid state transitions in setup Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Diffstat (limited to 'internal/api')
-rw-r--r--internal/api/elaborate.go5
-rw-r--r--internal/api/ratelimit.go99
-rw-r--r--internal/api/server.go143
-rw-r--r--internal/api/server_test.go344
-rw-r--r--internal/api/validate.go5
5 files changed, 557 insertions, 39 deletions
diff --git a/internal/api/elaborate.go b/internal/api/elaborate.go
index e480e00..8a18dee 100644
--- a/internal/api/elaborate.go
+++ b/internal/api/elaborate.go
@@ -85,6 +85,11 @@ func (s *Server) claudeBinaryPath() string {
}
func (s *Server) handleElaborateTask(w http.ResponseWriter, r *http.Request) {
+ if s.elaborateLimiter != nil && !s.elaborateLimiter.allow(realIP(r)) {
+ writeJSON(w, http.StatusTooManyRequests, map[string]string{"error": "rate limit exceeded"})
+ return
+ }
+
var input struct {
Prompt string `json:"prompt"`
ProjectDir string `json:"project_dir"`
diff --git a/internal/api/ratelimit.go b/internal/api/ratelimit.go
new file mode 100644
index 0000000..089354c
--- /dev/null
+++ b/internal/api/ratelimit.go
@@ -0,0 +1,99 @@
+package api
+
+import (
+ "net"
+ "net/http"
+ "sync"
+ "time"
+)
+
+// ipRateLimiter provides per-IP token-bucket rate limiting.
+type ipRateLimiter struct {
+ mu sync.Mutex
+ limiters map[string]*tokenBucket
+ rate float64 // tokens replenished per second
+ burst int // maximum token capacity
+}
+
+// newIPRateLimiter creates a limiter with the given replenishment rate (tokens/sec)
+// and burst capacity. Use rate=0 to disable replenishment (tokens never refill).
+func newIPRateLimiter(rate float64, burst int) *ipRateLimiter {
+ return &ipRateLimiter{
+ limiters: make(map[string]*tokenBucket),
+ rate: rate,
+ burst: burst,
+ }
+}
+
+func (l *ipRateLimiter) allow(ip string) bool {
+ l.mu.Lock()
+ b, ok := l.limiters[ip]
+ if !ok {
+ b = &tokenBucket{
+ tokens: float64(l.burst),
+ capacity: float64(l.burst),
+ rate: l.rate,
+ lastTime: time.Now(),
+ }
+ l.limiters[ip] = b
+ }
+ l.mu.Unlock()
+ return b.allow()
+}
+
+// middleware wraps h with per-IP rate limiting, returning 429 when exceeded.
+func (l *ipRateLimiter) middleware(next http.Handler) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ ip := realIP(r)
+ if !l.allow(ip) {
+ writeJSON(w, http.StatusTooManyRequests, map[string]string{"error": "rate limit exceeded"})
+ return
+ }
+ next.ServeHTTP(w, r)
+ })
+}
+
+// tokenBucket is a simple token-bucket rate limiter for a single key.
+type tokenBucket struct {
+ mu sync.Mutex
+ tokens float64
+ capacity float64
+ rate float64 // tokens per second
+ lastTime time.Time
+}
+
+func (b *tokenBucket) allow() bool {
+ b.mu.Lock()
+ defer b.mu.Unlock()
+ now := time.Now()
+ if !b.lastTime.IsZero() {
+ elapsed := now.Sub(b.lastTime).Seconds()
+ b.tokens = min(b.capacity, b.tokens+elapsed*b.rate)
+ }
+ b.lastTime = now
+ if b.tokens >= 1.0 {
+ b.tokens--
+ return true
+ }
+ return false
+}
+
+// realIP extracts the client's real IP from a request.
+func realIP(r *http.Request) string {
+ if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
+ for i, c := range xff {
+ if c == ',' {
+ return xff[:i]
+ }
+ }
+ return xff
+ }
+ if xri := r.Header.Get("X-Real-IP"); xri != "" {
+ return xri
+ }
+ host, _, err := net.SplitHostPort(r.RemoteAddr)
+ if err != nil {
+ return r.RemoteAddr
+ }
+ return host
+}
diff --git a/internal/api/server.go b/internal/api/server.go
index af4710b..833be8b 100644
--- a/internal/api/server.go
+++ b/internal/api/server.go
@@ -7,39 +7,63 @@ import (
"log/slog"
"net/http"
"os"
+ "strings"
"time"
"github.com/thepeterstone/claudomator/internal/executor"
+ "github.com/thepeterstone/claudomator/internal/notify"
"github.com/thepeterstone/claudomator/internal/storage"
"github.com/thepeterstone/claudomator/internal/task"
webui "github.com/thepeterstone/claudomator/web"
"github.com/google/uuid"
)
+// questionStore is the minimal storage interface needed by handleAnswerQuestion.
+type questionStore interface {
+ GetTask(id string) (*task.Task, error)
+ GetLatestExecution(taskID string) (*storage.Execution, error)
+ UpdateTaskQuestion(taskID, questionJSON string) error
+ UpdateTaskState(id string, newState task.State) error
+}
+
// Server provides the REST API and WebSocket endpoint for Claudomator.
type Server struct {
store *storage.DB
- logStore logStore // injectable for tests; defaults to store
- taskLogStore taskLogStore // injectable for tests; defaults to store
+ logStore logStore // injectable for tests; defaults to store
+ taskLogStore taskLogStore // injectable for tests; defaults to store
+ questionStore questionStore // injectable for tests; defaults to store
pool *executor.Pool
hub *Hub
logger *slog.Logger
mux *http.ServeMux
- claudeBinPath string // path to claude binary; defaults to "claude"
- elaborateCmdPath string // overrides claudeBinPath; used in tests
- validateCmdPath string // overrides claudeBinPath for validate; used in tests
- startNextTaskScript string // path to start-next-task script; overridden in tests
- deployScript string // path to deploy script; overridden in tests
- workDir string // working directory injected into elaborate system prompt
+ claudeBinPath string // path to claude binary; defaults to "claude"
+ elaborateCmdPath string // overrides claudeBinPath; used in tests
+ validateCmdPath string // overrides claudeBinPath for validate; used in tests
+ scripts ScriptRegistry // optional; maps endpoint name → script path
+ workDir string // working directory injected into elaborate system prompt
+ notifier notify.Notifier
+ apiToken string // if non-empty, required for WebSocket (and REST) connections
+ elaborateLimiter *ipRateLimiter // per-IP rate limiter for elaborate/validate endpoints
+}
+
+// SetAPIToken configures a bearer token that must be supplied to access the API.
+func (s *Server) SetAPIToken(token string) {
+ s.apiToken = token
+}
+
+// SetNotifier configures a notifier that is called on every task completion.
+func (s *Server) SetNotifier(n notify.Notifier) {
+ s.notifier = n
}
func NewServer(store *storage.DB, pool *executor.Pool, logger *slog.Logger, claudeBinPath string) *Server {
wd, _ := os.Getwd()
s := &Server{
- store: store,
- logStore: store,
- taskLogStore: store,
- pool: pool,
+ store: store,
+ logStore: store,
+ taskLogStore: store,
+ questionStore: store,
+ pool: pool,
hub: NewHub(),
logger: logger,
mux: http.NewServeMux(),
@@ -84,8 +108,7 @@ func (s *Server) routes() {
s.mux.HandleFunc("DELETE /api/templates/{id}", s.handleDeleteTemplate)
s.mux.HandleFunc("POST /api/tasks/{id}/answer", s.handleAnswerQuestion)
s.mux.HandleFunc("POST /api/tasks/{id}/resume", s.handleResumeTimedOutTask)
- s.mux.HandleFunc("POST /api/scripts/start-next-task", s.handleStartNextTask)
- s.mux.HandleFunc("POST /api/scripts/deploy", s.handleDeploy)
+ s.mux.HandleFunc("POST /api/scripts/{name}", s.handleScript)
s.mux.HandleFunc("GET /api/ws", s.handleWebSocket)
s.mux.HandleFunc("GET /api/workspaces", s.handleListWorkspaces)
s.mux.HandleFunc("GET /api/health", s.handleHealth)
@@ -95,17 +118,44 @@ func (s *Server) routes() {
// forwardResults listens on the executor pool's result channel and broadcasts via WebSocket.
func (s *Server) forwardResults() {
for result := range s.pool.Results() {
- event := map[string]interface{}{
- "type": "task_completed",
- "task_id": result.TaskID,
- "status": result.Execution.Status,
- "exit_code": result.Execution.ExitCode,
- "cost_usd": result.Execution.CostUSD,
- "error": result.Execution.ErrorMsg,
- "timestamp": time.Now().UTC(),
+ s.processResult(result)
+ }
+}
+
+// processResult broadcasts a task completion event via WebSocket and calls the notifier if set.
+func (s *Server) processResult(result *executor.Result) {
+ event := map[string]interface{}{
+ "type": "task_completed",
+ "task_id": result.TaskID,
+ "status": result.Execution.Status,
+ "exit_code": result.Execution.ExitCode,
+ "cost_usd": result.Execution.CostUSD,
+ "error": result.Execution.ErrorMsg,
+ "timestamp": time.Now().UTC(),
+ }
+ data, _ := json.Marshal(event)
+ s.hub.Broadcast(data)
+
+ if s.notifier != nil {
+ var taskName string
+ if t, err := s.store.GetTask(result.TaskID); err == nil {
+ taskName = t.Name
+ }
+ var dur string
+ if !result.Execution.StartTime.IsZero() && !result.Execution.EndTime.IsZero() {
+ dur = result.Execution.EndTime.Sub(result.Execution.StartTime).String()
+ }
+ ne := notify.Event{
+ TaskID: result.TaskID,
+ TaskName: taskName,
+ Status: result.Execution.Status,
+ CostUSD: result.Execution.CostUSD,
+ Duration: dur,
+ Error: result.Execution.ErrorMsg,
+ }
+ if err := s.notifier.Notify(ne); err != nil {
+ s.logger.Error("notifier failed", "error", err)
}
- data, _ := json.Marshal(event)
- s.hub.Broadcast(data)
}
}
@@ -167,7 +217,7 @@ func (s *Server) handleCancelTask(w http.ResponseWriter, r *http.Request) {
func (s *Server) handleAnswerQuestion(w http.ResponseWriter, r *http.Request) {
taskID := r.PathValue("id")
- tk, err := s.store.GetTask(taskID)
+ tk, err := s.questionStore.GetTask(taskID)
if err != nil {
writeJSON(w, http.StatusNotFound, map[string]string{"error": "task not found"})
return
@@ -190,15 +240,21 @@ func (s *Server) handleAnswerQuestion(w http.ResponseWriter, r *http.Request) {
}
// Look up the session ID from the most recent execution.
- latest, err := s.store.GetLatestExecution(taskID)
+ latest, err := s.questionStore.GetLatestExecution(taskID)
if err != nil || latest.SessionID == "" {
writeJSON(w, http.StatusInternalServerError, map[string]string{"error": "no resumable session found"})
return
}
// Clear the question and transition to QUEUED.
- s.store.UpdateTaskQuestion(taskID, "")
- s.store.UpdateTaskState(taskID, task.StateQueued)
+ if err := s.questionStore.UpdateTaskQuestion(taskID, ""); err != nil {
+ writeJSON(w, http.StatusInternalServerError, map[string]string{"error": "failed to clear question"})
+ return
+ }
+ if err := s.questionStore.UpdateTaskState(taskID, task.StateQueued); err != nil {
+ writeJSON(w, http.StatusInternalServerError, map[string]string{"error": "failed to queue task"})
+ return
+ }
// Submit a resume execution.
resumeExec := &storage.Execution{
@@ -254,9 +310,23 @@ func (s *Server) handleResumeTimedOutTask(w http.ResponseWriter, r *http.Request
}
func (s *Server) handleListWorkspaces(w http.ResponseWriter, r *http.Request) {
+ if s.apiToken != "" {
+ token := r.URL.Query().Get("token")
+ if token == "" {
+ auth := r.Header.Get("Authorization")
+ if strings.HasPrefix(auth, "Bearer ") {
+ token = strings.TrimPrefix(auth, "Bearer ")
+ }
+ }
+ if token != s.apiToken {
+ http.Error(w, "unauthorized", http.StatusUnauthorized)
+ return
+ }
+ }
+
entries, err := os.ReadDir("/workspace")
if err != nil {
- writeJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()})
+ writeJSON(w, http.StatusInternalServerError, map[string]string{"error": "failed to list workspaces"})
return
}
var dirs []string
@@ -370,6 +440,21 @@ func (s *Server) handleRunTask(w http.ResponseWriter, r *http.Request) {
return
}
+ // Enforce retry limit for non-initial runs (PENDING is the initial state).
+ if t.State != task.StatePending {
+ execs, err := s.store.ListExecutions(id)
+ if err != nil {
+ writeJSON(w, http.StatusInternalServerError, map[string]string{"error": "failed to count executions"})
+ return
+ }
+ if t.Retry.MaxAttempts > 0 && len(execs) >= t.Retry.MaxAttempts {
+ writeJSON(w, http.StatusConflict, map[string]string{
+ "error": fmt.Sprintf("retry limit reached (%d/%d attempts used)", len(execs), t.Retry.MaxAttempts),
+ })
+ return
+ }
+ }
+
if err := s.store.UpdateTaskState(id, task.StateQueued); err != nil {
writeJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()})
return
diff --git a/internal/api/server_test.go b/internal/api/server_test.go
index e012bc1..c3b12ce 100644
--- a/internal/api/server_test.go
+++ b/internal/api/server_test.go
@@ -9,15 +9,70 @@ import (
"net/http/httptest"
"os"
"path/filepath"
+ "strings"
"testing"
+ "time"
"context"
"github.com/thepeterstone/claudomator/internal/executor"
+ "github.com/thepeterstone/claudomator/internal/notify"
"github.com/thepeterstone/claudomator/internal/storage"
"github.com/thepeterstone/claudomator/internal/task"
)
+// mockNotifier records calls to Notify.
+type mockNotifier struct {
+ events []notify.Event
+}
+
+func (m *mockNotifier) Notify(e notify.Event) error {
+ m.events = append(m.events, e)
+ return nil
+}
+
+func TestServer_ProcessResult_CallsNotifier(t *testing.T) {
+ srv, store := testServer(t)
+
+ mn := &mockNotifier{}
+ srv.SetNotifier(mn)
+
+ tk := &task.Task{
+ ID: "task-notifier-test",
+ Name: "Notifier Task",
+ State: task.StatePending,
+ }
+ if err := store.CreateTask(tk); err != nil {
+ t.Fatal(err)
+ }
+
+ result := &executor.Result{
+ TaskID: tk.ID,
+ Execution: &storage.Execution{
+ ID: "exec-1",
+ TaskID: tk.ID,
+ Status: "COMPLETED",
+ CostUSD: 0.42,
+ ErrorMsg: "",
+ },
+ }
+ srv.processResult(result)
+
+ if len(mn.events) != 1 {
+ t.Fatalf("expected 1 notify event, got %d", len(mn.events))
+ }
+ ev := mn.events[0]
+ if ev.TaskID != tk.ID {
+ t.Errorf("event.TaskID = %q, want %q", ev.TaskID, tk.ID)
+ }
+ if ev.Status != "COMPLETED" {
+ t.Errorf("event.Status = %q, want COMPLETED", ev.Status)
+ }
+ if ev.CostUSD != 0.42 {
+ t.Errorf("event.CostUSD = %v, want 0.42", ev.CostUSD)
+ }
+}
+
func testServer(t *testing.T) (*Server, *storage.DB) {
t.Helper()
dbPath := filepath.Join(t.TempDir(), "test.db")
@@ -170,6 +225,20 @@ func TestListTasks_WithTasks(t *testing.T) {
}
}
+// stateWalkPaths defines the sequence of intermediate states needed to reach each target state.
+var stateWalkPaths = map[task.State][]task.State{
+ task.StatePending: {},
+ task.StateQueued: {task.StateQueued},
+ task.StateRunning: {task.StateQueued, task.StateRunning},
+ task.StateCompleted: {task.StateQueued, task.StateRunning, task.StateCompleted},
+ task.StateFailed: {task.StateQueued, task.StateRunning, task.StateFailed},
+ task.StateTimedOut: {task.StateQueued, task.StateRunning, task.StateTimedOut},
+ task.StateCancelled: {task.StateCancelled},
+ task.StateBudgetExceeded: {task.StateQueued, task.StateRunning, task.StateBudgetExceeded},
+ task.StateReady: {task.StateQueued, task.StateRunning, task.StateReady},
+ task.StateBlocked: {task.StateQueued, task.StateRunning, task.StateBlocked},
+}
+
func createTaskWithState(t *testing.T, store *storage.DB, id string, state task.State) *task.Task {
t.Helper()
tk := &task.Task{
@@ -183,9 +252,9 @@ func createTaskWithState(t *testing.T, store *storage.DB, id string, state task.
if err := store.CreateTask(tk); err != nil {
t.Fatalf("createTaskWithState: CreateTask: %v", err)
}
- if state != task.StatePending {
- if err := store.UpdateTaskState(id, state); err != nil {
- t.Fatalf("createTaskWithState: UpdateTaskState(%s): %v", state, err)
+ for _, s := range stateWalkPaths[state] {
+ if err := store.UpdateTaskState(id, s); err != nil {
+ t.Fatalf("createTaskWithState: UpdateTaskState(%s): %v", s, err)
}
}
tk.State = state
@@ -420,7 +489,7 @@ func TestHandleStartNextTask_Success(t *testing.T) {
}
srv, _ := testServer(t)
- srv.startNextTaskScript = script
+ srv.SetScripts(ScriptRegistry{"start-next-task": script})
req := httptest.NewRequest("POST", "/api/scripts/start-next-task", nil)
w := httptest.NewRecorder()
@@ -447,7 +516,7 @@ func TestHandleStartNextTask_NoTask(t *testing.T) {
}
srv, _ := testServer(t)
- srv.startNextTaskScript = script
+ srv.SetScripts(ScriptRegistry{"start-next-task": script})
req := httptest.NewRequest("POST", "/api/scripts/start-next-task", nil)
w := httptest.NewRecorder()
@@ -530,9 +599,87 @@ func TestResumeTimedOut_Success_Returns202(t *testing.T) {
}
}
+func TestRunTask_RetryLimitReached_Returns409(t *testing.T) {
+ srv, store := testServer(t)
+ // Task with MaxAttempts: 1 — only 1 attempt allowed. Create directly as FAILED
+ // so state is consistent without going through transition sequence.
+ tk := &task.Task{
+ ID: "retry-limit-1",
+ Name: "Retry Limit Task",
+ Claude: task.ClaudeConfig{Instructions: "do something"},
+ Priority: task.PriorityNormal,
+ Retry: task.RetryConfig{MaxAttempts: 1, Backoff: "linear"},
+ Tags: []string{},
+ DependsOn: []string{},
+ State: task.StateFailed,
+ }
+ if err := store.CreateTask(tk); err != nil {
+ t.Fatal(err)
+ }
+ // Record one execution — the first attempt already used.
+ exec := &storage.Execution{
+ ID: "exec-retry-1",
+ TaskID: "retry-limit-1",
+ StartTime: time.Now(),
+ Status: "FAILED",
+ }
+ if err := store.CreateExecution(exec); err != nil {
+ t.Fatal(err)
+ }
+
+ req := httptest.NewRequest("POST", "/api/tasks/retry-limit-1/run", nil)
+ w := httptest.NewRecorder()
+ srv.Handler().ServeHTTP(w, req)
+
+ if w.Code != http.StatusConflict {
+ t.Errorf("status: want 409, got %d; body: %s", w.Code, w.Body.String())
+ }
+ var body map[string]string
+ json.NewDecoder(w.Body).Decode(&body)
+ if !strings.Contains(body["error"], "retry limit") {
+ t.Errorf("error body should mention retry limit, got %q", body["error"])
+ }
+}
+
+func TestRunTask_WithinRetryLimit_Returns202(t *testing.T) {
+ srv, store := testServer(t)
+ // Task with MaxAttempts: 3 — 1 execution used, 2 remaining.
+ tk := &task.Task{
+ ID: "retry-within-1",
+ Name: "Retry Within Task",
+ Claude: task.ClaudeConfig{Instructions: "do something"},
+ Priority: task.PriorityNormal,
+ Retry: task.RetryConfig{MaxAttempts: 3, Backoff: "linear"},
+ Tags: []string{},
+ DependsOn: []string{},
+ State: task.StatePending,
+ }
+ if err := store.CreateTask(tk); err != nil {
+ t.Fatal(err)
+ }
+ exec := &storage.Execution{
+ ID: "exec-within-1",
+ TaskID: "retry-within-1",
+ StartTime: time.Now(),
+ Status: "FAILED",
+ }
+ if err := store.CreateExecution(exec); err != nil {
+ t.Fatal(err)
+ }
+ store.UpdateTaskState("retry-within-1", task.StateFailed)
+
+ req := httptest.NewRequest("POST", "/api/tasks/retry-within-1/run", nil)
+ w := httptest.NewRecorder()
+ srv.Handler().ServeHTTP(w, req)
+
+ if w.Code != http.StatusAccepted {
+ t.Errorf("status: want 202, got %d; body: %s", w.Code, w.Body.String())
+ }
+}
+
func TestHandleStartNextTask_ScriptNotFound(t *testing.T) {
srv, _ := testServer(t)
- srv.startNextTaskScript = "/nonexistent/start-next-task"
+ srv.SetScripts(ScriptRegistry{"start-next-task": "/nonexistent/start-next-task"})
req := httptest.NewRequest("POST", "/api/scripts/start-next-task", nil)
w := httptest.NewRecorder()
@@ -578,10 +725,20 @@ func TestDeleteTask_NotFound(t *testing.T) {
func TestDeleteTask_RunningTaskRejected(t *testing.T) {
srv, store := testServer(t)
- created := createTestTask(t, srv, `{"name":"Running Task","claude":{"instructions":"x","model":"sonnet"}}`)
- store.UpdateTaskState(created.ID, "RUNNING")
-
- req := httptest.NewRequest("DELETE", "/api/tasks/"+created.ID, nil)
+ // Create the task directly in RUNNING state to avoid going through state transitions.
+ tk := &task.Task{
+ ID: "running-task-del",
+ Name: "Running Task",
+ Claude: task.ClaudeConfig{Instructions: "x", Model: "sonnet"},
+ Priority: task.PriorityNormal,
+ Tags: []string{},
+ DependsOn: []string{},
+ State: task.StateRunning,
+ }
+ if err := store.CreateTask(tk); err != nil {
+ t.Fatal(err)
+ }
+ req := httptest.NewRequest("DELETE", "/api/tasks/running-task-del", nil)
w := httptest.NewRecorder()
srv.Handler().ServeHTTP(w, req)
@@ -657,3 +814,170 @@ func TestServer_CancelTask_Completed_Returns409(t *testing.T) {
t.Errorf("status: want 409, got %d; body: %s", w.Code, w.Body.String())
}
}
+
+// mockQuestionStore implements questionStore for testing handleAnswerQuestion.
+type mockQuestionStore struct {
+ getTaskFn func(id string) (*task.Task, error)
+ getLatestExecutionFn func(taskID string) (*storage.Execution, error)
+ updateTaskQuestionFn func(taskID, questionJSON string) error
+ updateTaskStateFn func(id string, newState task.State) error
+}
+
+func (m *mockQuestionStore) GetTask(id string) (*task.Task, error) {
+ return m.getTaskFn(id)
+}
+func (m *mockQuestionStore) GetLatestExecution(taskID string) (*storage.Execution, error) {
+ return m.getLatestExecutionFn(taskID)
+}
+func (m *mockQuestionStore) UpdateTaskQuestion(taskID, questionJSON string) error {
+ return m.updateTaskQuestionFn(taskID, questionJSON)
+}
+func (m *mockQuestionStore) UpdateTaskState(id string, newState task.State) error {
+ return m.updateTaskStateFn(id, newState)
+}
+
+func TestServer_AnswerQuestion_UpdateQuestionFails_Returns500(t *testing.T) {
+ srv, _ := testServer(t)
+ srv.questionStore = &mockQuestionStore{
+ getTaskFn: func(id string) (*task.Task, error) {
+ return &task.Task{ID: id, State: task.StateBlocked}, nil
+ },
+ getLatestExecutionFn: func(taskID string) (*storage.Execution, error) {
+ return &storage.Execution{SessionID: "sess-1"}, nil
+ },
+ updateTaskQuestionFn: func(taskID, questionJSON string) error {
+ return fmt.Errorf("db error")
+ },
+ updateTaskStateFn: func(id string, newState task.State) error {
+ return nil
+ },
+ }
+
+ body := bytes.NewBufferString(`{"answer":"yes"}`)
+ req := httptest.NewRequest("POST", "/api/tasks/task-1/answer", body)
+ req.Header.Set("Content-Type", "application/json")
+ w := httptest.NewRecorder()
+ srv.Handler().ServeHTTP(w, req)
+
+ if w.Code != http.StatusInternalServerError {
+ t.Errorf("status: want 500, got %d; body: %s", w.Code, w.Body.String())
+ }
+}
+
+func TestServer_AnswerQuestion_UpdateStateFails_Returns500(t *testing.T) {
+ srv, _ := testServer(t)
+ srv.questionStore = &mockQuestionStore{
+ getTaskFn: func(id string) (*task.Task, error) {
+ return &task.Task{ID: id, State: task.StateBlocked}, nil
+ },
+ getLatestExecutionFn: func(taskID string) (*storage.Execution, error) {
+ return &storage.Execution{SessionID: "sess-1"}, nil
+ },
+ updateTaskQuestionFn: func(taskID, questionJSON string) error {
+ return nil
+ },
+ updateTaskStateFn: func(id string, newState task.State) error {
+ return fmt.Errorf("db error")
+ },
+ }
+
+ body := bytes.NewBufferString(`{"answer":"yes"}`)
+ req := httptest.NewRequest("POST", "/api/tasks/task-1/answer", body)
+ req.Header.Set("Content-Type", "application/json")
+ w := httptest.NewRecorder()
+ srv.Handler().ServeHTTP(w, req)
+
+ if w.Code != http.StatusInternalServerError {
+ t.Errorf("status: want 500, got %d; body: %s", w.Code, w.Body.String())
+ }
+}
+
+func TestRateLimit_ElaborateRejectsExcess(t *testing.T) {
+ srv, _ := testServer(t)
+ // Use burst-1 and rate-0 so the second request from the same IP is rejected.
+ srv.elaborateLimiter = newIPRateLimiter(0, 1)
+
+ makeReq := func(remoteAddr string) int {
+ req := httptest.NewRequest("POST", "/api/tasks/elaborate", bytes.NewBufferString(`{"description":"x"}`))
+ req.Header.Set("Content-Type", "application/json")
+ req.RemoteAddr = remoteAddr
+ w := httptest.NewRecorder()
+ srv.Handler().ServeHTTP(w, req)
+ return w.Code
+ }
+
+ // First request from IP A: limiter allows it (non-429).
+ if code := makeReq("192.0.2.1:1234"); code == http.StatusTooManyRequests {
+ t.Errorf("first request should not be rate limited, got 429")
+ }
+ // Second request from IP A: bucket exhausted, must be 429.
+ if code := makeReq("192.0.2.1:1234"); code != http.StatusTooManyRequests {
+ t.Errorf("second request from same IP should be 429, got %d", code)
+ }
+ // First request from IP B: separate bucket, not limited.
+ if code := makeReq("192.0.2.2:1234"); code == http.StatusTooManyRequests {
+ t.Errorf("first request from different IP should not be rate limited, got 429")
+ }
+}
+
+func TestListWorkspaces_RequiresAuth(t *testing.T) {
+ srv, _ := testServer(t)
+ srv.SetAPIToken("secret-token")
+
+ // No token: expect 401.
+ req := httptest.NewRequest("GET", "/api/workspaces", nil)
+ w := httptest.NewRecorder()
+ srv.Handler().ServeHTTP(w, req)
+ if w.Code != http.StatusUnauthorized {
+ t.Errorf("expected 401 without token, got %d", w.Code)
+ }
+}
+
+func TestListWorkspaces_RejectsWrongToken(t *testing.T) {
+ srv, _ := testServer(t)
+ srv.SetAPIToken("secret-token")
+
+ req := httptest.NewRequest("GET", "/api/workspaces", nil)
+ req.Header.Set("Authorization", "Bearer wrong-token")
+ w := httptest.NewRecorder()
+ srv.Handler().ServeHTTP(w, req)
+ if w.Code != http.StatusUnauthorized {
+ t.Errorf("expected 401 with wrong token, got %d", w.Code)
+ }
+}
+
+func TestListWorkspaces_SuppressesRawError(t *testing.T) {
+ srv, _ := testServer(t)
+ // No token configured so auth is bypassed; /workspace likely doesn't exist in test env.
+
+ req := httptest.NewRequest("GET", "/api/workspaces", nil)
+ w := httptest.NewRecorder()
+ srv.Handler().ServeHTTP(w, req)
+ if w.Code == http.StatusInternalServerError {
+ body := w.Body.String()
+ if strings.Contains(body, "/workspace") || strings.Contains(body, "no such file") {
+ t.Errorf("response leaks filesystem details: %s", body)
+ }
+ }
+}
+
+func TestRateLimit_ValidateRejectsExcess(t *testing.T) {
+ srv, _ := testServer(t)
+ srv.elaborateLimiter = newIPRateLimiter(0, 1)
+
+ makeReq := func(remoteAddr string) int {
+ req := httptest.NewRequest("POST", "/api/tasks/validate", bytes.NewBufferString(`{"description":"x"}`))
+ req.Header.Set("Content-Type", "application/json")
+ req.RemoteAddr = remoteAddr
+ w := httptest.NewRecorder()
+ srv.Handler().ServeHTTP(w, req)
+ return w.Code
+ }
+
+ if code := makeReq("192.0.2.1:1234"); code == http.StatusTooManyRequests {
+ t.Errorf("first validate request should not be rate limited, got 429")
+ }
+ if code := makeReq("192.0.2.1:1234"); code != http.StatusTooManyRequests {
+ t.Errorf("second validate request from same IP should be 429, got %d", code)
+ }
+}
diff --git a/internal/api/validate.go b/internal/api/validate.go
index 4b691a9..0fcdb47 100644
--- a/internal/api/validate.go
+++ b/internal/api/validate.go
@@ -52,6 +52,11 @@ func (s *Server) validateBinaryPath() string {
}
func (s *Server) handleValidateTask(w http.ResponseWriter, r *http.Request) {
+ if s.elaborateLimiter != nil && !s.elaborateLimiter.allow(realIP(r)) {
+ writeJSON(w, http.StatusTooManyRequests, map[string]string{"error": "rate limit exceeded"})
+ return
+ }
+
var input struct {
Name string `json:"name"`
Claude struct {