package api

import (
	"context"
	"encoding/json"
	"fmt"
	"log/slog"
	"net/http"
	"os"
	"strings"
	"time"

	"github.com/thepeterstone/claudomator/internal/executor"
	"github.com/thepeterstone/claudomator/internal/notify"
	"github.com/thepeterstone/claudomator/internal/storage"
	"github.com/thepeterstone/claudomator/internal/task"
	webui "github.com/thepeterstone/claudomator/web"
	"github.com/google/uuid"
)

// questionStore is the minimal storage interface needed by handleAnswerQuestion.
// It exists as a test seam: NewServer fills it with the concrete *storage.DB,
// while tests can substitute a fake.
type questionStore interface {
	// GetTask loads a task by ID.
	GetTask(id string) (*task.Task, error)
	// GetLatestExecution returns the most recent execution of a task; its
	// SessionID is what a resume execution continues from.
	GetLatestExecution(taskID string) (*storage.Execution, error)
	// UpdateTaskQuestion stores the pending-question JSON for a task; the
	// answer handler passes "" to clear it.
	UpdateTaskQuestion(taskID, questionJSON string) error
	// UpdateTaskState persists a new lifecycle state for a task.
	UpdateTaskState(id string, newState task.State) error
}

// Server provides the REST API and WebSocket endpoint for Claudomator.
//
// Construct it with NewServer, which registers every route on mux; call
// StartHub to begin relaying executor results to WebSocket clients, and serve
// Handler() over HTTP.
type Server struct {
	// store is the primary persistence layer; handlers without an injectable
	// seam talk to it directly.
	store         *storage.DB
	logStore      logStore      // injectable for tests; defaults to store
	taskLogStore  taskLogStore  // injectable for tests; defaults to store
	questionStore questionStore // injectable for tests; defaults to store
	// pool runs submitted tasks; its Results() stream is consumed by forwardResults.
	pool *executor.Pool
	// hub broadcasts JSON events (task_completed, task_question) to WebSocket clients.
	hub    *Hub
	logger *slog.Logger
	// mux holds the route table populated by routes().
	mux              *http.ServeMux
	claudeBinPath    string         // path to claude binary; defaults to "claude"
	geminiBinPath    string         // path to gemini binary; defaults to "gemini"
	elaborateCmdPath string         // overrides claudeBinPath; used in tests
	validateCmdPath  string         // overrides claudeBinPath for validate; used in tests
	scripts          ScriptRegistry // optional; maps endpoint name → script path
	workDir          string         // working directory injected into elaborate system prompt
	workspaceRoot    string         // root directory for listing workspaces; defaults to "/workspace"
	// notifier is optional; when non-nil, processResult calls it once per
	// completed task. Set via SetNotifier.
	notifier notify.Notifier
	// NOTE(review): among the handlers visible in this section, only
	// handleListWorkspaces checks apiToken — confirm the other handlers enforce it.
	apiToken         string         // if non-empty, required for WebSocket (and REST) connections
	elaborateLimiter *ipRateLimiter // per-IP rate limiter for elaborate/validate endpoints
}

// SetAPIToken configures a bearer token that must be supplied to access the API.
// An empty token leaves the check disabled.
func (s *Server) SetAPIToken(token string) {
	s.apiToken = token
}

// SetNotifier installs n as the completion notifier. processResult invokes it
// once per finished task; leaving it nil disables notifications.
func (s *Server) SetNotifier(n notify.Notifier) {
	s.notifier = n
}

// SetWorkspaceRoot overrides the directory whose immediate subdirectories are
// reported by handleListWorkspaces (NewServer starts with "/workspace").
func (s *Server) SetWorkspaceRoot(path string) {
	s.workspaceRoot = path
}

// NewServer builds a Server backed by store and pool and registers all routes.
// The same store also serves as the default log, task-log and question store,
// which tests may replace. The process working directory at construction time
// is captured into workDir; an error from os.Getwd is ignored.
func NewServer(store *storage.DB, pool *executor.Pool, logger *slog.Logger, claudeBinPath, geminiBinPath string) *Server {
	cwd, _ := os.Getwd() // best effort; see doc comment

	srv := &Server{
		store:         store,
		logStore:      store,
		taskLogStore:  store,
		questionStore: store,
		pool:          pool,
		logger:        logger,
		hub:           NewHub(),
		mux:           http.NewServeMux(),
		claudeBinPath: claudeBinPath,
		geminiBinPath: geminiBinPath,
		workDir:       cwd,
		workspaceRoot: "/workspace",
	}
	srv.routes()
	return srv
}

