diff options
34 files changed, 3262 insertions, 0 deletions
diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..cdfa427 --- /dev/null +++ b/.gitignore @@ -0,0 +1,29 @@ +# Binaries +/bin/ +/bin/claudomator +*.exe + +# Go +vendor/ + +# IDE +.idea/ +.vscode/ +*.swp +*.swo + +# OS +.DS_Store +Thumbs.db + +# Data +*.db +*.db-journal +*.db-wal + +# Test +coverage.out +coverage.html + +# Session state +SESSION_STATE.md diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..a8b99ff --- /dev/null +++ b/Makefile @@ -0,0 +1,27 @@ +.PHONY: build test lint clean run + +BINARY := claudomator +BUILD_DIR := bin +GO := go + +build: + $(GO) build -o $(BUILD_DIR)/$(BINARY) ./cmd/claudomator + +test: + $(GO) test ./... -v -race -count=1 + +test-cover: + $(GO) test ./... -coverprofile=coverage.out -race + $(GO) tool cover -html=coverage.out -o coverage.html + +lint: + golangci-lint run ./... + +clean: + rm -rf $(BUILD_DIR) coverage.out coverage.html + +run: build + ./$(BUILD_DIR)/$(BINARY) + +tidy: + $(GO) mod tidy diff --git a/cmd/claudomator/main.go b/cmd/claudomator/main.go new file mode 100644 index 0000000..e6cba37 --- /dev/null +++ b/cmd/claudomator/main.go @@ -0,0 +1,15 @@ +package main + +import ( + "fmt" + "os" + + "github.com/claudomator/claudomator/internal/cli" +) + +func main() { + if err := cli.Execute(); err != nil { + fmt.Fprintf(os.Stderr, "Error: %v\n", err) + os.Exit(1) + } +} diff --git a/docs/adr/001-language-and-architecture.md b/docs/adr/001-language-and-architecture.md new file mode 100644 index 0000000..99022f1 --- /dev/null +++ b/docs/adr/001-language-and-architecture.md @@ -0,0 +1,32 @@ +# ADR-001: Go with Server + Mobile PWA Architecture + +## Status +Accepted + +## Context +Claudomator needs to capture tasks, dispatch them to Claude Code, and report results. +The primary human interface is a mobile device requiring push notifications and frictionless input. + +## Decision +- **Language**: Go for the backend (CLI + API server) +- **Architecture**: Pipeline with bounded executor pool (goroutines) +- **API**: REST + WebSocket for real-time updates +- **Storage**: SQLite + filesystem (logs, artifacts) +- **Task format**: YAML definitions +- **Mobile**: PWA with Web Push notifications (future phase) + +## Rationale +- Go: single binary, excellent process management (`os/exec`), natural concurrency +- SQLite: zero-dependency, embeddable, queryable metadata +- WebSocket: real-time progress streaming to mobile clients +- REST: simple task creation from any HTTP client (mobile, curl, CI) + +## Alternatives Considered +- **TypeScript/Node.js**: Claude Code SDK exists, but runtime dependency hurts distribution +- **Python**: Good async, but packaging/distribution is painful for CLI tools +- **Rust**: Overkill for this problem domain; slower iteration speed + +## Consequences +- CGo dependency via `go-sqlite3` (requires C compiler for builds) +- Mobile PWA to be built as a separate frontend phase +- Claude Code invoked via CLI flags (`-p`, `--output-format stream-json`), not SDK @@ -0,0 +1,16 @@ +module github.com/claudomator/claudomator + +go 1.25.3 + +require ( + github.com/google/uuid v1.6.0 + github.com/mattn/go-sqlite3 v1.14.33 + github.com/spf13/cobra v1.10.2 + golang.org/x/net v0.49.0 + gopkg.in/yaml.v3 v3.0.1 +) + +require ( + github.com/inconshreveable/mousetrap v1.1.0 // indirect + github.com/spf13/pflag v1.0.9 // indirect +) @@ -0,0 +1,19 @@ +github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= +github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= +github.com/mattn/go-sqlite3 v1.14.33 h1:A5blZ5ulQo2AtayQ9/limgHEkFreKj1Dv226a1K73s0= +github.com/mattn/go-sqlite3 v1.14.33/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= +github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= +github.com/spf13/cobra v1.10.2 h1:DMTTonx5m65Ic0GOoRY2c16WCbHxOOw6xxezuLaBpcU= +github.com/spf13/cobra v1.10.2/go.mod h1:7C1pvHqHw5A4vrJfjNwvOdzYu0Gml16OCs2GRiTUUS4= +github.com/spf13/pflag v1.0.9 h1:9exaQaMOCwffKiiiYk6/BndUBv+iRViNW+4lEMi0PvY= +github.com/spf13/pflag v1.0.9/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= +go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg= +golang.org/x/net v0.49.0 h1:eeHFmOGUTtaaPSGNmjBKpbng9MulQsJURQUAfUwY++o= +golang.org/x/net v0.49.0/go.mod h1:/ysNB2EvaqvesRkuLAyjI1ycPZlQHM3q01F02UY/MV8= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/internal/api/server.go b/internal/api/server.go new file mode 100644 index 0000000..cc5e6e5 --- /dev/null +++ b/internal/api/server.go @@ -0,0 +1,225 @@ +package api + +import ( + "context" + "encoding/json" + "fmt" + "log/slog" + "net/http" + "time" + + "github.com/claudomator/claudomator/internal/executor" + "github.com/claudomator/claudomator/internal/storage" + "github.com/claudomator/claudomator/internal/task" + "github.com/google/uuid" +) + +// Server provides the REST API and WebSocket endpoint for Claudomator. +type Server struct { + store *storage.DB + pool *executor.Pool + hub *Hub + logger *slog.Logger + mux *http.ServeMux +} + +func NewServer(store *storage.DB, pool *executor.Pool, logger *slog.Logger) *Server { + s := &Server{ + store: store, + pool: pool, + hub: NewHub(), + logger: logger, + mux: http.NewServeMux(), + } + s.routes() + return s +} + +func (s *Server) Handler() http.Handler { + return corsMiddleware(s.mux) +} + +func (s *Server) StartHub() { + go s.hub.Run() + go s.forwardResults() +} + +func (s *Server) routes() { + s.mux.HandleFunc("POST /api/tasks", s.handleCreateTask) + s.mux.HandleFunc("GET /api/tasks", s.handleListTasks) + s.mux.HandleFunc("GET /api/tasks/{id}", s.handleGetTask) + s.mux.HandleFunc("POST /api/tasks/{id}/run", s.handleRunTask) + s.mux.HandleFunc("GET /api/tasks/{id}/executions", s.handleListExecutions) + s.mux.HandleFunc("GET /api/executions/{id}", s.handleGetExecution) + s.mux.HandleFunc("GET /api/ws", s.handleWebSocket) + s.mux.HandleFunc("GET /api/health", s.handleHealth) +} + +// forwardResults listens on the executor pool's result channel and broadcasts via WebSocket. +func (s *Server) forwardResults() { + for result := range s.pool.Results() { + event := map[string]interface{}{ + "type": "task_completed", + "task_id": result.TaskID, + "status": result.Execution.Status, + "exit_code": result.Execution.ExitCode, + "cost_usd": result.Execution.CostUSD, + "error": result.Execution.ErrorMsg, + "timestamp": time.Now().UTC(), + } + data, _ := json.Marshal(event) + s.hub.Broadcast(data) + } +} + +func (s *Server) handleHealth(w http.ResponseWriter, r *http.Request) { + writeJSON(w, http.StatusOK, map[string]string{"status": "ok"}) +} + +func (s *Server) handleCreateTask(w http.ResponseWriter, r *http.Request) { + var input struct { + Name string `json:"name"` + Description string `json:"description"` + Claude task.ClaudeConfig `json:"claude"` + Timeout string `json:"timeout"` + Priority string `json:"priority"` + Tags []string `json:"tags"` + } + if err := json.NewDecoder(r.Body).Decode(&input); err != nil { + writeJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid JSON: " + err.Error()}) + return + } + + now := time.Now().UTC() + t := &task.Task{ + ID: uuid.New().String(), + Name: input.Name, + Description: input.Description, + Claude: input.Claude, + Priority: task.Priority(input.Priority), + Tags: input.Tags, + DependsOn: []string{}, + Retry: task.RetryConfig{MaxAttempts: 1, Backoff: "exponential"}, + State: task.StatePending, + CreatedAt: now, + UpdatedAt: now, + } + if t.Priority == "" { + t.Priority = task.PriorityNormal + } + if t.Tags == nil { + t.Tags = []string{} + } + if input.Timeout != "" { + dur, err := time.ParseDuration(input.Timeout) + if err != nil { + writeJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid timeout: " + err.Error()}) + return + } + t.Timeout.Duration = dur + } + + if err := task.Validate(t); err != nil { + writeJSON(w, http.StatusBadRequest, map[string]string{"error": err.Error()}) + return + } + if err := s.store.CreateTask(t); err != nil { + writeJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) + return + } + + writeJSON(w, http.StatusCreated, t) +} + +func (s *Server) handleListTasks(w http.ResponseWriter, r *http.Request) { + filter := storage.TaskFilter{} + if state := r.URL.Query().Get("state"); state != "" { + filter.State = task.State(state) + } + tasks, err := s.store.ListTasks(filter) + if err != nil { + writeJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) + return + } + if tasks == nil { + tasks = []*task.Task{} + } + writeJSON(w, http.StatusOK, tasks) +} + +func (s *Server) handleGetTask(w http.ResponseWriter, r *http.Request) { + id := r.PathValue("id") + t, err := s.store.GetTask(id) + if err != nil { + writeJSON(w, http.StatusNotFound, map[string]string{"error": "task not found"}) + return + } + writeJSON(w, http.StatusOK, t) +} + +func (s *Server) handleRunTask(w http.ResponseWriter, r *http.Request) { + id := r.PathValue("id") + t, err := s.store.GetTask(id) + if err != nil { + writeJSON(w, http.StatusNotFound, map[string]string{"error": "task not found"}) + return + } + + if err := s.store.UpdateTaskState(id, task.StateQueued); err != nil { + writeJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) + return + } + t.State = task.StateQueued + + if err := s.pool.Submit(context.Background(), t); err != nil { + writeJSON(w, http.StatusServiceUnavailable, map[string]string{"error": fmt.Sprintf("executor pool: %v", err)}) + return + } + + writeJSON(w, http.StatusAccepted, map[string]string{ + "message": "task queued for execution", + "task_id": id, + }) +} + +func (s *Server) handleListExecutions(w http.ResponseWriter, r *http.Request) { + taskID := r.PathValue("id") + execs, err := s.store.ListExecutions(taskID) + if err != nil { + writeJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) + return + } + if execs == nil { + execs = []*storage.Execution{} + } + writeJSON(w, http.StatusOK, execs) +} + +func (s *Server) handleGetExecution(w http.ResponseWriter, r *http.Request) { + id := r.PathValue("id") + exec, err := s.store.GetExecution(id) + if err != nil { + writeJSON(w, http.StatusNotFound, map[string]string{"error": "execution not found"}) + return + } + writeJSON(w, http.StatusOK, exec) +} + +func writeJSON(w http.ResponseWriter, status int, v interface{}) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(status) + json.NewEncoder(w).Encode(v) +} + +func corsMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Access-Control-Allow-Origin", "*") + w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS") + w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization") + if r.Method == "OPTIONS" { + w.WriteHeader(http.StatusOK) + return + } + next.ServeHTTP(w, r) + }) +} diff --git a/internal/api/server_test.go b/internal/api/server_test.go new file mode 100644 index 0000000..c3b77ae --- /dev/null +++ b/internal/api/server_test.go @@ -0,0 +1,186 @@ +package api + +import ( + "bytes" + "encoding/json" + "fmt" + "log/slog" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "testing" + + "context" + + "github.com/claudomator/claudomator/internal/executor" + "github.com/claudomator/claudomator/internal/storage" + "github.com/claudomator/claudomator/internal/task" +) + +func testServer(t *testing.T) (*Server, *storage.DB) { + t.Helper() + dbPath := filepath.Join(t.TempDir(), "test.db") + store, err := storage.Open(dbPath) + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { store.Close() }) + + logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) + runner := &mockRunner{} + pool := executor.NewPool(2, runner, store, logger) + srv := NewServer(store, pool, logger) + return srv, store +} + +type mockRunner struct{} + +func (m *mockRunner) Run(_ context.Context, _ *task.Task, _ *storage.Execution) error { + return nil +} + +func TestHealthEndpoint(t *testing.T) { + srv, _ := testServer(t) + req := httptest.NewRequest("GET", "/api/health", nil) + w := httptest.NewRecorder() + + srv.Handler().ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("status: want 200, got %d", w.Code) + } + var body map[string]string + json.NewDecoder(w.Body).Decode(&body) + if body["status"] != "ok" { + t.Errorf("want status=ok, got %v", body) + } +} + +func TestCreateTask_Success(t *testing.T) { + srv, _ := testServer(t) + + payload := `{ + "name": "API Task", + "description": "Created via API", + "claude": { + "instructions": "do the thing", + "model": "sonnet" + }, + "timeout": "5m", + "tags": ["api"] + }` + req := httptest.NewRequest("POST", "/api/tasks", bytes.NewBufferString(payload)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + + srv.Handler().ServeHTTP(w, req) + + if w.Code != http.StatusCreated { + t.Fatalf("status: want 201, got %d; body: %s", w.Code, w.Body.String()) + } + + var created task.Task + json.NewDecoder(w.Body).Decode(&created) + if created.Name != "API Task" { + t.Errorf("name: want 'API Task', got %q", created.Name) + } + if created.ID == "" { + t.Error("expected auto-generated ID") + } +} + +func TestCreateTask_InvalidJSON(t *testing.T) { + srv, _ := testServer(t) + + req := httptest.NewRequest("POST", "/api/tasks", bytes.NewBufferString("{bad json")) + w := httptest.NewRecorder() + srv.Handler().ServeHTTP(w, req) + + if w.Code != http.StatusBadRequest { + t.Errorf("status: want 400, got %d", w.Code) + } +} + +func TestCreateTask_ValidationFailure(t *testing.T) { + srv, _ := testServer(t) + + payload := `{"name": "", "claude": {"instructions": ""}}` + req := httptest.NewRequest("POST", "/api/tasks", bytes.NewBufferString(payload)) + w := httptest.NewRecorder() + srv.Handler().ServeHTTP(w, req) + + if w.Code != http.StatusBadRequest { + t.Errorf("status: want 400, got %d", w.Code) + } +} + +func TestListTasks_Empty(t *testing.T) { + srv, _ := testServer(t) + + req := httptest.NewRequest("GET", "/api/tasks", nil) + w := httptest.NewRecorder() + srv.Handler().ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("status: want 200, got %d", w.Code) + } + + var tasks []task.Task + json.NewDecoder(w.Body).Decode(&tasks) + if len(tasks) != 0 { + t.Errorf("want 0 tasks, got %d", len(tasks)) + } +} + +func TestGetTask_NotFound(t *testing.T) { + srv, _ := testServer(t) + + req := httptest.NewRequest("GET", "/api/tasks/nonexistent", nil) + w := httptest.NewRecorder() + srv.Handler().ServeHTTP(w, req) + + if w.Code != http.StatusNotFound { + t.Errorf("status: want 404, got %d", w.Code) + } +} + +func TestListTasks_WithTasks(t *testing.T) { + srv, store := testServer(t) + + // Create tasks directly in store. + for i := 0; i < 3; i++ { + tk := &task.Task{ + ID: fmt.Sprintf("lt-%d", i), Name: fmt.Sprintf("T%d", i), + Claude: task.ClaudeConfig{Instructions: "x"}, Priority: task.PriorityNormal, + Retry: task.RetryConfig{MaxAttempts: 1, Backoff: "linear"}, + Tags: []string{}, DependsOn: []string{}, State: task.StatePending, + } + store.CreateTask(tk) + } + + req := httptest.NewRequest("GET", "/api/tasks", nil) + w := httptest.NewRecorder() + srv.Handler().ServeHTTP(w, req) + + var tasks []task.Task + json.NewDecoder(w.Body).Decode(&tasks) + if len(tasks) != 3 { + t.Errorf("want 3 tasks, got %d", len(tasks)) + } +} + +func TestCORS_Headers(t *testing.T) { + srv, _ := testServer(t) + + req := httptest.NewRequest("OPTIONS", "/api/tasks", nil) + w := httptest.NewRecorder() + srv.Handler().ServeHTTP(w, req) + + if w.Header().Get("Access-Control-Allow-Origin") != "*" { + t.Error("missing CORS origin header") + } + if w.Code != http.StatusOK { + t.Errorf("OPTIONS status: want 200, got %d", w.Code) + } +} diff --git a/internal/api/websocket.go b/internal/api/websocket.go new file mode 100644 index 0000000..6bd8c88 --- /dev/null +++ b/internal/api/websocket.go @@ -0,0 +1,72 @@ +package api + +import ( + "log/slog" + "net/http" + "sync" + + "golang.org/x/net/websocket" +) + +// Hub manages WebSocket connections and broadcasts messages. +type Hub struct { + mu sync.RWMutex + clients map[*websocket.Conn]bool + logger *slog.Logger +} + +func NewHub() *Hub { + return &Hub{ + clients: make(map[*websocket.Conn]bool), + logger: slog.Default(), + } +} + +// Run is a no-op loop kept for future cleanup/heartbeat logic. +func (h *Hub) Run() {} + +func (h *Hub) Register(ws *websocket.Conn) { + h.mu.Lock() + h.clients[ws] = true + h.mu.Unlock() +} + +func (h *Hub) Unregister(ws *websocket.Conn) { + h.mu.Lock() + delete(h.clients, ws) + h.mu.Unlock() +} + +// Broadcast sends a message to all connected WebSocket clients. +func (h *Hub) Broadcast(msg []byte) { + h.mu.RLock() + defer h.mu.RUnlock() + for conn := range h.clients { + if _, err := conn.Write(msg); err != nil { + h.logger.Error("websocket write error", "error", err) + } + } +} + +// ClientCount returns the number of connected clients. +func (h *Hub) ClientCount() int { + h.mu.RLock() + defer h.mu.RUnlock() + return len(h.clients) +} + +func (s *Server) handleWebSocket(w http.ResponseWriter, r *http.Request) { + handler := websocket.Handler(func(ws *websocket.Conn) { + s.hub.Register(ws) + defer s.hub.Unregister(ws) + + // Keep connection alive until client disconnects. + buf := make([]byte, 1024) + for { + if _, err := ws.Read(buf); err != nil { + break + } + } + }) + handler.ServeHTTP(w, r) +} diff --git a/internal/cli/init.go b/internal/cli/init.go new file mode 100644 index 0000000..6660f9d --- /dev/null +++ b/internal/cli/init.go @@ -0,0 +1,58 @@ +package cli + +import ( + "fmt" + "os" + "path/filepath" + + "github.com/claudomator/claudomator/internal/storage" + "github.com/spf13/cobra" +) + +func newInitCmd() *cobra.Command { + return &cobra.Command{ + Use: "init", + Short: "Initialize Claudomator data directory", + RunE: func(cmd *cobra.Command, args []string) error { + return initClaudomator() + }, + } +} + +func initClaudomator() error { + if err := cfg.EnsureDirs(); err != nil { + return fmt.Errorf("creating directories: %w", err) + } + + // Initialize database. + store, err := storage.Open(cfg.DBPath) + if err != nil { + return fmt.Errorf("initializing database: %w", err) + } + store.Close() + + // Create example task file if it doesn't exist. + examplePath := filepath.Join(cfg.DataDir, "example-task.yaml") + if _, err := os.Stat(examplePath); os.IsNotExist(err) { + example := `name: "Example Task" +description: "A sample task to get started" +claude: + model: "sonnet" + instructions: | + Say hello and list the files in the current directory. + working_dir: "." +timeout: "5m" +tags: + - "example" +` + if err := os.WriteFile(examplePath, []byte(example), 0644); err != nil { + return fmt.Errorf("writing example: %w", err) + } + } + + fmt.Printf("Claudomator initialized at %s\n", cfg.DataDir) + fmt.Printf(" Database: %s\n", cfg.DBPath) + fmt.Printf(" Logs: %s\n", cfg.LogDir) + fmt.Printf(" Example: %s\n", examplePath) + return nil +} diff --git a/internal/cli/list.go b/internal/cli/list.go new file mode 100644 index 0000000..a7515a1 --- /dev/null +++ b/internal/cli/list.go @@ -0,0 +1,59 @@ +package cli + +import ( + "fmt" + "os" + "text/tabwriter" + + "github.com/claudomator/claudomator/internal/storage" + "github.com/claudomator/claudomator/internal/task" + "github.com/spf13/cobra" +) + +func newListCmd() *cobra.Command { + var state string + + cmd := &cobra.Command{ + Use: "list", + Short: "List tasks", + RunE: func(cmd *cobra.Command, args []string) error { + return listTasks(state) + }, + } + + cmd.Flags().StringVar(&state, "state", "", "filter by state (PENDING, RUNNING, COMPLETED, FAILED)") + + return cmd +} + +func listTasks(state string) error { + store, err := storage.Open(cfg.DBPath) + if err != nil { + return fmt.Errorf("opening db: %w", err) + } + defer store.Close() + + filter := storage.TaskFilter{} + if state != "" { + filter.State = task.State(state) + } + + tasks, err := store.ListTasks(filter) + if err != nil { + return err + } + + if len(tasks) == 0 { + fmt.Println("No tasks found.") + return nil + } + + w := tabwriter.NewWriter(os.Stdout, 0, 0, 2, ' ', 0) + fmt.Fprintln(w, "ID\tNAME\tSTATE\tPRIORITY\tCREATED") + for _, t := range tasks { + fmt.Fprintf(w, "%.8s\t%s\t%s\t%s\t%s\n", + t.ID, t.Name, t.State, t.Priority, t.CreatedAt.Format("2006-01-02 15:04")) + } + w.Flush() + return nil +} diff --git a/internal/cli/root.go b/internal/cli/root.go new file mode 100644 index 0000000..2800a76 --- /dev/null +++ b/internal/cli/root.go @@ -0,0 +1,40 @@ +package cli + +import ( + "github.com/claudomator/claudomator/internal/config" + "github.com/spf13/cobra" +) + +var ( + cfgFile string + verbose bool + cfg *config.Config +) + +func NewRootCmd() *cobra.Command { + cfg = config.Default() + + cmd := &cobra.Command{ + Use: "claudomator", + Short: "Automation toolkit for Claude Code", + Long: "Claudomator captures tasks, dispatches them to Claude Code, and reports results.", + } + + cmd.PersistentFlags().StringVar(&cfgFile, "config", "", "config file (default $HOME/.claudomator/config.toml)") + cmd.PersistentFlags().StringVar(&cfg.DataDir, "data-dir", cfg.DataDir, "data directory") + cmd.PersistentFlags().BoolVarP(&verbose, "verbose", "v", false, "verbose output") + + cmd.AddCommand( + newRunCmd(), + newServeCmd(), + newListCmd(), + newStatusCmd(), + newInitCmd(), + ) + + return cmd +} + +func Execute() error { + return NewRootCmd().Execute() +} diff --git a/internal/cli/run.go b/internal/cli/run.go new file mode 100644 index 0000000..e74b247 --- /dev/null +++ b/internal/cli/run.go @@ -0,0 +1,128 @@ +package cli + +import ( + "context" + "fmt" + "log/slog" + "os" + "os/signal" + "syscall" + + "github.com/claudomator/claudomator/internal/executor" + "github.com/claudomator/claudomator/internal/storage" + "github.com/claudomator/claudomator/internal/task" + "github.com/spf13/cobra" +) + +func newRunCmd() *cobra.Command { + var ( + parallel int + dryRun bool + ) + + cmd := &cobra.Command{ + Use: "run <task-file>", + Short: "Run task(s) from a YAML file", + Args: cobra.ExactArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + return runTasks(args[0], parallel, dryRun) + }, + } + + cmd.Flags().IntVarP(¶llel, "parallel", "p", 3, "max concurrent executions") + cmd.Flags().BoolVar(&dryRun, "dry-run", false, "validate without executing") + + return cmd +} + +func runTasks(file string, parallel int, dryRun bool) error { + tasks, err := task.ParseFile(file) + if err != nil { + return fmt.Errorf("parsing: %w", err) + } + + // Validate all tasks. + for i := range tasks { + if err := task.Validate(&tasks[i]); err != nil { + return fmt.Errorf("task %q: %w", tasks[i].Name, err) + } + } + + if dryRun { + fmt.Printf("Validated %d task(s) successfully.\n", len(tasks)) + for _, t := range tasks { + fmt.Printf(" - %s (model: %s, timeout: %v)\n", t.Name, t.Claude.Model, t.Timeout.Duration) + } + return nil + } + + // Setup infrastructure. + if err := cfg.EnsureDirs(); err != nil { + return fmt.Errorf("creating dirs: %w", err) + } + + store, err := storage.Open(cfg.DBPath) + if err != nil { + return fmt.Errorf("opening db: %w", err) + } + defer store.Close() + + level := slog.LevelInfo + if verbose { + level = slog.LevelDebug + } + logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: level})) + + runner := &executor.ClaudeRunner{ + BinaryPath: cfg.ClaudeBinaryPath, + Logger: logger, + LogDir: cfg.LogDir, + } + pool := executor.NewPool(parallel, runner, store, logger) + + // Handle graceful shutdown. + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + sigCh := make(chan os.Signal, 1) + signal.Notify(sigCh, os.Interrupt, syscall.SIGTERM) + go func() { + <-sigCh + fmt.Fprintln(os.Stderr, "\nShutting down...") + cancel() + }() + + // Submit all tasks. + fmt.Printf("Dispatching %d task(s) (max concurrency: %d)...\n", len(tasks), parallel) + for i := range tasks { + if err := store.CreateTask(&tasks[i]); err != nil { + return fmt.Errorf("storing task: %w", err) + } + if err := store.UpdateTaskState(tasks[i].ID, task.StateQueued); err != nil { + return fmt.Errorf("queuing task: %w", err) + } + tasks[i].State = task.StateQueued + if err := pool.Submit(ctx, &tasks[i]); err != nil { + logger.Warn("could not submit task", "name", tasks[i].Name, "error", err) + } + } + + // Wait for all results. + completed, failed := 0, 0 + for i := 0; i < len(tasks); i++ { + result := <-pool.Results() + if result.Err != nil { + failed++ + fmt.Printf(" FAIL %s: %v\n", result.TaskID, result.Err) + } else { + completed++ + fmt.Printf(" OK %s (cost: $%.4f)\n", result.TaskID, result.Execution.CostUSD) + } + } + + fmt.Printf("\nDone: %d completed, %d failed\n", completed, failed) + if failed > 0 { + return fmt.Errorf("%d task(s) failed", failed) + } + return nil +} diff --git a/internal/cli/serve.go b/internal/cli/serve.go new file mode 100644 index 0000000..5d41395 --- /dev/null +++ b/internal/cli/serve.go @@ -0,0 +1,86 @@ +package cli + +import ( + "context" + "fmt" + "log/slog" + "net/http" + "os" + "os/signal" + "syscall" + "time" + + "github.com/claudomator/claudomator/internal/api" + "github.com/claudomator/claudomator/internal/executor" + "github.com/claudomator/claudomator/internal/storage" + "github.com/spf13/cobra" +) + +func newServeCmd() *cobra.Command { + var addr string + + cmd := &cobra.Command{ + Use: "serve", + Short: "Start the Claudomator API server", + RunE: func(cmd *cobra.Command, args []string) error { + return serve(addr) + }, + } + + cmd.Flags().StringVar(&addr, "addr", ":8484", "listen address") + + return cmd +} + +func serve(addr string) error { + if err := cfg.EnsureDirs(); err != nil { + return fmt.Errorf("creating dirs: %w", err) + } + + store, err := storage.Open(cfg.DBPath) + if err != nil { + return fmt.Errorf("opening db: %w", err) + } + defer store.Close() + + level := slog.LevelInfo + if verbose { + level = slog.LevelDebug + } + logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: level})) + + runner := &executor.ClaudeRunner{ + BinaryPath: cfg.ClaudeBinaryPath, + Logger: logger, + LogDir: cfg.LogDir, + } + pool := executor.NewPool(cfg.MaxConcurrent, runner, store, logger) + + srv := api.NewServer(store, pool, logger) + srv.StartHub() + + httpSrv := &http.Server{ + Addr: addr, + Handler: srv.Handler(), + } + + // Graceful shutdown. + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + sigCh := make(chan os.Signal, 1) + signal.Notify(sigCh, os.Interrupt, syscall.SIGTERM) + go func() { + <-sigCh + logger.Info("shutting down server...") + shutdownCtx, shutdownCancel := context.WithTimeout(ctx, 5*time.Second) + defer shutdownCancel() + httpSrv.Shutdown(shutdownCtx) + }() + + fmt.Printf("Claudomator server listening on %s\n", addr) + if err := httpSrv.ListenAndServe(); err != http.ErrServerClosed { + return err + } + return nil +} diff --git a/internal/cli/status.go b/internal/cli/status.go new file mode 100644 index 0000000..4613fee --- /dev/null +++ b/internal/cli/status.go @@ -0,0 +1,63 @@ +package cli + +import ( + "fmt" + "os" + "text/tabwriter" + + "github.com/claudomator/claudomator/internal/storage" + "github.com/spf13/cobra" +) + +func newStatusCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "status <task-id>", + Short: "Show task status and execution history", + Args: cobra.ExactArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + return showStatus(args[0]) + }, + } + return cmd +} + +func showStatus(id string) error { + store, err := storage.Open(cfg.DBPath) + if err != nil { + return fmt.Errorf("opening db: %w", err) + } + defer store.Close() + + // Try full ID first, then prefix match. + t, err := store.GetTask(id) + if err != nil { + return fmt.Errorf("task %q not found", id) + } + + fmt.Printf("Task: %s\n", t.Name) + fmt.Printf("ID: %s\n", t.ID) + fmt.Printf("State: %s\n", t.State) + fmt.Printf("Priority: %s\n", t.Priority) + fmt.Printf("Model: %s\n", t.Claude.Model) + if t.Description != "" { + fmt.Printf("Description: %s\n", t.Description) + } + + execs, err := store.ListExecutions(t.ID) + if err != nil { + return err + } + + if len(execs) > 0 { + fmt.Printf("\nExecutions (%d):\n", len(execs)) + w := tabwriter.NewWriter(os.Stdout, 0, 0, 2, ' ', 0) + fmt.Fprintln(w, " ID\tSTATUS\tEXIT\tCOST\tDURATION\tSTARTED") + for _, e := range execs { + dur := e.EndTime.Sub(e.StartTime) + fmt.Fprintf(w, " %.8s\t%s\t%d\t$%.4f\t%v\t%s\n", + e.ID, e.Status, e.ExitCode, e.CostUSD, dur.Round(1e9), e.StartTime.Format("15:04:05")) + } + w.Flush() + } + return nil +} diff --git a/internal/config/config.go b/internal/config/config.go new file mode 100644 index 0000000..da7f264 --- /dev/null +++ b/internal/config/config.go @@ -0,0 +1,41 @@ +package config + +import ( + "os" + "path/filepath" +) + +type Config struct { + DataDir string `toml:"data_dir"` + DBPath string `toml:"-"` + LogDir string `toml:"-"` + ClaudeBinaryPath string `toml:"claude_binary_path"` + MaxConcurrent int `toml:"max_concurrent"` + DefaultTimeout string `toml:"default_timeout"` + ServerAddr string `toml:"server_addr"` + WebhookURL string `toml:"webhook_url"` +} + +func Default() *Config { + home, _ := os.UserHomeDir() + dataDir := filepath.Join(home, ".claudomator") + return &Config{ + DataDir: dataDir, + DBPath: filepath.Join(dataDir, "claudomator.db"), + LogDir: filepath.Join(dataDir, "executions"), + ClaudeBinaryPath: "claude", + MaxConcurrent: 3, + DefaultTimeout: "15m", + ServerAddr: ":8484", + } +} + +// EnsureDirs creates the data directory structure. +func (c *Config) EnsureDirs() error { + for _, dir := range []string{c.DataDir, c.LogDir} { + if err := os.MkdirAll(dir, 0700); err != nil { + return err + } + } + return nil +} diff --git a/internal/executor/claude.go b/internal/executor/claude.go new file mode 100644 index 0000000..c845d58 --- /dev/null +++ b/internal/executor/claude.go @@ -0,0 +1,152 @@ +package executor + +import ( + "bufio" + "context" + "encoding/json" + "fmt" + "io" + "log/slog" + "os" + "os/exec" + "path/filepath" + + "github.com/claudomator/claudomator/internal/storage" + "github.com/claudomator/claudomator/internal/task" +) + +// ClaudeRunner spawns the `claude` CLI in non-interactive mode. +type ClaudeRunner struct { + BinaryPath string // defaults to "claude" + Logger *slog.Logger + LogDir string // base directory for execution logs +} + +func (r *ClaudeRunner) binaryPath() string { + if r.BinaryPath != "" { + return r.BinaryPath + } + return "claude" +} + +// Run executes a claude -p invocation, streaming output to log files. +func (r *ClaudeRunner) Run(ctx context.Context, t *task.Task, e *storage.Execution) error { + args := r.buildArgs(t) + + cmd := exec.CommandContext(ctx, r.binaryPath(), args...) + if t.Claude.WorkingDir != "" { + cmd.Dir = t.Claude.WorkingDir + } + + // Setup log directory for this execution. + logDir := filepath.Join(r.LogDir, e.ID) + if err := os.MkdirAll(logDir, 0700); err != nil { + return fmt.Errorf("creating log dir: %w", err) + } + + stdoutPath := filepath.Join(logDir, "stdout.log") + stderrPath := filepath.Join(logDir, "stderr.log") + e.StdoutPath = stdoutPath + e.StderrPath = stderrPath + e.ArtifactDir = logDir + + stdoutFile, err := os.Create(stdoutPath) + if err != nil { + return fmt.Errorf("creating stdout log: %w", err) + } + defer stdoutFile.Close() + + stderrFile, err := os.Create(stderrPath) + if err != nil { + return fmt.Errorf("creating stderr log: %w", err) + } + defer stderrFile.Close() + + stdoutPipe, err := cmd.StdoutPipe() + if err != nil { + return fmt.Errorf("creating stdout pipe: %w", err) + } + stderrPipe, err := cmd.StderrPipe() + if err != nil { + return fmt.Errorf("creating stderr pipe: %w", err) + } + + if err := cmd.Start(); err != nil { + return fmt.Errorf("starting claude: %w", err) + } + + // Stream output to log files and parse cost info. + var costUSD float64 + go func() { + costUSD = streamAndParseCost(stdoutPipe, stdoutFile, r.Logger) + }() + go io.Copy(stderrFile, stderrPipe) + + if err := cmd.Wait(); err != nil { + if exitErr, ok := err.(*exec.ExitError); ok { + e.ExitCode = exitErr.ExitCode() + } + e.CostUSD = costUSD + return fmt.Errorf("claude exited with error: %w", err) + } + + e.ExitCode = 0 + e.CostUSD = costUSD + return nil +} + +func (r *ClaudeRunner) buildArgs(t *task.Task) []string { + args := []string{ + "-p", t.Claude.Instructions, + "--output-format", "stream-json", + } + + if t.Claude.Model != "" { + args = append(args, "--model", t.Claude.Model) + } + if t.Claude.MaxBudgetUSD > 0 { + args = append(args, "--max-budget-usd", fmt.Sprintf("%.2f", t.Claude.MaxBudgetUSD)) + } + if t.Claude.PermissionMode != "" { + args = append(args, "--permission-mode", t.Claude.PermissionMode) + } + if t.Claude.SystemPromptAppend != "" { + args = append(args, "--append-system-prompt", t.Claude.SystemPromptAppend) + } + for _, tool := range t.Claude.AllowedTools { + args = append(args, "--allowedTools", tool) + } + for _, tool := range t.Claude.DisallowedTools { + args = append(args, "--disallowedTools", tool) + } + for _, f := range t.Claude.ContextFiles { + args = append(args, "--add-dir", f) + } + args = append(args, t.Claude.AdditionalArgs...) + + return args +} + +// streamAndParseCost reads streaming JSON from claude and writes to the log file, +// extracting cost data from the stream. +func streamAndParseCost(r io.Reader, w io.Writer, logger *slog.Logger) float64 { + tee := io.TeeReader(r, w) + scanner := bufio.NewScanner(tee) + scanner.Buffer(make([]byte, 1024*1024), 1024*1024) // 1MB buffer for large lines + + var totalCost float64 + for scanner.Scan() { + line := scanner.Bytes() + var msg map[string]interface{} + if err := json.Unmarshal(line, &msg); err != nil { + continue + } + // Extract cost from result messages. + if costData, ok := msg["cost_usd"]; ok { + if cost, ok := costData.(float64); ok { + totalCost = cost + } + } + } + return totalCost +} diff --git a/internal/executor/claude_test.go b/internal/executor/claude_test.go new file mode 100644 index 0000000..448ab40 --- /dev/null +++ b/internal/executor/claude_test.go @@ -0,0 +1,84 @@ +package executor + +import ( + "testing" + + "github.com/claudomator/claudomator/internal/task" +) + +func TestClaudeRunner_BuildArgs_BasicTask(t *testing.T) { + r := &ClaudeRunner{} + tk := &task.Task{ + Claude: task.ClaudeConfig{ + Instructions: "fix the bug", + Model: "sonnet", + }, + } + + args := r.buildArgs(tk) + + expected := []string{"-p", "fix the bug", "--output-format", "stream-json", "--model", "sonnet"} + if len(args) != len(expected) { + t.Fatalf("args length: want %d, got %d: %v", len(expected), len(args), args) + } + for i, want := range expected { + if args[i] != want { + t.Errorf("arg[%d]: want %q, got %q", i, want, args[i]) + } + } +} + +func TestClaudeRunner_BuildArgs_FullConfig(t *testing.T) { + r := &ClaudeRunner{} + tk := &task.Task{ + Claude: task.ClaudeConfig{ + Instructions: "implement feature", + Model: "opus", + MaxBudgetUSD: 5.0, + PermissionMode: "bypassPermissions", + SystemPromptAppend: "Follow TDD", + AllowedTools: []string{"Bash", "Edit"}, + DisallowedTools: []string{"Write"}, + ContextFiles: []string{"/src"}, + AdditionalArgs: []string{"--verbose"}, + }, + } + + args := r.buildArgs(tk) + + // Check key args are present. + argMap := make(map[string]bool) + for _, a := range args { + argMap[a] = true + } + + requiredArgs := []string{ + "-p", "implement feature", "--output-format", "stream-json", + "--model", "opus", "--max-budget-usd", "5.00", + "--permission-mode", "bypassPermissions", + "--append-system-prompt", "Follow TDD", + "--allowedTools", "Bash", "Edit", + "--disallowedTools", "Write", + "--add-dir", "/src", + "--verbose", + } + for _, req := range requiredArgs { + if !argMap[req] { + t.Errorf("missing arg %q in %v", req, args) + } + } +} + +func TestClaudeRunner_BinaryPath_Default(t *testing.T) { + r := &ClaudeRunner{} + if r.binaryPath() != "claude" { + t.Errorf("want 'claude', got %q", r.binaryPath()) + } +} + +func TestClaudeRunner_BinaryPath_Custom(t *testing.T) { + r := &ClaudeRunner{BinaryPath: "/usr/local/bin/claude"} + if r.binaryPath() != "/usr/local/bin/claude" { + t.Errorf("want custom path, got %q", r.binaryPath()) + } +} diff --git a/internal/executor/executor.go b/internal/executor/executor.go new file mode 100644 index 0000000..c6c5124 --- /dev/null +++ b/internal/executor/executor.go @@ -0,0 +1,138 @@ +package executor + +import ( + "context" + "fmt" + "log/slog" + "sync" + "time" + + "github.com/claudomator/claudomator/internal/storage" + "github.com/claudomator/claudomator/internal/task" + "github.com/google/uuid" +) + +// Runner executes a single task and returns the result. +type Runner interface { + Run(ctx context.Context, t *task.Task, exec *storage.Execution) error +} + +// Pool manages a bounded set of concurrent task workers. +type Pool struct { + maxConcurrent int + runner Runner + store *storage.DB + logger *slog.Logger + + mu sync.Mutex + active int + resultCh chan *Result +} + +// Result is emitted when a task execution completes. +type Result struct { + TaskID string + Execution *storage.Execution + Err error +} + +func NewPool(maxConcurrent int, runner Runner, store *storage.DB, logger *slog.Logger) *Pool { + if maxConcurrent < 1 { + maxConcurrent = 1 + } + return &Pool{ + maxConcurrent: maxConcurrent, + runner: runner, + store: store, + logger: logger, + resultCh: make(chan *Result, maxConcurrent*2), + } +} + +// Submit dispatches a task for execution. Blocks if pool is at capacity. +func (p *Pool) Submit(ctx context.Context, t *task.Task) error { + p.mu.Lock() + if p.active >= p.maxConcurrent { + active := p.active + max := p.maxConcurrent + p.mu.Unlock() + return fmt.Errorf("executor pool at capacity (%d/%d)", active, max) + } + p.active++ + p.mu.Unlock() + + go p.execute(ctx, t) + return nil +} + +// Results returns the channel for reading execution results. +func (p *Pool) Results() <-chan *Result { + return p.resultCh +} + +// ActiveCount returns the number of currently running tasks. +func (p *Pool) ActiveCount() int { + p.mu.Lock() + defer p.mu.Unlock() + return p.active +} + +func (p *Pool) execute(ctx context.Context, t *task.Task) { + execID := uuid.New().String() + exec := &storage.Execution{ + ID: execID, + TaskID: t.ID, + StartTime: time.Now().UTC(), + Status: "RUNNING", + } + + // Record execution start. + if err := p.store.CreateExecution(exec); err != nil { + p.logger.Error("failed to create execution record", "error", err) + } + if err := p.store.UpdateTaskState(t.ID, task.StateRunning); err != nil { + p.logger.Error("failed to update task state", "error", err) + } + + // Apply task timeout. + var cancel context.CancelFunc + if t.Timeout.Duration > 0 { + ctx, cancel = context.WithTimeout(ctx, t.Timeout.Duration) + } else { + ctx, cancel = context.WithCancel(ctx) + } + defer cancel() + + // Run the task. + err := p.runner.Run(ctx, t, exec) + exec.EndTime = time.Now().UTC() + + if err != nil { + if ctx.Err() == context.DeadlineExceeded { + exec.Status = "TIMED_OUT" + exec.ErrorMsg = "execution timed out" + p.store.UpdateTaskState(t.ID, task.StateTimedOut) + } else if ctx.Err() == context.Canceled { + exec.Status = "CANCELLED" + exec.ErrorMsg = "execution cancelled" + p.store.UpdateTaskState(t.ID, task.StateCancelled) + } else { + exec.Status = "FAILED" + exec.ErrorMsg = err.Error() + p.store.UpdateTaskState(t.ID, task.StateFailed) + } + } else { + exec.Status = "COMPLETED" + p.store.UpdateTaskState(t.ID, task.StateCompleted) + } + + if updateErr := p.store.UpdateExecution(exec); updateErr != nil { + p.logger.Error("failed to update execution", "error", updateErr) + } + + p.mu.Lock() + p.active-- + p.mu.Unlock() + + p.resultCh <- &Result{TaskID: t.ID, Execution: exec, Err: err} +} diff --git a/internal/executor/executor_test.go b/internal/executor/executor_test.go new file mode 100644 index 0000000..acce95b --- /dev/null +++ b/internal/executor/executor_test.go @@ -0,0 +1,206 @@ +package executor + +import ( + "context" + "fmt" + "log/slog" + "os" + "path/filepath" + "sync" + "testing" + "time" + + "github.com/claudomator/claudomator/internal/storage" + "github.com/claudomator/claudomator/internal/task" +) + +// mockRunner implements Runner for testing. +type mockRunner struct { + mu sync.Mutex + calls int + delay time.Duration + err error + exitCode int +} + +func (m *mockRunner) Run(ctx context.Context, t *task.Task, e *storage.Execution) error { + m.mu.Lock() + m.calls++ + m.mu.Unlock() + + if m.delay > 0 { + select { + case <-time.After(m.delay): + case <-ctx.Done(): + return ctx.Err() + } + } + if m.err != nil { + e.ExitCode = m.exitCode + return m.err + } + return nil +} + +func (m *mockRunner) callCount() int { + m.mu.Lock() + defer m.mu.Unlock() + return m.calls +} + +func testStore(t *testing.T) *storage.DB { + t.Helper() + dbPath := filepath.Join(t.TempDir(), "test.db") + db, err := storage.Open(dbPath) + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { db.Close() }) + return db +} + +func makeTask(id string) *task.Task { + now := time.Now().UTC() + return &task.Task{ + ID: id, Name: "Test " + id, + Claude: task.ClaudeConfig{Instructions: "test"}, + Priority: task.PriorityNormal, + Retry: task.RetryConfig{MaxAttempts: 1, Backoff: "linear"}, + Tags: []string{}, + DependsOn: []string{}, + State: task.StateQueued, + CreatedAt: now, UpdatedAt: now, + } +} + +func TestPool_Submit_Success(t *testing.T) { + store := testStore(t) + runner := &mockRunner{} + logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) + pool := NewPool(2, runner, store, logger) + + tk := makeTask("ps-1") + store.CreateTask(tk) + + if err := pool.Submit(context.Background(), tk); err != nil { + t.Fatalf("submit: %v", err) + } + + result := <-pool.Results() + if result.Err != nil { + t.Errorf("expected no error, got: %v", result.Err) + } + if result.Execution.Status != "COMPLETED" { + t.Errorf("status: want COMPLETED, got %q", result.Execution.Status) + } + + // Verify task state in DB. + got, _ := store.GetTask("ps-1") + if got.State != task.StateCompleted { + t.Errorf("task state: want COMPLETED, got %v", got.State) + } +} + +func TestPool_Submit_Failure(t *testing.T) { + store := testStore(t) + runner := &mockRunner{err: fmt.Errorf("boom"), exitCode: 1} + logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) + pool := NewPool(2, runner, store, logger) + + tk := makeTask("pf-1") + store.CreateTask(tk) + pool.Submit(context.Background(), tk) + + result := <-pool.Results() + if result.Err == nil { + t.Fatal("expected error") + } + if result.Execution.Status != "FAILED" { + t.Errorf("status: want FAILED, got %q", result.Execution.Status) + } +} + +func TestPool_Submit_Timeout(t *testing.T) { + store := testStore(t) + runner := &mockRunner{delay: 5 * time.Second} + logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) + pool := NewPool(2, runner, store, logger) + + tk := makeTask("pt-1") + tk.Timeout.Duration = 50 * time.Millisecond + store.CreateTask(tk) + pool.Submit(context.Background(), tk) + + result := <-pool.Results() + if result.Execution.Status != "TIMED_OUT" { + t.Errorf("status: want TIMED_OUT, got %q", result.Execution.Status) + } +} + +func TestPool_Submit_Cancellation(t *testing.T) { + store := testStore(t) + runner := &mockRunner{delay: 5 * time.Second} + logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) + pool := NewPool(2, runner, store, logger) + + ctx, cancel := context.WithCancel(context.Background()) + tk := makeTask("pc-1") + store.CreateTask(tk) + pool.Submit(ctx, tk) + + time.Sleep(20 * time.Millisecond) + cancel() + + result := <-pool.Results() + if result.Execution.Status != "CANCELLED" { + t.Errorf("status: want CANCELLED, got %q", result.Execution.Status) + } +} + +func TestPool_AtCapacity(t *testing.T) { + store := testStore(t) + runner := &mockRunner{delay: time.Second} + logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) + pool := NewPool(1, runner, store, logger) + + tk1 := makeTask("cap-1") + store.CreateTask(tk1) + pool.Submit(context.Background(), tk1) + + // Pool is at capacity, second submit should fail. + time.Sleep(10 * time.Millisecond) // let goroutine start + tk2 := makeTask("cap-2") + store.CreateTask(tk2) + err := pool.Submit(context.Background(), tk2) + if err == nil { + t.Fatal("expected capacity error") + } + + <-pool.Results() // drain +} + +func TestPool_ConcurrentExecution(t *testing.T) { + store := testStore(t) + runner := &mockRunner{delay: 50 * time.Millisecond} + logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) + pool := NewPool(3, runner, store, logger) + + for i := 0; i < 3; i++ { + tk := makeTask(fmt.Sprintf("cc-%d", i)) + store.CreateTask(tk) + if err := pool.Submit(context.Background(), tk); err != nil { + t.Fatalf("submit %d: %v", i, err) + } + } + + for i := 0; i < 3; i++ { + result := <-pool.Results() + if result.Execution.Status != "COMPLETED" { + t.Errorf("task %s: want COMPLETED, got %q", result.TaskID, result.Execution.Status) + } + } + + if runner.callCount() != 3 { + t.Errorf("calls: want 3, got %d", runner.callCount()) + } +} diff --git a/internal/notify/notify.go b/internal/notify/notify.go new file mode 100644 index 0000000..86e641f --- /dev/null +++ b/internal/notify/notify.go @@ -0,0 +1,95 @@ +package notify + +import ( + "bytes" + "encoding/json" + "fmt" + "log/slog" + "net/http" + "time" +) + +// Notifier sends notifications when tasks complete. +type Notifier interface { + Notify(event Event) error +} + +// Event represents a task completion event. +type Event struct { + TaskID string `json:"task_id"` + TaskName string `json:"task_name"` + Status string `json:"status"` + CostUSD float64 `json:"cost_usd"` + Duration string `json:"duration"` + Error string `json:"error,omitempty"` +} + +// WebhookNotifier sends POST requests to a configured URL. +type WebhookNotifier struct { + URL string + client *http.Client + logger *slog.Logger +} + +func NewWebhookNotifier(url string, logger *slog.Logger) *WebhookNotifier { + return &WebhookNotifier{ + URL: url, + client: &http.Client{Timeout: 10 * time.Second}, + logger: logger, + } +} + +func (w *WebhookNotifier) Notify(event Event) error { + body, err := json.Marshal(event) + if err != nil { + return fmt.Errorf("marshaling event: %w", err) + } + + resp, err := w.client.Post(w.URL, "application/json", bytes.NewReader(body)) + if err != nil { + return fmt.Errorf("sending webhook: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode >= 400 { + return fmt.Errorf("webhook returned %d", resp.StatusCode) + } + return nil +} + +// MultiNotifier fans out to multiple notifiers. +type MultiNotifier struct { + notifiers []Notifier + logger *slog.Logger +} + +func NewMultiNotifier(logger *slog.Logger, notifiers ...Notifier) *MultiNotifier { + return &MultiNotifier{notifiers: notifiers, logger: logger} +} + +func (m *MultiNotifier) Notify(event Event) error { + var lastErr error + for _, n := range m.notifiers { + if err := n.Notify(event); err != nil { + m.logger.Error("notification failed", "error", err) + lastErr = err + } + } + return lastErr +} + +// LogNotifier logs events (useful as a default/fallback). +type LogNotifier struct { + Logger *slog.Logger +} + +func (l *LogNotifier) Notify(event Event) error { + l.Logger.Info("task completed", + "task_id", event.TaskID, + "task_name", event.TaskName, + "status", event.Status, + "cost_usd", event.CostUSD, + "duration", event.Duration, + ) + return nil +} diff --git a/internal/notify/notify_test.go b/internal/notify/notify_test.go new file mode 100644 index 0000000..fcb5345 --- /dev/null +++ b/internal/notify/notify_test.go @@ -0,0 +1,86 @@ +package notify + +import ( + "encoding/json" + "io" + "log/slog" + "net/http" + "net/http/httptest" + "os" + "testing" +) + +func TestWebhookNotifier_Success(t *testing.T) { + var received Event + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + json.Unmarshal(body, &received) + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) + notifier := NewWebhookNotifier(server.URL, logger) + + event := Event{ + TaskID: "t-1", + TaskName: "Test", + Status: "COMPLETED", + CostUSD: 0.50, + Duration: "2m30s", + } + + if err := notifier.Notify(event); err != nil { + t.Fatalf("unexpected error: %v", err) + } + if received.TaskID != "t-1" { + t.Errorf("task_id: want 't-1', got %q", received.TaskID) + } + if received.CostUSD != 0.50 { + t.Errorf("cost: want 0.50, got %f", received.CostUSD) + } +} + +func TestWebhookNotifier_ServerError(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + })) + defer server.Close() + + logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) + notifier := NewWebhookNotifier(server.URL, logger) + + err := notifier.Notify(Event{TaskID: "t-1", Status: "COMPLETED"}) + if err == nil { + t.Fatal("expected error for 500 response") + } +} + +func TestMultiNotifier_FansOut(t *testing.T) { + var count int + counter := &countingNotifier{count: &count} + logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) + multi := NewMultiNotifier(logger, counter, counter, counter) + + multi.Notify(Event{TaskID: "t-1"}) + if count != 3 { + t.Errorf("want 3 notifications, got %d", count) + } +} + +func TestLogNotifier_NoError(t *testing.T) { + logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) + notifier := &LogNotifier{Logger: logger} + if err := notifier.Notify(Event{TaskID: "t-1", Status: "COMPLETED"}); err != nil { + t.Errorf("unexpected error: %v", err) + } +} + +type countingNotifier struct { + count *int +} + +func (c *countingNotifier) Notify(_ Event) error { + *c.count++ + return nil +} diff --git a/internal/reporter/reporter.go b/internal/reporter/reporter.go new file mode 100644 index 0000000..4ba66e0 --- /dev/null +++ b/internal/reporter/reporter.go @@ -0,0 +1,117 @@ +package reporter + +import ( + "encoding/json" + "fmt" + "io" + "text/tabwriter" + "time" + + "github.com/claudomator/claudomator/internal/storage" +) + +// Reporter generates reports from execution data. +type Reporter interface { + Generate(w io.Writer, executions []*storage.Execution) error +} + +// ConsoleReporter outputs a formatted table. +type ConsoleReporter struct{} + +func (r *ConsoleReporter) Generate(w io.Writer, executions []*storage.Execution) error { + if len(executions) == 0 { + fmt.Fprintln(w, "No executions found.") + return nil + } + + tw := tabwriter.NewWriter(w, 0, 0, 2, ' ', 0) + fmt.Fprintln(tw, "ID\tTASK\tSTATUS\tEXIT\tCOST\tDURATION\tSTARTED") + + var totalCost float64 + var completed, failed int + + for _, e := range executions { + dur := e.EndTime.Sub(e.StartTime) + totalCost += e.CostUSD + if e.Status == "COMPLETED" { + completed++ + } else { + failed++ + } + + fmt.Fprintf(tw, "%.8s\t%.8s\t%s\t%d\t$%.4f\t%v\t%s\n", + e.ID, e.TaskID, e.Status, e.ExitCode, e.CostUSD, + dur.Round(time.Second), e.StartTime.Format("2006-01-02 15:04")) + } + tw.Flush() + + fmt.Fprintf(w, "\nSummary: %d completed, %d failed, total cost $%.4f\n", completed, failed, totalCost) + return nil +} + +// JSONReporter outputs JSON. +type JSONReporter struct { + Pretty bool +} + +func (r *JSONReporter) Generate(w io.Writer, executions []*storage.Execution) error { + enc := json.NewEncoder(w) + if r.Pretty { + enc.SetIndent("", " ") + } + return enc.Encode(executions) +} + +// HTMLReporter generates a standalone HTML report. +type HTMLReporter struct{} + +func (r *HTMLReporter) Generate(w io.Writer, executions []*storage.Execution) error { + var totalCost float64 + var completed, failed int + for _, e := range executions { + totalCost += e.CostUSD + if e.Status == "COMPLETED" { + completed++ + } else { + failed++ + } + } + + fmt.Fprint(w, `<!DOCTYPE html> +<html lang="en"><head><meta charset="utf-8"><meta name="viewport" content="width=device-width, initial-scale=1"> +<title>Claudomator Report</title> +<style> + * { box-sizing: border-box; margin: 0; padding: 0; } + body { font-family: -apple-system, BlinkMacSystemFont, sans-serif; background: #0f172a; color: #e2e8f0; padding: 1rem; } + .container { max-width: 960px; margin: 0 auto; } + h1 { font-size: 1.5rem; margin-bottom: 1rem; color: #7dd3fc; } + .summary { display: flex; gap: 1rem; margin-bottom: 1.5rem; flex-wrap: wrap; } + .stat { background: #1e293b; padding: 1rem; border-radius: 0.5rem; flex: 1; min-width: 120px; } + .stat .label { font-size: 0.75rem; color: #94a3b8; text-transform: uppercase; } + .stat .value { font-size: 1.5rem; font-weight: bold; margin-top: 0.25rem; } + .ok { color: #4ade80; } .fail { color: #f87171; } .cost { color: #fbbf24; } + table { width: 100%; border-collapse: collapse; background: #1e293b; border-radius: 0.5rem; overflow: hidden; } + th { background: #334155; padding: 0.75rem; text-align: left; font-size: 0.75rem; text-transform: uppercase; color: #94a3b8; } + td { padding: 0.75rem; border-top: 1px solid #334155; font-size: 0.875rem; } + tr:hover { background: #334155; } + .status-COMPLETED { color: #4ade80; } .status-FAILED { color: #f87171; } .status-TIMED_OUT { color: #fbbf24; } +</style></head><body><div class="container"> +<h1>Claudomator Report</h1> +<div class="summary">`) + + fmt.Fprintf(w, `<div class="stat"><div class="label">Completed</div><div class="value ok">%d</div></div>`, completed) + fmt.Fprintf(w, `<div class="stat"><div class="label">Failed</div><div class="value fail">%d</div></div>`, failed) + fmt.Fprintf(w, `<div class="stat"><div class="label">Total Cost</div><div class="value cost">$%.4f</div></div>`, totalCost) + fmt.Fprintf(w, `<div class="stat"><div class="label">Total</div><div class="value">%d</div></div>`, len(executions)) + + fmt.Fprint(w, `</div><table><thead><tr><th>ID</th><th>Task</th><th>Status</th><th>Exit</th><th>Cost</th><th>Duration</th><th>Started</th></tr></thead><tbody>`) + + for _, e := range executions { + dur := e.EndTime.Sub(e.StartTime).Round(time.Second) + fmt.Fprintf(w, `<tr><td>%.8s</td><td>%.8s</td><td class="status-%s">%s</td><td>%d</td><td>$%.4f</td><td>%v</td><td>%s</td></tr>`, + e.ID, e.TaskID, e.Status, e.Status, e.ExitCode, e.CostUSD, dur, e.StartTime.Format("2006-01-02 15:04")) + } + + fmt.Fprint(w, `</tbody></table></div></body></html>`) + return nil +} diff --git a/internal/reporter/reporter_test.go b/internal/reporter/reporter_test.go new file mode 100644 index 0000000..1ddce23 --- /dev/null +++ b/internal/reporter/reporter_test.go @@ -0,0 +1,114 @@ +package reporter + +import ( + "bytes" + "encoding/json" + "strings" + "testing" + "time" + + "github.com/claudomator/claudomator/internal/storage" +) + +func sampleExecutions() []*storage.Execution { + now := time.Date(2026, 2, 8, 10, 0, 0, 0, time.UTC) + return []*storage.Execution{ + { + ID: "exec-1", TaskID: "task-1", Status: "COMPLETED", + StartTime: now, EndTime: now.Add(2 * time.Minute), + ExitCode: 0, CostUSD: 0.25, + }, + { + ID: "exec-2", TaskID: "task-2", Status: "FAILED", + StartTime: now, EndTime: now.Add(30 * time.Second), + ExitCode: 1, CostUSD: 0.10, ErrorMsg: "something broke", + }, + } +} + +func TestConsoleReporter_WithExecutions(t *testing.T) { + r := &ConsoleReporter{} + var buf bytes.Buffer + err := r.Generate(&buf, sampleExecutions()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + output := buf.String() + if !strings.Contains(output, "COMPLETED") { + t.Error("missing COMPLETED status") + } + if !strings.Contains(output, "FAILED") { + t.Error("missing FAILED status") + } + if !strings.Contains(output, "1 completed, 1 failed") { + t.Errorf("missing summary in output: %s", output) + } + if !strings.Contains(output, "$0.3500") { + t.Errorf("missing total cost in output: %s", output) + } +} + +func TestConsoleReporter_Empty(t *testing.T) { + r := &ConsoleReporter{} + var buf bytes.Buffer + r.Generate(&buf, []*storage.Execution{}) + if !strings.Contains(buf.String(), "No executions") { + t.Error("expected 'No executions' message") + } +} + +func TestJSONReporter(t *testing.T) { + r := &JSONReporter{Pretty: false} + var buf bytes.Buffer + err := r.Generate(&buf, sampleExecutions()) + if err != nil { + t.Fatal(err) + } + + var result []storage.Execution + if err := json.Unmarshal(buf.Bytes(), &result); err != nil { + t.Fatalf("invalid JSON: %v", err) + } + if len(result) != 2 { + t.Errorf("want 2 results, got %d", len(result)) + } + if result[0].Status != "COMPLETED" { + t.Errorf("want COMPLETED, got %q", result[0].Status) + } +} + +func TestJSONReporter_Pretty(t *testing.T) { + r := &JSONReporter{Pretty: true} + var buf bytes.Buffer + r.Generate(&buf, sampleExecutions()) + if !strings.Contains(buf.String(), " ") { + t.Error("expected indented JSON") + } +} + +func TestHTMLReporter(t *testing.T) { + r := &HTMLReporter{} + var buf bytes.Buffer + err := r.Generate(&buf, sampleExecutions()) + if err != nil { + t.Fatal(err) + } + + html := buf.String() + if !strings.Contains(html, "<!DOCTYPE html>") { + t.Error("missing DOCTYPE") + } + if !strings.Contains(html, "Claudomator Report") { + t.Error("missing title") + } + if !strings.Contains(html, "COMPLETED") { + t.Error("missing COMPLETED status") + } + if !strings.Contains(html, "FAILED") { + t.Error("missing FAILED status") + } + if !strings.Contains(html, "$0.3500") { + t.Error("missing total cost") + } +} diff --git a/internal/storage/db.go b/internal/storage/db.go new file mode 100644 index 0000000..67fbe08 --- /dev/null +++ b/internal/storage/db.go @@ -0,0 +1,278 @@ +package storage + +import ( + "database/sql" + "encoding/json" + "fmt" + "time" + + "github.com/claudomator/claudomator/internal/task" + _ "github.com/mattn/go-sqlite3" +) + +type DB struct { + db *sql.DB +} + +func Open(path string) (*DB, error) { + db, err := sql.Open("sqlite3", path+"?_journal_mode=WAL&_busy_timeout=5000") + if err != nil { + return nil, fmt.Errorf("opening database: %w", err) + } + s := &DB{db: db} + if err := s.migrate(); err != nil { + db.Close() + return nil, fmt.Errorf("running migrations: %w", err) + } + return s, nil +} + +func (s *DB) Close() error { + return s.db.Close() +} + +func (s *DB) migrate() error { + schema := ` + CREATE TABLE IF NOT EXISTS tasks ( + id TEXT PRIMARY KEY, + name TEXT NOT NULL, + description TEXT, + config_json TEXT NOT NULL, + priority TEXT NOT NULL DEFAULT 'normal', + timeout_ns INTEGER NOT NULL DEFAULT 0, + retry_json TEXT NOT NULL DEFAULT '{}', + tags_json TEXT NOT NULL DEFAULT '[]', + depends_on_json TEXT NOT NULL DEFAULT '[]', + state TEXT NOT NULL DEFAULT 'PENDING', + created_at DATETIME NOT NULL, + updated_at DATETIME NOT NULL + ); + + CREATE TABLE IF NOT EXISTS executions ( + id TEXT PRIMARY KEY, + task_id TEXT NOT NULL, + start_time DATETIME NOT NULL, + end_time DATETIME, + exit_code INTEGER, + status TEXT NOT NULL, + stdout_path TEXT, + stderr_path TEXT, + artifact_dir TEXT, + cost_usd REAL, + error_msg TEXT, + FOREIGN KEY (task_id) REFERENCES tasks(id) + ); + + CREATE INDEX IF NOT EXISTS idx_tasks_state ON tasks(state); + CREATE INDEX IF NOT EXISTS idx_executions_status ON executions(status); + CREATE INDEX IF NOT EXISTS idx_executions_task_id ON executions(task_id); + ` + _, err := s.db.Exec(schema) + return err +} + +// CreateTask inserts a task into the database. +func (s *DB) CreateTask(t *task.Task) error { + configJSON, err := json.Marshal(t.Claude) + if err != nil { + return fmt.Errorf("marshaling config: %w", err) + } + retryJSON, err := json.Marshal(t.Retry) + if err != nil { + return fmt.Errorf("marshaling retry: %w", err) + } + tagsJSON, err := json.Marshal(t.Tags) + if err != nil { + return fmt.Errorf("marshaling tags: %w", err) + } + depsJSON, err := json.Marshal(t.DependsOn) + if err != nil { + return fmt.Errorf("marshaling depends_on: %w", err) + } + + _, err = s.db.Exec(` + INSERT INTO tasks (id, name, description, config_json, priority, timeout_ns, retry_json, tags_json, depends_on_json, state, created_at, updated_at) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, + t.ID, t.Name, t.Description, string(configJSON), string(t.Priority), + t.Timeout.Duration.Nanoseconds(), string(retryJSON), string(tagsJSON), string(depsJSON), + string(t.State), t.CreatedAt.UTC(), t.UpdatedAt.UTC(), + ) + return err +} + +// GetTask retrieves a task by ID. +func (s *DB) GetTask(id string) (*task.Task, error) { + row := s.db.QueryRow(`SELECT id, name, description, config_json, priority, timeout_ns, retry_json, tags_json, depends_on_json, state, created_at, updated_at FROM tasks WHERE id = ?`, id) + return scanTask(row) +} + +// ListTasks returns tasks matching the given filter. +func (s *DB) ListTasks(filter TaskFilter) ([]*task.Task, error) { + query := `SELECT id, name, description, config_json, priority, timeout_ns, retry_json, tags_json, depends_on_json, state, created_at, updated_at FROM tasks WHERE 1=1` + var args []interface{} + + if filter.State != "" { + query += " AND state = ?" + args = append(args, string(filter.State)) + } + query += " ORDER BY created_at DESC" + if filter.Limit > 0 { + query += " LIMIT ?" + args = append(args, filter.Limit) + } + + rows, err := s.db.Query(query, args...) + if err != nil { + return nil, err + } + defer rows.Close() + + var tasks []*task.Task + for rows.Next() { + t, err := scanTaskRows(rows) + if err != nil { + return nil, err + } + tasks = append(tasks, t) + } + return tasks, rows.Err() +} + +// UpdateTaskState atomically updates a task's state. +func (s *DB) UpdateTaskState(id string, newState task.State) error { + now := time.Now().UTC() + result, err := s.db.Exec(`UPDATE tasks SET state = ?, updated_at = ? WHERE id = ?`, string(newState), now, id) + if err != nil { + return err + } + n, err := result.RowsAffected() + if err != nil { + return err + } + if n == 0 { + return fmt.Errorf("task %q not found", id) + } + return nil +} + +// TaskFilter specifies criteria for listing tasks. +type TaskFilter struct { + State task.State + Limit int +} + +// Execution represents a single run of a task. +type Execution struct { + ID string + TaskID string + StartTime time.Time + EndTime time.Time + ExitCode int + Status string + StdoutPath string + StderrPath string + ArtifactDir string + CostUSD float64 + ErrorMsg string +} + +// CreateExecution inserts an execution record. +func (s *DB) CreateExecution(e *Execution) error { + _, err := s.db.Exec(` + INSERT INTO executions (id, task_id, start_time, end_time, exit_code, status, stdout_path, stderr_path, artifact_dir, cost_usd, error_msg) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, + e.ID, e.TaskID, e.StartTime.UTC(), e.EndTime.UTC(), e.ExitCode, e.Status, + e.StdoutPath, e.StderrPath, e.ArtifactDir, e.CostUSD, e.ErrorMsg, + ) + return err +} + +// GetExecution retrieves an execution by ID. +func (s *DB) GetExecution(id string) (*Execution, error) { + row := s.db.QueryRow(`SELECT id, task_id, start_time, end_time, exit_code, status, stdout_path, stderr_path, artifact_dir, cost_usd, error_msg FROM executions WHERE id = ?`, id) + return scanExecution(row) +} + +// ListExecutions returns executions for a task. +func (s *DB) ListExecutions(taskID string) ([]*Execution, error) { + rows, err := s.db.Query(`SELECT id, task_id, start_time, end_time, exit_code, status, stdout_path, stderr_path, artifact_dir, cost_usd, error_msg FROM executions WHERE task_id = ? ORDER BY start_time DESC`, taskID) + if err != nil { + return nil, err + } + defer rows.Close() + + var execs []*Execution + for rows.Next() { + e, err := scanExecutionRows(rows) + if err != nil { + return nil, err + } + execs = append(execs, e) + } + return execs, rows.Err() +} + +// UpdateExecution updates a completed execution. +func (s *DB) UpdateExecution(e *Execution) error { + _, err := s.db.Exec(` + UPDATE executions SET end_time = ?, exit_code = ?, status = ?, cost_usd = ?, error_msg = ? + WHERE id = ?`, + e.EndTime.UTC(), e.ExitCode, e.Status, e.CostUSD, e.ErrorMsg, e.ID, + ) + return err +} + +type scanner interface { + Scan(dest ...interface{}) error +} + +func scanTask(row scanner) (*task.Task, error) { + var ( + t task.Task + configJSON string + retryJSON string + tagsJSON string + depsJSON string + state string + priority string + timeoutNS int64 + ) + err := row.Scan(&t.ID, &t.Name, &t.Description, &configJSON, &priority, &timeoutNS, &retryJSON, &tagsJSON, &depsJSON, &state, &t.CreatedAt, &t.UpdatedAt) + if err != nil { + return nil, err + } + t.State = task.State(state) + t.Priority = task.Priority(priority) + t.Timeout.Duration = time.Duration(timeoutNS) + if err := json.Unmarshal([]byte(configJSON), &t.Claude); err != nil { + return nil, fmt.Errorf("unmarshaling config: %w", err) + } + if err := json.Unmarshal([]byte(retryJSON), &t.Retry); err != nil { + return nil, fmt.Errorf("unmarshaling retry: %w", err) + } + if err := json.Unmarshal([]byte(tagsJSON), &t.Tags); err != nil { + return nil, fmt.Errorf("unmarshaling tags: %w", err) + } + if err := json.Unmarshal([]byte(depsJSON), &t.DependsOn); err != nil { + return nil, fmt.Errorf("unmarshaling depends_on: %w", err) + } + return &t, nil +} + +func scanTaskRows(rows *sql.Rows) (*task.Task, error) { + return scanTask(rows) +} + +func scanExecution(row scanner) (*Execution, error) { + var e Execution + err := row.Scan(&e.ID, &e.TaskID, &e.StartTime, &e.EndTime, &e.ExitCode, &e.Status, + &e.StdoutPath, &e.StderrPath, &e.ArtifactDir, &e.CostUSD, &e.ErrorMsg) + if err != nil { + return nil, err + } + return &e, nil +} + +func scanExecutionRows(rows *sql.Rows) (*Execution, error) { + return scanExecution(rows) +} diff --git a/internal/storage/db_test.go b/internal/storage/db_test.go new file mode 100644 index 0000000..78cb1e1 --- /dev/null +++ b/internal/storage/db_test.go @@ -0,0 +1,285 @@ +package storage + +import ( + "fmt" + "path/filepath" + "testing" + "time" + + "github.com/claudomator/claudomator/internal/task" +) + +func testDB(t *testing.T) *DB { + t.Helper() + dbPath := filepath.Join(t.TempDir(), "test.db") + db, err := Open(dbPath) + if err != nil { + t.Fatalf("opening db: %v", err) + } + t.Cleanup(func() { db.Close() }) + return db +} + +func TestOpen_CreatesSchema(t *testing.T) { + db := testDB(t) + // Should be able to query tasks table. + _, err := db.ListTasks(TaskFilter{}) + if err != nil { + t.Fatalf("querying tasks: %v", err) + } +} + +func TestCreateTask_AndGetTask(t *testing.T) { + db := testDB(t) + now := time.Now().UTC().Truncate(time.Second) + + tk := &task.Task{ + ID: "task-1", + Name: "Test Task", + Description: "A test", + Claude: task.ClaudeConfig{ + Model: "sonnet", + Instructions: "do it", + WorkingDir: "/tmp", + MaxBudgetUSD: 2.5, + }, + Priority: task.PriorityHigh, + Tags: []string{"test", "alpha"}, + DependsOn: []string{}, + Retry: task.RetryConfig{MaxAttempts: 3, Backoff: "exponential"}, + State: task.StatePending, + CreatedAt: now, + UpdatedAt: now, + } + tk.Timeout.Duration = 10 * time.Minute + + if err := db.CreateTask(tk); err != nil { + t.Fatalf("creating task: %v", err) + } + + got, err := db.GetTask("task-1") + if err != nil { + t.Fatalf("getting task: %v", err) + } + if got.Name != "Test Task" { + t.Errorf("name: want 'Test Task', got %q", got.Name) + } + if got.Claude.Model != "sonnet" { + t.Errorf("model: want 'sonnet', got %q", got.Claude.Model) + } + if got.Claude.MaxBudgetUSD != 2.5 { + t.Errorf("budget: want 2.5, got %f", got.Claude.MaxBudgetUSD) + } + if got.Priority != task.PriorityHigh { + t.Errorf("priority: want 'high', got %q", got.Priority) + } + if got.Timeout.Duration != 10*time.Minute { + t.Errorf("timeout: want 10m, got %v", got.Timeout.Duration) + } + if got.Retry.MaxAttempts != 3 { + t.Errorf("retry: want 3, got %d", got.Retry.MaxAttempts) + } + if len(got.Tags) != 2 || got.Tags[0] != "test" { + t.Errorf("tags: want [test alpha], got %v", got.Tags) + } + if got.State != task.StatePending { + t.Errorf("state: want PENDING, got %v", got.State) + } +} + +func TestUpdateTaskState(t *testing.T) { + db := testDB(t) + now := time.Now().UTC() + tk := &task.Task{ + ID: "task-2", + Name: "Stateful", + Claude: task.ClaudeConfig{Instructions: "test"}, + Priority: task.PriorityNormal, + Retry: task.RetryConfig{MaxAttempts: 1, Backoff: "linear"}, + Tags: []string{}, + DependsOn: []string{}, + State: task.StatePending, + CreatedAt: now, + UpdatedAt: now, + } + if err := db.CreateTask(tk); err != nil { + t.Fatal(err) + } + + if err := db.UpdateTaskState("task-2", task.StateQueued); err != nil { + t.Fatalf("updating state: %v", err) + } + got, _ := db.GetTask("task-2") + if got.State != task.StateQueued { + t.Errorf("state: want QUEUED, got %v", got.State) + } +} + +func TestUpdateTaskState_NotFound(t *testing.T) { + db := testDB(t) + err := db.UpdateTaskState("nonexistent", task.StateQueued) + if err == nil { + t.Fatal("expected error for nonexistent task") + } +} + +func TestListTasks_FilterByState(t *testing.T) { + db := testDB(t) + now := time.Now().UTC() + + for i, state := range []task.State{task.StatePending, task.StatePending, task.StateRunning} { + tk := &task.Task{ + ID: fmt.Sprintf("t-%d", i), Name: fmt.Sprintf("Task %d", i), + Claude: task.ClaudeConfig{Instructions: "x"}, Priority: task.PriorityNormal, + Retry: task.RetryConfig{MaxAttempts: 1, Backoff: "linear"}, + Tags: []string{}, DependsOn: []string{}, + State: state, CreatedAt: now, UpdatedAt: now, + } + if err := db.CreateTask(tk); err != nil { + t.Fatal(err) + } + } + + pending, err := db.ListTasks(TaskFilter{State: task.StatePending}) + if err != nil { + t.Fatal(err) + } + if len(pending) != 2 { + t.Errorf("want 2 pending, got %d", len(pending)) + } + + running, err := db.ListTasks(TaskFilter{State: task.StateRunning}) + if err != nil { + t.Fatal(err) + } + if len(running) != 1 { + t.Errorf("want 1 running, got %d", len(running)) + } +} + +func TestListTasks_WithLimit(t *testing.T) { + db := testDB(t) + now := time.Now().UTC() + for i := 0; i < 5; i++ { + tk := &task.Task{ + ID: fmt.Sprintf("lt-%d", i), Name: fmt.Sprintf("T%d", i), + Claude: task.ClaudeConfig{Instructions: "x"}, Priority: task.PriorityNormal, + Retry: task.RetryConfig{MaxAttempts: 1, Backoff: "linear"}, + Tags: []string{}, DependsOn: []string{}, + State: task.StatePending, CreatedAt: now.Add(time.Duration(i) * time.Second), UpdatedAt: now, + } + db.CreateTask(tk) + } + + tasks, err := db.ListTasks(TaskFilter{Limit: 3}) + if err != nil { + t.Fatal(err) + } + if len(tasks) != 3 { + t.Errorf("want 3, got %d", len(tasks)) + } +} + +func TestCreateExecution_AndGet(t *testing.T) { + db := testDB(t) + now := time.Now().UTC().Truncate(time.Second) + + // Need a task first. + tk := &task.Task{ + ID: "etask", Name: "E", Claude: task.ClaudeConfig{Instructions: "x"}, + Priority: task.PriorityNormal, Retry: task.RetryConfig{MaxAttempts: 1, Backoff: "linear"}, + Tags: []string{}, DependsOn: []string{}, + State: task.StatePending, CreatedAt: now, UpdatedAt: now, + } + db.CreateTask(tk) + + exec := &Execution{ + ID: "exec-1", + TaskID: "etask", + StartTime: now, + EndTime: now.Add(5 * time.Minute), + ExitCode: 0, + Status: "COMPLETED", + StdoutPath: "/tmp/stdout.log", + StderrPath: "/tmp/stderr.log", + CostUSD: 0.42, + } + if err := db.CreateExecution(exec); err != nil { + t.Fatalf("creating execution: %v", err) + } + + got, err := db.GetExecution("exec-1") + if err != nil { + t.Fatalf("getting execution: %v", err) + } + if got.Status != "COMPLETED" { + t.Errorf("status: want COMPLETED, got %q", got.Status) + } + if got.CostUSD != 0.42 { + t.Errorf("cost: want 0.42, got %f", got.CostUSD) + } + if got.ExitCode != 0 { + t.Errorf("exit code: want 0, got %d", got.ExitCode) + } +} + +func TestListExecutions(t *testing.T) { + db := testDB(t) + now := time.Now().UTC() + tk := &task.Task{ + ID: "ltask", Name: "L", Claude: task.ClaudeConfig{Instructions: "x"}, + Priority: task.PriorityNormal, Retry: task.RetryConfig{MaxAttempts: 1, Backoff: "linear"}, + Tags: []string{}, DependsOn: []string{}, + State: task.StatePending, CreatedAt: now, UpdatedAt: now, + } + db.CreateTask(tk) + + for i := 0; i < 3; i++ { + db.CreateExecution(&Execution{ + ID: fmt.Sprintf("le-%d", i), TaskID: "ltask", + StartTime: now.Add(time.Duration(i) * time.Minute), EndTime: now.Add(time.Duration(i+1) * time.Minute), + Status: "COMPLETED", + }) + } + + execs, err := db.ListExecutions("ltask") + if err != nil { + t.Fatal(err) + } + if len(execs) != 3 { + t.Errorf("want 3, got %d", len(execs)) + } +} + +func TestUpdateExecution(t *testing.T) { + db := testDB(t) + now := time.Now().UTC() + tk := &task.Task{ + ID: "utask", Name: "U", Claude: task.ClaudeConfig{Instructions: "x"}, + Priority: task.PriorityNormal, Retry: task.RetryConfig{MaxAttempts: 1, Backoff: "linear"}, + Tags: []string{}, DependsOn: []string{}, + State: task.StatePending, CreatedAt: now, UpdatedAt: now, + } + db.CreateTask(tk) + + exec := &Execution{ + ID: "ue-1", TaskID: "utask", StartTime: now, EndTime: now, Status: "RUNNING", + } + db.CreateExecution(exec) + + exec.Status = "FAILED" + exec.ExitCode = 1 + exec.ErrorMsg = "something broke" + exec.EndTime = now.Add(2 * time.Minute) + if err := db.UpdateExecution(exec); err != nil { + t.Fatal(err) + } + + got, _ := db.GetExecution("ue-1") + if got.Status != "FAILED" { + t.Errorf("status: want FAILED, got %q", got.Status) + } + if got.ErrorMsg != "something broke" { + t.Errorf("error: want 'something broke', got %q", got.ErrorMsg) + } +} diff --git a/internal/task/parser.go b/internal/task/parser.go new file mode 100644 index 0000000..7a450b8 --- /dev/null +++ b/internal/task/parser.go @@ -0,0 +1,61 @@ +package task + +import ( + "fmt" + "os" + "time" + + "github.com/google/uuid" + "gopkg.in/yaml.v3" +) + +// ParseFile reads a YAML file and returns tasks. Supports both single-task +// and batch (tasks: [...]) formats. +func ParseFile(path string) ([]Task, error) { + data, err := os.ReadFile(path) + if err != nil { + return nil, fmt.Errorf("reading task file: %w", err) + } + return Parse(data) +} + +// Parse parses YAML bytes into tasks. +func Parse(data []byte) ([]Task, error) { + // Try batch format first. + var batch BatchFile + if err := yaml.Unmarshal(data, &batch); err == nil && len(batch.Tasks) > 0 { + return initTasks(batch.Tasks), nil + } + + // Try single task. + var t Task + if err := yaml.Unmarshal(data, &t); err != nil { + return nil, fmt.Errorf("parsing task YAML: %w", err) + } + if t.Name == "" { + return nil, fmt.Errorf("task must have a name") + } + return initTasks([]Task{t}), nil +} + +func initTasks(tasks []Task) []Task { + now := time.Now() + for i := range tasks { + if tasks[i].ID == "" { + tasks[i].ID = uuid.New().String() + } + if tasks[i].Priority == "" { + tasks[i].Priority = PriorityNormal + } + if tasks[i].Retry.MaxAttempts == 0 { + tasks[i].Retry.MaxAttempts = 1 + } + if tasks[i].Retry.Backoff == "" { + tasks[i].Retry.Backoff = "exponential" + } + tasks[i].State = StatePending + tasks[i].CreatedAt = now + tasks[i].UpdatedAt = now + } + return tasks +} diff --git a/internal/task/parser_test.go b/internal/task/parser_test.go new file mode 100644 index 0000000..cb68e86 --- /dev/null +++ b/internal/task/parser_test.go @@ -0,0 +1,152 @@ +package task + +import ( + "os" + "path/filepath" + "testing" + "time" +) + +func TestParse_SingleTask(t *testing.T) { + yaml := ` +name: "Test Task" +description: "A simple test" +claude: + model: "sonnet" + instructions: "Do something" + working_dir: "/tmp" +timeout: "10m" +tags: + - "test" +` + tasks, err := Parse([]byte(yaml)) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(tasks) != 1 { + t.Fatalf("expected 1 task, got %d", len(tasks)) + } + task := tasks[0] + if task.Name != "Test Task" { + t.Errorf("expected name 'Test Task', got %q", task.Name) + } + if task.Claude.Model != "sonnet" { + t.Errorf("expected model 'sonnet', got %q", task.Claude.Model) + } + if task.Timeout.Duration != 10*time.Minute { + t.Errorf("expected timeout 10m, got %v", task.Timeout.Duration) + } + if task.State != StatePending { + t.Errorf("expected state PENDING, got %v", task.State) + } + if task.ID == "" { + t.Error("expected auto-generated ID") + } + if task.Priority != PriorityNormal { + t.Errorf("expected default priority 'normal', got %q", task.Priority) + } +} + +func TestParse_BatchTasks(t *testing.T) { + yaml := ` +tasks: + - name: "Task A" + claude: + instructions: "Do A" + working_dir: "/tmp" + tags: ["alpha"] + - name: "Task B" + claude: + instructions: "Do B" + working_dir: "/tmp" + tags: ["beta"] +` + tasks, err := Parse([]byte(yaml)) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(tasks) != 2 { + t.Fatalf("expected 2 tasks, got %d", len(tasks)) + } + if tasks[0].Name != "Task A" { + t.Errorf("expected 'Task A', got %q", tasks[0].Name) + } + if tasks[1].Name != "Task B" { + t.Errorf("expected 'Task B', got %q", tasks[1].Name) + } +} + +func TestParse_MissingName_ReturnsError(t *testing.T) { + yaml := ` +description: "no name" +claude: + instructions: "something" +` + _, err := Parse([]byte(yaml)) + if err == nil { + t.Fatal("expected error for missing name") + } +} + +func TestParse_DefaultRetryConfig(t *testing.T) { + yaml := ` +name: "Defaults" +claude: + instructions: "test" +` + tasks, err := Parse([]byte(yaml)) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if tasks[0].Retry.MaxAttempts != 1 { + t.Errorf("expected default max_attempts=1, got %d", tasks[0].Retry.MaxAttempts) + } + if tasks[0].Retry.Backoff != "exponential" { + t.Errorf("expected default backoff 'exponential', got %q", tasks[0].Retry.Backoff) + } +} + +func TestParse_WithPriority(t *testing.T) { + yaml := ` +name: "High Priority" +priority: "high" +claude: + instructions: "urgent" +` + tasks, err := Parse([]byte(yaml)) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if tasks[0].Priority != PriorityHigh { + t.Errorf("expected priority 'high', got %q", tasks[0].Priority) + } +} + +func TestParseFile(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "task.yaml") + content := ` +name: "File Task" +claude: + instructions: "from file" + working_dir: "/tmp" +` + if err := os.WriteFile(path, []byte(content), 0644); err != nil { + t.Fatal(err) + } + + tasks, err := ParseFile(path) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(tasks) != 1 || tasks[0].Name != "File Task" { + t.Errorf("unexpected tasks: %+v", tasks) + } +} + +func TestParseFile_NotFound(t *testing.T) { + _, err := ParseFile("/nonexistent/task.yaml") + if err == nil { + t.Fatal("expected error for nonexistent file") + } +} diff --git a/internal/task/task.go b/internal/task/task.go new file mode 100644 index 0000000..3796cf3 --- /dev/null +++ b/internal/task/task.go @@ -0,0 +1,100 @@ +package task + +import "time" + +type State string + +const ( + StatePending State = "PENDING" + StateQueued State = "QUEUED" + StateRunning State = "RUNNING" + StateCompleted State = "COMPLETED" + StateFailed State = "FAILED" + StateTimedOut State = "TIMED_OUT" + StateCancelled State = "CANCELLED" + StateBudgetExceeded State = "BUDGET_EXCEEDED" +) + +type Priority string + +const ( + PriorityHigh Priority = "high" + PriorityNormal Priority = "normal" + PriorityLow Priority = "low" +) + +type ClaudeConfig struct { + Model string `yaml:"model"` + ContextFiles []string `yaml:"context_files"` + Instructions string `yaml:"instructions"` + WorkingDir string `yaml:"working_dir"` + MaxBudgetUSD float64 `yaml:"max_budget_usd"` + PermissionMode string `yaml:"permission_mode"` + AllowedTools []string `yaml:"allowed_tools"` + DisallowedTools []string `yaml:"disallowed_tools"` + SystemPromptAppend string `yaml:"system_prompt_append"` + AdditionalArgs []string `yaml:"additional_args"` +} + +type RetryConfig struct { + MaxAttempts int `yaml:"max_attempts"` + Backoff string `yaml:"backoff"` // "linear", "exponential" +} + +type Task struct { + ID string `yaml:"id"` + Name string `yaml:"name"` + Description string `yaml:"description"` + Claude ClaudeConfig `yaml:"claude"` + Timeout Duration `yaml:"timeout"` + Retry RetryConfig `yaml:"retry"` + Priority Priority `yaml:"priority"` + Tags []string `yaml:"tags"` + DependsOn []string `yaml:"depends_on"` + State State `yaml:"-"` + CreatedAt time.Time `yaml:"-"` + UpdatedAt time.Time `yaml:"-"` +} + +// Duration wraps time.Duration for YAML unmarshaling from strings like "30m". +type Duration struct { + time.Duration +} + +func (d *Duration) UnmarshalYAML(unmarshal func(interface{}) error) error { + var s string + if err := unmarshal(&s); err != nil { + return err + } + dur, err := time.ParseDuration(s) + if err != nil { + return err + } + d.Duration = dur + return nil +} + +func (d Duration) MarshalYAML() (interface{}, error) { + return d.Duration.String(), nil +} + +// BatchFile represents a YAML file containing multiple tasks. +type BatchFile struct { + Tasks []Task `yaml:"tasks"` +} + +// ValidTransition returns true if moving from the current state to next is allowed. +func ValidTransition(from, to State) bool { + transitions := map[State][]State{ + StatePending: {StateQueued, StateCancelled}, + StateQueued: {StateRunning, StateCancelled}, + StateRunning: {StateCompleted, StateFailed, StateTimedOut, StateCancelled, StateBudgetExceeded}, + StateFailed: {StateQueued}, // retry + } + for _, allowed := range transitions[from] { + if allowed == to { + return true + } + } + return false +} diff --git a/internal/task/task_test.go b/internal/task/task_test.go new file mode 100644 index 0000000..96f5f6f --- /dev/null +++ b/internal/task/task_test.go @@ -0,0 +1,80 @@ +package task + +import ( + "testing" + "time" +) + +func TestValidTransition_AllowedTransitions(t *testing.T) { + tests := []struct { + name string + from State + to State + }{ + {"pending to queued", StatePending, StateQueued}, + {"pending to cancelled", StatePending, StateCancelled}, + {"queued to running", StateQueued, StateRunning}, + {"queued to cancelled", StateQueued, StateCancelled}, + {"running to completed", StateRunning, StateCompleted}, + {"running to failed", StateRunning, StateFailed}, + {"running to timed out", StateRunning, StateTimedOut}, + {"running to cancelled", StateRunning, StateCancelled}, + {"running to budget exceeded", StateRunning, StateBudgetExceeded}, + {"failed to queued (retry)", StateFailed, StateQueued}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if !ValidTransition(tt.from, tt.to) { + t.Errorf("expected transition %s -> %s to be valid", tt.from, tt.to) + } + }) + } +} + +func TestValidTransition_DisallowedTransitions(t *testing.T) { + tests := []struct { + name string + from State + to State + }{ + {"pending to running", StatePending, StateRunning}, + {"pending to completed", StatePending, StateCompleted}, + {"queued to completed", StateQueued, StateCompleted}, + {"completed to running", StateCompleted, StateRunning}, + {"completed to queued", StateCompleted, StateQueued}, + {"failed to completed", StateFailed, StateCompleted}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if ValidTransition(tt.from, tt.to) { + t.Errorf("expected transition %s -> %s to be invalid", tt.from, tt.to) + } + }) + } +} + +func TestDuration_UnmarshalYAML(t *testing.T) { + var d Duration + unmarshal := func(v interface{}) error { + ptr := v.(*string) + *ptr = "30m" + return nil + } + if err := d.UnmarshalYAML(unmarshal); err != nil { + t.Fatalf("unexpected error: %v", err) + } + if d.Duration != 30*time.Minute { + t.Errorf("expected 30m, got %v", d.Duration) + } +} + +func TestDuration_MarshalYAML(t *testing.T) { + d := Duration{Duration: 15 * time.Minute} + v, err := d.MarshalYAML() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if v != "15m0s" { + t.Errorf("expected '15m0s', got %v", v) + } +} diff --git a/internal/task/validator.go b/internal/task/validator.go new file mode 100644 index 0000000..ea0b1c2 --- /dev/null +++ b/internal/task/validator.go @@ -0,0 +1,65 @@ +package task + +import ( + "fmt" + "strings" +) + +// ValidationError collects multiple validation failures. +type ValidationError struct { + Errors []string +} + +func (e *ValidationError) Error() string { + return fmt.Sprintf("validation failed: %s", strings.Join(e.Errors, "; ")) +} + +func (e *ValidationError) Add(msg string) { + e.Errors = append(e.Errors, msg) +} + +func (e *ValidationError) HasErrors() bool { + return len(e.Errors) > 0 +} + +// Validate checks a task for required fields and valid values. +func Validate(t *Task) error { + ve := &ValidationError{} + + if t.Name == "" { + ve.Add("name is required") + } + if t.Claude.Instructions == "" { + ve.Add("claude.instructions is required") + } + if t.Claude.MaxBudgetUSD < 0 { + ve.Add("claude.max_budget_usd must be non-negative") + } + if t.Timeout.Duration < 0 { + ve.Add("timeout must be non-negative") + } + if t.Retry.MaxAttempts < 1 { + ve.Add("retry.max_attempts must be at least 1") + } + if t.Retry.Backoff != "" && t.Retry.Backoff != "linear" && t.Retry.Backoff != "exponential" { + ve.Add("retry.backoff must be 'linear' or 'exponential'") + } + validPriorities := map[Priority]bool{PriorityHigh: true, PriorityNormal: true, PriorityLow: true} + if t.Priority != "" && !validPriorities[t.Priority] { + ve.Add(fmt.Sprintf("invalid priority %q; must be high, normal, or low", t.Priority)) + } + if t.Claude.PermissionMode != "" { + validModes := map[string]bool{ + "default": true, "acceptEdits": true, "bypassPermissions": true, + "plan": true, "dontAsk": true, "delegate": true, + } + if !validModes[t.Claude.PermissionMode] { + ve.Add(fmt.Sprintf("invalid permission_mode %q", t.Claude.PermissionMode)) + } + } + + if ve.HasErrors() { + return ve + } + return nil +} diff --git a/internal/task/validator_test.go b/internal/task/validator_test.go new file mode 100644 index 0000000..967eed3 --- /dev/null +++ b/internal/task/validator_test.go @@ -0,0 +1,115 @@ +package task + +import ( + "strings" + "testing" +) + +func validTask() *Task { + return &Task{ + ID: "test-id", + Name: "Valid Task", + Claude: ClaudeConfig{ + Instructions: "do something", + WorkingDir: "/tmp", + }, + Priority: PriorityNormal, + Retry: RetryConfig{MaxAttempts: 1, Backoff: "exponential"}, + } +} + +func TestValidate_ValidTask_NoError(t *testing.T) { + task := validTask() + if err := Validate(task); err != nil { + t.Errorf("expected no error, got: %v", err) + } +} + +func TestValidate_MissingName_ReturnsError(t *testing.T) { + task := validTask() + task.Name = "" + err := Validate(task) + if err == nil { + t.Fatal("expected error") + } + if !strings.Contains(err.Error(), "name is required") { + t.Errorf("expected 'name is required' in error, got: %v", err) + } +} + +func TestValidate_MissingInstructions_ReturnsError(t *testing.T) { + task := validTask() + task.Claude.Instructions = "" + err := Validate(task) + if err == nil { + t.Fatal("expected error") + } + if !strings.Contains(err.Error(), "instructions is required") { + t.Errorf("expected 'instructions is required' in error, got: %v", err) + } +} + +func TestValidate_NegativeBudget_ReturnsError(t *testing.T) { + task := validTask() + task.Claude.MaxBudgetUSD = -1.0 + err := Validate(task) + if err == nil { + t.Fatal("expected error") + } + if !strings.Contains(err.Error(), "max_budget_usd") { + t.Errorf("expected budget error, got: %v", err) + } +} + +func TestValidate_InvalidBackoff_ReturnsError(t *testing.T) { + task := validTask() + task.Retry.Backoff = "random" + err := Validate(task) + if err == nil { + t.Fatal("expected error") + } + if !strings.Contains(err.Error(), "backoff") { + t.Errorf("expected backoff error, got: %v", err) + } +} + +func TestValidate_InvalidPriority_ReturnsError(t *testing.T) { + task := validTask() + task.Priority = "urgent" + err := Validate(task) + if err == nil { + t.Fatal("expected error") + } + if !strings.Contains(err.Error(), "invalid priority") { + t.Errorf("expected priority error, got: %v", err) + } +} + +func TestValidate_InvalidPermissionMode_ReturnsError(t *testing.T) { + task := validTask() + task.Claude.PermissionMode = "yolo" + err := Validate(task) + if err == nil { + t.Fatal("expected error") + } + if !strings.Contains(err.Error(), "permission_mode") { + t.Errorf("expected permission_mode error, got: %v", err) + } +} + +func TestValidate_MultipleErrors(t *testing.T) { + task := &Task{ + Retry: RetryConfig{MaxAttempts: 0, Backoff: "bad"}, + } + err := Validate(task) + if err == nil { + t.Fatal("expected error") + } + ve, ok := err.(*ValidationError) + if !ok { + t.Fatalf("expected *ValidationError, got %T", err) + } + if len(ve.Errors) < 3 { + t.Errorf("expected at least 3 errors, got %d: %v", len(ve.Errors), ve.Errors) + } +} diff --git a/test/fixtures/tasks/batch-tasks.yaml b/test/fixtures/tasks/batch-tasks.yaml new file mode 100644 index 0000000..e1c97b7 --- /dev/null +++ b/test/fixtures/tasks/batch-tasks.yaml @@ -0,0 +1,27 @@ +tasks: + - name: "Lint Check" + claude: + model: "sonnet" + instructions: "Run the linter and report any issues." + working_dir: "/tmp" + timeout: "10m" + priority: "high" + tags: ["ci", "lint"] + + - name: "Unit Tests" + claude: + model: "sonnet" + instructions: "Run the test suite and summarize results." + working_dir: "/tmp" + timeout: "15m" + tags: ["ci", "test"] + + - name: "Security Audit" + claude: + model: "opus" + instructions: "Review the codebase for security vulnerabilities." + working_dir: "/tmp" + max_budget_usd: 5.00 + timeout: "30m" + priority: "low" + tags: ["security"] diff --git a/test/fixtures/tasks/simple-task.yaml b/test/fixtures/tasks/simple-task.yaml new file mode 100644 index 0000000..d49b066 --- /dev/null +++ b/test/fixtures/tasks/simple-task.yaml @@ -0,0 +1,11 @@ +name: "Hello World" +description: "A simple test task" +claude: + model: "sonnet" + instructions: | + Say hello and list the files in the current directory. + working_dir: "/tmp" +timeout: "5m" +tags: + - "test" + - "simple" |
