diff options
Diffstat (limited to 'internal/api/server.go')
| -rw-r--r-- | internal/api/server.go | 225 |
1 files changed, 225 insertions, 0 deletions
diff --git a/internal/api/server.go b/internal/api/server.go new file mode 100644 index 0000000..cc5e6e5 --- /dev/null +++ b/internal/api/server.go @@ -0,0 +1,225 @@ +package api + +import ( + "context" + "encoding/json" + "fmt" + "log/slog" + "net/http" + "time" + + "github.com/claudomator/claudomator/internal/executor" + "github.com/claudomator/claudomator/internal/storage" + "github.com/claudomator/claudomator/internal/task" + "github.com/google/uuid" +) + +// Server provides the REST API and WebSocket endpoint for Claudomator. +type Server struct { + store *storage.DB + pool *executor.Pool + hub *Hub + logger *slog.Logger + mux *http.ServeMux +} + +func NewServer(store *storage.DB, pool *executor.Pool, logger *slog.Logger) *Server { + s := &Server{ + store: store, + pool: pool, + hub: NewHub(), + logger: logger, + mux: http.NewServeMux(), + } + s.routes() + return s +} + +func (s *Server) Handler() http.Handler { + return corsMiddleware(s.mux) +} + +func (s *Server) StartHub() { + go s.hub.Run() + go s.forwardResults() +} + +func (s *Server) routes() { + s.mux.HandleFunc("POST /api/tasks", s.handleCreateTask) + s.mux.HandleFunc("GET /api/tasks", s.handleListTasks) + s.mux.HandleFunc("GET /api/tasks/{id}", s.handleGetTask) + s.mux.HandleFunc("POST /api/tasks/{id}/run", s.handleRunTask) + s.mux.HandleFunc("GET /api/tasks/{id}/executions", s.handleListExecutions) + s.mux.HandleFunc("GET /api/executions/{id}", s.handleGetExecution) + s.mux.HandleFunc("GET /api/ws", s.handleWebSocket) + s.mux.HandleFunc("GET /api/health", s.handleHealth) +} + +// forwardResults listens on the executor pool's result channel and broadcasts via WebSocket. +func (s *Server) forwardResults() { + for result := range s.pool.Results() { + event := map[string]interface{}{ + "type": "task_completed", + "task_id": result.TaskID, + "status": result.Execution.Status, + "exit_code": result.Execution.ExitCode, + "cost_usd": result.Execution.CostUSD, + "error": result.Execution.ErrorMsg, + "timestamp": time.Now().UTC(), + } + data, _ := json.Marshal(event) + s.hub.Broadcast(data) + } +} + +func (s *Server) handleHealth(w http.ResponseWriter, r *http.Request) { + writeJSON(w, http.StatusOK, map[string]string{"status": "ok"}) +} + +func (s *Server) handleCreateTask(w http.ResponseWriter, r *http.Request) { + var input struct { + Name string `json:"name"` + Description string `json:"description"` + Claude task.ClaudeConfig `json:"claude"` + Timeout string `json:"timeout"` + Priority string `json:"priority"` + Tags []string `json:"tags"` + } + if err := json.NewDecoder(r.Body).Decode(&input); err != nil { + writeJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid JSON: " + err.Error()}) + return + } + + now := time.Now().UTC() + t := &task.Task{ + ID: uuid.New().String(), + Name: input.Name, + Description: input.Description, + Claude: input.Claude, + Priority: task.Priority(input.Priority), + Tags: input.Tags, + DependsOn: []string{}, + Retry: task.RetryConfig{MaxAttempts: 1, Backoff: "exponential"}, + State: task.StatePending, + CreatedAt: now, + UpdatedAt: now, + } + if t.Priority == "" { + t.Priority = task.PriorityNormal + } + if t.Tags == nil { + t.Tags = []string{} + } + if input.Timeout != "" { + dur, err := time.ParseDuration(input.Timeout) + if err != nil { + writeJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid timeout: " + err.Error()}) + return + } + t.Timeout.Duration = dur + } + + if err := task.Validate(t); err != nil { + writeJSON(w, http.StatusBadRequest, map[string]string{"error": err.Error()}) + return + } + if err := s.store.CreateTask(t); err != nil { + writeJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) + return + } + + writeJSON(w, http.StatusCreated, t) +} + +func (s *Server) handleListTasks(w http.ResponseWriter, r *http.Request) { + filter := storage.TaskFilter{} + if state := r.URL.Query().Get("state"); state != "" { + filter.State = task.State(state) + } + tasks, err := s.store.ListTasks(filter) + if err != nil { + writeJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) + return + } + if tasks == nil { + tasks = []*task.Task{} + } + writeJSON(w, http.StatusOK, tasks) +} + +func (s *Server) handleGetTask(w http.ResponseWriter, r *http.Request) { + id := r.PathValue("id") + t, err := s.store.GetTask(id) + if err != nil { + writeJSON(w, http.StatusNotFound, map[string]string{"error": "task not found"}) + return + } + writeJSON(w, http.StatusOK, t) +} + +func (s *Server) handleRunTask(w http.ResponseWriter, r *http.Request) { + id := r.PathValue("id") + t, err := s.store.GetTask(id) + if err != nil { + writeJSON(w, http.StatusNotFound, map[string]string{"error": "task not found"}) + return + } + + if err := s.store.UpdateTaskState(id, task.StateQueued); err != nil { + writeJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) + return + } + t.State = task.StateQueued + + if err := s.pool.Submit(context.Background(), t); err != nil { + writeJSON(w, http.StatusServiceUnavailable, map[string]string{"error": fmt.Sprintf("executor pool: %v", err)}) + return + } + + writeJSON(w, http.StatusAccepted, map[string]string{ + "message": "task queued for execution", + "task_id": id, + }) +} + +func (s *Server) handleListExecutions(w http.ResponseWriter, r *http.Request) { + taskID := r.PathValue("id") + execs, err := s.store.ListExecutions(taskID) + if err != nil { + writeJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) + return + } + if execs == nil { + execs = []*storage.Execution{} + } + writeJSON(w, http.StatusOK, execs) +} + +func (s *Server) handleGetExecution(w http.ResponseWriter, r *http.Request) { + id := r.PathValue("id") + exec, err := s.store.GetExecution(id) + if err != nil { + writeJSON(w, http.StatusNotFound, map[string]string{"error": "execution not found"}) + return + } + writeJSON(w, http.StatusOK, exec) +} + +func writeJSON(w http.ResponseWriter, status int, v interface{}) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(status) + json.NewEncoder(w).Encode(v) +} + +func corsMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Access-Control-Allow-Origin", "*") + w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS") + w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization") + if r.Method == "OPTIONS" { + w.WriteHeader(http.StatusOK) + return + } + next.ServeHTTP(w, r) + }) +} |
