summaryrefslogtreecommitdiff
path: root/internal/api/server.go
diff options
context:
space:
mode:
authorPeter Stone <thepeterstone@gmail.com>2026-03-08 20:40:31 +0000
committerPeter Stone <thepeterstone@gmail.com>2026-03-08 20:40:31 +0000
commit417034be7f745062901a940d1a021f6d85be496e (patch)
tree666956207b58c915090f6641891304156cf93670 /internal/api/server.go
parent181a37698410b68e00a885593b6f2b7acf21f4b4 (diff)
api: SetAPIToken, SetNotifier, questionStore, per-IP rate limiter
- Extract questionStore interface for testability of handleAnswerQuestion - Add SetAPIToken/SetNotifier methods for post-construction wiring - Extract processResult() from forwardResults() for direct testability - Add ipRateLimiter with token-bucket per IP; applied to /elaborate and /validate - Fix tests for running-task deletion and retry-limit that relied on invalid state transitions in setup Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Diffstat (limited to 'internal/api/server.go')
-rw-r--r--internal/api/server.go143
1 files changed, 114 insertions, 29 deletions
diff --git a/internal/api/server.go b/internal/api/server.go
index af4710b..833be8b 100644
--- a/internal/api/server.go
+++ b/internal/api/server.go
@@ -7,39 +7,63 @@ import (
"log/slog"
"net/http"
"os"
+ "strings"
"time"
"github.com/thepeterstone/claudomator/internal/executor"
+ "github.com/thepeterstone/claudomator/internal/notify"
"github.com/thepeterstone/claudomator/internal/storage"
"github.com/thepeterstone/claudomator/internal/task"
webui "github.com/thepeterstone/claudomator/web"
"github.com/google/uuid"
)
+// questionStore is the minimal storage interface needed by handleAnswerQuestion.
+type questionStore interface {
+ GetTask(id string) (*task.Task, error)
+ GetLatestExecution(taskID string) (*storage.Execution, error)
+ UpdateTaskQuestion(taskID, questionJSON string) error
+ UpdateTaskState(id string, newState task.State) error
+}
+
// Server provides the REST API and WebSocket endpoint for Claudomator.
type Server struct {
store *storage.DB
- logStore logStore // injectable for tests; defaults to store
- taskLogStore taskLogStore // injectable for tests; defaults to store
+ logStore logStore // injectable for tests; defaults to store
+ taskLogStore taskLogStore // injectable for tests; defaults to store
+ questionStore questionStore // injectable for tests; defaults to store
pool *executor.Pool
hub *Hub
logger *slog.Logger
mux *http.ServeMux
- claudeBinPath string // path to claude binary; defaults to "claude"
- elaborateCmdPath string // overrides claudeBinPath; used in tests
- validateCmdPath string // overrides claudeBinPath for validate; used in tests
- startNextTaskScript string // path to start-next-task script; overridden in tests
- deployScript string // path to deploy script; overridden in tests
- workDir string // working directory injected into elaborate system prompt
+ claudeBinPath string // path to claude binary; defaults to "claude"
+ elaborateCmdPath string // overrides claudeBinPath; used in tests
+ validateCmdPath string // overrides claudeBinPath for validate; used in tests
+ scripts ScriptRegistry // optional; maps endpoint name → script path
+ workDir string // working directory injected into elaborate system prompt
+ notifier notify.Notifier
+ apiToken string // if non-empty, required for WebSocket (and REST) connections
+ elaborateLimiter *ipRateLimiter // per-IP rate limiter for elaborate/validate endpoints
+}
+
+// SetAPIToken configures a bearer token that must be supplied to access the API.
+func (s *Server) SetAPIToken(token string) {
+ s.apiToken = token
+}
+
+// SetNotifier configures a notifier that is called on every task completion.
+func (s *Server) SetNotifier(n notify.Notifier) {
+ s.notifier = n
}
func NewServer(store *storage.DB, pool *executor.Pool, logger *slog.Logger, claudeBinPath string) *Server {
wd, _ := os.Getwd()
s := &Server{
- store: store,
- logStore: store,
- taskLogStore: store,
- pool: pool,
+ store: store,
+ logStore: store,
+ taskLogStore: store,
+ questionStore: store,
+ pool: pool,
hub: NewHub(),
logger: logger,
mux: http.NewServeMux(),
@@ -84,8 +108,7 @@ func (s *Server) routes() {
s.mux.HandleFunc("DELETE /api/templates/{id}", s.handleDeleteTemplate)
s.mux.HandleFunc("POST /api/tasks/{id}/answer", s.handleAnswerQuestion)
s.mux.HandleFunc("POST /api/tasks/{id}/resume", s.handleResumeTimedOutTask)
- s.mux.HandleFunc("POST /api/scripts/start-next-task", s.handleStartNextTask)
- s.mux.HandleFunc("POST /api/scripts/deploy", s.handleDeploy)
+ s.mux.HandleFunc("POST /api/scripts/{name}", s.handleScript)
s.mux.HandleFunc("GET /api/ws", s.handleWebSocket)
s.mux.HandleFunc("GET /api/workspaces", s.handleListWorkspaces)
s.mux.HandleFunc("GET /api/health", s.handleHealth)
@@ -95,17 +118,44 @@ func (s *Server) routes() {
// forwardResults listens on the executor pool's result channel and broadcasts via WebSocket.
func (s *Server) forwardResults() {
for result := range s.pool.Results() {
- event := map[string]interface{}{
- "type": "task_completed",
- "task_id": result.TaskID,
- "status": result.Execution.Status,
- "exit_code": result.Execution.ExitCode,
- "cost_usd": result.Execution.CostUSD,
- "error": result.Execution.ErrorMsg,
- "timestamp": time.Now().UTC(),
+ s.processResult(result)
+ }
+}
+
+// processResult broadcasts a task completion event via WebSocket and calls the notifier if set.
+func (s *Server) processResult(result *executor.Result) {
+ event := map[string]interface{}{
+ "type": "task_completed",
+ "task_id": result.TaskID,
+ "status": result.Execution.Status,
+ "exit_code": result.Execution.ExitCode,
+ "cost_usd": result.Execution.CostUSD,
+ "error": result.Execution.ErrorMsg,
+ "timestamp": time.Now().UTC(),
+ }
+ data, _ := json.Marshal(event)
+ s.hub.Broadcast(data)
+
+ if s.notifier != nil {
+ var taskName string
+ if t, err := s.store.GetTask(result.TaskID); err == nil {
+ taskName = t.Name
+ }
+ var dur string
+ if !result.Execution.StartTime.IsZero() && !result.Execution.EndTime.IsZero() {
+ dur = result.Execution.EndTime.Sub(result.Execution.StartTime).String()
+ }
+ ne := notify.Event{
+ TaskID: result.TaskID,
+ TaskName: taskName,
+ Status: result.Execution.Status,
+ CostUSD: result.Execution.CostUSD,
+ Duration: dur,
+ Error: result.Execution.ErrorMsg,
+ }
+ if err := s.notifier.Notify(ne); err != nil {
+ s.logger.Error("notifier failed", "error", err)
}
- data, _ := json.Marshal(event)
- s.hub.Broadcast(data)
}
}
@@ -167,7 +217,7 @@ func (s *Server) handleCancelTask(w http.ResponseWriter, r *http.Request) {
func (s *Server) handleAnswerQuestion(w http.ResponseWriter, r *http.Request) {
taskID := r.PathValue("id")
- tk, err := s.store.GetTask(taskID)
+ tk, err := s.questionStore.GetTask(taskID)
if err != nil {
writeJSON(w, http.StatusNotFound, map[string]string{"error": "task not found"})
return
@@ -190,15 +240,21 @@ func (s *Server) handleAnswerQuestion(w http.ResponseWriter, r *http.Request) {
}
// Look up the session ID from the most recent execution.
- latest, err := s.store.GetLatestExecution(taskID)
+ latest, err := s.questionStore.GetLatestExecution(taskID)
if err != nil || latest.SessionID == "" {
writeJSON(w, http.StatusInternalServerError, map[string]string{"error": "no resumable session found"})
return
}
// Clear the question and transition to QUEUED.
- s.store.UpdateTaskQuestion(taskID, "")
- s.store.UpdateTaskState(taskID, task.StateQueued)
+ if err := s.questionStore.UpdateTaskQuestion(taskID, ""); err != nil {
+ writeJSON(w, http.StatusInternalServerError, map[string]string{"error": "failed to clear question"})
+ return
+ }
+ if err := s.questionStore.UpdateTaskState(taskID, task.StateQueued); err != nil {
+ writeJSON(w, http.StatusInternalServerError, map[string]string{"error": "failed to queue task"})
+ return
+ }
// Submit a resume execution.
resumeExec := &storage.Execution{
@@ -254,9 +310,23 @@ func (s *Server) handleResumeTimedOutTask(w http.ResponseWriter, r *http.Request
}
func (s *Server) handleListWorkspaces(w http.ResponseWriter, r *http.Request) {
+ if s.apiToken != "" {
+ token := r.URL.Query().Get("token")
+ if token == "" {
+ auth := r.Header.Get("Authorization")
+ if strings.HasPrefix(auth, "Bearer ") {
+ token = strings.TrimPrefix(auth, "Bearer ")
+ }
+ }
+ if token != s.apiToken {
+ http.Error(w, "unauthorized", http.StatusUnauthorized)
+ return
+ }
+ }
+
entries, err := os.ReadDir("/workspace")
if err != nil {
- writeJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()})
+ writeJSON(w, http.StatusInternalServerError, map[string]string{"error": "failed to list workspaces"})
return
}
var dirs []string
@@ -370,6 +440,21 @@ func (s *Server) handleRunTask(w http.ResponseWriter, r *http.Request) {
return
}
+ // Enforce retry limit for non-initial runs (PENDING is the initial state).
+ if t.State != task.StatePending {
+ execs, err := s.store.ListExecutions(id)
+ if err != nil {
+ writeJSON(w, http.StatusInternalServerError, map[string]string{"error": "failed to count executions"})
+ return
+ }
+ if t.Retry.MaxAttempts > 0 && len(execs) >= t.Retry.MaxAttempts {
+ writeJSON(w, http.StatusConflict, map[string]string{
+ "error": fmt.Sprintf("retry limit reached (%d/%d attempts used)", len(execs), t.Retry.MaxAttempts),
+ })
+ return
+ }
+ }
+
if err := s.store.UpdateTaskState(id, task.StateQueued); err != nil {
writeJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()})
return