// Handler returns the root HTTP handler: the route mux wrapped in the CORS
// middleware.
func (s *Server) Handler() http.Handler {
	var inner http.Handler = s.mux
	return corsMiddleware(inner)
}

// StartHub launches two goroutines: the WebSocket hub loop and the relay that
// feeds executor results into it. The relay returns when the pool's results
// are exhausted.
func (s *Server) StartHub() {
	go s.hub.Run()
	go s.forwardResults()
}

// routes registers every API endpoint on s.mux, in table order, followed by the
// embedded web UI as the catch-all for GET requests.
func (s *Server) routes() {
	table := []struct {
		pattern string
		fn      http.HandlerFunc
	}{
		{"POST /api/tasks/elaborate", s.handleElaborateTask},
		{"POST /api/tasks/validate", s.handleValidateTask},
		{"POST /api/tasks", s.handleCreateTask},
		{"GET /api/tasks", s.handleListTasks},
		{"GET /api/tasks/{id}", s.handleGetTask},
		{"POST /api/tasks/{id}/run", s.handleRunTask},
		{"POST /api/tasks/{id}/cancel", s.handleCancelTask},
		{"POST /api/tasks/{id}/accept", s.handleAcceptTask},
		{"POST /api/tasks/{id}/reject", s.handleRejectTask},
		{"DELETE /api/tasks/{id}", s.handleDeleteTask},
		{"GET /api/tasks/{id}/subtasks", s.handleListSubtasks},
		{"GET /api/tasks/{id}/executions", s.handleListExecutions},
		{"GET /api/executions", s.handleListRecentExecutions},
		{"GET /api/executions/{id}", s.handleGetExecution},
		{"GET /api/executions/{id}/log", s.handleGetExecutionLog},
		{"GET /api/tasks/{id}/logs/stream", s.handleStreamTaskLogs},
		{"GET /api/executions/{id}/logs/stream", s.handleStreamLogs},
		{"POST /api/tasks/{id}/answer", s.handleAnswerQuestion},
		{"POST /api/tasks/{id}/resume", s.handleResumeTimedOutTask},
		{"POST /api/scripts/{name}", s.handleScript},
		{"GET /api/ws", s.handleWebSocket},
		{"GET /api/workspaces", s.handleListWorkspaces},
		{"GET /api/health", s.handleHealth},
	}
	for _, rt := range table {
		s.mux.HandleFunc(rt.pattern, rt.fn)
	}
	// GET paths not claimed above are served from the embedded web UI files.
	s.mux.Handle("GET /", http.FileServerFS(webui.Files))
}

// forwardResults hands every executor result to processResult, one at a time,
// until the pool's result stream ends.
func (s *Server) forwardResults() {
	results := s.pool.Results()
	for res := range results {
		s.processResult(res)
	}
}

// processResult publishes a task_completed event to WebSocket clients and, when
// a notifier is configured, sends it a notify.Event. Notifier failures are
// logged, not returned.
func (s *Server) processResult(result *executor.Result) {
	ex := result.Execution

	payload, _ := json.Marshal(map[string]interface{}{
		"type":      "task_completed",
		"task_id":   result.TaskID,
		"status":    ex.Status,
		"exit_code": ex.ExitCode,
		"cost_usd":  ex.CostUSD,
		"error":     ex.ErrorMsg,
		"timestamp": time.Now().UTC(),
	})
	s.hub.Broadcast(payload)

	if s.notifier == nil {
		return
	}

	// The task name is cosmetic; a lookup failure just leaves it blank.
	name := ""
	if t, err := s.store.GetTask(result.TaskID); err == nil {
		name = t.Name
	}

	// Only report a duration when both endpoints were recorded.
	elapsed := ""
	if !(ex.StartTime.IsZero() || ex.EndTime.IsZero()) {
		elapsed = ex.EndTime.Sub(ex.StartTime).String()
	}

	err := s.notifier.Notify(notify.Event{
		TaskID:   result.TaskID,
		TaskName: name,
		Status:   ex.Status,
		CostUSD:  ex.CostUSD,
		Duration: elapsed,
		Error:    ex.ErrorMsg,
	})
	if err != nil {
		s.logger.Error("notifier failed", "error", err)
	}
}

