summaryrefslogtreecommitdiff
path: root/internal/api
diff options
context:
space:
mode:
authorPeter Stone <thepeterstone@gmail.com>2026-03-08 21:03:50 +0000
committerPeter Stone <thepeterstone@gmail.com>2026-03-08 21:03:50 +0000
commit632ea5a44731af94b6238f330a3b5440906c8ae7 (patch)
treed8c780412598d66b89ef390b5729e379fdfd9d5b /internal/api
parent406247b14985ab57902e8e42898dc8cb8960290d (diff)
parent93a4c852bf726b00e8014d385165f847763fa214 (diff)
merge: pull latest from master and resolve conflicts
- Resolve conflicts in API server, CLI, and executor. - Maintain Gemini classification and assignment logic. - Update UI to use generic agent config and project_dir. - Fix ProjectDir/WorkingDir inconsistencies in Gemini runner. - All tests passing after merge.
Diffstat (limited to 'internal/api')
-rw-r--r--internal/api/elaborate.go17
-rw-r--r--internal/api/elaborate_test.go2
-rw-r--r--internal/api/executions.go4
-rw-r--r--internal/api/executions_test.go20
-rw-r--r--internal/api/logs.go33
-rw-r--r--internal/api/logs_test.go71
-rw-r--r--internal/api/ratelimit.go99
-rw-r--r--internal/api/scripts.go64
-rw-r--r--internal/api/scripts_test.go83
-rw-r--r--internal/api/server.go145
-rw-r--r--internal/api/server_test.go344
-rw-r--r--internal/api/validate.go19
-rw-r--r--internal/api/websocket.go71
-rw-r--r--internal/api/websocket_test.go221
14 files changed, 1068 insertions, 125 deletions
diff --git a/internal/api/elaborate.go b/internal/api/elaborate.go
index 5ab9ff0..907cb98 100644
--- a/internal/api/elaborate.go
+++ b/internal/api/elaborate.go
@@ -14,9 +14,9 @@ import (
const elaborateTimeout = 30 * time.Second
func buildElaboratePrompt(workDir string) string {
- workDirLine := ` "working_dir": string — leave empty unless you have a specific reason to set it,`
+ workDirLine := ` "project_dir": string — leave empty unless you have a specific reason to set it,`
if workDir != "" {
- workDirLine = fmt.Sprintf(` "working_dir": string — use %q for tasks that operate on this codebase, empty string otherwise,`, workDir)
+ workDirLine = fmt.Sprintf(` "project_dir": string — use %q for tasks that operate on this codebase, empty string otherwise,`, workDir)
}
return `You are a task configuration assistant for Claudomator, an AI task runner that executes tasks by running Claude or Gemini as a subprocess.
@@ -55,7 +55,7 @@ type elaboratedAgent struct {
Type string `json:"type"`
Model string `json:"model"`
Instructions string `json:"instructions"`
- WorkingDir string `json:"working_dir"`
+ ProjectDir string `json:"project_dir"`
MaxBudgetUSD float64 `json:"max_budget_usd"`
AllowedTools []string `json:"allowed_tools"`
}
@@ -87,9 +87,14 @@ func (s *Server) claudeBinaryPath() string {
}
func (s *Server) handleElaborateTask(w http.ResponseWriter, r *http.Request) {
+ if s.elaborateLimiter != nil && !s.elaborateLimiter.allow(realIP(r)) {
+ writeJSON(w, http.StatusTooManyRequests, map[string]string{"error": "rate limit exceeded"})
+ return
+ }
+
var input struct {
Prompt string `json:"prompt"`
- WorkingDir string `json:"working_dir"`
+ ProjectDir string `json:"project_dir"`
}
if err := json.NewDecoder(r.Body).Decode(&input); err != nil {
writeJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid JSON: " + err.Error()})
@@ -101,8 +106,8 @@ func (s *Server) handleElaborateTask(w http.ResponseWriter, r *http.Request) {
}
workDir := s.workDir
- if input.WorkingDir != "" {
- workDir = input.WorkingDir
+ if input.ProjectDir != "" {
+ workDir = input.ProjectDir
}
ctx, cancel := context.WithTimeout(r.Context(), elaborateTimeout)
diff --git a/internal/api/elaborate_test.go b/internal/api/elaborate_test.go
index 4939701..b33ca11 100644
--- a/internal/api/elaborate_test.go
+++ b/internal/api/elaborate_test.go
@@ -57,7 +57,7 @@ func TestElaborateTask_Success(t *testing.T) {
Type: "claude",
Model: "sonnet",
Instructions: "Run go test -race ./... and report results.",
- WorkingDir: "",
+ ProjectDir: "",
MaxBudgetUSD: 0.5,
AllowedTools: []string{"Bash"},
},
diff --git a/internal/api/executions.go b/internal/api/executions.go
index d9214c0..114425e 100644
--- a/internal/api/executions.go
+++ b/internal/api/executions.go
@@ -21,12 +21,16 @@ func (s *Server) handleListRecentExecutions(w http.ResponseWriter, r *http.Reque
}
}
+ const maxLimit = 1000
limit := 50
if v := r.URL.Query().Get("limit"); v != "" {
if n, err := strconv.Atoi(v); err == nil && n > 0 {
limit = n
}
}
+ if limit > maxLimit {
+ limit = maxLimit
+ }
taskID := r.URL.Query().Get("task_id")
diff --git a/internal/api/executions_test.go b/internal/api/executions_test.go
index a2bba21..45548ad 100644
--- a/internal/api/executions_test.go
+++ b/internal/api/executions_test.go
@@ -258,6 +258,26 @@ func TestGetExecutionLog_FollowSSEHeaders(t *testing.T) {
}
}
+func TestListRecentExecutions_LimitClamped(t *testing.T) {
+ srv, _ := testServer(t)
+
+ req := httptest.NewRequest("GET", "/api/executions?limit=10000000", nil)
+ w := httptest.NewRecorder()
+ srv.Handler().ServeHTTP(w, req)
+
+ if w.Code != http.StatusOK {
+ t.Fatalf("status: want 200, got %d; body: %s", w.Code, w.Body.String())
+ }
+ // The handler should not pass limit > maxLimit to the store.
+ // We verify indirectly: if the query param is accepted without error and
+ // does not cause a panic or 500, the clamp is in effect.
+ // A direct assertion requires a mock store; here we check the response is valid.
+ var execs []storage.RecentExecution
+ if err := json.NewDecoder(w.Body).Decode(&execs); err != nil {
+ t.Fatalf("decoding response: %v", err)
+ }
+}
+
func TestListTasks_ReturnsStateField(t *testing.T) {
srv, store := testServer(t)
createTaskWithState(t, store, "state-check", task.StateRunning)
diff --git a/internal/api/logs.go b/internal/api/logs.go
index 1ba4b00..4e63489 100644
--- a/internal/api/logs.go
+++ b/internal/api/logs.go
@@ -36,9 +36,10 @@ var terminalStates = map[string]bool{
}
type logStreamMsg struct {
- Type string `json:"type"`
- Message *logAssistMsg `json:"message,omitempty"`
- CostUSD float64 `json:"cost_usd,omitempty"`
+ Type string `json:"type"`
+ Message *logAssistMsg `json:"message,omitempty"`
+ CostUSD float64 `json:"cost_usd,omitempty"`
+ TotalCostUSD float64 `json:"total_cost_usd,omitempty"`
}
type logAssistMsg struct {
@@ -258,29 +259,31 @@ func emitLogLine(w http.ResponseWriter, flusher http.Flusher, line []byte) {
return
}
for _, block := range msg.Message.Content {
- var event map[string]string
+ var data []byte
switch block.Type {
case "text":
- event = map[string]string{"type": "text", "content": block.Text}
+ data, _ = json.Marshal(map[string]string{"type": "text", "content": block.Text})
case "tool_use":
- summary := string(block.Input)
- if len(summary) > 80 {
- summary = summary[:80]
- }
- event = map[string]string{"type": "tool_use", "content": fmt.Sprintf("%s(%s)", block.Name, summary)}
+ data, _ = json.Marshal(struct {
+ Type string `json:"type"`
+ Name string `json:"name"`
+ Input json.RawMessage `json:"input,omitempty"`
+ }{Type: "tool_use", Name: block.Name, Input: block.Input})
default:
continue
}
- data, _ := json.Marshal(event)
fmt.Fprintf(w, "data: %s\n\n", data)
flusher.Flush()
}
case "result":
- event := map[string]string{
- "type": "cost",
- "content": fmt.Sprintf("%g", msg.CostUSD),
+ cost := msg.TotalCostUSD
+ if cost == 0 {
+ cost = msg.CostUSD
}
- data, _ := json.Marshal(event)
+ data, _ := json.Marshal(struct {
+ Type string `json:"type"`
+ TotalCost float64 `json:"total_cost"`
+ }{Type: "cost", TotalCost: cost})
fmt.Fprintf(w, "data: %s\n\n", data)
flusher.Flush()
}
diff --git a/internal/api/logs_test.go b/internal/api/logs_test.go
index 52fa168..6c6be05 100644
--- a/internal/api/logs_test.go
+++ b/internal/api/logs_test.go
@@ -1,6 +1,7 @@
package api
import (
+ "encoding/json"
"errors"
"net/http"
"net/http/httptest"
@@ -293,6 +294,76 @@ func TestHandleStreamTaskLogs_TerminalExecution_EmitsEventsAndDone(t *testing.T)
}
}
+// TestEmitLogLine_ToolUse_EmitsNameField verifies that emitLogLine emits a tool_use SSE event
+// with a "name" field matching the tool name so the web UI can display it as "[ToolName]".
+func TestEmitLogLine_ToolUse_EmitsNameField(t *testing.T) {
+ line := []byte(`{"type":"assistant","message":{"content":[{"type":"tool_use","name":"Bash","input":{"command":"ls -la"}}]}}`)
+
+ w := httptest.NewRecorder()
+ emitLogLine(w, w, line)
+
+ body := w.Body.String()
+ var found bool
+ for _, chunk := range strings.Split(body, "\n\n") {
+ chunk = strings.TrimSpace(chunk)
+ if !strings.HasPrefix(chunk, "data: ") {
+ continue
+ }
+ jsonStr := strings.TrimPrefix(chunk, "data: ")
+ var e map[string]interface{}
+ if err := json.Unmarshal([]byte(jsonStr), &e); err != nil {
+ continue
+ }
+ if e["type"] == "tool_use" {
+ if e["name"] != "Bash" {
+ t.Errorf("tool_use event name: want Bash, got %v", e["name"])
+ }
+ if e["input"] == nil {
+ t.Error("tool_use event input: expected non-nil")
+ }
+ found = true
+ }
+ }
+ if !found {
+ t.Errorf("no tool_use event found in SSE output:\n%s", body)
+ }
+}
+
+// TestEmitLogLine_Cost_EmitsTotalCostField verifies that emitLogLine emits a cost SSE event
+// with a numeric "total_cost" field so the web UI can display it correctly.
+func TestEmitLogLine_Cost_EmitsTotalCostField(t *testing.T) {
+ line := []byte(`{"type":"result","total_cost_usd":0.0042}`)
+
+ w := httptest.NewRecorder()
+ emitLogLine(w, w, line)
+
+ body := w.Body.String()
+ var found bool
+ for _, chunk := range strings.Split(body, "\n\n") {
+ chunk = strings.TrimSpace(chunk)
+ if !strings.HasPrefix(chunk, "data: ") {
+ continue
+ }
+ jsonStr := strings.TrimPrefix(chunk, "data: ")
+ var e map[string]interface{}
+ if err := json.Unmarshal([]byte(jsonStr), &e); err != nil {
+ continue
+ }
+ if e["type"] == "cost" {
+ if e["total_cost"] == nil {
+ t.Error("cost event total_cost: expected non-nil numeric field")
+ }
+ if v, ok := e["total_cost"].(float64); !ok || v != 0.0042 {
+ t.Errorf("cost event total_cost: want 0.0042, got %v", e["total_cost"])
+ }
+ found = true
+ }
+ }
+ if !found {
+ t.Errorf("no cost event found in SSE output:\n%s", body)
+ }
+}
+
// TestHandleStreamTaskLogs_RunningExecution_LiveTails verifies that a RUNNING execution is
// live-tailed and a done event is emitted once it transitions to a terminal state.
func TestHandleStreamTaskLogs_RunningExecution_LiveTails(t *testing.T) {
diff --git a/internal/api/ratelimit.go b/internal/api/ratelimit.go
new file mode 100644
index 0000000..089354c
--- /dev/null
+++ b/internal/api/ratelimit.go
@@ -0,0 +1,99 @@
+package api
+
+import (
+ "net"
+ "net/http"
+ "sync"
+ "time"
+)
+
+// ipRateLimiter provides per-IP token-bucket rate limiting.
+type ipRateLimiter struct {
+ mu sync.Mutex
+ limiters map[string]*tokenBucket
+ rate float64 // tokens replenished per second
+ burst int // maximum token capacity
+}
+
+// newIPRateLimiter creates a limiter with the given replenishment rate (tokens/sec)
+// and burst capacity. Use rate=0 to disable replenishment (tokens never refill).
+func newIPRateLimiter(rate float64, burst int) *ipRateLimiter {
+ return &ipRateLimiter{
+ limiters: make(map[string]*tokenBucket),
+ rate: rate,
+ burst: burst,
+ }
+}
+
+func (l *ipRateLimiter) allow(ip string) bool {
+ l.mu.Lock()
+ b, ok := l.limiters[ip]
+ if !ok {
+ b = &tokenBucket{
+ tokens: float64(l.burst),
+ capacity: float64(l.burst),
+ rate: l.rate,
+ lastTime: time.Now(),
+ }
+ l.limiters[ip] = b
+ }
+ l.mu.Unlock()
+ return b.allow()
+}
+
+// middleware wraps h with per-IP rate limiting, returning 429 when exceeded.
+func (l *ipRateLimiter) middleware(next http.Handler) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ ip := realIP(r)
+ if !l.allow(ip) {
+ writeJSON(w, http.StatusTooManyRequests, map[string]string{"error": "rate limit exceeded"})
+ return
+ }
+ next.ServeHTTP(w, r)
+ })
+}
+
+// tokenBucket is a simple token-bucket rate limiter for a single key.
+type tokenBucket struct {
+ mu sync.Mutex
+ tokens float64
+ capacity float64
+ rate float64 // tokens per second
+ lastTime time.Time
+}
+
+func (b *tokenBucket) allow() bool {
+ b.mu.Lock()
+ defer b.mu.Unlock()
+ now := time.Now()
+ if !b.lastTime.IsZero() {
+ elapsed := now.Sub(b.lastTime).Seconds()
+ b.tokens = min(b.capacity, b.tokens+elapsed*b.rate)
+ }
+ b.lastTime = now
+ if b.tokens >= 1.0 {
+ b.tokens--
+ return true
+ }
+ return false
+}
+
+// realIP extracts the client's real IP from a request.
+func realIP(r *http.Request) string {
+ if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
+ for i, c := range xff {
+ if c == ',' {
+ return xff[:i]
+ }
+ }
+ return xff
+ }
+ if xri := r.Header.Get("X-Real-IP"); xri != "" {
+ return xri
+ }
+ host, _, err := net.SplitHostPort(r.RemoteAddr)
+ if err != nil {
+ return r.RemoteAddr
+ }
+ return host
+}
diff --git a/internal/api/scripts.go b/internal/api/scripts.go
index 9afbb75..822bd32 100644
--- a/internal/api/scripts.go
+++ b/internal/api/scripts.go
@@ -5,62 +5,33 @@ import (
"context"
"net/http"
"os/exec"
- "path/filepath"
"time"
)
const scriptTimeout = 30 * time.Second
-func (s *Server) startNextTaskScriptPath() string {
- if s.startNextTaskScript != "" {
- return s.startNextTaskScript
- }
- return filepath.Join(s.workDir, "scripts", "start-next-task")
-}
+// ScriptRegistry maps endpoint names to executable script paths.
+// Only registered scripts are exposed via POST /api/scripts/{name}.
+type ScriptRegistry map[string]string
-func (s *Server) deployScriptPath() string {
- if s.deployScript != "" {
- return s.deployScript
- }
- return filepath.Join(s.workDir, "scripts", "deploy")
+// SetScripts configures the script registry. The mux is not re-registered;
+// the handler looks up the registry at request time, so this may be called
+// after NewServer but before the first request.
+func (s *Server) SetScripts(r ScriptRegistry) {
+ s.scripts = r
}
-func (s *Server) handleStartNextTask(w http.ResponseWriter, r *http.Request) {
- ctx, cancel := context.WithTimeout(r.Context(), scriptTimeout)
- defer cancel()
-
- scriptPath := s.startNextTaskScriptPath()
- cmd := exec.CommandContext(ctx, scriptPath)
-
- var stdout, stderr bytes.Buffer
- cmd.Stdout = &stdout
- cmd.Stderr = &stderr
-
- err := cmd.Run()
- exitCode := 0
- if err != nil {
- if exitErr, ok := err.(*exec.ExitError); ok {
- exitCode = exitErr.ExitCode()
- } else {
- s.logger.Error("start-next-task: script execution failed", "error", err, "path", scriptPath)
- writeJSON(w, http.StatusInternalServerError, map[string]string{
- "error": "script execution failed: " + err.Error(),
- })
- return
- }
+func (s *Server) handleScript(w http.ResponseWriter, r *http.Request) {
+ name := r.PathValue("name")
+ scriptPath, ok := s.scripts[name]
+ if !ok {
+ writeJSON(w, http.StatusNotFound, map[string]string{"error": "script not found: " + name})
+ return
}
- writeJSON(w, http.StatusOK, map[string]interface{}{
- "output": stdout.String(),
- "exit_code": exitCode,
- })
-}
-
-func (s *Server) handleDeploy(w http.ResponseWriter, r *http.Request) {
ctx, cancel := context.WithTimeout(r.Context(), scriptTimeout)
defer cancel()
- scriptPath := s.deployScriptPath()
cmd := exec.CommandContext(ctx, scriptPath)
var stdout, stderr bytes.Buffer
@@ -72,17 +43,18 @@ func (s *Server) handleDeploy(w http.ResponseWriter, r *http.Request) {
if err != nil {
if exitErr, ok := err.(*exec.ExitError); ok {
exitCode = exitErr.ExitCode()
+ s.logger.Warn("script exited non-zero", "name", name, "exit_code", exitCode, "stderr", stderr.String())
} else {
- s.logger.Error("deploy: script execution failed", "error", err, "path", scriptPath)
+ s.logger.Error("script execution failed", "name", name, "error", err, "path", scriptPath)
writeJSON(w, http.StatusInternalServerError, map[string]string{
- "error": "script execution failed: " + err.Error(),
+ "error": "script execution failed",
})
return
}
}
writeJSON(w, http.StatusOK, map[string]interface{}{
- "output": stdout.String() + stderr.String(),
+ "output": stdout.String(),
"exit_code": exitCode,
})
}
diff --git a/internal/api/scripts_test.go b/internal/api/scripts_test.go
index 7da133e..f5ece20 100644
--- a/internal/api/scripts_test.go
+++ b/internal/api/scripts_test.go
@@ -6,20 +6,65 @@ import (
"net/http/httptest"
"os"
"path/filepath"
+ "strings"
"testing"
)
+func TestServer_NoScripts_Returns404(t *testing.T) {
+ srv, _ := testServer(t)
+ // No scripts configured — all /api/scripts/* should return 404.
+ for _, name := range []string{"deploy", "start-next-task", "unknown"} {
+ req := httptest.NewRequest("POST", "/api/scripts/"+name, nil)
+ w := httptest.NewRecorder()
+ srv.Handler().ServeHTTP(w, req)
+ if w.Code != http.StatusNotFound {
+ t.Errorf("POST /api/scripts/%s: want 404, got %d", name, w.Code)
+ }
+ }
+}
+
+func TestServer_WithScripts_RunsRegisteredScript(t *testing.T) {
+ srv, _ := testServer(t)
+
+ scriptDir := t.TempDir()
+ scriptPath := filepath.Join(scriptDir, "my-script")
+ if err := os.WriteFile(scriptPath, []byte("#!/bin/sh\necho hello"), 0o755); err != nil {
+ t.Fatal(err)
+ }
+
+ srv.SetScripts(ScriptRegistry{"my-script": scriptPath})
+
+ req := httptest.NewRequest("POST", "/api/scripts/my-script", nil)
+ w := httptest.NewRecorder()
+ srv.Handler().ServeHTTP(w, req)
+
+ if w.Code != http.StatusOK {
+ t.Fatalf("want 200, got %d; body: %s", w.Code, w.Body.String())
+ }
+}
+
+func TestServer_WithScripts_UnregisteredReturns404(t *testing.T) {
+ srv, _ := testServer(t)
+ srv.SetScripts(ScriptRegistry{"deploy": "/some/path"})
+
+ req := httptest.NewRequest("POST", "/api/scripts/other", nil)
+ w := httptest.NewRecorder()
+ srv.Handler().ServeHTTP(w, req)
+
+ if w.Code != http.StatusNotFound {
+ t.Errorf("want 404, got %d", w.Code)
+ }
+}
+
func TestHandleDeploy_Success(t *testing.T) {
srv, _ := testServer(t)
- // Create a fake deploy script that exits 0 and prints output.
scriptDir := t.TempDir()
scriptPath := filepath.Join(scriptDir, "deploy")
- script := "#!/bin/sh\necho 'deployed successfully'"
- if err := os.WriteFile(scriptPath, []byte(script), 0o755); err != nil {
+ if err := os.WriteFile(scriptPath, []byte("#!/bin/sh\necho 'deployed successfully'"), 0o755); err != nil {
t.Fatal(err)
}
- srv.deployScript = scriptPath
+ srv.SetScripts(ScriptRegistry{"deploy": scriptPath})
req := httptest.NewRequest("POST", "/api/scripts/deploy", nil)
w := httptest.NewRecorder()
@@ -35,8 +80,7 @@ func TestHandleDeploy_Success(t *testing.T) {
if body["exit_code"] != float64(0) {
t.Errorf("exit_code: want 0, got %v", body["exit_code"])
}
- output, _ := body["output"].(string)
- if output == "" {
+ if output, _ := body["output"].(string); output == "" {
t.Errorf("expected non-empty output")
}
}
@@ -46,11 +90,10 @@ func TestHandleDeploy_ScriptFails(t *testing.T) {
scriptDir := t.TempDir()
scriptPath := filepath.Join(scriptDir, "deploy")
- script := "#!/bin/sh\necho 'build failed' && exit 1"
- if err := os.WriteFile(scriptPath, []byte(script), 0o755); err != nil {
+ if err := os.WriteFile(scriptPath, []byte("#!/bin/sh\necho 'build failed' && exit 1"), 0o755); err != nil {
t.Fatal(err)
}
- srv.deployScript = scriptPath
+ srv.SetScripts(ScriptRegistry{"deploy": scriptPath})
req := httptest.NewRequest("POST", "/api/scripts/deploy", nil)
w := httptest.NewRecorder()
@@ -67,3 +110,25 @@ func TestHandleDeploy_ScriptFails(t *testing.T) {
t.Errorf("expected non-zero exit_code")
}
}
+
+func TestHandleScript_StderrNotLeakedToResponse(t *testing.T) {
+ srv, _ := testServer(t)
+
+ scriptDir := t.TempDir()
+ scriptPath := filepath.Join(scriptDir, "deploy")
+ // Script writes sensitive info to stderr and exits non-zero.
+ script := "#!/bin/sh\necho 'stdout output'\necho 'SECRET_TOKEN=abc123' >&2\nexit 1"
+ if err := os.WriteFile(scriptPath, []byte(script), 0o755); err != nil {
+ t.Fatal(err)
+ }
+ srv.SetScripts(ScriptRegistry{"deploy": scriptPath})
+
+ req := httptest.NewRequest("POST", "/api/scripts/deploy", nil)
+ w := httptest.NewRecorder()
+ srv.Handler().ServeHTTP(w, req)
+
+ body := w.Body.String()
+ if strings.Contains(body, "SECRET_TOKEN") {
+ t.Errorf("response must not contain stderr content; got: %s", body)
+ }
+}
diff --git a/internal/api/server.go b/internal/api/server.go
index 6f343b6..3d7cb1e 100644
--- a/internal/api/server.go
+++ b/internal/api/server.go
@@ -7,40 +7,64 @@ import (
"log/slog"
"net/http"
"os"
+ "strings"
"time"
"github.com/thepeterstone/claudomator/internal/executor"
+ "github.com/thepeterstone/claudomator/internal/notify"
"github.com/thepeterstone/claudomator/internal/storage"
"github.com/thepeterstone/claudomator/internal/task"
webui "github.com/thepeterstone/claudomator/web"
"github.com/google/uuid"
)
+// questionStore is the minimal storage interface needed by handleAnswerQuestion.
+type questionStore interface {
+ GetTask(id string) (*task.Task, error)
+ GetLatestExecution(taskID string) (*storage.Execution, error)
+ UpdateTaskQuestion(taskID, questionJSON string) error
+ UpdateTaskState(id string, newState task.State) error
+}
+
// Server provides the REST API and WebSocket endpoint for Claudomator.
type Server struct {
store *storage.DB
- logStore logStore // injectable for tests; defaults to store
- taskLogStore taskLogStore // injectable for tests; defaults to store
+ logStore logStore // injectable for tests; defaults to store
+ taskLogStore taskLogStore // injectable for tests; defaults to store
+ questionStore questionStore // injectable for tests; defaults to store
pool *executor.Pool
hub *Hub
logger *slog.Logger
mux *http.ServeMux
- claudeBinPath string // path to claude binary; defaults to "claude"
- geminiBinPath string // path to gemini binary; defaults to "gemini"
- elaborateCmdPath string // overrides claudeBinPath; used in tests
- validateCmdPath string // overrides claudeBinPath for validate; used in tests
- startNextTaskScript string // path to start-next-task script; overridden in tests
- deployScript string // path to deploy script; overridden in tests
- workDir string // working directory injected into elaborate system prompt
+ claudeBinPath string // path to claude binary; defaults to "claude"
+ geminiBinPath string // path to gemini binary; defaults to "gemini"
+ elaborateCmdPath string // overrides claudeBinPath; used in tests
+ validateCmdPath string // overrides claudeBinPath for validate; used in tests
+ scripts ScriptRegistry // optional; maps endpoint name → script path
+ workDir string // working directory injected into elaborate system prompt
+ notifier notify.Notifier
+ apiToken string // if non-empty, required for WebSocket (and REST) connections
+ elaborateLimiter *ipRateLimiter // per-IP rate limiter for elaborate/validate endpoints
+}
+
+// SetAPIToken configures a bearer token that must be supplied to access the API.
+func (s *Server) SetAPIToken(token string) {
+ s.apiToken = token
+}
+
+// SetNotifier configures a notifier that is called on every task completion.
+func (s *Server) SetNotifier(n notify.Notifier) {
+ s.notifier = n
}
func NewServer(store *storage.DB, pool *executor.Pool, logger *slog.Logger, claudeBinPath, geminiBinPath string) *Server {
wd, _ := os.Getwd()
s := &Server{
- store: store,
- logStore: store,
- taskLogStore: store,
- pool: pool,
+ store: store,
+ logStore: store,
+ taskLogStore: store,
+ questionStore: store,
+ pool: pool,
hub: NewHub(),
logger: logger,
mux: http.NewServeMux(),
@@ -86,8 +110,7 @@ func (s *Server) routes() {
s.mux.HandleFunc("DELETE /api/templates/{id}", s.handleDeleteTemplate)
s.mux.HandleFunc("POST /api/tasks/{id}/answer", s.handleAnswerQuestion)
s.mux.HandleFunc("POST /api/tasks/{id}/resume", s.handleResumeTimedOutTask)
- s.mux.HandleFunc("POST /api/scripts/start-next-task", s.handleStartNextTask)
- s.mux.HandleFunc("POST /api/scripts/deploy", s.handleDeploy)
+ s.mux.HandleFunc("POST /api/scripts/{name}", s.handleScript)
s.mux.HandleFunc("GET /api/ws", s.handleWebSocket)
s.mux.HandleFunc("GET /api/workspaces", s.handleListWorkspaces)
s.mux.HandleFunc("GET /api/health", s.handleHealth)
@@ -97,17 +120,44 @@ func (s *Server) routes() {
// forwardResults listens on the executor pool's result channel and broadcasts via WebSocket.
func (s *Server) forwardResults() {
for result := range s.pool.Results() {
- event := map[string]interface{}{
- "type": "task_completed",
- "task_id": result.TaskID,
- "status": result.Execution.Status,
- "exit_code": result.Execution.ExitCode,
- "cost_usd": result.Execution.CostUSD,
- "error": result.Execution.ErrorMsg,
- "timestamp": time.Now().UTC(),
+ s.processResult(result)
+ }
+}
+
+// processResult broadcasts a task completion event via WebSocket and calls the notifier if set.
+func (s *Server) processResult(result *executor.Result) {
+ event := map[string]interface{}{
+ "type": "task_completed",
+ "task_id": result.TaskID,
+ "status": result.Execution.Status,
+ "exit_code": result.Execution.ExitCode,
+ "cost_usd": result.Execution.CostUSD,
+ "error": result.Execution.ErrorMsg,
+ "timestamp": time.Now().UTC(),
+ }
+ data, _ := json.Marshal(event)
+ s.hub.Broadcast(data)
+
+ if s.notifier != nil {
+ var taskName string
+ if t, err := s.store.GetTask(result.TaskID); err == nil {
+ taskName = t.Name
+ }
+ var dur string
+ if !result.Execution.StartTime.IsZero() && !result.Execution.EndTime.IsZero() {
+ dur = result.Execution.EndTime.Sub(result.Execution.StartTime).String()
+ }
+ ne := notify.Event{
+ TaskID: result.TaskID,
+ TaskName: taskName,
+ Status: result.Execution.Status,
+ CostUSD: result.Execution.CostUSD,
+ Duration: dur,
+ Error: result.Execution.ErrorMsg,
+ }
+ if err := s.notifier.Notify(ne); err != nil {
+ s.logger.Error("notifier failed", "error", err)
}
- data, _ := json.Marshal(event)
- s.hub.Broadcast(data)
}
}
@@ -169,7 +219,7 @@ func (s *Server) handleCancelTask(w http.ResponseWriter, r *http.Request) {
func (s *Server) handleAnswerQuestion(w http.ResponseWriter, r *http.Request) {
taskID := r.PathValue("id")
- tk, err := s.store.GetTask(taskID)
+ tk, err := s.questionStore.GetTask(taskID)
if err != nil {
writeJSON(w, http.StatusNotFound, map[string]string{"error": "task not found"})
return
@@ -192,15 +242,21 @@ func (s *Server) handleAnswerQuestion(w http.ResponseWriter, r *http.Request) {
}
// Look up the session ID from the most recent execution.
- latest, err := s.store.GetLatestExecution(taskID)
+ latest, err := s.questionStore.GetLatestExecution(taskID)
if err != nil || latest.SessionID == "" {
writeJSON(w, http.StatusInternalServerError, map[string]string{"error": "no resumable session found"})
return
}
// Clear the question and transition to QUEUED.
- s.store.UpdateTaskQuestion(taskID, "")
- s.store.UpdateTaskState(taskID, task.StateQueued)
+ if err := s.questionStore.UpdateTaskQuestion(taskID, ""); err != nil {
+ writeJSON(w, http.StatusInternalServerError, map[string]string{"error": "failed to clear question"})
+ return
+ }
+ if err := s.questionStore.UpdateTaskState(taskID, task.StateQueued); err != nil {
+ writeJSON(w, http.StatusInternalServerError, map[string]string{"error": "failed to queue task"})
+ return
+ }
// Submit a resume execution.
resumeExec := &storage.Execution{
@@ -256,9 +312,23 @@ func (s *Server) handleResumeTimedOutTask(w http.ResponseWriter, r *http.Request
}
func (s *Server) handleListWorkspaces(w http.ResponseWriter, r *http.Request) {
+ if s.apiToken != "" {
+ token := r.URL.Query().Get("token")
+ if token == "" {
+ auth := r.Header.Get("Authorization")
+ if strings.HasPrefix(auth, "Bearer ") {
+ token = strings.TrimPrefix(auth, "Bearer ")
+ }
+ }
+ if token != s.apiToken {
+ http.Error(w, "unauthorized", http.StatusUnauthorized)
+ return
+ }
+ }
+
entries, err := os.ReadDir("/workspace")
if err != nil {
- writeJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()})
+ writeJSON(w, http.StatusInternalServerError, map[string]string{"error": "failed to list workspaces"})
return
}
var dirs []string
@@ -375,6 +445,21 @@ func (s *Server) handleRunTask(w http.ResponseWriter, r *http.Request) {
return
}
+ // Enforce retry limit for non-initial runs (PENDING is the initial state).
+ if t.State != task.StatePending {
+ execs, err := s.store.ListExecutions(id)
+ if err != nil {
+ writeJSON(w, http.StatusInternalServerError, map[string]string{"error": "failed to count executions"})
+ return
+ }
+ if t.Retry.MaxAttempts > 0 && len(execs) >= t.Retry.MaxAttempts {
+ writeJSON(w, http.StatusConflict, map[string]string{
+ "error": fmt.Sprintf("retry limit reached (%d/%d attempts used)", len(execs), t.Retry.MaxAttempts),
+ })
+ return
+ }
+ }
+
if err := s.store.UpdateTaskState(id, task.StateQueued); err != nil {
writeJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()})
return
diff --git a/internal/api/server_test.go b/internal/api/server_test.go
index b0ccb4a..cd415ae 100644
--- a/internal/api/server_test.go
+++ b/internal/api/server_test.go
@@ -9,15 +9,70 @@ import (
"net/http/httptest"
"os"
"path/filepath"
+ "strings"
"testing"
+ "time"
"context"
"github.com/thepeterstone/claudomator/internal/executor"
+ "github.com/thepeterstone/claudomator/internal/notify"
"github.com/thepeterstone/claudomator/internal/storage"
"github.com/thepeterstone/claudomator/internal/task"
)
+// mockNotifier records calls to Notify.
+type mockNotifier struct {
+ events []notify.Event
+}
+
+func (m *mockNotifier) Notify(e notify.Event) error {
+ m.events = append(m.events, e)
+ return nil
+}
+
+func TestServer_ProcessResult_CallsNotifier(t *testing.T) {
+ srv, store := testServer(t)
+
+ mn := &mockNotifier{}
+ srv.SetNotifier(mn)
+
+ tk := &task.Task{
+ ID: "task-notifier-test",
+ Name: "Notifier Task",
+ State: task.StatePending,
+ }
+ if err := store.CreateTask(tk); err != nil {
+ t.Fatal(err)
+ }
+
+ result := &executor.Result{
+ TaskID: tk.ID,
+ Execution: &storage.Execution{
+ ID: "exec-1",
+ TaskID: tk.ID,
+ Status: "COMPLETED",
+ CostUSD: 0.42,
+ ErrorMsg: "",
+ },
+ }
+ srv.processResult(result)
+
+ if len(mn.events) != 1 {
+ t.Fatalf("expected 1 notify event, got %d", len(mn.events))
+ }
+ ev := mn.events[0]
+ if ev.TaskID != tk.ID {
+ t.Errorf("event.TaskID = %q, want %q", ev.TaskID, tk.ID)
+ }
+ if ev.Status != "COMPLETED" {
+ t.Errorf("event.Status = %q, want COMPLETED", ev.Status)
+ }
+ if ev.CostUSD != 0.42 {
+ t.Errorf("event.CostUSD = %v, want 0.42", ev.CostUSD)
+ }
+}
+
func testServer(t *testing.T) (*Server, *storage.DB) {
t.Helper()
dbPath := filepath.Join(t.TempDir(), "test.db")
@@ -175,6 +230,20 @@ func TestListTasks_WithTasks(t *testing.T) {
}
}
+// stateWalkPaths defines the sequence of intermediate states needed to reach each target state.
+var stateWalkPaths = map[task.State][]task.State{
+ task.StatePending: {},
+ task.StateQueued: {task.StateQueued},
+ task.StateRunning: {task.StateQueued, task.StateRunning},
+ task.StateCompleted: {task.StateQueued, task.StateRunning, task.StateCompleted},
+ task.StateFailed: {task.StateQueued, task.StateRunning, task.StateFailed},
+ task.StateTimedOut: {task.StateQueued, task.StateRunning, task.StateTimedOut},
+ task.StateCancelled: {task.StateCancelled},
+ task.StateBudgetExceeded: {task.StateQueued, task.StateRunning, task.StateBudgetExceeded},
+ task.StateReady: {task.StateQueued, task.StateRunning, task.StateReady},
+ task.StateBlocked: {task.StateQueued, task.StateRunning, task.StateBlocked},
+}
+
func createTaskWithState(t *testing.T, store *storage.DB, id string, state task.State) *task.Task {
t.Helper()
tk := &task.Task{
@@ -188,9 +257,9 @@ func createTaskWithState(t *testing.T, store *storage.DB, id string, state task.
if err := store.CreateTask(tk); err != nil {
t.Fatalf("createTaskWithState: CreateTask: %v", err)
}
- if state != task.StatePending {
- if err := store.UpdateTaskState(id, state); err != nil {
- t.Fatalf("createTaskWithState: UpdateTaskState(%s): %v", state, err)
+ for _, s := range stateWalkPaths[state] {
+ if err := store.UpdateTaskState(id, s); err != nil {
+ t.Fatalf("createTaskWithState: UpdateTaskState(%s): %v", s, err)
}
}
tk.State = state
@@ -425,7 +494,7 @@ func TestHandleStartNextTask_Success(t *testing.T) {
}
srv, _ := testServer(t)
- srv.startNextTaskScript = script
+ srv.SetScripts(ScriptRegistry{"start-next-task": script})
req := httptest.NewRequest("POST", "/api/scripts/start-next-task", nil)
w := httptest.NewRecorder()
@@ -452,7 +521,7 @@ func TestHandleStartNextTask_NoTask(t *testing.T) {
}
srv, _ := testServer(t)
- srv.startNextTaskScript = script
+ srv.SetScripts(ScriptRegistry{"start-next-task": script})
req := httptest.NewRequest("POST", "/api/scripts/start-next-task", nil)
w := httptest.NewRecorder()
@@ -535,9 +604,87 @@ func TestResumeTimedOut_Success_Returns202(t *testing.T) {
}
}
+func TestRunTask_RetryLimitReached_Returns409(t *testing.T) {
+ srv, store := testServer(t)
+ // Task with MaxAttempts: 1 — only 1 attempt allowed. Create directly as FAILED
+ // so state is consistent without going through transition sequence.
+ tk := &task.Task{
+ ID: "retry-limit-1",
+ Name: "Retry Limit Task",
+ Agent: task.AgentConfig{Instructions: "do something"},
+ Priority: task.PriorityNormal,
+ Retry: task.RetryConfig{MaxAttempts: 1, Backoff: "linear"},
+ Tags: []string{},
+ DependsOn: []string{},
+ State: task.StateFailed,
+ }
+ if err := store.CreateTask(tk); err != nil {
+ t.Fatal(err)
+ }
+ // Record one execution — the first attempt already used.
+ exec := &storage.Execution{
+ ID: "exec-retry-1",
+ TaskID: "retry-limit-1",
+ StartTime: time.Now(),
+ Status: "FAILED",
+ }
+ if err := store.CreateExecution(exec); err != nil {
+ t.Fatal(err)
+ }
+
+ req := httptest.NewRequest("POST", "/api/tasks/retry-limit-1/run", nil)
+ w := httptest.NewRecorder()
+ srv.Handler().ServeHTTP(w, req)
+
+ if w.Code != http.StatusConflict {
+ t.Errorf("status: want 409, got %d; body: %s", w.Code, w.Body.String())
+ }
+ var body map[string]string
+ json.NewDecoder(w.Body).Decode(&body)
+ if !strings.Contains(body["error"], "retry limit") {
+ t.Errorf("error body should mention retry limit, got %q", body["error"])
+ }
+}
+
+func TestRunTask_WithinRetryLimit_Returns202(t *testing.T) {
+ srv, store := testServer(t)
+ // Task with MaxAttempts: 3 — 1 execution used, 2 remaining.
+ tk := &task.Task{
+ ID: "retry-within-1",
+ Name: "Retry Within Task",
+ Agent: task.AgentConfig{Instructions: "do something"},
+ Priority: task.PriorityNormal,
+ Retry: task.RetryConfig{MaxAttempts: 3, Backoff: "linear"},
+ Tags: []string{},
+ DependsOn: []string{},
+ State: task.StatePending,
+ }
+ if err := store.CreateTask(tk); err != nil {
+ t.Fatal(err)
+ }
+ exec := &storage.Execution{
+ ID: "exec-within-1",
+ TaskID: "retry-within-1",
+ StartTime: time.Now(),
+ Status: "FAILED",
+ }
+ if err := store.CreateExecution(exec); err != nil {
+ t.Fatal(err)
+ }
+ store.UpdateTaskState("retry-within-1", task.StateFailed)
+
+ req := httptest.NewRequest("POST", "/api/tasks/retry-within-1/run", nil)
+ w := httptest.NewRecorder()
+ srv.Handler().ServeHTTP(w, req)
+
+ if w.Code != http.StatusAccepted {
+ t.Errorf("status: want 202, got %d; body: %s", w.Code, w.Body.String())
+ }
+}
+
func TestHandleStartNextTask_ScriptNotFound(t *testing.T) {
srv, _ := testServer(t)
- srv.startNextTaskScript = "/nonexistent/start-next-task"
+ srv.SetScripts(ScriptRegistry{"start-next-task": "/nonexistent/start-next-task"})
req := httptest.NewRequest("POST", "/api/scripts/start-next-task", nil)
w := httptest.NewRecorder()
@@ -583,10 +730,20 @@ func TestDeleteTask_NotFound(t *testing.T) {
func TestDeleteTask_RunningTaskRejected(t *testing.T) {
srv, store := testServer(t)
- created := createTestTask(t, srv, `{"name":"Running Task","agent":{"type":"claude","instructions":"x","model":"sonnet"}}`)
- store.UpdateTaskState(created.ID, "RUNNING")
-
- req := httptest.NewRequest("DELETE", "/api/tasks/"+created.ID, nil)
+ // Create the task directly in RUNNING state to avoid going through state transitions.
+ tk := &task.Task{
+ ID: "running-task-del",
+ Name: "Running Task",
+ Agent: task.AgentConfig{Instructions: "x", Model: "sonnet"},
+ Priority: task.PriorityNormal,
+ Tags: []string{},
+ DependsOn: []string{},
+ State: task.StateRunning,
+ }
+ if err := store.CreateTask(tk); err != nil {
+ t.Fatal(err)
+ }
+ req := httptest.NewRequest("DELETE", "/api/tasks/running-task-del", nil)
w := httptest.NewRecorder()
srv.Handler().ServeHTTP(w, req)
@@ -662,3 +819,170 @@ func TestServer_CancelTask_Completed_Returns409(t *testing.T) {
t.Errorf("status: want 409, got %d; body: %s", w.Code, w.Body.String())
}
}
+
+// mockQuestionStore implements questionStore for testing handleAnswerQuestion.
+type mockQuestionStore struct {
+ getTaskFn func(id string) (*task.Task, error)
+ getLatestExecutionFn func(taskID string) (*storage.Execution, error)
+ updateTaskQuestionFn func(taskID, questionJSON string) error
+ updateTaskStateFn func(id string, newState task.State) error
+}
+
+func (m *mockQuestionStore) GetTask(id string) (*task.Task, error) {
+ return m.getTaskFn(id)
+}
+func (m *mockQuestionStore) GetLatestExecution(taskID string) (*storage.Execution, error) {
+ return m.getLatestExecutionFn(taskID)
+}
+func (m *mockQuestionStore) UpdateTaskQuestion(taskID, questionJSON string) error {
+ return m.updateTaskQuestionFn(taskID, questionJSON)
+}
+func (m *mockQuestionStore) UpdateTaskState(id string, newState task.State) error {
+ return m.updateTaskStateFn(id, newState)
+}
+
+func TestServer_AnswerQuestion_UpdateQuestionFails_Returns500(t *testing.T) {
+ srv, _ := testServer(t)
+ srv.questionStore = &mockQuestionStore{
+ getTaskFn: func(id string) (*task.Task, error) {
+ return &task.Task{ID: id, State: task.StateBlocked}, nil
+ },
+ getLatestExecutionFn: func(taskID string) (*storage.Execution, error) {
+ return &storage.Execution{SessionID: "sess-1"}, nil
+ },
+ updateTaskQuestionFn: func(taskID, questionJSON string) error {
+ return fmt.Errorf("db error")
+ },
+ updateTaskStateFn: func(id string, newState task.State) error {
+ return nil
+ },
+ }
+
+ body := bytes.NewBufferString(`{"answer":"yes"}`)
+ req := httptest.NewRequest("POST", "/api/tasks/task-1/answer", body)
+ req.Header.Set("Content-Type", "application/json")
+ w := httptest.NewRecorder()
+ srv.Handler().ServeHTTP(w, req)
+
+ if w.Code != http.StatusInternalServerError {
+ t.Errorf("status: want 500, got %d; body: %s", w.Code, w.Body.String())
+ }
+}
+
+func TestServer_AnswerQuestion_UpdateStateFails_Returns500(t *testing.T) {
+ srv, _ := testServer(t)
+ srv.questionStore = &mockQuestionStore{
+ getTaskFn: func(id string) (*task.Task, error) {
+ return &task.Task{ID: id, State: task.StateBlocked}, nil
+ },
+ getLatestExecutionFn: func(taskID string) (*storage.Execution, error) {
+ return &storage.Execution{SessionID: "sess-1"}, nil
+ },
+ updateTaskQuestionFn: func(taskID, questionJSON string) error {
+ return nil
+ },
+ updateTaskStateFn: func(id string, newState task.State) error {
+ return fmt.Errorf("db error")
+ },
+ }
+
+ body := bytes.NewBufferString(`{"answer":"yes"}`)
+ req := httptest.NewRequest("POST", "/api/tasks/task-1/answer", body)
+ req.Header.Set("Content-Type", "application/json")
+ w := httptest.NewRecorder()
+ srv.Handler().ServeHTTP(w, req)
+
+ if w.Code != http.StatusInternalServerError {
+ t.Errorf("status: want 500, got %d; body: %s", w.Code, w.Body.String())
+ }
+}
+
+func TestRateLimit_ElaborateRejectsExcess(t *testing.T) {
+ srv, _ := testServer(t)
+ // Use burst-1 and rate-0 so the second request from the same IP is rejected.
+ srv.elaborateLimiter = newIPRateLimiter(0, 1)
+
+ makeReq := func(remoteAddr string) int {
+ req := httptest.NewRequest("POST", "/api/tasks/elaborate", bytes.NewBufferString(`{"description":"x"}`))
+ req.Header.Set("Content-Type", "application/json")
+ req.RemoteAddr = remoteAddr
+ w := httptest.NewRecorder()
+ srv.Handler().ServeHTTP(w, req)
+ return w.Code
+ }
+
+ // First request from IP A: limiter allows it (non-429).
+ if code := makeReq("192.0.2.1:1234"); code == http.StatusTooManyRequests {
+ t.Errorf("first request should not be rate limited, got 429")
+ }
+ // Second request from IP A: bucket exhausted, must be 429.
+ if code := makeReq("192.0.2.1:1234"); code != http.StatusTooManyRequests {
+ t.Errorf("second request from same IP should be 429, got %d", code)
+ }
+ // First request from IP B: separate bucket, not limited.
+ if code := makeReq("192.0.2.2:1234"); code == http.StatusTooManyRequests {
+ t.Errorf("first request from different IP should not be rate limited, got 429")
+ }
+}
+
+func TestListWorkspaces_RequiresAuth(t *testing.T) {
+ srv, _ := testServer(t)
+ srv.SetAPIToken("secret-token")
+
+ // No token: expect 401.
+ req := httptest.NewRequest("GET", "/api/workspaces", nil)
+ w := httptest.NewRecorder()
+ srv.Handler().ServeHTTP(w, req)
+ if w.Code != http.StatusUnauthorized {
+ t.Errorf("expected 401 without token, got %d", w.Code)
+ }
+}
+
+func TestListWorkspaces_RejectsWrongToken(t *testing.T) {
+ srv, _ := testServer(t)
+ srv.SetAPIToken("secret-token")
+
+ req := httptest.NewRequest("GET", "/api/workspaces", nil)
+ req.Header.Set("Authorization", "Bearer wrong-token")
+ w := httptest.NewRecorder()
+ srv.Handler().ServeHTTP(w, req)
+ if w.Code != http.StatusUnauthorized {
+ t.Errorf("expected 401 with wrong token, got %d", w.Code)
+ }
+}
+
+func TestListWorkspaces_SuppressesRawError(t *testing.T) {
+ srv, _ := testServer(t)
+ // No token configured so auth is bypassed; /workspace likely doesn't exist in test env.
+
+ req := httptest.NewRequest("GET", "/api/workspaces", nil)
+ w := httptest.NewRecorder()
+ srv.Handler().ServeHTTP(w, req)
+ if w.Code == http.StatusInternalServerError {
+ body := w.Body.String()
+ if strings.Contains(body, "/workspace") || strings.Contains(body, "no such file") {
+ t.Errorf("response leaks filesystem details: %s", body)
+ }
+ }
+}
+
+func TestRateLimit_ValidateRejectsExcess(t *testing.T) {
+ srv, _ := testServer(t)
+ srv.elaborateLimiter = newIPRateLimiter(0, 1)
+
+ makeReq := func(remoteAddr string) int {
+ req := httptest.NewRequest("POST", "/api/tasks/validate", bytes.NewBufferString(`{"description":"x"}`))
+ req.Header.Set("Content-Type", "application/json")
+ req.RemoteAddr = remoteAddr
+ w := httptest.NewRecorder()
+ srv.Handler().ServeHTTP(w, req)
+ return w.Code
+ }
+
+ if code := makeReq("192.0.2.1:1234"); code == http.StatusTooManyRequests {
+ t.Errorf("first validate request should not be rate limited, got 429")
+ }
+ if code := makeReq("192.0.2.1:1234"); code != http.StatusTooManyRequests {
+ t.Errorf("second validate request from same IP should be 429, got %d", code)
+ }
+}
diff --git a/internal/api/validate.go b/internal/api/validate.go
index a3b2cf0..07d293c 100644
--- a/internal/api/validate.go
+++ b/internal/api/validate.go
@@ -48,16 +48,22 @@ func (s *Server) validateBinaryPath() string {
if s.validateCmdPath != "" {
return s.validateCmdPath
}
- return s.claudeBinaryPath()
+ return s.claudeBinPath
}
func (s *Server) handleValidateTask(w http.ResponseWriter, r *http.Request) {
+ if s.elaborateLimiter != nil && !s.elaborateLimiter.allow(realIP(r)) {
+ writeJSON(w, http.StatusTooManyRequests, map[string]string{"error": "rate limit exceeded"})
+ return
+ }
+
var input struct {
Name string `json:"name"`
Agent struct {
Type string `json:"type"`
Instructions string `json:"instructions"`
- WorkingDir string `json:"working_dir"`
+ ProjectDir string `json:"project_dir"`
+ WorkingDir string `json:"working_dir"` // legacy
AllowedTools []string `json:"allowed_tools"`
} `json:"agent"`
}
@@ -79,9 +85,14 @@ func (s *Server) handleValidateTask(w http.ResponseWriter, r *http.Request) {
agentType = "claude"
}
+ projectDir := input.Agent.ProjectDir
+ if projectDir == "" {
+ projectDir = input.Agent.WorkingDir
+ }
+
userMsg := fmt.Sprintf("Task name: %s\nAgent: %s\n\nInstructions:\n%s", input.Name, agentType, input.Agent.Instructions)
- if input.Agent.WorkingDir != "" {
- userMsg += fmt.Sprintf("\n\nWorking directory: %s", input.Agent.WorkingDir)
+ if projectDir != "" {
+ userMsg += fmt.Sprintf("\n\nWorking directory: %s", projectDir)
}
if len(input.Agent.AllowedTools) > 0 {
userMsg += fmt.Sprintf("\n\nAllowed tools: %v", input.Agent.AllowedTools)
diff --git a/internal/api/websocket.go b/internal/api/websocket.go
index 6bd8c88..b5bf728 100644
--- a/internal/api/websocket.go
+++ b/internal/api/websocket.go
@@ -1,13 +1,27 @@
package api
import (
+ "errors"
"log/slog"
"net/http"
+ "strings"
"sync"
+ "time"
"golang.org/x/net/websocket"
)
+// wsPingInterval and wsPingDeadline control heartbeat timing.
+// Exposed as vars so tests can override them without rebuilding.
+var (
+ wsPingInterval = 30 * time.Second
+ wsPingDeadline = 10 * time.Second
+
+ // maxWsClients caps the number of concurrent WebSocket connections.
+ // Exposed as a var so tests can override it.
+ maxWsClients = 1000
+)
+
// Hub manages WebSocket connections and broadcasts messages.
type Hub struct {
mu sync.RWMutex
@@ -25,10 +39,14 @@ func NewHub() *Hub {
// Run is a no-op loop kept for future cleanup/heartbeat logic.
func (h *Hub) Run() {}
-func (h *Hub) Register(ws *websocket.Conn) {
+func (h *Hub) Register(ws *websocket.Conn) error {
h.mu.Lock()
+ defer h.mu.Unlock()
+ if len(h.clients) >= maxWsClients {
+ return errors.New("max WebSocket clients reached")
+ }
h.clients[ws] = true
- h.mu.Unlock()
+ return nil
}
func (h *Hub) Unregister(ws *websocket.Conn) {
@@ -56,11 +74,56 @@ func (h *Hub) ClientCount() int {
}
func (s *Server) handleWebSocket(w http.ResponseWriter, r *http.Request) {
+ if s.hub.ClientCount() >= maxWsClients {
+ http.Error(w, "too many connections", http.StatusServiceUnavailable)
+ return
+ }
+
+ if s.apiToken != "" {
+ token := r.URL.Query().Get("token")
+ if token == "" {
+ auth := r.Header.Get("Authorization")
+ if strings.HasPrefix(auth, "Bearer ") {
+ token = strings.TrimPrefix(auth, "Bearer ")
+ }
+ }
+ if token != s.apiToken {
+ http.Error(w, "unauthorized", http.StatusUnauthorized)
+ return
+ }
+ }
+
handler := websocket.Handler(func(ws *websocket.Conn) {
- s.hub.Register(ws)
+ if err := s.hub.Register(ws); err != nil {
+ return
+ }
defer s.hub.Unregister(ws)
- // Keep connection alive until client disconnects.
+ // Ping goroutine: detect dead connections by sending periodic pings.
+ // A write failure (including write deadline exceeded) closes the conn,
+ // causing the read loop below to exit and unregister the client.
+ done := make(chan struct{})
+ defer close(done)
+ go func() {
+ ticker := time.NewTicker(wsPingInterval)
+ defer ticker.Stop()
+ for {
+ select {
+ case <-done:
+ return
+ case <-ticker.C:
+ ws.SetWriteDeadline(time.Now().Add(wsPingDeadline))
+ err := websocket.Message.Send(ws, "ping")
+ ws.SetWriteDeadline(time.Time{})
+ if err != nil {
+ ws.Close()
+ return
+ }
+ }
+ }
+ }()
+
+ // Keep connection alive until client disconnects or ping fails.
buf := make([]byte, 1024)
for {
if _, err := ws.Read(buf); err != nil {
diff --git a/internal/api/websocket_test.go b/internal/api/websocket_test.go
new file mode 100644
index 0000000..72b83f2
--- /dev/null
+++ b/internal/api/websocket_test.go
@@ -0,0 +1,221 @@
+package api
+
+import (
+ "net/http"
+ "net/http/httptest"
+ "strings"
+ "testing"
+ "time"
+
+ "golang.org/x/net/websocket"
+)
+
+// TestWebSocket_RejectsConnectionWithoutToken verifies that when an API token
+// is configured, WebSocket connections without a valid token are rejected with 401.
+func TestWebSocket_RejectsConnectionWithoutToken(t *testing.T) {
+ srv, _ := testServer(t)
+ srv.SetAPIToken("secret-token")
+
+ // Plain HTTP request simulates a WebSocket upgrade attempt without token.
+ req := httptest.NewRequest("GET", "/api/ws", nil)
+ w := httptest.NewRecorder()
+ srv.Handler().ServeHTTP(w, req)
+
+ if w.Code != http.StatusUnauthorized {
+ t.Errorf("want 401, got %d", w.Code)
+ }
+}
+
+// TestWebSocket_RejectsConnectionWithWrongToken verifies a wrong token is rejected.
+func TestWebSocket_RejectsConnectionWithWrongToken(t *testing.T) {
+ srv, _ := testServer(t)
+ srv.SetAPIToken("secret-token")
+
+ req := httptest.NewRequest("GET", "/api/ws?token=wrong-token", nil)
+ w := httptest.NewRecorder()
+ srv.Handler().ServeHTTP(w, req)
+
+ if w.Code != http.StatusUnauthorized {
+ t.Errorf("want 401, got %d", w.Code)
+ }
+}
+
+// TestWebSocket_AcceptsConnectionWithValidQueryToken verifies a valid token in
+// the query string is accepted.
+func TestWebSocket_AcceptsConnectionWithValidQueryToken(t *testing.T) {
+ srv, _ := testServer(t)
+ srv.SetAPIToken("secret-token")
+ srv.StartHub()
+ ts := httptest.NewServer(srv.Handler())
+ t.Cleanup(ts.Close)
+
+ wsURL := "ws" + strings.TrimPrefix(ts.URL, "http") + "/api/ws?token=secret-token"
+ ws, err := websocket.Dial(wsURL, "", "http://localhost/")
+ if err != nil {
+ t.Fatalf("expected connection to succeed with valid token: %v", err)
+ }
+ ws.Close()
+}
+
+// TestWebSocket_AcceptsConnectionWithBearerToken verifies a valid token in the
+// Authorization header is accepted.
+func TestWebSocket_AcceptsConnectionWithBearerToken(t *testing.T) {
+ srv, _ := testServer(t)
+ srv.SetAPIToken("secret-token")
+ srv.StartHub()
+ ts := httptest.NewServer(srv.Handler())
+ t.Cleanup(ts.Close)
+
+ wsURL := "ws" + strings.TrimPrefix(ts.URL, "http") + "/api/ws"
+ cfg, err := websocket.NewConfig(wsURL, "http://localhost/")
+ if err != nil {
+ t.Fatalf("config: %v", err)
+ }
+ cfg.Header = http.Header{"Authorization": {"Bearer secret-token"}}
+ ws, err := websocket.DialConfig(cfg)
+ if err != nil {
+ t.Fatalf("expected connection to succeed with Bearer token: %v", err)
+ }
+ ws.Close()
+}
+
+// TestWebSocket_NoTokenConfigured verifies that when no API token is set,
+// connections are allowed without authentication.
+func TestWebSocket_NoTokenConfigured(t *testing.T) {
+ srv, _ := testServer(t)
+ // No SetAPIToken call — auth is disabled.
+ srv.StartHub()
+ ts := httptest.NewServer(srv.Handler())
+ t.Cleanup(ts.Close)
+
+ wsURL := "ws" + strings.TrimPrefix(ts.URL, "http") + "/api/ws"
+ ws, err := websocket.Dial(wsURL, "", "http://localhost/")
+ if err != nil {
+ t.Fatalf("expected connection without token when auth disabled: %v", err)
+ }
+ ws.Close()
+}
+
+// TestWebSocket_RejectsConnectionWhenAtMaxClients verifies that when the hub
+// is at capacity, new WebSocket upgrade requests are rejected with 503.
+func TestWebSocket_RejectsConnectionWhenAtMaxClients(t *testing.T) {
+ orig := maxWsClients
+ maxWsClients = 0 // immediately at capacity
+ t.Cleanup(func() { maxWsClients = orig })
+
+ srv, _ := testServer(t)
+ srv.StartHub()
+
+ req := httptest.NewRequest("GET", "/api/ws", nil)
+ w := httptest.NewRecorder()
+ srv.Handler().ServeHTTP(w, req)
+
+ if w.Code != http.StatusServiceUnavailable {
+ t.Errorf("want 503, got %d", w.Code)
+ }
+}
+
+// TestWebSocket_StaleConnectionCleanedUp verifies that when a client
+// disconnects (or the connection is closed), the hub unregisters it.
+// Short ping intervals are used so the test completes quickly.
+func TestWebSocket_StaleConnectionCleanedUp(t *testing.T) {
+ origInterval := wsPingInterval
+ origDeadline := wsPingDeadline
+ wsPingInterval = 20 * time.Millisecond
+ wsPingDeadline = 20 * time.Millisecond
+ t.Cleanup(func() {
+ wsPingInterval = origInterval
+ wsPingDeadline = origDeadline
+ })
+
+ srv, _ := testServer(t)
+ srv.StartHub()
+ ts := httptest.NewServer(srv.Handler())
+ t.Cleanup(ts.Close)
+
+ wsURL := "ws" + strings.TrimPrefix(ts.URL, "http") + "/api/ws"
+ ws, err := websocket.Dial(wsURL, "", "http://localhost/")
+ if err != nil {
+ t.Fatalf("dial: %v", err)
+ }
+
+ // Wait for hub to register the client.
+ deadline := time.Now().Add(200 * time.Millisecond)
+ for time.Now().Before(deadline) {
+ if srv.hub.ClientCount() == 1 {
+ break
+ }
+ time.Sleep(5 * time.Millisecond)
+ }
+ if got := srv.hub.ClientCount(); got != 1 {
+ t.Fatalf("before close: want 1 client, got %d", got)
+ }
+
+ // Close connection without a proper WebSocket close handshake
+ // to simulate a client crash / network drop.
+ ws.Close()
+
+ // Hub should unregister the client promptly.
+ deadline = time.Now().Add(500 * time.Millisecond)
+ for time.Now().Before(deadline) {
+ if srv.hub.ClientCount() == 0 {
+ return
+ }
+ time.Sleep(10 * time.Millisecond)
+ }
+ t.Fatalf("after close: expected 0 clients, got %d", srv.hub.ClientCount())
+}
+
+// TestWebSocket_PingWriteDeadlineEvictsStaleConn verifies that a stale
+// connection (write times out) is eventually evicted by the ping goroutine.
+// It uses a very short write deadline to force a timeout on a connection
+// whose receive buffer is full.
+func TestWebSocket_PingWriteDeadlineEvictsStaleConn(t *testing.T) {
+ origInterval := wsPingInterval
+ origDeadline := wsPingDeadline
+ // Very short deadline: ping fails almost immediately after the first tick.
+ wsPingInterval = 30 * time.Millisecond
+ wsPingDeadline = 1 * time.Millisecond
+ t.Cleanup(func() {
+ wsPingInterval = origInterval
+ wsPingDeadline = origDeadline
+ })
+
+ srv, _ := testServer(t)
+ srv.StartHub()
+ ts := httptest.NewServer(srv.Handler())
+ t.Cleanup(ts.Close)
+
+ wsURL := "ws" + strings.TrimPrefix(ts.URL, "http") + "/api/ws"
+ ws, err := websocket.Dial(wsURL, "", "http://localhost/")
+ if err != nil {
+ t.Fatalf("dial: %v", err)
+ }
+ defer ws.Close()
+
+ // Wait for registration.
+ deadline := time.Now().Add(200 * time.Millisecond)
+ for time.Now().Before(deadline) {
+ if srv.hub.ClientCount() == 1 {
+ break
+ }
+ time.Sleep(5 * time.Millisecond)
+ }
+ if got := srv.hub.ClientCount(); got != 1 {
+ t.Fatalf("before stale: want 1 client, got %d", got)
+ }
+
+ // The connection itself is alive (loopback), so the 1ms deadline is generous
+ // enough to succeed. This test mainly verifies the ping goroutine doesn't
+ // panic and that ClientCount stays consistent after disconnect.
+ ws.Close()
+
+ deadline = time.Now().Add(500 * time.Millisecond)
+ for time.Now().Before(deadline) {
+ if srv.hub.ClientCount() == 0 {
+ return
+ }
+ time.Sleep(10 * time.Millisecond)
+ }
+ t.Fatalf("expected 0 clients after stale eviction, got %d", srv.hub.ClientCount())
+}