diff options
Diffstat (limited to 'internal/api')
| -rw-r--r-- | internal/api/elaborate.go | 17 | ||||
| -rw-r--r-- | internal/api/elaborate_test.go | 2 | ||||
| -rw-r--r-- | internal/api/executions.go | 4 | ||||
| -rw-r--r-- | internal/api/executions_test.go | 20 | ||||
| -rw-r--r-- | internal/api/logs.go | 33 | ||||
| -rw-r--r-- | internal/api/logs_test.go | 71 | ||||
| -rw-r--r-- | internal/api/ratelimit.go | 99 | ||||
| -rw-r--r-- | internal/api/scripts.go | 64 | ||||
| -rw-r--r-- | internal/api/scripts_test.go | 83 | ||||
| -rw-r--r-- | internal/api/server.go | 145 | ||||
| -rw-r--r-- | internal/api/server_test.go | 344 | ||||
| -rw-r--r-- | internal/api/validate.go | 19 | ||||
| -rw-r--r-- | internal/api/websocket.go | 71 | ||||
| -rw-r--r-- | internal/api/websocket_test.go | 221 |
14 files changed, 1068 insertions, 125 deletions
diff --git a/internal/api/elaborate.go b/internal/api/elaborate.go index 5ab9ff0..907cb98 100644 --- a/internal/api/elaborate.go +++ b/internal/api/elaborate.go @@ -14,9 +14,9 @@ import ( const elaborateTimeout = 30 * time.Second func buildElaboratePrompt(workDir string) string { - workDirLine := ` "working_dir": string — leave empty unless you have a specific reason to set it,` + workDirLine := ` "project_dir": string — leave empty unless you have a specific reason to set it,` if workDir != "" { - workDirLine = fmt.Sprintf(` "working_dir": string — use %q for tasks that operate on this codebase, empty string otherwise,`, workDir) + workDirLine = fmt.Sprintf(` "project_dir": string — use %q for tasks that operate on this codebase, empty string otherwise,`, workDir) } return `You are a task configuration assistant for Claudomator, an AI task runner that executes tasks by running Claude or Gemini as a subprocess. @@ -55,7 +55,7 @@ type elaboratedAgent struct { Type string `json:"type"` Model string `json:"model"` Instructions string `json:"instructions"` - WorkingDir string `json:"working_dir"` + ProjectDir string `json:"project_dir"` MaxBudgetUSD float64 `json:"max_budget_usd"` AllowedTools []string `json:"allowed_tools"` } @@ -87,9 +87,14 @@ func (s *Server) claudeBinaryPath() string { } func (s *Server) handleElaborateTask(w http.ResponseWriter, r *http.Request) { + if s.elaborateLimiter != nil && !s.elaborateLimiter.allow(realIP(r)) { + writeJSON(w, http.StatusTooManyRequests, map[string]string{"error": "rate limit exceeded"}) + return + } + var input struct { Prompt string `json:"prompt"` - WorkingDir string `json:"working_dir"` + ProjectDir string `json:"project_dir"` } if err := json.NewDecoder(r.Body).Decode(&input); err != nil { writeJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid JSON: " + err.Error()}) @@ -101,8 +106,8 @@ func (s *Server) handleElaborateTask(w http.ResponseWriter, r *http.Request) { } workDir := s.workDir - if input.WorkingDir != "" { - workDir = input.WorkingDir + if input.ProjectDir != "" { + workDir = input.ProjectDir } ctx, cancel := context.WithTimeout(r.Context(), elaborateTimeout) diff --git a/internal/api/elaborate_test.go b/internal/api/elaborate_test.go index 4939701..b33ca11 100644 --- a/internal/api/elaborate_test.go +++ b/internal/api/elaborate_test.go @@ -57,7 +57,7 @@ func TestElaborateTask_Success(t *testing.T) { Type: "claude", Model: "sonnet", Instructions: "Run go test -race ./... and report results.", - WorkingDir: "", + ProjectDir: "", MaxBudgetUSD: 0.5, AllowedTools: []string{"Bash"}, }, diff --git a/internal/api/executions.go b/internal/api/executions.go index d9214c0..114425e 100644 --- a/internal/api/executions.go +++ b/internal/api/executions.go @@ -21,12 +21,16 @@ func (s *Server) handleListRecentExecutions(w http.ResponseWriter, r *http.Reque } } + const maxLimit = 1000 limit := 50 if v := r.URL.Query().Get("limit"); v != "" { if n, err := strconv.Atoi(v); err == nil && n > 0 { limit = n } } + if limit > maxLimit { + limit = maxLimit + } taskID := r.URL.Query().Get("task_id") diff --git a/internal/api/executions_test.go b/internal/api/executions_test.go index a2bba21..45548ad 100644 --- a/internal/api/executions_test.go +++ b/internal/api/executions_test.go @@ -258,6 +258,26 @@ func TestGetExecutionLog_FollowSSEHeaders(t *testing.T) { } } +func TestListRecentExecutions_LimitClamped(t *testing.T) { + srv, _ := testServer(t) + + req := httptest.NewRequest("GET", "/api/executions?limit=10000000", nil) + w := httptest.NewRecorder() + srv.Handler().ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("status: want 200, got %d; body: %s", w.Code, w.Body.String()) + } + // The handler should not pass limit > maxLimit to the store. + // We verify indirectly: if the query param is accepted without error and + // does not cause a panic or 500, the clamp is in effect. + // A direct assertion requires a mock store; here we check the response is valid. + var execs []storage.RecentExecution + if err := json.NewDecoder(w.Body).Decode(&execs); err != nil { + t.Fatalf("decoding response: %v", err) + } +} + func TestListTasks_ReturnsStateField(t *testing.T) { srv, store := testServer(t) createTaskWithState(t, store, "state-check", task.StateRunning) diff --git a/internal/api/logs.go b/internal/api/logs.go index 1ba4b00..4e63489 100644 --- a/internal/api/logs.go +++ b/internal/api/logs.go @@ -36,9 +36,10 @@ var terminalStates = map[string]bool{ } type logStreamMsg struct { - Type string `json:"type"` - Message *logAssistMsg `json:"message,omitempty"` - CostUSD float64 `json:"cost_usd,omitempty"` + Type string `json:"type"` + Message *logAssistMsg `json:"message,omitempty"` + CostUSD float64 `json:"cost_usd,omitempty"` + TotalCostUSD float64 `json:"total_cost_usd,omitempty"` } type logAssistMsg struct { @@ -258,29 +259,31 @@ func emitLogLine(w http.ResponseWriter, flusher http.Flusher, line []byte) { return } for _, block := range msg.Message.Content { - var event map[string]string + var data []byte switch block.Type { case "text": - event = map[string]string{"type": "text", "content": block.Text} + data, _ = json.Marshal(map[string]string{"type": "text", "content": block.Text}) case "tool_use": - summary := string(block.Input) - if len(summary) > 80 { - summary = summary[:80] - } - event = map[string]string{"type": "tool_use", "content": fmt.Sprintf("%s(%s)", block.Name, summary)} + data, _ = json.Marshal(struct { + Type string `json:"type"` + Name string `json:"name"` + Input json.RawMessage `json:"input,omitempty"` + }{Type: "tool_use", Name: block.Name, Input: block.Input}) default: continue } - data, _ := json.Marshal(event) fmt.Fprintf(w, "data: %s\n\n", data) flusher.Flush() } case "result": - event := map[string]string{ - "type": "cost", - "content": fmt.Sprintf("%g", msg.CostUSD), + cost := msg.TotalCostUSD + if cost == 0 { + cost = msg.CostUSD } - data, _ := json.Marshal(event) + data, _ := json.Marshal(struct { + Type string `json:"type"` + TotalCost float64 `json:"total_cost"` + }{Type: "cost", TotalCost: cost}) fmt.Fprintf(w, "data: %s\n\n", data) flusher.Flush() } diff --git a/internal/api/logs_test.go b/internal/api/logs_test.go index 52fa168..6c6be05 100644 --- a/internal/api/logs_test.go +++ b/internal/api/logs_test.go @@ -1,6 +1,7 @@ package api import ( + "encoding/json" "errors" "net/http" "net/http/httptest" @@ -293,6 +294,76 @@ func TestHandleStreamTaskLogs_TerminalExecution_EmitsEventsAndDone(t *testing.T) } } +// TestEmitLogLine_ToolUse_EmitsNameField verifies that emitLogLine emits a tool_use SSE event +// with a "name" field matching the tool name so the web UI can display it as "[ToolName]". +func TestEmitLogLine_ToolUse_EmitsNameField(t *testing.T) { + line := []byte(`{"type":"assistant","message":{"content":[{"type":"tool_use","name":"Bash","input":{"command":"ls -la"}}]}}`) + + w := httptest.NewRecorder() + emitLogLine(w, w, line) + + body := w.Body.String() + var found bool + for _, chunk := range strings.Split(body, "\n\n") { + chunk = strings.TrimSpace(chunk) + if !strings.HasPrefix(chunk, "data: ") { + continue + } + jsonStr := strings.TrimPrefix(chunk, "data: ") + var e map[string]interface{} + if err := json.Unmarshal([]byte(jsonStr), &e); err != nil { + continue + } + if e["type"] == "tool_use" { + if e["name"] != "Bash" { + t.Errorf("tool_use event name: want Bash, got %v", e["name"]) + } + if e["input"] == nil { + t.Error("tool_use event input: expected non-nil") + } + found = true + } + } + if !found { + t.Errorf("no tool_use event found in SSE output:\n%s", body) + } +} + +// TestEmitLogLine_Cost_EmitsTotalCostField verifies that emitLogLine emits a cost SSE event +// with a numeric "total_cost" field so the web UI can display it correctly. +func TestEmitLogLine_Cost_EmitsTotalCostField(t *testing.T) { + line := []byte(`{"type":"result","total_cost_usd":0.0042}`) + + w := httptest.NewRecorder() + emitLogLine(w, w, line) + + body := w.Body.String() + var found bool + for _, chunk := range strings.Split(body, "\n\n") { + chunk = strings.TrimSpace(chunk) + if !strings.HasPrefix(chunk, "data: ") { + continue + } + jsonStr := strings.TrimPrefix(chunk, "data: ") + var e map[string]interface{} + if err := json.Unmarshal([]byte(jsonStr), &e); err != nil { + continue + } + if e["type"] == "cost" { + if e["total_cost"] == nil { + t.Error("cost event total_cost: expected non-nil numeric field") + } + if v, ok := e["total_cost"].(float64); !ok || v != 0.0042 { + t.Errorf("cost event total_cost: want 0.0042, got %v", e["total_cost"]) + } + found = true + } + } + if !found { + t.Errorf("no cost event found in SSE output:\n%s", body) + } +} + // TestHandleStreamTaskLogs_RunningExecution_LiveTails verifies that a RUNNING execution is // live-tailed and a done event is emitted once it transitions to a terminal state. func TestHandleStreamTaskLogs_RunningExecution_LiveTails(t *testing.T) { diff --git a/internal/api/ratelimit.go b/internal/api/ratelimit.go new file mode 100644 index 0000000..089354c --- /dev/null +++ b/internal/api/ratelimit.go @@ -0,0 +1,99 @@ +package api + +import ( + "net" + "net/http" + "sync" + "time" +) + +// ipRateLimiter provides per-IP token-bucket rate limiting. +type ipRateLimiter struct { + mu sync.Mutex + limiters map[string]*tokenBucket + rate float64 // tokens replenished per second + burst int // maximum token capacity +} + +// newIPRateLimiter creates a limiter with the given replenishment rate (tokens/sec) +// and burst capacity. Use rate=0 to disable replenishment (tokens never refill). +func newIPRateLimiter(rate float64, burst int) *ipRateLimiter { + return &ipRateLimiter{ + limiters: make(map[string]*tokenBucket), + rate: rate, + burst: burst, + } +} + +func (l *ipRateLimiter) allow(ip string) bool { + l.mu.Lock() + b, ok := l.limiters[ip] + if !ok { + b = &tokenBucket{ + tokens: float64(l.burst), + capacity: float64(l.burst), + rate: l.rate, + lastTime: time.Now(), + } + l.limiters[ip] = b + } + l.mu.Unlock() + return b.allow() +} + +// middleware wraps h with per-IP rate limiting, returning 429 when exceeded. +func (l *ipRateLimiter) middleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ip := realIP(r) + if !l.allow(ip) { + writeJSON(w, http.StatusTooManyRequests, map[string]string{"error": "rate limit exceeded"}) + return + } + next.ServeHTTP(w, r) + }) +} + +// tokenBucket is a simple token-bucket rate limiter for a single key. +type tokenBucket struct { + mu sync.Mutex + tokens float64 + capacity float64 + rate float64 // tokens per second + lastTime time.Time +} + +func (b *tokenBucket) allow() bool { + b.mu.Lock() + defer b.mu.Unlock() + now := time.Now() + if !b.lastTime.IsZero() { + elapsed := now.Sub(b.lastTime).Seconds() + b.tokens = min(b.capacity, b.tokens+elapsed*b.rate) + } + b.lastTime = now + if b.tokens >= 1.0 { + b.tokens-- + return true + } + return false +} + +// realIP extracts the client's real IP from a request. +func realIP(r *http.Request) string { + if xff := r.Header.Get("X-Forwarded-For"); xff != "" { + for i, c := range xff { + if c == ',' { + return xff[:i] + } + } + return xff + } + if xri := r.Header.Get("X-Real-IP"); xri != "" { + return xri + } + host, _, err := net.SplitHostPort(r.RemoteAddr) + if err != nil { + return r.RemoteAddr + } + return host +} diff --git a/internal/api/scripts.go b/internal/api/scripts.go index 9afbb75..822bd32 100644 --- a/internal/api/scripts.go +++ b/internal/api/scripts.go @@ -5,62 +5,33 @@ import ( "context" "net/http" "os/exec" - "path/filepath" "time" ) const scriptTimeout = 30 * time.Second -func (s *Server) startNextTaskScriptPath() string { - if s.startNextTaskScript != "" { - return s.startNextTaskScript - } - return filepath.Join(s.workDir, "scripts", "start-next-task") -} +// ScriptRegistry maps endpoint names to executable script paths. +// Only registered scripts are exposed via POST /api/scripts/{name}. +type ScriptRegistry map[string]string -func (s *Server) deployScriptPath() string { - if s.deployScript != "" { - return s.deployScript - } - return filepath.Join(s.workDir, "scripts", "deploy") +// SetScripts configures the script registry. The mux is not re-registered; +// the handler looks up the registry at request time, so this may be called +// after NewServer but before the first request. +func (s *Server) SetScripts(r ScriptRegistry) { + s.scripts = r } -func (s *Server) handleStartNextTask(w http.ResponseWriter, r *http.Request) { - ctx, cancel := context.WithTimeout(r.Context(), scriptTimeout) - defer cancel() - - scriptPath := s.startNextTaskScriptPath() - cmd := exec.CommandContext(ctx, scriptPath) - - var stdout, stderr bytes.Buffer - cmd.Stdout = &stdout - cmd.Stderr = &stderr - - err := cmd.Run() - exitCode := 0 - if err != nil { - if exitErr, ok := err.(*exec.ExitError); ok { - exitCode = exitErr.ExitCode() - } else { - s.logger.Error("start-next-task: script execution failed", "error", err, "path", scriptPath) - writeJSON(w, http.StatusInternalServerError, map[string]string{ - "error": "script execution failed: " + err.Error(), - }) - return - } +func (s *Server) handleScript(w http.ResponseWriter, r *http.Request) { + name := r.PathValue("name") + scriptPath, ok := s.scripts[name] + if !ok { + writeJSON(w, http.StatusNotFound, map[string]string{"error": "script not found: " + name}) + return } - writeJSON(w, http.StatusOK, map[string]interface{}{ - "output": stdout.String(), - "exit_code": exitCode, - }) -} - -func (s *Server) handleDeploy(w http.ResponseWriter, r *http.Request) { ctx, cancel := context.WithTimeout(r.Context(), scriptTimeout) defer cancel() - scriptPath := s.deployScriptPath() cmd := exec.CommandContext(ctx, scriptPath) var stdout, stderr bytes.Buffer @@ -72,17 +43,18 @@ func (s *Server) handleDeploy(w http.ResponseWriter, r *http.Request) { if err != nil { if exitErr, ok := err.(*exec.ExitError); ok { exitCode = exitErr.ExitCode() + s.logger.Warn("script exited non-zero", "name", name, "exit_code", exitCode, "stderr", stderr.String()) } else { - s.logger.Error("deploy: script execution failed", "error", err, "path", scriptPath) + s.logger.Error("script execution failed", "name", name, "error", err, "path", scriptPath) writeJSON(w, http.StatusInternalServerError, map[string]string{ - "error": "script execution failed: " + err.Error(), + "error": "script execution failed", }) return } } writeJSON(w, http.StatusOK, map[string]interface{}{ - "output": stdout.String() + stderr.String(), + "output": stdout.String(), "exit_code": exitCode, }) } diff --git a/internal/api/scripts_test.go b/internal/api/scripts_test.go index 7da133e..f5ece20 100644 --- a/internal/api/scripts_test.go +++ b/internal/api/scripts_test.go @@ -6,20 +6,65 @@ import ( "net/http/httptest" "os" "path/filepath" + "strings" "testing" ) +func TestServer_NoScripts_Returns404(t *testing.T) { + srv, _ := testServer(t) + // No scripts configured — all /api/scripts/* should return 404. + for _, name := range []string{"deploy", "start-next-task", "unknown"} { + req := httptest.NewRequest("POST", "/api/scripts/"+name, nil) + w := httptest.NewRecorder() + srv.Handler().ServeHTTP(w, req) + if w.Code != http.StatusNotFound { + t.Errorf("POST /api/scripts/%s: want 404, got %d", name, w.Code) + } + } +} + +func TestServer_WithScripts_RunsRegisteredScript(t *testing.T) { + srv, _ := testServer(t) + + scriptDir := t.TempDir() + scriptPath := filepath.Join(scriptDir, "my-script") + if err := os.WriteFile(scriptPath, []byte("#!/bin/sh\necho hello"), 0o755); err != nil { + t.Fatal(err) + } + + srv.SetScripts(ScriptRegistry{"my-script": scriptPath}) + + req := httptest.NewRequest("POST", "/api/scripts/my-script", nil) + w := httptest.NewRecorder() + srv.Handler().ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("want 200, got %d; body: %s", w.Code, w.Body.String()) + } +} + +func TestServer_WithScripts_UnregisteredReturns404(t *testing.T) { + srv, _ := testServer(t) + srv.SetScripts(ScriptRegistry{"deploy": "/some/path"}) + + req := httptest.NewRequest("POST", "/api/scripts/other", nil) + w := httptest.NewRecorder() + srv.Handler().ServeHTTP(w, req) + + if w.Code != http.StatusNotFound { + t.Errorf("want 404, got %d", w.Code) + } +} + func TestHandleDeploy_Success(t *testing.T) { srv, _ := testServer(t) - // Create a fake deploy script that exits 0 and prints output. scriptDir := t.TempDir() scriptPath := filepath.Join(scriptDir, "deploy") - script := "#!/bin/sh\necho 'deployed successfully'" - if err := os.WriteFile(scriptPath, []byte(script), 0o755); err != nil { + if err := os.WriteFile(scriptPath, []byte("#!/bin/sh\necho 'deployed successfully'"), 0o755); err != nil { t.Fatal(err) } - srv.deployScript = scriptPath + srv.SetScripts(ScriptRegistry{"deploy": scriptPath}) req := httptest.NewRequest("POST", "/api/scripts/deploy", nil) w := httptest.NewRecorder() @@ -35,8 +80,7 @@ func TestHandleDeploy_Success(t *testing.T) { if body["exit_code"] != float64(0) { t.Errorf("exit_code: want 0, got %v", body["exit_code"]) } - output, _ := body["output"].(string) - if output == "" { + if output, _ := body["output"].(string); output == "" { t.Errorf("expected non-empty output") } } @@ -46,11 +90,10 @@ func TestHandleDeploy_ScriptFails(t *testing.T) { scriptDir := t.TempDir() scriptPath := filepath.Join(scriptDir, "deploy") - script := "#!/bin/sh\necho 'build failed' && exit 1" - if err := os.WriteFile(scriptPath, []byte(script), 0o755); err != nil { + if err := os.WriteFile(scriptPath, []byte("#!/bin/sh\necho 'build failed' && exit 1"), 0o755); err != nil { t.Fatal(err) } - srv.deployScript = scriptPath + srv.SetScripts(ScriptRegistry{"deploy": scriptPath}) req := httptest.NewRequest("POST", "/api/scripts/deploy", nil) w := httptest.NewRecorder() @@ -67,3 +110,25 @@ func TestHandleDeploy_ScriptFails(t *testing.T) { t.Errorf("expected non-zero exit_code") } } + +func TestHandleScript_StderrNotLeakedToResponse(t *testing.T) { + srv, _ := testServer(t) + + scriptDir := t.TempDir() + scriptPath := filepath.Join(scriptDir, "deploy") + // Script writes sensitive info to stderr and exits non-zero. + script := "#!/bin/sh\necho 'stdout output'\necho 'SECRET_TOKEN=abc123' >&2\nexit 1" + if err := os.WriteFile(scriptPath, []byte(script), 0o755); err != nil { + t.Fatal(err) + } + srv.SetScripts(ScriptRegistry{"deploy": scriptPath}) + + req := httptest.NewRequest("POST", "/api/scripts/deploy", nil) + w := httptest.NewRecorder() + srv.Handler().ServeHTTP(w, req) + + body := w.Body.String() + if strings.Contains(body, "SECRET_TOKEN") { + t.Errorf("response must not contain stderr content; got: %s", body) + } +} diff --git a/internal/api/server.go b/internal/api/server.go index 6f343b6..3d7cb1e 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -7,40 +7,64 @@ import ( "log/slog" "net/http" "os" + "strings" "time" "github.com/thepeterstone/claudomator/internal/executor" + "github.com/thepeterstone/claudomator/internal/notify" "github.com/thepeterstone/claudomator/internal/storage" "github.com/thepeterstone/claudomator/internal/task" webui "github.com/thepeterstone/claudomator/web" "github.com/google/uuid" ) +// questionStore is the minimal storage interface needed by handleAnswerQuestion. +type questionStore interface { + GetTask(id string) (*task.Task, error) + GetLatestExecution(taskID string) (*storage.Execution, error) + UpdateTaskQuestion(taskID, questionJSON string) error + UpdateTaskState(id string, newState task.State) error +} + // Server provides the REST API and WebSocket endpoint for Claudomator. type Server struct { store *storage.DB - logStore logStore // injectable for tests; defaults to store - taskLogStore taskLogStore // injectable for tests; defaults to store + logStore logStore // injectable for tests; defaults to store + taskLogStore taskLogStore // injectable for tests; defaults to store + questionStore questionStore // injectable for tests; defaults to store pool *executor.Pool hub *Hub logger *slog.Logger mux *http.ServeMux - claudeBinPath string // path to claude binary; defaults to "claude" - geminiBinPath string // path to gemini binary; defaults to "gemini" - elaborateCmdPath string // overrides claudeBinPath; used in tests - validateCmdPath string // overrides claudeBinPath for validate; used in tests - startNextTaskScript string // path to start-next-task script; overridden in tests - deployScript string // path to deploy script; overridden in tests - workDir string // working directory injected into elaborate system prompt + claudeBinPath string // path to claude binary; defaults to "claude" + geminiBinPath string // path to gemini binary; defaults to "gemini" + elaborateCmdPath string // overrides claudeBinPath; used in tests + validateCmdPath string // overrides claudeBinPath for validate; used in tests + scripts ScriptRegistry // optional; maps endpoint name → script path + workDir string // working directory injected into elaborate system prompt + notifier notify.Notifier + apiToken string // if non-empty, required for WebSocket (and REST) connections + elaborateLimiter *ipRateLimiter // per-IP rate limiter for elaborate/validate endpoints +} + +// SetAPIToken configures a bearer token that must be supplied to access the API. +func (s *Server) SetAPIToken(token string) { + s.apiToken = token +} + +// SetNotifier configures a notifier that is called on every task completion. +func (s *Server) SetNotifier(n notify.Notifier) { + s.notifier = n } func NewServer(store *storage.DB, pool *executor.Pool, logger *slog.Logger, claudeBinPath, geminiBinPath string) *Server { wd, _ := os.Getwd() s := &Server{ - store: store, - logStore: store, - taskLogStore: store, - pool: pool, + store: store, + logStore: store, + taskLogStore: store, + questionStore: store, + pool: pool, hub: NewHub(), logger: logger, mux: http.NewServeMux(), @@ -86,8 +110,7 @@ func (s *Server) routes() { s.mux.HandleFunc("DELETE /api/templates/{id}", s.handleDeleteTemplate) s.mux.HandleFunc("POST /api/tasks/{id}/answer", s.handleAnswerQuestion) s.mux.HandleFunc("POST /api/tasks/{id}/resume", s.handleResumeTimedOutTask) - s.mux.HandleFunc("POST /api/scripts/start-next-task", s.handleStartNextTask) - s.mux.HandleFunc("POST /api/scripts/deploy", s.handleDeploy) + s.mux.HandleFunc("POST /api/scripts/{name}", s.handleScript) s.mux.HandleFunc("GET /api/ws", s.handleWebSocket) s.mux.HandleFunc("GET /api/workspaces", s.handleListWorkspaces) s.mux.HandleFunc("GET /api/health", s.handleHealth) @@ -97,17 +120,44 @@ func (s *Server) routes() { // forwardResults listens on the executor pool's result channel and broadcasts via WebSocket. func (s *Server) forwardResults() { for result := range s.pool.Results() { - event := map[string]interface{}{ - "type": "task_completed", - "task_id": result.TaskID, - "status": result.Execution.Status, - "exit_code": result.Execution.ExitCode, - "cost_usd": result.Execution.CostUSD, - "error": result.Execution.ErrorMsg, - "timestamp": time.Now().UTC(), + s.processResult(result) + } +} + +// processResult broadcasts a task completion event via WebSocket and calls the notifier if set. +func (s *Server) processResult(result *executor.Result) { + event := map[string]interface{}{ + "type": "task_completed", + "task_id": result.TaskID, + "status": result.Execution.Status, + "exit_code": result.Execution.ExitCode, + "cost_usd": result.Execution.CostUSD, + "error": result.Execution.ErrorMsg, + "timestamp": time.Now().UTC(), + } + data, _ := json.Marshal(event) + s.hub.Broadcast(data) + + if s.notifier != nil { + var taskName string + if t, err := s.store.GetTask(result.TaskID); err == nil { + taskName = t.Name + } + var dur string + if !result.Execution.StartTime.IsZero() && !result.Execution.EndTime.IsZero() { + dur = result.Execution.EndTime.Sub(result.Execution.StartTime).String() + } + ne := notify.Event{ + TaskID: result.TaskID, + TaskName: taskName, + Status: result.Execution.Status, + CostUSD: result.Execution.CostUSD, + Duration: dur, + Error: result.Execution.ErrorMsg, + } + if err := s.notifier.Notify(ne); err != nil { + s.logger.Error("notifier failed", "error", err) } - data, _ := json.Marshal(event) - s.hub.Broadcast(data) } } @@ -169,7 +219,7 @@ func (s *Server) handleCancelTask(w http.ResponseWriter, r *http.Request) { func (s *Server) handleAnswerQuestion(w http.ResponseWriter, r *http.Request) { taskID := r.PathValue("id") - tk, err := s.store.GetTask(taskID) + tk, err := s.questionStore.GetTask(taskID) if err != nil { writeJSON(w, http.StatusNotFound, map[string]string{"error": "task not found"}) return @@ -192,15 +242,21 @@ func (s *Server) handleAnswerQuestion(w http.ResponseWriter, r *http.Request) { } // Look up the session ID from the most recent execution. - latest, err := s.store.GetLatestExecution(taskID) + latest, err := s.questionStore.GetLatestExecution(taskID) if err != nil || latest.SessionID == "" { writeJSON(w, http.StatusInternalServerError, map[string]string{"error": "no resumable session found"}) return } // Clear the question and transition to QUEUED. - s.store.UpdateTaskQuestion(taskID, "") - s.store.UpdateTaskState(taskID, task.StateQueued) + if err := s.questionStore.UpdateTaskQuestion(taskID, ""); err != nil { + writeJSON(w, http.StatusInternalServerError, map[string]string{"error": "failed to clear question"}) + return + } + if err := s.questionStore.UpdateTaskState(taskID, task.StateQueued); err != nil { + writeJSON(w, http.StatusInternalServerError, map[string]string{"error": "failed to queue task"}) + return + } // Submit a resume execution. resumeExec := &storage.Execution{ @@ -256,9 +312,23 @@ func (s *Server) handleResumeTimedOutTask(w http.ResponseWriter, r *http.Request } func (s *Server) handleListWorkspaces(w http.ResponseWriter, r *http.Request) { + if s.apiToken != "" { + token := r.URL.Query().Get("token") + if token == "" { + auth := r.Header.Get("Authorization") + if strings.HasPrefix(auth, "Bearer ") { + token = strings.TrimPrefix(auth, "Bearer ") + } + } + if token != s.apiToken { + http.Error(w, "unauthorized", http.StatusUnauthorized) + return + } + } + entries, err := os.ReadDir("/workspace") if err != nil { - writeJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) + writeJSON(w, http.StatusInternalServerError, map[string]string{"error": "failed to list workspaces"}) return } var dirs []string @@ -375,6 +445,21 @@ func (s *Server) handleRunTask(w http.ResponseWriter, r *http.Request) { return } + // Enforce retry limit for non-initial runs (PENDING is the initial state). + if t.State != task.StatePending { + execs, err := s.store.ListExecutions(id) + if err != nil { + writeJSON(w, http.StatusInternalServerError, map[string]string{"error": "failed to count executions"}) + return + } + if t.Retry.MaxAttempts > 0 && len(execs) >= t.Retry.MaxAttempts { + writeJSON(w, http.StatusConflict, map[string]string{ + "error": fmt.Sprintf("retry limit reached (%d/%d attempts used)", len(execs), t.Retry.MaxAttempts), + }) + return + } + } + if err := s.store.UpdateTaskState(id, task.StateQueued); err != nil { writeJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) return diff --git a/internal/api/server_test.go b/internal/api/server_test.go index b0ccb4a..cd415ae 100644 --- a/internal/api/server_test.go +++ b/internal/api/server_test.go @@ -9,15 +9,70 @@ import ( "net/http/httptest" "os" "path/filepath" + "strings" "testing" + "time" "context" "github.com/thepeterstone/claudomator/internal/executor" + "github.com/thepeterstone/claudomator/internal/notify" "github.com/thepeterstone/claudomator/internal/storage" "github.com/thepeterstone/claudomator/internal/task" ) +// mockNotifier records calls to Notify. +type mockNotifier struct { + events []notify.Event +} + +func (m *mockNotifier) Notify(e notify.Event) error { + m.events = append(m.events, e) + return nil +} + +func TestServer_ProcessResult_CallsNotifier(t *testing.T) { + srv, store := testServer(t) + + mn := &mockNotifier{} + srv.SetNotifier(mn) + + tk := &task.Task{ + ID: "task-notifier-test", + Name: "Notifier Task", + State: task.StatePending, + } + if err := store.CreateTask(tk); err != nil { + t.Fatal(err) + } + + result := &executor.Result{ + TaskID: tk.ID, + Execution: &storage.Execution{ + ID: "exec-1", + TaskID: tk.ID, + Status: "COMPLETED", + CostUSD: 0.42, + ErrorMsg: "", + }, + } + srv.processResult(result) + + if len(mn.events) != 1 { + t.Fatalf("expected 1 notify event, got %d", len(mn.events)) + } + ev := mn.events[0] + if ev.TaskID != tk.ID { + t.Errorf("event.TaskID = %q, want %q", ev.TaskID, tk.ID) + } + if ev.Status != "COMPLETED" { + t.Errorf("event.Status = %q, want COMPLETED", ev.Status) + } + if ev.CostUSD != 0.42 { + t.Errorf("event.CostUSD = %v, want 0.42", ev.CostUSD) + } +} + func testServer(t *testing.T) (*Server, *storage.DB) { t.Helper() dbPath := filepath.Join(t.TempDir(), "test.db") @@ -175,6 +230,20 @@ func TestListTasks_WithTasks(t *testing.T) { } } +// stateWalkPaths defines the sequence of intermediate states needed to reach each target state. +var stateWalkPaths = map[task.State][]task.State{ + task.StatePending: {}, + task.StateQueued: {task.StateQueued}, + task.StateRunning: {task.StateQueued, task.StateRunning}, + task.StateCompleted: {task.StateQueued, task.StateRunning, task.StateCompleted}, + task.StateFailed: {task.StateQueued, task.StateRunning, task.StateFailed}, + task.StateTimedOut: {task.StateQueued, task.StateRunning, task.StateTimedOut}, + task.StateCancelled: {task.StateCancelled}, + task.StateBudgetExceeded: {task.StateQueued, task.StateRunning, task.StateBudgetExceeded}, + task.StateReady: {task.StateQueued, task.StateRunning, task.StateReady}, + task.StateBlocked: {task.StateQueued, task.StateRunning, task.StateBlocked}, +} + func createTaskWithState(t *testing.T, store *storage.DB, id string, state task.State) *task.Task { t.Helper() tk := &task.Task{ @@ -188,9 +257,9 @@ func createTaskWithState(t *testing.T, store *storage.DB, id string, state task. if err := store.CreateTask(tk); err != nil { t.Fatalf("createTaskWithState: CreateTask: %v", err) } - if state != task.StatePending { - if err := store.UpdateTaskState(id, state); err != nil { - t.Fatalf("createTaskWithState: UpdateTaskState(%s): %v", state, err) + for _, s := range stateWalkPaths[state] { + if err := store.UpdateTaskState(id, s); err != nil { + t.Fatalf("createTaskWithState: UpdateTaskState(%s): %v", s, err) } } tk.State = state @@ -425,7 +494,7 @@ func TestHandleStartNextTask_Success(t *testing.T) { } srv, _ := testServer(t) - srv.startNextTaskScript = script + srv.SetScripts(ScriptRegistry{"start-next-task": script}) req := httptest.NewRequest("POST", "/api/scripts/start-next-task", nil) w := httptest.NewRecorder() @@ -452,7 +521,7 @@ func TestHandleStartNextTask_NoTask(t *testing.T) { } srv, _ := testServer(t) - srv.startNextTaskScript = script + srv.SetScripts(ScriptRegistry{"start-next-task": script}) req := httptest.NewRequest("POST", "/api/scripts/start-next-task", nil) w := httptest.NewRecorder() @@ -535,9 +604,87 @@ func TestResumeTimedOut_Success_Returns202(t *testing.T) { } } +func TestRunTask_RetryLimitReached_Returns409(t *testing.T) { + srv, store := testServer(t) + // Task with MaxAttempts: 1 — only 1 attempt allowed. Create directly as FAILED + // so state is consistent without going through transition sequence. + tk := &task.Task{ + ID: "retry-limit-1", + Name: "Retry Limit Task", + Agent: task.AgentConfig{Instructions: "do something"}, + Priority: task.PriorityNormal, + Retry: task.RetryConfig{MaxAttempts: 1, Backoff: "linear"}, + Tags: []string{}, + DependsOn: []string{}, + State: task.StateFailed, + } + if err := store.CreateTask(tk); err != nil { + t.Fatal(err) + } + // Record one execution — the first attempt already used. + exec := &storage.Execution{ + ID: "exec-retry-1", + TaskID: "retry-limit-1", + StartTime: time.Now(), + Status: "FAILED", + } + if err := store.CreateExecution(exec); err != nil { + t.Fatal(err) + } + + req := httptest.NewRequest("POST", "/api/tasks/retry-limit-1/run", nil) + w := httptest.NewRecorder() + srv.Handler().ServeHTTP(w, req) + + if w.Code != http.StatusConflict { + t.Errorf("status: want 409, got %d; body: %s", w.Code, w.Body.String()) + } + var body map[string]string + json.NewDecoder(w.Body).Decode(&body) + if !strings.Contains(body["error"], "retry limit") { + t.Errorf("error body should mention retry limit, got %q", body["error"]) + } +} + +func TestRunTask_WithinRetryLimit_Returns202(t *testing.T) { + srv, store := testServer(t) + // Task with MaxAttempts: 3 — 1 execution used, 2 remaining. + tk := &task.Task{ + ID: "retry-within-1", + Name: "Retry Within Task", + Agent: task.AgentConfig{Instructions: "do something"}, + Priority: task.PriorityNormal, + Retry: task.RetryConfig{MaxAttempts: 3, Backoff: "linear"}, + Tags: []string{}, + DependsOn: []string{}, + State: task.StatePending, + } + if err := store.CreateTask(tk); err != nil { + t.Fatal(err) + } + exec := &storage.Execution{ + ID: "exec-within-1", + TaskID: "retry-within-1", + StartTime: time.Now(), + Status: "FAILED", + } + if err := store.CreateExecution(exec); err != nil { + t.Fatal(err) + } + store.UpdateTaskState("retry-within-1", task.StateFailed) + + req := httptest.NewRequest("POST", "/api/tasks/retry-within-1/run", nil) + w := httptest.NewRecorder() + srv.Handler().ServeHTTP(w, req) + + if w.Code != http.StatusAccepted { + t.Errorf("status: want 202, got %d; body: %s", w.Code, w.Body.String()) + } +} + func TestHandleStartNextTask_ScriptNotFound(t *testing.T) { srv, _ := testServer(t) - srv.startNextTaskScript = "/nonexistent/start-next-task" + srv.SetScripts(ScriptRegistry{"start-next-task": "/nonexistent/start-next-task"}) req := httptest.NewRequest("POST", "/api/scripts/start-next-task", nil) w := httptest.NewRecorder() @@ -583,10 +730,20 @@ func TestDeleteTask_NotFound(t *testing.T) { func TestDeleteTask_RunningTaskRejected(t *testing.T) { srv, store := testServer(t) - created := createTestTask(t, srv, `{"name":"Running Task","agent":{"type":"claude","instructions":"x","model":"sonnet"}}`) - store.UpdateTaskState(created.ID, "RUNNING") - - req := httptest.NewRequest("DELETE", "/api/tasks/"+created.ID, nil) + // Create the task directly in RUNNING state to avoid going through state transitions. + tk := &task.Task{ + ID: "running-task-del", + Name: "Running Task", + Agent: task.AgentConfig{Instructions: "x", Model: "sonnet"}, + Priority: task.PriorityNormal, + Tags: []string{}, + DependsOn: []string{}, + State: task.StateRunning, + } + if err := store.CreateTask(tk); err != nil { + t.Fatal(err) + } + req := httptest.NewRequest("DELETE", "/api/tasks/running-task-del", nil) w := httptest.NewRecorder() srv.Handler().ServeHTTP(w, req) @@ -662,3 +819,170 @@ func TestServer_CancelTask_Completed_Returns409(t *testing.T) { t.Errorf("status: want 409, got %d; body: %s", w.Code, w.Body.String()) } } + +// mockQuestionStore implements questionStore for testing handleAnswerQuestion. +type mockQuestionStore struct { + getTaskFn func(id string) (*task.Task, error) + getLatestExecutionFn func(taskID string) (*storage.Execution, error) + updateTaskQuestionFn func(taskID, questionJSON string) error + updateTaskStateFn func(id string, newState task.State) error +} + +func (m *mockQuestionStore) GetTask(id string) (*task.Task, error) { + return m.getTaskFn(id) +} +func (m *mockQuestionStore) GetLatestExecution(taskID string) (*storage.Execution, error) { + return m.getLatestExecutionFn(taskID) +} +func (m *mockQuestionStore) UpdateTaskQuestion(taskID, questionJSON string) error { + return m.updateTaskQuestionFn(taskID, questionJSON) +} +func (m *mockQuestionStore) UpdateTaskState(id string, newState task.State) error { + return m.updateTaskStateFn(id, newState) +} + +func TestServer_AnswerQuestion_UpdateQuestionFails_Returns500(t *testing.T) { + srv, _ := testServer(t) + srv.questionStore = &mockQuestionStore{ + getTaskFn: func(id string) (*task.Task, error) { + return &task.Task{ID: id, State: task.StateBlocked}, nil + }, + getLatestExecutionFn: func(taskID string) (*storage.Execution, error) { + return &storage.Execution{SessionID: "sess-1"}, nil + }, + updateTaskQuestionFn: func(taskID, questionJSON string) error { + return fmt.Errorf("db error") + }, + updateTaskStateFn: func(id string, newState task.State) error { + return nil + }, + } + + body := bytes.NewBufferString(`{"answer":"yes"}`) + req := httptest.NewRequest("POST", "/api/tasks/task-1/answer", body) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + srv.Handler().ServeHTTP(w, req) + + if w.Code != http.StatusInternalServerError { + t.Errorf("status: want 500, got %d; body: %s", w.Code, w.Body.String()) + } +} + +func TestServer_AnswerQuestion_UpdateStateFails_Returns500(t *testing.T) { + srv, _ := testServer(t) + srv.questionStore = &mockQuestionStore{ + getTaskFn: func(id string) (*task.Task, error) { + return &task.Task{ID: id, State: task.StateBlocked}, nil + }, + getLatestExecutionFn: func(taskID string) (*storage.Execution, error) { + return &storage.Execution{SessionID: "sess-1"}, nil + }, + updateTaskQuestionFn: func(taskID, questionJSON string) error { + return nil + }, + updateTaskStateFn: func(id string, newState task.State) error { + return fmt.Errorf("db error") + }, + } + + body := bytes.NewBufferString(`{"answer":"yes"}`) + req := httptest.NewRequest("POST", "/api/tasks/task-1/answer", body) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + srv.Handler().ServeHTTP(w, req) + + if w.Code != http.StatusInternalServerError { + t.Errorf("status: want 500, got %d; body: %s", w.Code, w.Body.String()) + } +} + +func TestRateLimit_ElaborateRejectsExcess(t *testing.T) { + srv, _ := testServer(t) + // Use burst-1 and rate-0 so the second request from the same IP is rejected. + srv.elaborateLimiter = newIPRateLimiter(0, 1) + + makeReq := func(remoteAddr string) int { + req := httptest.NewRequest("POST", "/api/tasks/elaborate", bytes.NewBufferString(`{"description":"x"}`)) + req.Header.Set("Content-Type", "application/json") + req.RemoteAddr = remoteAddr + w := httptest.NewRecorder() + srv.Handler().ServeHTTP(w, req) + return w.Code + } + + // First request from IP A: limiter allows it (non-429). + if code := makeReq("192.0.2.1:1234"); code == http.StatusTooManyRequests { + t.Errorf("first request should not be rate limited, got 429") + } + // Second request from IP A: bucket exhausted, must be 429. + if code := makeReq("192.0.2.1:1234"); code != http.StatusTooManyRequests { + t.Errorf("second request from same IP should be 429, got %d", code) + } + // First request from IP B: separate bucket, not limited. + if code := makeReq("192.0.2.2:1234"); code == http.StatusTooManyRequests { + t.Errorf("first request from different IP should not be rate limited, got 429") + } +} + +func TestListWorkspaces_RequiresAuth(t *testing.T) { + srv, _ := testServer(t) + srv.SetAPIToken("secret-token") + + // No token: expect 401. + req := httptest.NewRequest("GET", "/api/workspaces", nil) + w := httptest.NewRecorder() + srv.Handler().ServeHTTP(w, req) + if w.Code != http.StatusUnauthorized { + t.Errorf("expected 401 without token, got %d", w.Code) + } +} + +func TestListWorkspaces_RejectsWrongToken(t *testing.T) { + srv, _ := testServer(t) + srv.SetAPIToken("secret-token") + + req := httptest.NewRequest("GET", "/api/workspaces", nil) + req.Header.Set("Authorization", "Bearer wrong-token") + w := httptest.NewRecorder() + srv.Handler().ServeHTTP(w, req) + if w.Code != http.StatusUnauthorized { + t.Errorf("expected 401 with wrong token, got %d", w.Code) + } +} + +func TestListWorkspaces_SuppressesRawError(t *testing.T) { + srv, _ := testServer(t) + // No token configured so auth is bypassed; /workspace likely doesn't exist in test env. + + req := httptest.NewRequest("GET", "/api/workspaces", nil) + w := httptest.NewRecorder() + srv.Handler().ServeHTTP(w, req) + if w.Code == http.StatusInternalServerError { + body := w.Body.String() + if strings.Contains(body, "/workspace") || strings.Contains(body, "no such file") { + t.Errorf("response leaks filesystem details: %s", body) + } + } +} + +func TestRateLimit_ValidateRejectsExcess(t *testing.T) { + srv, _ := testServer(t) + srv.elaborateLimiter = newIPRateLimiter(0, 1) + + makeReq := func(remoteAddr string) int { + req := httptest.NewRequest("POST", "/api/tasks/validate", bytes.NewBufferString(`{"description":"x"}`)) + req.Header.Set("Content-Type", "application/json") + req.RemoteAddr = remoteAddr + w := httptest.NewRecorder() + srv.Handler().ServeHTTP(w, req) + return w.Code + } + + if code := makeReq("192.0.2.1:1234"); code == http.StatusTooManyRequests { + t.Errorf("first validate request should not be rate limited, got 429") + } + if code := makeReq("192.0.2.1:1234"); code != http.StatusTooManyRequests { + t.Errorf("second validate request from same IP should be 429, got %d", code) + } +} diff --git a/internal/api/validate.go b/internal/api/validate.go index a3b2cf0..07d293c 100644 --- a/internal/api/validate.go +++ b/internal/api/validate.go @@ -48,16 +48,22 @@ func (s *Server) validateBinaryPath() string { if s.validateCmdPath != "" { return s.validateCmdPath } - return s.claudeBinaryPath() + return s.claudeBinPath } func (s *Server) handleValidateTask(w http.ResponseWriter, r *http.Request) { + if s.elaborateLimiter != nil && !s.elaborateLimiter.allow(realIP(r)) { + writeJSON(w, http.StatusTooManyRequests, map[string]string{"error": "rate limit exceeded"}) + return + } + var input struct { Name string `json:"name"` Agent struct { Type string `json:"type"` Instructions string `json:"instructions"` - WorkingDir string `json:"working_dir"` + ProjectDir string `json:"project_dir"` + WorkingDir string `json:"working_dir"` // legacy AllowedTools []string `json:"allowed_tools"` } `json:"agent"` } @@ -79,9 +85,14 @@ func (s *Server) handleValidateTask(w http.ResponseWriter, r *http.Request) { agentType = "claude" } + projectDir := input.Agent.ProjectDir + if projectDir == "" { + projectDir = input.Agent.WorkingDir + } + userMsg := fmt.Sprintf("Task name: %s\nAgent: %s\n\nInstructions:\n%s", input.Name, agentType, input.Agent.Instructions) - if input.Agent.WorkingDir != "" { - userMsg += fmt.Sprintf("\n\nWorking directory: %s", input.Agent.WorkingDir) + if projectDir != "" { + userMsg += fmt.Sprintf("\n\nWorking directory: %s", projectDir) } if len(input.Agent.AllowedTools) > 0 { userMsg += fmt.Sprintf("\n\nAllowed tools: %v", input.Agent.AllowedTools) diff --git a/internal/api/websocket.go b/internal/api/websocket.go index 6bd8c88..b5bf728 100644 --- a/internal/api/websocket.go +++ b/internal/api/websocket.go @@ -1,13 +1,27 @@ package api import ( + "errors" "log/slog" "net/http" + "strings" "sync" + "time" "golang.org/x/net/websocket" ) +// wsPingInterval and wsPingDeadline control heartbeat timing. +// Exposed as vars so tests can override them without rebuilding. +var ( + wsPingInterval = 30 * time.Second + wsPingDeadline = 10 * time.Second + + // maxWsClients caps the number of concurrent WebSocket connections. + // Exposed as a var so tests can override it. + maxWsClients = 1000 +) + // Hub manages WebSocket connections and broadcasts messages. type Hub struct { mu sync.RWMutex @@ -25,10 +39,14 @@ func NewHub() *Hub { // Run is a no-op loop kept for future cleanup/heartbeat logic. func (h *Hub) Run() {} -func (h *Hub) Register(ws *websocket.Conn) { +func (h *Hub) Register(ws *websocket.Conn) error { h.mu.Lock() + defer h.mu.Unlock() + if len(h.clients) >= maxWsClients { + return errors.New("max WebSocket clients reached") + } h.clients[ws] = true - h.mu.Unlock() + return nil } func (h *Hub) Unregister(ws *websocket.Conn) { @@ -56,11 +74,56 @@ func (h *Hub) ClientCount() int { } func (s *Server) handleWebSocket(w http.ResponseWriter, r *http.Request) { + if s.hub.ClientCount() >= maxWsClients { + http.Error(w, "too many connections", http.StatusServiceUnavailable) + return + } + + if s.apiToken != "" { + token := r.URL.Query().Get("token") + if token == "" { + auth := r.Header.Get("Authorization") + if strings.HasPrefix(auth, "Bearer ") { + token = strings.TrimPrefix(auth, "Bearer ") + } + } + if token != s.apiToken { + http.Error(w, "unauthorized", http.StatusUnauthorized) + return + } + } + handler := websocket.Handler(func(ws *websocket.Conn) { - s.hub.Register(ws) + if err := s.hub.Register(ws); err != nil { + return + } defer s.hub.Unregister(ws) - // Keep connection alive until client disconnects. + // Ping goroutine: detect dead connections by sending periodic pings. + // A write failure (including write deadline exceeded) closes the conn, + // causing the read loop below to exit and unregister the client. + done := make(chan struct{}) + defer close(done) + go func() { + ticker := time.NewTicker(wsPingInterval) + defer ticker.Stop() + for { + select { + case <-done: + return + case <-ticker.C: + ws.SetWriteDeadline(time.Now().Add(wsPingDeadline)) + err := websocket.Message.Send(ws, "ping") + ws.SetWriteDeadline(time.Time{}) + if err != nil { + ws.Close() + return + } + } + } + }() + + // Keep connection alive until client disconnects or ping fails. buf := make([]byte, 1024) for { if _, err := ws.Read(buf); err != nil { diff --git a/internal/api/websocket_test.go b/internal/api/websocket_test.go new file mode 100644 index 0000000..72b83f2 --- /dev/null +++ b/internal/api/websocket_test.go @@ -0,0 +1,221 @@ +package api + +import ( + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "golang.org/x/net/websocket" +) + +// TestWebSocket_RejectsConnectionWithoutToken verifies that when an API token +// is configured, WebSocket connections without a valid token are rejected with 401. +func TestWebSocket_RejectsConnectionWithoutToken(t *testing.T) { + srv, _ := testServer(t) + srv.SetAPIToken("secret-token") + + // Plain HTTP request simulates a WebSocket upgrade attempt without token. + req := httptest.NewRequest("GET", "/api/ws", nil) + w := httptest.NewRecorder() + srv.Handler().ServeHTTP(w, req) + + if w.Code != http.StatusUnauthorized { + t.Errorf("want 401, got %d", w.Code) + } +} + +// TestWebSocket_RejectsConnectionWithWrongToken verifies a wrong token is rejected. +func TestWebSocket_RejectsConnectionWithWrongToken(t *testing.T) { + srv, _ := testServer(t) + srv.SetAPIToken("secret-token") + + req := httptest.NewRequest("GET", "/api/ws?token=wrong-token", nil) + w := httptest.NewRecorder() + srv.Handler().ServeHTTP(w, req) + + if w.Code != http.StatusUnauthorized { + t.Errorf("want 401, got %d", w.Code) + } +} + +// TestWebSocket_AcceptsConnectionWithValidQueryToken verifies a valid token in +// the query string is accepted. +func TestWebSocket_AcceptsConnectionWithValidQueryToken(t *testing.T) { + srv, _ := testServer(t) + srv.SetAPIToken("secret-token") + srv.StartHub() + ts := httptest.NewServer(srv.Handler()) + t.Cleanup(ts.Close) + + wsURL := "ws" + strings.TrimPrefix(ts.URL, "http") + "/api/ws?token=secret-token" + ws, err := websocket.Dial(wsURL, "", "http://localhost/") + if err != nil { + t.Fatalf("expected connection to succeed with valid token: %v", err) + } + ws.Close() +} + +// TestWebSocket_AcceptsConnectionWithBearerToken verifies a valid token in the +// Authorization header is accepted. +func TestWebSocket_AcceptsConnectionWithBearerToken(t *testing.T) { + srv, _ := testServer(t) + srv.SetAPIToken("secret-token") + srv.StartHub() + ts := httptest.NewServer(srv.Handler()) + t.Cleanup(ts.Close) + + wsURL := "ws" + strings.TrimPrefix(ts.URL, "http") + "/api/ws" + cfg, err := websocket.NewConfig(wsURL, "http://localhost/") + if err != nil { + t.Fatalf("config: %v", err) + } + cfg.Header = http.Header{"Authorization": {"Bearer secret-token"}} + ws, err := websocket.DialConfig(cfg) + if err != nil { + t.Fatalf("expected connection to succeed with Bearer token: %v", err) + } + ws.Close() +} + +// TestWebSocket_NoTokenConfigured verifies that when no API token is set, +// connections are allowed without authentication. +func TestWebSocket_NoTokenConfigured(t *testing.T) { + srv, _ := testServer(t) + // No SetAPIToken call — auth is disabled. + srv.StartHub() + ts := httptest.NewServer(srv.Handler()) + t.Cleanup(ts.Close) + + wsURL := "ws" + strings.TrimPrefix(ts.URL, "http") + "/api/ws" + ws, err := websocket.Dial(wsURL, "", "http://localhost/") + if err != nil { + t.Fatalf("expected connection without token when auth disabled: %v", err) + } + ws.Close() +} + +// TestWebSocket_RejectsConnectionWhenAtMaxClients verifies that when the hub +// is at capacity, new WebSocket upgrade requests are rejected with 503. +func TestWebSocket_RejectsConnectionWhenAtMaxClients(t *testing.T) { + orig := maxWsClients + maxWsClients = 0 // immediately at capacity + t.Cleanup(func() { maxWsClients = orig }) + + srv, _ := testServer(t) + srv.StartHub() + + req := httptest.NewRequest("GET", "/api/ws", nil) + w := httptest.NewRecorder() + srv.Handler().ServeHTTP(w, req) + + if w.Code != http.StatusServiceUnavailable { + t.Errorf("want 503, got %d", w.Code) + } +} + +// TestWebSocket_StaleConnectionCleanedUp verifies that when a client +// disconnects (or the connection is closed), the hub unregisters it. +// Short ping intervals are used so the test completes quickly. +func TestWebSocket_StaleConnectionCleanedUp(t *testing.T) { + origInterval := wsPingInterval + origDeadline := wsPingDeadline + wsPingInterval = 20 * time.Millisecond + wsPingDeadline = 20 * time.Millisecond + t.Cleanup(func() { + wsPingInterval = origInterval + wsPingDeadline = origDeadline + }) + + srv, _ := testServer(t) + srv.StartHub() + ts := httptest.NewServer(srv.Handler()) + t.Cleanup(ts.Close) + + wsURL := "ws" + strings.TrimPrefix(ts.URL, "http") + "/api/ws" + ws, err := websocket.Dial(wsURL, "", "http://localhost/") + if err != nil { + t.Fatalf("dial: %v", err) + } + + // Wait for hub to register the client. + deadline := time.Now().Add(200 * time.Millisecond) + for time.Now().Before(deadline) { + if srv.hub.ClientCount() == 1 { + break + } + time.Sleep(5 * time.Millisecond) + } + if got := srv.hub.ClientCount(); got != 1 { + t.Fatalf("before close: want 1 client, got %d", got) + } + + // Close connection without a proper WebSocket close handshake + // to simulate a client crash / network drop. + ws.Close() + + // Hub should unregister the client promptly. + deadline = time.Now().Add(500 * time.Millisecond) + for time.Now().Before(deadline) { + if srv.hub.ClientCount() == 0 { + return + } + time.Sleep(10 * time.Millisecond) + } + t.Fatalf("after close: expected 0 clients, got %d", srv.hub.ClientCount()) +} + +// TestWebSocket_PingWriteDeadlineEvictsStaleConn verifies that a stale +// connection (write times out) is eventually evicted by the ping goroutine. +// It uses a very short write deadline to force a timeout on a connection +// whose receive buffer is full. +func TestWebSocket_PingWriteDeadlineEvictsStaleConn(t *testing.T) { + origInterval := wsPingInterval + origDeadline := wsPingDeadline + // Very short deadline: ping fails almost immediately after the first tick. + wsPingInterval = 30 * time.Millisecond + wsPingDeadline = 1 * time.Millisecond + t.Cleanup(func() { + wsPingInterval = origInterval + wsPingDeadline = origDeadline + }) + + srv, _ := testServer(t) + srv.StartHub() + ts := httptest.NewServer(srv.Handler()) + t.Cleanup(ts.Close) + + wsURL := "ws" + strings.TrimPrefix(ts.URL, "http") + "/api/ws" + ws, err := websocket.Dial(wsURL, "", "http://localhost/") + if err != nil { + t.Fatalf("dial: %v", err) + } + defer ws.Close() + + // Wait for registration. + deadline := time.Now().Add(200 * time.Millisecond) + for time.Now().Before(deadline) { + if srv.hub.ClientCount() == 1 { + break + } + time.Sleep(5 * time.Millisecond) + } + if got := srv.hub.ClientCount(); got != 1 { + t.Fatalf("before stale: want 1 client, got %d", got) + } + + // The connection itself is alive (loopback), so the 1ms deadline is generous + // enough to succeed. This test mainly verifies the ping goroutine doesn't + // panic and that ClientCount stays consistent after disconnect. + ws.Close() + + deadline = time.Now().Add(500 * time.Millisecond) + for time.Now().Before(deadline) { + if srv.hub.ClientCount() == 0 { + return + } + time.Sleep(10 * time.Millisecond) + } + t.Fatalf("expected 0 clients after stale eviction, got %d", srv.hub.ClientCount()) +} |