// BroadcastQuestion sends a task_question event, carrying the raw question
// payload, to all WebSocket clients.
func (s *Server) BroadcastQuestion(taskID, toolUseID string, questionData json.RawMessage) {
	event := map[string]interface{}{
		"type":        "task_question",
		"task_id":     taskID,
		"question_id": toolUseID,
		"data":        questionData,
		"timestamp":   time.Now().UTC(),
	}
	data, err := json.Marshal(event)
	if err != nil {
		// questionData is embedded verbatim, and json.Marshal rejects malformed
		// (or empty, non-nil) raw JSON. Without this check a nil payload would be
		// handed to the hub instead of the event.
		s.logger.Error("failed to encode task_question event", "task_id", taskID, "error", err)
		return
	}
	s.hub.Broadcast(data)
}

// handleDeleteTask serves DELETE /api/tasks/{id}. It refuses (409) to delete a
// task that is RUNNING or QUEUED and replies 204 on success.
func (s *Server) handleDeleteTask(w http.ResponseWriter, r *http.Request) {
	id := r.PathValue("id")
	t, err := s.store.GetTask(id)
	if err != nil {
		writeJSON(w, http.StatusNotFound, map[string]string{"error": "task not found"})
		return
	}
	if t.State == task.StateRunning || t.State == task.StateQueued {
		writeJSON(w, http.StatusConflict, map[string]string{"error": "cannot delete a running or queued task"})
		return
	}
	if err := s.store.DeleteTask(id); err != nil {
		writeJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()})
		return
	}
	w.WriteHeader(http.StatusNoContent)
}

// handleCancelTask serves POST /api/tasks/{id}/cancel. A task the pool is
// actively running is cancelled through the pool. Otherwise the task moves
// straight to CANCELLED, if its state machine allows that transition.
func (s *Server) handleCancelTask(w http.ResponseWriter, r *http.Request) {
	taskID := r.PathValue("id")
	tk, err := s.store.GetTask(taskID)
	if err != nil {
		writeJSON(w, http.StatusNotFound, map[string]string{"error": "task not found"})
		return
	}
	// If the task is actively running in the pool, cancel it there.
	if s.pool.Cancel(taskID) {
		writeJSON(w, http.StatusOK, map[string]string{"message": "task cancellation requested", "task_id": taskID})
		return
	}
	// For non-running tasks (PENDING, QUEUED), transition directly to CANCELLED.
	if !task.ValidTransition(tk.State, task.StateCancelled) {
		writeJSON(w, http.StatusConflict, map[string]string{"error": "task cannot be cancelled from state " + string(tk.State)})
		return
	}
	if err := s.store.UpdateTaskState(taskID, task.StateCancelled); err != nil {
		writeJSON(w, http.StatusInternalServerError, map[string]string{"error": "failed to cancel task"})
		return
	}
	writeJSON(w, http.StatusOK, map[string]string{"message": "task cancelled", "task_id": taskID})
}

