summaryrefslogtreecommitdiff
path: root/internal/api
diff options
context:
space:
mode:
Diffstat (limited to 'internal/api')
-rw-r--r--internal/api/server.go225
-rw-r--r--internal/api/server_test.go186
-rw-r--r--internal/api/websocket.go72
3 files changed, 483 insertions, 0 deletions
diff --git a/internal/api/server.go b/internal/api/server.go
new file mode 100644
index 0000000..cc5e6e5
--- /dev/null
+++ b/internal/api/server.go
@@ -0,0 +1,225 @@
+package api
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "log/slog"
+ "net/http"
+ "time"
+
+ "github.com/claudomator/claudomator/internal/executor"
+ "github.com/claudomator/claudomator/internal/storage"
+ "github.com/claudomator/claudomator/internal/task"
+ "github.com/google/uuid"
+)
+
+// Server provides the REST API and WebSocket endpoint for Claudomator.
+type Server struct {
+ store *storage.DB
+ pool *executor.Pool
+ hub *Hub
+ logger *slog.Logger
+ mux *http.ServeMux
+}
+
+func NewServer(store *storage.DB, pool *executor.Pool, logger *slog.Logger) *Server {
+ s := &Server{
+ store: store,
+ pool: pool,
+ hub: NewHub(),
+ logger: logger,
+ mux: http.NewServeMux(),
+ }
+ s.routes()
+ return s
+}
+
+func (s *Server) Handler() http.Handler {
+ return corsMiddleware(s.mux)
+}
+
+func (s *Server) StartHub() {
+ go s.hub.Run()
+ go s.forwardResults()
+}
+
+func (s *Server) routes() {
+ s.mux.HandleFunc("POST /api/tasks", s.handleCreateTask)
+ s.mux.HandleFunc("GET /api/tasks", s.handleListTasks)
+ s.mux.HandleFunc("GET /api/tasks/{id}", s.handleGetTask)
+ s.mux.HandleFunc("POST /api/tasks/{id}/run", s.handleRunTask)
+ s.mux.HandleFunc("GET /api/tasks/{id}/executions", s.handleListExecutions)
+ s.mux.HandleFunc("GET /api/executions/{id}", s.handleGetExecution)
+ s.mux.HandleFunc("GET /api/ws", s.handleWebSocket)
+ s.mux.HandleFunc("GET /api/health", s.handleHealth)
+}
+
+// forwardResults listens on the executor pool's result channel and broadcasts via WebSocket.
+func (s *Server) forwardResults() {
+ for result := range s.pool.Results() {
+ event := map[string]interface{}{
+ "type": "task_completed",
+ "task_id": result.TaskID,
+ "status": result.Execution.Status,
+ "exit_code": result.Execution.ExitCode,
+ "cost_usd": result.Execution.CostUSD,
+ "error": result.Execution.ErrorMsg,
+ "timestamp": time.Now().UTC(),
+ }
+ data, _ := json.Marshal(event)
+ s.hub.Broadcast(data)
+ }
+}
+
+func (s *Server) handleHealth(w http.ResponseWriter, r *http.Request) {
+ writeJSON(w, http.StatusOK, map[string]string{"status": "ok"})
+}
+
+func (s *Server) handleCreateTask(w http.ResponseWriter, r *http.Request) {
+ var input struct {
+ Name string `json:"name"`
+ Description string `json:"description"`
+ Claude task.ClaudeConfig `json:"claude"`
+ Timeout string `json:"timeout"`
+ Priority string `json:"priority"`
+ Tags []string `json:"tags"`
+ }
+ if err := json.NewDecoder(r.Body).Decode(&input); err != nil {
+ writeJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid JSON: " + err.Error()})
+ return
+ }
+
+ now := time.Now().UTC()
+ t := &task.Task{
+ ID: uuid.New().String(),
+ Name: input.Name,
+ Description: input.Description,
+ Claude: input.Claude,
+ Priority: task.Priority(input.Priority),
+ Tags: input.Tags,
+ DependsOn: []string{},
+ Retry: task.RetryConfig{MaxAttempts: 1, Backoff: "exponential"},
+ State: task.StatePending,
+ CreatedAt: now,
+ UpdatedAt: now,
+ }
+ if t.Priority == "" {
+ t.Priority = task.PriorityNormal
+ }
+ if t.Tags == nil {
+ t.Tags = []string{}
+ }
+ if input.Timeout != "" {
+ dur, err := time.ParseDuration(input.Timeout)
+ if err != nil {
+ writeJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid timeout: " + err.Error()})
+ return
+ }
+ t.Timeout.Duration = dur
+ }
+
+ if err := task.Validate(t); err != nil {
+ writeJSON(w, http.StatusBadRequest, map[string]string{"error": err.Error()})
+ return
+ }
+ if err := s.store.CreateTask(t); err != nil {
+ writeJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()})
+ return
+ }
+
+ writeJSON(w, http.StatusCreated, t)
+}
+
+func (s *Server) handleListTasks(w http.ResponseWriter, r *http.Request) {
+ filter := storage.TaskFilter{}
+ if state := r.URL.Query().Get("state"); state != "" {
+ filter.State = task.State(state)
+ }
+ tasks, err := s.store.ListTasks(filter)
+ if err != nil {
+ writeJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()})
+ return
+ }
+ if tasks == nil {
+ tasks = []*task.Task{}
+ }
+ writeJSON(w, http.StatusOK, tasks)
+}
+
+func (s *Server) handleGetTask(w http.ResponseWriter, r *http.Request) {
+ id := r.PathValue("id")
+ t, err := s.store.GetTask(id)
+ if err != nil {
+ writeJSON(w, http.StatusNotFound, map[string]string{"error": "task not found"})
+ return
+ }
+ writeJSON(w, http.StatusOK, t)
+}
+
+func (s *Server) handleRunTask(w http.ResponseWriter, r *http.Request) {
+ id := r.PathValue("id")
+ t, err := s.store.GetTask(id)
+ if err != nil {
+ writeJSON(w, http.StatusNotFound, map[string]string{"error": "task not found"})
+ return
+ }
+
+ if err := s.store.UpdateTaskState(id, task.StateQueued); err != nil {
+ writeJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()})
+ return
+ }
+ t.State = task.StateQueued
+
+ if err := s.pool.Submit(context.Background(), t); err != nil {
+ writeJSON(w, http.StatusServiceUnavailable, map[string]string{"error": fmt.Sprintf("executor pool: %v", err)})
+ return
+ }
+
+ writeJSON(w, http.StatusAccepted, map[string]string{
+ "message": "task queued for execution",
+ "task_id": id,
+ })
+}
+
+func (s *Server) handleListExecutions(w http.ResponseWriter, r *http.Request) {
+ taskID := r.PathValue("id")
+ execs, err := s.store.ListExecutions(taskID)
+ if err != nil {
+ writeJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()})
+ return
+ }
+ if execs == nil {
+ execs = []*storage.Execution{}
+ }
+ writeJSON(w, http.StatusOK, execs)
+}
+
+func (s *Server) handleGetExecution(w http.ResponseWriter, r *http.Request) {
+ id := r.PathValue("id")
+ exec, err := s.store.GetExecution(id)
+ if err != nil {
+ writeJSON(w, http.StatusNotFound, map[string]string{"error": "execution not found"})
+ return
+ }
+ writeJSON(w, http.StatusOK, exec)
+}
+
+func writeJSON(w http.ResponseWriter, status int, v interface{}) {
+ w.Header().Set("Content-Type", "application/json")
+ w.WriteHeader(status)
+ json.NewEncoder(w).Encode(v)
+}
+
+func corsMiddleware(next http.Handler) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("Access-Control-Allow-Origin", "*")
+ w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
+ w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization")
+ if r.Method == "OPTIONS" {
+ w.WriteHeader(http.StatusOK)
+ return
+ }
+ next.ServeHTTP(w, r)
+ })
+}
diff --git a/internal/api/server_test.go b/internal/api/server_test.go
new file mode 100644
index 0000000..c3b77ae
--- /dev/null
+++ b/internal/api/server_test.go
@@ -0,0 +1,186 @@
+package api
+
+import (
+ "bytes"
+ "encoding/json"
+ "fmt"
+ "log/slog"
+ "net/http"
+ "net/http/httptest"
+ "os"
+ "path/filepath"
+ "testing"
+
+ "context"
+
+ "github.com/claudomator/claudomator/internal/executor"
+ "github.com/claudomator/claudomator/internal/storage"
+ "github.com/claudomator/claudomator/internal/task"
+)
+
+func testServer(t *testing.T) (*Server, *storage.DB) {
+ t.Helper()
+ dbPath := filepath.Join(t.TempDir(), "test.db")
+ store, err := storage.Open(dbPath)
+ if err != nil {
+ t.Fatal(err)
+ }
+ t.Cleanup(func() { store.Close() })
+
+ logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))
+ runner := &mockRunner{}
+ pool := executor.NewPool(2, runner, store, logger)
+ srv := NewServer(store, pool, logger)
+ return srv, store
+}
+
+type mockRunner struct{}
+
+func (m *mockRunner) Run(_ context.Context, _ *task.Task, _ *storage.Execution) error {
+ return nil
+}
+
+func TestHealthEndpoint(t *testing.T) {
+ srv, _ := testServer(t)
+ req := httptest.NewRequest("GET", "/api/health", nil)
+ w := httptest.NewRecorder()
+
+ srv.Handler().ServeHTTP(w, req)
+
+ if w.Code != http.StatusOK {
+ t.Errorf("status: want 200, got %d", w.Code)
+ }
+ var body map[string]string
+ json.NewDecoder(w.Body).Decode(&body)
+ if body["status"] != "ok" {
+ t.Errorf("want status=ok, got %v", body)
+ }
+}
+
+func TestCreateTask_Success(t *testing.T) {
+ srv, _ := testServer(t)
+
+ payload := `{
+ "name": "API Task",
+ "description": "Created via API",
+ "claude": {
+ "instructions": "do the thing",
+ "model": "sonnet"
+ },
+ "timeout": "5m",
+ "tags": ["api"]
+ }`
+ req := httptest.NewRequest("POST", "/api/tasks", bytes.NewBufferString(payload))
+ req.Header.Set("Content-Type", "application/json")
+ w := httptest.NewRecorder()
+
+ srv.Handler().ServeHTTP(w, req)
+
+ if w.Code != http.StatusCreated {
+ t.Fatalf("status: want 201, got %d; body: %s", w.Code, w.Body.String())
+ }
+
+ var created task.Task
+ json.NewDecoder(w.Body).Decode(&created)
+ if created.Name != "API Task" {
+ t.Errorf("name: want 'API Task', got %q", created.Name)
+ }
+ if created.ID == "" {
+ t.Error("expected auto-generated ID")
+ }
+}
+
+func TestCreateTask_InvalidJSON(t *testing.T) {
+ srv, _ := testServer(t)
+
+ req := httptest.NewRequest("POST", "/api/tasks", bytes.NewBufferString("{bad json"))
+ w := httptest.NewRecorder()
+ srv.Handler().ServeHTTP(w, req)
+
+ if w.Code != http.StatusBadRequest {
+ t.Errorf("status: want 400, got %d", w.Code)
+ }
+}
+
+func TestCreateTask_ValidationFailure(t *testing.T) {
+ srv, _ := testServer(t)
+
+ payload := `{"name": "", "claude": {"instructions": ""}}`
+ req := httptest.NewRequest("POST", "/api/tasks", bytes.NewBufferString(payload))
+ w := httptest.NewRecorder()
+ srv.Handler().ServeHTTP(w, req)
+
+ if w.Code != http.StatusBadRequest {
+ t.Errorf("status: want 400, got %d", w.Code)
+ }
+}
+
+func TestListTasks_Empty(t *testing.T) {
+ srv, _ := testServer(t)
+
+ req := httptest.NewRequest("GET", "/api/tasks", nil)
+ w := httptest.NewRecorder()
+ srv.Handler().ServeHTTP(w, req)
+
+ if w.Code != http.StatusOK {
+ t.Errorf("status: want 200, got %d", w.Code)
+ }
+
+ var tasks []task.Task
+ json.NewDecoder(w.Body).Decode(&tasks)
+ if len(tasks) != 0 {
+ t.Errorf("want 0 tasks, got %d", len(tasks))
+ }
+}
+
+func TestGetTask_NotFound(t *testing.T) {
+ srv, _ := testServer(t)
+
+ req := httptest.NewRequest("GET", "/api/tasks/nonexistent", nil)
+ w := httptest.NewRecorder()
+ srv.Handler().ServeHTTP(w, req)
+
+ if w.Code != http.StatusNotFound {
+ t.Errorf("status: want 404, got %d", w.Code)
+ }
+}
+
+func TestListTasks_WithTasks(t *testing.T) {
+ srv, store := testServer(t)
+
+ // Create tasks directly in store.
+ for i := 0; i < 3; i++ {
+ tk := &task.Task{
+ ID: fmt.Sprintf("lt-%d", i), Name: fmt.Sprintf("T%d", i),
+ Claude: task.ClaudeConfig{Instructions: "x"}, Priority: task.PriorityNormal,
+ Retry: task.RetryConfig{MaxAttempts: 1, Backoff: "linear"},
+ Tags: []string{}, DependsOn: []string{}, State: task.StatePending,
+ }
+ store.CreateTask(tk)
+ }
+
+ req := httptest.NewRequest("GET", "/api/tasks", nil)
+ w := httptest.NewRecorder()
+ srv.Handler().ServeHTTP(w, req)
+
+ var tasks []task.Task
+ json.NewDecoder(w.Body).Decode(&tasks)
+ if len(tasks) != 3 {
+ t.Errorf("want 3 tasks, got %d", len(tasks))
+ }
+}
+
+func TestCORS_Headers(t *testing.T) {
+ srv, _ := testServer(t)
+
+ req := httptest.NewRequest("OPTIONS", "/api/tasks", nil)
+ w := httptest.NewRecorder()
+ srv.Handler().ServeHTTP(w, req)
+
+ if w.Header().Get("Access-Control-Allow-Origin") != "*" {
+ t.Error("missing CORS origin header")
+ }
+ if w.Code != http.StatusOK {
+ t.Errorf("OPTIONS status: want 200, got %d", w.Code)
+ }
+}
diff --git a/internal/api/websocket.go b/internal/api/websocket.go
new file mode 100644
index 0000000..6bd8c88
--- /dev/null
+++ b/internal/api/websocket.go
@@ -0,0 +1,72 @@
+package api
+
+import (
+ "log/slog"
+ "net/http"
+ "sync"
+
+ "golang.org/x/net/websocket"
+)
+
+// Hub manages WebSocket connections and broadcasts messages.
+type Hub struct {
+ mu sync.RWMutex
+ clients map[*websocket.Conn]bool
+ logger *slog.Logger
+}
+
+func NewHub() *Hub {
+ return &Hub{
+ clients: make(map[*websocket.Conn]bool),
+ logger: slog.Default(),
+ }
+}
+
+// Run is a no-op loop kept for future cleanup/heartbeat logic.
+func (h *Hub) Run() {}
+
+func (h *Hub) Register(ws *websocket.Conn) {
+ h.mu.Lock()
+ h.clients[ws] = true
+ h.mu.Unlock()
+}
+
+func (h *Hub) Unregister(ws *websocket.Conn) {
+ h.mu.Lock()
+ delete(h.clients, ws)
+ h.mu.Unlock()
+}
+
+// Broadcast sends a message to all connected WebSocket clients.
+func (h *Hub) Broadcast(msg []byte) {
+ h.mu.RLock()
+ defer h.mu.RUnlock()
+ for conn := range h.clients {
+ if _, err := conn.Write(msg); err != nil {
+ h.logger.Error("websocket write error", "error", err)
+ }
+ }
+}
+
+// ClientCount returns the number of connected clients.
+func (h *Hub) ClientCount() int {
+ h.mu.RLock()
+ defer h.mu.RUnlock()
+ return len(h.clients)
+}
+
+func (s *Server) handleWebSocket(w http.ResponseWriter, r *http.Request) {
+ handler := websocket.Handler(func(ws *websocket.Conn) {
+ s.hub.Register(ws)
+ defer s.hub.Unregister(ws)
+
+ // Keep connection alive until client disconnects.
+ buf := make([]byte, 1024)
+ for {
+ if _, err := ws.Read(buf); err != nil {
+ break
+ }
+ }
+ })
+ handler.ServeHTTP(w, r)
+}