diff options
| author | Peter Stone <thepeterstone@gmail.com> | 2026-03-08 21:03:50 +0000 |
|---|---|---|
| committer | Peter Stone <thepeterstone@gmail.com> | 2026-03-08 21:03:50 +0000 |
| commit | 632ea5a44731af94b6238f330a3b5440906c8ae7 (patch) | |
| tree | d8c780412598d66b89ef390b5729e379fdfd9d5b /internal/api/server.go | |
| parent | 406247b14985ab57902e8e42898dc8cb8960290d (diff) | |
| parent | 93a4c852bf726b00e8014d385165f847763fa214 (diff) | |
merge: pull latest from master and resolve conflicts
- Resolve conflicts in API server, CLI, and executor.
- Maintain Gemini classification and assignment logic.
- Update UI to use generic agent config and project_dir.
- Fix ProjectDir/WorkingDir inconsistencies in Gemini runner.
- All tests passing after merge.
Diffstat (limited to 'internal/api/server.go')
| -rw-r--r-- | internal/api/server.go | 145 |
1 files changed, 115 insertions, 30 deletions
diff --git a/internal/api/server.go b/internal/api/server.go index 6f343b6..3d7cb1e 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -7,40 +7,64 @@ import ( "log/slog" "net/http" "os" + "strings" "time" "github.com/thepeterstone/claudomator/internal/executor" + "github.com/thepeterstone/claudomator/internal/notify" "github.com/thepeterstone/claudomator/internal/storage" "github.com/thepeterstone/claudomator/internal/task" webui "github.com/thepeterstone/claudomator/web" "github.com/google/uuid" ) +// questionStore is the minimal storage interface needed by handleAnswerQuestion. +type questionStore interface { + GetTask(id string) (*task.Task, error) + GetLatestExecution(taskID string) (*storage.Execution, error) + UpdateTaskQuestion(taskID, questionJSON string) error + UpdateTaskState(id string, newState task.State) error +} + // Server provides the REST API and WebSocket endpoint for Claudomator. type Server struct { store *storage.DB - logStore logStore // injectable for tests; defaults to store - taskLogStore taskLogStore // injectable for tests; defaults to store + logStore logStore // injectable for tests; defaults to store + taskLogStore taskLogStore // injectable for tests; defaults to store + questionStore questionStore // injectable for tests; defaults to store pool *executor.Pool hub *Hub logger *slog.Logger mux *http.ServeMux - claudeBinPath string // path to claude binary; defaults to "claude" - geminiBinPath string // path to gemini binary; defaults to "gemini" - elaborateCmdPath string // overrides claudeBinPath; used in tests - validateCmdPath string // overrides claudeBinPath for validate; used in tests - startNextTaskScript string // path to start-next-task script; overridden in tests - deployScript string // path to deploy script; overridden in tests - workDir string // working directory injected into elaborate system prompt + claudeBinPath string // path to claude binary; defaults to "claude" + geminiBinPath string // path to gemini binary; defaults to "gemini" + elaborateCmdPath string // overrides claudeBinPath; used in tests + validateCmdPath string // overrides claudeBinPath for validate; used in tests + scripts ScriptRegistry // optional; maps endpoint name → script path + workDir string // working directory injected into elaborate system prompt + notifier notify.Notifier + apiToken string // if non-empty, required for WebSocket (and REST) connections + elaborateLimiter *ipRateLimiter // per-IP rate limiter for elaborate/validate endpoints +} + +// SetAPIToken configures a bearer token that must be supplied to access the API. +func (s *Server) SetAPIToken(token string) { + s.apiToken = token +} + +// SetNotifier configures a notifier that is called on every task completion. +func (s *Server) SetNotifier(n notify.Notifier) { + s.notifier = n } func NewServer(store *storage.DB, pool *executor.Pool, logger *slog.Logger, claudeBinPath, geminiBinPath string) *Server { wd, _ := os.Getwd() s := &Server{ - store: store, - logStore: store, - taskLogStore: store, - pool: pool, + store: store, + logStore: store, + taskLogStore: store, + questionStore: store, + pool: pool, hub: NewHub(), logger: logger, mux: http.NewServeMux(), @@ -86,8 +110,7 @@ func (s *Server) routes() { s.mux.HandleFunc("DELETE /api/templates/{id}", s.handleDeleteTemplate) s.mux.HandleFunc("POST /api/tasks/{id}/answer", s.handleAnswerQuestion) s.mux.HandleFunc("POST /api/tasks/{id}/resume", s.handleResumeTimedOutTask) - s.mux.HandleFunc("POST /api/scripts/start-next-task", s.handleStartNextTask) - s.mux.HandleFunc("POST /api/scripts/deploy", s.handleDeploy) + s.mux.HandleFunc("POST /api/scripts/{name}", s.handleScript) s.mux.HandleFunc("GET /api/ws", s.handleWebSocket) s.mux.HandleFunc("GET /api/workspaces", s.handleListWorkspaces) s.mux.HandleFunc("GET /api/health", s.handleHealth) @@ -97,17 +120,44 @@ func (s *Server) routes() { // forwardResults listens on the executor pool's result channel and broadcasts via WebSocket. func (s *Server) forwardResults() { for result := range s.pool.Results() { - event := map[string]interface{}{ - "type": "task_completed", - "task_id": result.TaskID, - "status": result.Execution.Status, - "exit_code": result.Execution.ExitCode, - "cost_usd": result.Execution.CostUSD, - "error": result.Execution.ErrorMsg, - "timestamp": time.Now().UTC(), + s.processResult(result) + } +} + +// processResult broadcasts a task completion event via WebSocket and calls the notifier if set. +func (s *Server) processResult(result *executor.Result) { + event := map[string]interface{}{ + "type": "task_completed", + "task_id": result.TaskID, + "status": result.Execution.Status, + "exit_code": result.Execution.ExitCode, + "cost_usd": result.Execution.CostUSD, + "error": result.Execution.ErrorMsg, + "timestamp": time.Now().UTC(), + } + data, _ := json.Marshal(event) + s.hub.Broadcast(data) + + if s.notifier != nil { + var taskName string + if t, err := s.store.GetTask(result.TaskID); err == nil { + taskName = t.Name + } + var dur string + if !result.Execution.StartTime.IsZero() && !result.Execution.EndTime.IsZero() { + dur = result.Execution.EndTime.Sub(result.Execution.StartTime).String() + } + ne := notify.Event{ + TaskID: result.TaskID, + TaskName: taskName, + Status: result.Execution.Status, + CostUSD: result.Execution.CostUSD, + Duration: dur, + Error: result.Execution.ErrorMsg, + } + if err := s.notifier.Notify(ne); err != nil { + s.logger.Error("notifier failed", "error", err) } - data, _ := json.Marshal(event) - s.hub.Broadcast(data) } } @@ -169,7 +219,7 @@ func (s *Server) handleCancelTask(w http.ResponseWriter, r *http.Request) { func (s *Server) handleAnswerQuestion(w http.ResponseWriter, r *http.Request) { taskID := r.PathValue("id") - tk, err := s.store.GetTask(taskID) + tk, err := s.questionStore.GetTask(taskID) if err != nil { writeJSON(w, http.StatusNotFound, map[string]string{"error": "task not found"}) return @@ -192,15 +242,21 @@ func (s *Server) handleAnswerQuestion(w http.ResponseWriter, r *http.Request) { } // Look up the session ID from the most recent execution. - latest, err := s.store.GetLatestExecution(taskID) + latest, err := s.questionStore.GetLatestExecution(taskID) if err != nil || latest.SessionID == "" { writeJSON(w, http.StatusInternalServerError, map[string]string{"error": "no resumable session found"}) return } // Clear the question and transition to QUEUED. - s.store.UpdateTaskQuestion(taskID, "") - s.store.UpdateTaskState(taskID, task.StateQueued) + if err := s.questionStore.UpdateTaskQuestion(taskID, ""); err != nil { + writeJSON(w, http.StatusInternalServerError, map[string]string{"error": "failed to clear question"}) + return + } + if err := s.questionStore.UpdateTaskState(taskID, task.StateQueued); err != nil { + writeJSON(w, http.StatusInternalServerError, map[string]string{"error": "failed to queue task"}) + return + } // Submit a resume execution. resumeExec := &storage.Execution{ @@ -256,9 +312,23 @@ func (s *Server) handleResumeTimedOutTask(w http.ResponseWriter, r *http.Request } func (s *Server) handleListWorkspaces(w http.ResponseWriter, r *http.Request) { + if s.apiToken != "" { + token := r.URL.Query().Get("token") + if token == "" { + auth := r.Header.Get("Authorization") + if strings.HasPrefix(auth, "Bearer ") { + token = strings.TrimPrefix(auth, "Bearer ") + } + } + if token != s.apiToken { + http.Error(w, "unauthorized", http.StatusUnauthorized) + return + } + } + entries, err := os.ReadDir("/workspace") if err != nil { - writeJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) + writeJSON(w, http.StatusInternalServerError, map[string]string{"error": "failed to list workspaces"}) return } var dirs []string @@ -375,6 +445,21 @@ func (s *Server) handleRunTask(w http.ResponseWriter, r *http.Request) { return } + // Enforce retry limit for non-initial runs (PENDING is the initial state). + if t.State != task.StatePending { + execs, err := s.store.ListExecutions(id) + if err != nil { + writeJSON(w, http.StatusInternalServerError, map[string]string{"error": "failed to count executions"}) + return + } + if t.Retry.MaxAttempts > 0 && len(execs) >= t.Retry.MaxAttempts { + writeJSON(w, http.StatusConflict, map[string]string{ + "error": fmt.Sprintf("retry limit reached (%d/%d attempts used)", len(execs), t.Retry.MaxAttempts), + }) + return + } + } + if err := s.store.UpdateTaskState(id, task.StateQueued); err != nil { writeJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) return |