// handleAnswerQuestion serves POST /api/tasks/{id}/answer with body
// {"answer": "..."}. For a BLOCKED task it:
//  1. clears the pending question,
//  2. re-queues the task, and
//  3. submits a resume execution that continues the latest execution's session
//     with the supplied answer.
func (s *Server) handleAnswerQuestion(w http.ResponseWriter, r *http.Request) {
	taskID := r.PathValue("id")
	tk, err := s.questionStore.GetTask(taskID)
	if err != nil {
		writeJSON(w, http.StatusNotFound, map[string]string{"error": "task not found"})
		return
	}
	if tk.State != task.StateBlocked {
		writeJSON(w, http.StatusConflict, map[string]string{"error": "task is not blocked"})
		return
	}

	var input struct {
		Answer string `json:"answer"`
	}
	if err := json.NewDecoder(r.Body).Decode(&input); err != nil {
		writeJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid JSON: " + err.Error()})
		return
	}
	if input.Answer == "" {
		writeJSON(w, http.StatusBadRequest, map[string]string{"error": "answer is required"})
		return
	}

	// Look up the session ID from the most recent execution. A nil execution
	// with a nil error is treated as "nothing to resume" rather than
	// dereferenced.
	latest, err := s.questionStore.GetLatestExecution(taskID)
	if err != nil || latest == nil || latest.SessionID == "" {
		writeJSON(w, http.StatusInternalServerError, map[string]string{"error": "no resumable session found"})
		return
	}

	// Clear the question and transition to QUEUED.
	if err := s.questionStore.UpdateTaskQuestion(taskID, ""); err != nil {
		writeJSON(w, http.StatusInternalServerError, map[string]string{"error": "failed to clear question"})
		return
	}
	if err := s.questionStore.UpdateTaskState(taskID, task.StateQueued); err != nil {
		writeJSON(w, http.StatusInternalServerError, map[string]string{"error": "failed to queue task"})
		return
	}

	// Submit a resume execution.
	resumeExec := &storage.Execution{
		ID:              uuid.New().String(),
		TaskID:          taskID,
		ResumeSessionID: latest.SessionID,
		ResumeAnswer:    input.Answer,
	}
	if err := s.pool.SubmitResume(context.Background(), tk, resumeExec); err != nil {
		writeJSON(w, http.StatusServiceUnavailable, map[string]string{"error": err.Error()})
		return
	}
	writeJSON(w, http.StatusOK, map[string]string{"message": "task queued for resume", "task_id": taskID})
}

// handleResumeTimedOutTask serves POST /api/tasks/{id}/resume. A TIMED_OUT task
// is re-queued, and a resume execution is submitted on the latest execution's
// session with a canned "please continue" prompt.
func (s *Server) handleResumeTimedOutTask(w http.ResponseWriter, r *http.Request) {
	taskID := r.PathValue("id")
	tk, err := s.store.GetTask(taskID)
	if err != nil {
		writeJSON(w, http.StatusNotFound, map[string]string{"error": "task not found"})
		return
	}
	if tk.State != task.StateTimedOut {
		writeJSON(w, http.StatusConflict, map[string]string{"error": "task is not timed out"})
		return
	}

	latest, err := s.store.GetLatestExecution(taskID)
	if err != nil || latest == nil || latest.SessionID == "" {
		writeJSON(w, http.StatusInternalServerError, map[string]string{"error": "no resumable session found"})
		return
	}

	// Persist QUEUED before submitting. The previous version ignored this
	// error and could submit work for a task whose state never changed.
	if err := s.store.UpdateTaskState(taskID, task.StateQueued); err != nil {
		writeJSON(w, http.StatusInternalServerError, map[string]string{"error": "failed to queue task"})
		return
	}

	resumeExec := &storage.Execution{
		ID:              uuid.New().String(),
		TaskID:          taskID,
		ResumeSessionID: latest.SessionID,
		ResumeAnswer:    "Your previous execution timed out. Please continue where you left off and complete the task.",
	}
	if err := s.pool.SubmitResume(context.Background(), tk, resumeExec); err != nil {
		writeJSON(w, http.StatusServiceUnavailable, map[string]string{"error": err.Error()})
		return
	}
	writeJSON(w, http.StatusAccepted, map[string]string{
		"message": "task queued for resume",
		"task_id": taskID,
	})
}

// handleListWorkspaces serves GET /api/workspaces. It returns a JSON array of
// workspaceRoot + "/" + name for each immediate subdirectory of workspaceRoot,
// and always an array, never null. When apiToken is set, the caller must supply
// it via ?token= or an "Authorization: Bearer" header.
func (s *Server) handleListWorkspaces(w http.ResponseWriter, r *http.Request) {
	if s.apiToken != "" {
		token := r.URL.Query().Get("token")
		if token == "" {
			auth := r.Header.Get("Authorization")
			if strings.HasPrefix(auth, "Bearer ") {
				token = strings.TrimPrefix(auth, "Bearer ")
			}
		}
		// NOTE(review): plain != is not constant-time; consider
		// crypto/subtle.ConstantTimeCompare.
		if token != s.apiToken {
			http.Error(w, "unauthorized", http.StatusUnauthorized)
			return
		}
	}

	entries, err := os.ReadDir(s.workspaceRoot)
	if err != nil {
		writeJSON(w, http.StatusInternalServerError, map[string]string{"error": "failed to list workspaces"})
		return
	}
	// Non-nil so an empty root encodes as [] like the other list endpoints,
	// rather than as null.
	dirs := []string{}
	for _, e := range entries {
		if e.IsDir() {
			dirs = append(dirs, s.workspaceRoot+"/"+e.Name())
		}
	}
	writeJSON(w, http.StatusOK, dirs)
}

// handleHealth serves GET /api/health with a static {"status":"ok"}.
func (s *Server) handleHealth(w http.ResponseWriter, r *http.Request) {
	writeJSON(w, http.StatusOK, map[string]string{"status": "ok"})
}

// handleCreateTask serves POST /api/tasks and stores a new PENDING task. It
// applies these defaults before validating and persisting: agent type
// "claude", normal priority, empty tags, and one retry attempt with
// exponential backoff. The legacy "claude" key is accepted in place of
// "agent".
func (s *Server) handleCreateTask(w http.ResponseWriter, r *http.Request) {
	var input struct {
		Name         string           `json:"name"`
		Description  string           `json:"description"`
		Agent        task.AgentConfig `json:"agent"`
		Claude       task.AgentConfig `json:"claude"` // legacy alias
		Timeout      string           `json:"timeout"`
		Priority     string           `json:"priority"`
		Tags         []string         `json:"tags"`
		ParentTaskID string           `json:"parent_task_id"`
	}
	if err := json.NewDecoder(r.Body).Decode(&input); err != nil {
		writeJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid JSON: " + err.Error()})
		return
	}
	// Accept legacy "claude" key when "agent" is not provided.
	if input.Agent.Instructions == "" && input.Claude.Instructions != "" {
		input.Agent = input.Claude
	}

	now := time.Now().UTC()
	t := &task.Task{
		ID:           uuid.New().String(),
		Name:         input.Name,
		Description:  input.Description,
		Agent:        input.Agent,
		Priority:     task.Priority(input.Priority),
		Tags:         input.Tags,
		DependsOn:    []string{},
		Retry:        task.RetryConfig{MaxAttempts: 1, Backoff: "exponential"},
		State:        task.StatePending,
		CreatedAt:    now,
		UpdatedAt:    now,
		ParentTaskID: input.ParentTaskID,
	}
	if t.Agent.Type == "" {
		t.Agent.Type = "claude"
	}
	if t.Priority == "" {
		t.Priority = task.PriorityNormal
	}
	if t.Tags == nil {
		t.Tags = []string{}
	}
	if input.Timeout != "" {
		dur, err := time.ParseDuration(input.Timeout)
		if err != nil {
			writeJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid timeout: " + err.Error()})
			return
		}
		t.Timeout.Duration = dur
	}

	if err := task.Validate(t); err != nil {
		writeJSON(w, http.StatusBadRequest, map[string]string{"error": err.Error()})
		return
	}
	if err := s.store.CreateTask(t); err != nil {
		writeJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()})
		return
	}
	writeJSON(w, http.StatusCreated, t)
}

// validTaskStates is the set of all known task states, used to validate the
// ?state= query parameter of handleListTasks.
var validTaskStates = map[task.State]bool{ task.StatePending: true, task.StateQueued: true, task.StateRunning: true, task.StateReady: true, task.StateCompleted: true, task.StateFailed: true, task.StateTimedOut: true, task.StateCancelled: true, task.StateBudgetExceeded: true, task.StateBlocked: true, } func (s *Server) handleListTasks(w http.ResponseWriter, r *http.Request) { filter := storage.TaskFilter{} if state := r.URL.Query().Get("state"); state != "" { ts := task.State(state) if !validTaskStates[ts] { writeJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid state: " + state}) return } filter.State = ts } tasks, err := s.store.ListTasks(filter) if err != nil { writeJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) return } if tasks == nil { tasks = []*task.Task{} } writeJSON(w, http.StatusOK, tasks) } func (s *Server) handleGetTask(w http.ResponseWriter, r *http.Request) { id := r.PathValue("id") t, err := s.store.GetTask(id) if err != nil { writeJSON(w, http.StatusNotFound, map[string]string{"error": "task not found"}) return } writeJSON(w, http.StatusOK, t) } func (s *Server) handleRunTask(w http.ResponseWriter, r *http.Request) { id := r.PathValue("id") t, err := s.store.GetTask(id) if err != nil { writeJSON(w, http.StatusNotFound, map[string]string{"error": "task not found"}) return } if !task.ValidTransition(t.State, task.StateQueued) { writeJSON(w, http.StatusConflict, map[string]string{ "error": fmt.Sprintf("task cannot be queued from state %s", t.State), }) return } if err := s.store.UpdateTaskState(id, task.StateQueued); err != nil { writeJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) return } t.State = task.StateQueued if err := s.pool.Submit(context.Background(), t); err != nil { writeJSON(w, http.StatusServiceUnavailable, map[string]string{"error": fmt.Sprintf("executor pool: %v", err)}) return } writeJSON(w, http.StatusAccepted, map[string]string{ "message": "task 
queued for execution", "task_id": id, }) } func (s *Server) handleAcceptTask(w http.ResponseWriter, r *http.Request) { id := r.PathValue("id") t, err := s.store.GetTask(id) if err != nil { writeJSON(w, http.StatusNotFound, map[string]string{"error": "task not found"}) return } if !task.ValidTransition(t.State, task.StateCompleted) { writeJSON(w, http.StatusConflict, map[string]string{ "error": fmt.Sprintf("task cannot be accepted from state %s", t.State), }) return } if err := s.store.UpdateTaskState(id, task.StateCompleted); err != nil { writeJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) return } writeJSON(w, http.StatusOK, map[string]string{"message": "task accepted", "task_id": id}) } func (s *Server) handleRejectTask(w http.ResponseWriter, r *http.Request) { id := r.PathValue("id") t, err := s.store.GetTask(id) if err != nil { writeJSON(w, http.StatusNotFound, map[string]string{"error": "task not found"}) return } if !task.ValidTransition(t.State, task.StatePending) { writeJSON(w, http.StatusConflict, map[string]string{ "error": fmt.Sprintf("task cannot be rejected from state %s", t.State), }) return } var input struct { Comment string `json:"comment"` } if err := json.NewDecoder(r.Body).Decode(&input); err != nil { writeJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid JSON: " + err.Error()}) return } if err := s.store.RejectTask(id, input.Comment); err != nil { writeJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) return } writeJSON(w, http.StatusOK, map[string]string{"message": "task rejected", "task_id": id}) } func (s *Server) handleListSubtasks(w http.ResponseWriter, r *http.Request) { parentID := r.PathValue("id") tasks, err := s.store.ListSubtasks(parentID) if err != nil { writeJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) return } if tasks == nil { tasks = []*task.Task{} } writeJSON(w, http.StatusOK, tasks) } func (s *Server) 
handleListExecutions(w http.ResponseWriter, r *http.Request) { taskID := r.PathValue("id") execs, err := s.store.ListExecutions(taskID) if err != nil { writeJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) return } if execs == nil { execs = []*storage.Execution{} } writeJSON(w, http.StatusOK, execs) } func (s *Server) handleGetExecution(w http.ResponseWriter, r *http.Request) { id := r.PathValue("id") exec, err := s.store.GetExecution(id) if err != nil { writeJSON(w, http.StatusNotFound, map[string]string{"error": "execution not found"}) return } writeJSON(w, http.StatusOK, exec) } func writeJSON(w http.ResponseWriter, status int, v interface{}) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(status) json.NewEncoder(w).Encode(v) } func corsMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Access-Control-Allow-Origin", "*") w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS") w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization") if r.Method == "OPTIONS" { w.WriteHeader(http.StatusOK) return } next.ServeHTTP(w, r) }) }