summaryrefslogtreecommitdiff
path: root/internal/storage
diff options
context:
space:
mode:
Diffstat (limited to 'internal/storage')
-rw-r--r--internal/storage/db.go278
-rw-r--r--internal/storage/db_test.go285
2 files changed, 563 insertions, 0 deletions
diff --git a/internal/storage/db.go b/internal/storage/db.go
new file mode 100644
index 0000000..67fbe08
--- /dev/null
+++ b/internal/storage/db.go
@@ -0,0 +1,278 @@
+package storage
+
+import (
+ "database/sql"
+ "encoding/json"
+ "fmt"
+ "time"
+
+ "github.com/claudomator/claudomator/internal/task"
+ _ "github.com/mattn/go-sqlite3"
+)
+
+type DB struct {
+ db *sql.DB
+}
+
+func Open(path string) (*DB, error) {
+ db, err := sql.Open("sqlite3", path+"?_journal_mode=WAL&_busy_timeout=5000")
+ if err != nil {
+ return nil, fmt.Errorf("opening database: %w", err)
+ }
+ s := &DB{db: db}
+ if err := s.migrate(); err != nil {
+ db.Close()
+ return nil, fmt.Errorf("running migrations: %w", err)
+ }
+ return s, nil
+}
+
+func (s *DB) Close() error {
+ return s.db.Close()
+}
+
+func (s *DB) migrate() error {
+ schema := `
+ CREATE TABLE IF NOT EXISTS tasks (
+ id TEXT PRIMARY KEY,
+ name TEXT NOT NULL,
+ description TEXT,
+ config_json TEXT NOT NULL,
+ priority TEXT NOT NULL DEFAULT 'normal',
+ timeout_ns INTEGER NOT NULL DEFAULT 0,
+ retry_json TEXT NOT NULL DEFAULT '{}',
+ tags_json TEXT NOT NULL DEFAULT '[]',
+ depends_on_json TEXT NOT NULL DEFAULT '[]',
+ state TEXT NOT NULL DEFAULT 'PENDING',
+ created_at DATETIME NOT NULL,
+ updated_at DATETIME NOT NULL
+ );
+
+ CREATE TABLE IF NOT EXISTS executions (
+ id TEXT PRIMARY KEY,
+ task_id TEXT NOT NULL,
+ start_time DATETIME NOT NULL,
+ end_time DATETIME,
+ exit_code INTEGER,
+ status TEXT NOT NULL,
+ stdout_path TEXT,
+ stderr_path TEXT,
+ artifact_dir TEXT,
+ cost_usd REAL,
+ error_msg TEXT,
+ FOREIGN KEY (task_id) REFERENCES tasks(id)
+ );
+
+ CREATE INDEX IF NOT EXISTS idx_tasks_state ON tasks(state);
+ CREATE INDEX IF NOT EXISTS idx_executions_status ON executions(status);
+ CREATE INDEX IF NOT EXISTS idx_executions_task_id ON executions(task_id);
+ `
+ _, err := s.db.Exec(schema)
+ return err
+}
+
+// CreateTask inserts a task into the database.
+func (s *DB) CreateTask(t *task.Task) error {
+ configJSON, err := json.Marshal(t.Claude)
+ if err != nil {
+ return fmt.Errorf("marshaling config: %w", err)
+ }
+ retryJSON, err := json.Marshal(t.Retry)
+ if err != nil {
+ return fmt.Errorf("marshaling retry: %w", err)
+ }
+ tagsJSON, err := json.Marshal(t.Tags)
+ if err != nil {
+ return fmt.Errorf("marshaling tags: %w", err)
+ }
+ depsJSON, err := json.Marshal(t.DependsOn)
+ if err != nil {
+ return fmt.Errorf("marshaling depends_on: %w", err)
+ }
+
+ _, err = s.db.Exec(`
+ INSERT INTO tasks (id, name, description, config_json, priority, timeout_ns, retry_json, tags_json, depends_on_json, state, created_at, updated_at)
+ VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
+ t.ID, t.Name, t.Description, string(configJSON), string(t.Priority),
+ t.Timeout.Duration.Nanoseconds(), string(retryJSON), string(tagsJSON), string(depsJSON),
+ string(t.State), t.CreatedAt.UTC(), t.UpdatedAt.UTC(),
+ )
+ return err
+}
+
+// GetTask retrieves a task by ID.
+func (s *DB) GetTask(id string) (*task.Task, error) {
+ row := s.db.QueryRow(`SELECT id, name, description, config_json, priority, timeout_ns, retry_json, tags_json, depends_on_json, state, created_at, updated_at FROM tasks WHERE id = ?`, id)
+ return scanTask(row)
+}
+
+// ListTasks returns tasks matching the given filter.
+func (s *DB) ListTasks(filter TaskFilter) ([]*task.Task, error) {
+ query := `SELECT id, name, description, config_json, priority, timeout_ns, retry_json, tags_json, depends_on_json, state, created_at, updated_at FROM tasks WHERE 1=1`
+ var args []interface{}
+
+ if filter.State != "" {
+ query += " AND state = ?"
+ args = append(args, string(filter.State))
+ }
+ query += " ORDER BY created_at DESC"
+ if filter.Limit > 0 {
+ query += " LIMIT ?"
+ args = append(args, filter.Limit)
+ }
+
+ rows, err := s.db.Query(query, args...)
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close()
+
+ var tasks []*task.Task
+ for rows.Next() {
+ t, err := scanTaskRows(rows)
+ if err != nil {
+ return nil, err
+ }
+ tasks = append(tasks, t)
+ }
+ return tasks, rows.Err()
+}
+
+// UpdateTaskState atomically updates a task's state.
+func (s *DB) UpdateTaskState(id string, newState task.State) error {
+ now := time.Now().UTC()
+ result, err := s.db.Exec(`UPDATE tasks SET state = ?, updated_at = ? WHERE id = ?`, string(newState), now, id)
+ if err != nil {
+ return err
+ }
+ n, err := result.RowsAffected()
+ if err != nil {
+ return err
+ }
+ if n == 0 {
+ return fmt.Errorf("task %q not found", id)
+ }
+ return nil
+}
+
+// TaskFilter specifies criteria for listing tasks.
+type TaskFilter struct {
+ State task.State
+ Limit int
+}
+
+// Execution represents a single run of a task.
+type Execution struct {
+ ID string
+ TaskID string
+ StartTime time.Time
+ EndTime time.Time
+ ExitCode int
+ Status string
+ StdoutPath string
+ StderrPath string
+ ArtifactDir string
+ CostUSD float64
+ ErrorMsg string
+}
+
+// CreateExecution inserts an execution record.
+func (s *DB) CreateExecution(e *Execution) error {
+ _, err := s.db.Exec(`
+ INSERT INTO executions (id, task_id, start_time, end_time, exit_code, status, stdout_path, stderr_path, artifact_dir, cost_usd, error_msg)
+ VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
+ e.ID, e.TaskID, e.StartTime.UTC(), e.EndTime.UTC(), e.ExitCode, e.Status,
+ e.StdoutPath, e.StderrPath, e.ArtifactDir, e.CostUSD, e.ErrorMsg,
+ )
+ return err
+}
+
+// GetExecution retrieves an execution by ID.
+func (s *DB) GetExecution(id string) (*Execution, error) {
+ row := s.db.QueryRow(`SELECT id, task_id, start_time, end_time, exit_code, status, stdout_path, stderr_path, artifact_dir, cost_usd, error_msg FROM executions WHERE id = ?`, id)
+ return scanExecution(row)
+}
+
+// ListExecutions returns executions for a task.
+func (s *DB) ListExecutions(taskID string) ([]*Execution, error) {
+ rows, err := s.db.Query(`SELECT id, task_id, start_time, end_time, exit_code, status, stdout_path, stderr_path, artifact_dir, cost_usd, error_msg FROM executions WHERE task_id = ? ORDER BY start_time DESC`, taskID)
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close()
+
+ var execs []*Execution
+ for rows.Next() {
+ e, err := scanExecutionRows(rows)
+ if err != nil {
+ return nil, err
+ }
+ execs = append(execs, e)
+ }
+ return execs, rows.Err()
+}
+
+// UpdateExecution updates a completed execution.
+func (s *DB) UpdateExecution(e *Execution) error {
+ _, err := s.db.Exec(`
+ UPDATE executions SET end_time = ?, exit_code = ?, status = ?, cost_usd = ?, error_msg = ?
+ WHERE id = ?`,
+ e.EndTime.UTC(), e.ExitCode, e.Status, e.CostUSD, e.ErrorMsg, e.ID,
+ )
+ return err
+}
+
+type scanner interface {
+ Scan(dest ...interface{}) error
+}
+
+func scanTask(row scanner) (*task.Task, error) {
+ var (
+ t task.Task
+ configJSON string
+ retryJSON string
+ tagsJSON string
+ depsJSON string
+ state string
+ priority string
+ timeoutNS int64
+ )
+ err := row.Scan(&t.ID, &t.Name, &t.Description, &configJSON, &priority, &timeoutNS, &retryJSON, &tagsJSON, &depsJSON, &state, &t.CreatedAt, &t.UpdatedAt)
+ if err != nil {
+ return nil, err
+ }
+ t.State = task.State(state)
+ t.Priority = task.Priority(priority)
+ t.Timeout.Duration = time.Duration(timeoutNS)
+ if err := json.Unmarshal([]byte(configJSON), &t.Claude); err != nil {
+ return nil, fmt.Errorf("unmarshaling config: %w", err)
+ }
+ if err := json.Unmarshal([]byte(retryJSON), &t.Retry); err != nil {
+ return nil, fmt.Errorf("unmarshaling retry: %w", err)
+ }
+ if err := json.Unmarshal([]byte(tagsJSON), &t.Tags); err != nil {
+ return nil, fmt.Errorf("unmarshaling tags: %w", err)
+ }
+ if err := json.Unmarshal([]byte(depsJSON), &t.DependsOn); err != nil {
+ return nil, fmt.Errorf("unmarshaling depends_on: %w", err)
+ }
+ return &t, nil
+}
+
+func scanTaskRows(rows *sql.Rows) (*task.Task, error) {
+ return scanTask(rows)
+}
+
+func scanExecution(row scanner) (*Execution, error) {
+ var e Execution
+ err := row.Scan(&e.ID, &e.TaskID, &e.StartTime, &e.EndTime, &e.ExitCode, &e.Status,
+ &e.StdoutPath, &e.StderrPath, &e.ArtifactDir, &e.CostUSD, &e.ErrorMsg)
+ if err != nil {
+ return nil, err
+ }
+ return &e, nil
+}
+
+func scanExecutionRows(rows *sql.Rows) (*Execution, error) {
+ return scanExecution(rows)
+}
diff --git a/internal/storage/db_test.go b/internal/storage/db_test.go
new file mode 100644
index 0000000..78cb1e1
--- /dev/null
+++ b/internal/storage/db_test.go
@@ -0,0 +1,285 @@
+package storage
+
+import (
+ "fmt"
+ "path/filepath"
+ "testing"
+ "time"
+
+ "github.com/claudomator/claudomator/internal/task"
+)
+
+func testDB(t *testing.T) *DB {
+ t.Helper()
+ dbPath := filepath.Join(t.TempDir(), "test.db")
+ db, err := Open(dbPath)
+ if err != nil {
+ t.Fatalf("opening db: %v", err)
+ }
+ t.Cleanup(func() { db.Close() })
+ return db
+}
+
+func TestOpen_CreatesSchema(t *testing.T) {
+ db := testDB(t)
+ // Should be able to query tasks table.
+ _, err := db.ListTasks(TaskFilter{})
+ if err != nil {
+ t.Fatalf("querying tasks: %v", err)
+ }
+}
+
+func TestCreateTask_AndGetTask(t *testing.T) {
+ db := testDB(t)
+ now := time.Now().UTC().Truncate(time.Second)
+
+ tk := &task.Task{
+ ID: "task-1",
+ Name: "Test Task",
+ Description: "A test",
+ Claude: task.ClaudeConfig{
+ Model: "sonnet",
+ Instructions: "do it",
+ WorkingDir: "/tmp",
+ MaxBudgetUSD: 2.5,
+ },
+ Priority: task.PriorityHigh,
+ Tags: []string{"test", "alpha"},
+ DependsOn: []string{},
+ Retry: task.RetryConfig{MaxAttempts: 3, Backoff: "exponential"},
+ State: task.StatePending,
+ CreatedAt: now,
+ UpdatedAt: now,
+ }
+ tk.Timeout.Duration = 10 * time.Minute
+
+ if err := db.CreateTask(tk); err != nil {
+ t.Fatalf("creating task: %v", err)
+ }
+
+ got, err := db.GetTask("task-1")
+ if err != nil {
+ t.Fatalf("getting task: %v", err)
+ }
+ if got.Name != "Test Task" {
+ t.Errorf("name: want 'Test Task', got %q", got.Name)
+ }
+ if got.Claude.Model != "sonnet" {
+ t.Errorf("model: want 'sonnet', got %q", got.Claude.Model)
+ }
+ if got.Claude.MaxBudgetUSD != 2.5 {
+ t.Errorf("budget: want 2.5, got %f", got.Claude.MaxBudgetUSD)
+ }
+ if got.Priority != task.PriorityHigh {
+ t.Errorf("priority: want 'high', got %q", got.Priority)
+ }
+ if got.Timeout.Duration != 10*time.Minute {
+ t.Errorf("timeout: want 10m, got %v", got.Timeout.Duration)
+ }
+ if got.Retry.MaxAttempts != 3 {
+ t.Errorf("retry: want 3, got %d", got.Retry.MaxAttempts)
+ }
+ if len(got.Tags) != 2 || got.Tags[0] != "test" {
+ t.Errorf("tags: want [test alpha], got %v", got.Tags)
+ }
+ if got.State != task.StatePending {
+ t.Errorf("state: want PENDING, got %v", got.State)
+ }
+}
+
+func TestUpdateTaskState(t *testing.T) {
+ db := testDB(t)
+ now := time.Now().UTC()
+ tk := &task.Task{
+ ID: "task-2",
+ Name: "Stateful",
+ Claude: task.ClaudeConfig{Instructions: "test"},
+ Priority: task.PriorityNormal,
+ Retry: task.RetryConfig{MaxAttempts: 1, Backoff: "linear"},
+ Tags: []string{},
+ DependsOn: []string{},
+ State: task.StatePending,
+ CreatedAt: now,
+ UpdatedAt: now,
+ }
+ if err := db.CreateTask(tk); err != nil {
+ t.Fatal(err)
+ }
+
+ if err := db.UpdateTaskState("task-2", task.StateQueued); err != nil {
+ t.Fatalf("updating state: %v", err)
+ }
+ got, _ := db.GetTask("task-2")
+ if got.State != task.StateQueued {
+ t.Errorf("state: want QUEUED, got %v", got.State)
+ }
+}
+
+func TestUpdateTaskState_NotFound(t *testing.T) {
+ db := testDB(t)
+ err := db.UpdateTaskState("nonexistent", task.StateQueued)
+ if err == nil {
+ t.Fatal("expected error for nonexistent task")
+ }
+}
+
+func TestListTasks_FilterByState(t *testing.T) {
+ db := testDB(t)
+ now := time.Now().UTC()
+
+ for i, state := range []task.State{task.StatePending, task.StatePending, task.StateRunning} {
+ tk := &task.Task{
+ ID: fmt.Sprintf("t-%d", i), Name: fmt.Sprintf("Task %d", i),
+ Claude: task.ClaudeConfig{Instructions: "x"}, Priority: task.PriorityNormal,
+ Retry: task.RetryConfig{MaxAttempts: 1, Backoff: "linear"},
+ Tags: []string{}, DependsOn: []string{},
+ State: state, CreatedAt: now, UpdatedAt: now,
+ }
+ if err := db.CreateTask(tk); err != nil {
+ t.Fatal(err)
+ }
+ }
+
+ pending, err := db.ListTasks(TaskFilter{State: task.StatePending})
+ if err != nil {
+ t.Fatal(err)
+ }
+ if len(pending) != 2 {
+ t.Errorf("want 2 pending, got %d", len(pending))
+ }
+
+ running, err := db.ListTasks(TaskFilter{State: task.StateRunning})
+ if err != nil {
+ t.Fatal(err)
+ }
+ if len(running) != 1 {
+ t.Errorf("want 1 running, got %d", len(running))
+ }
+}
+
+func TestListTasks_WithLimit(t *testing.T) {
+ db := testDB(t)
+ now := time.Now().UTC()
+ for i := 0; i < 5; i++ {
+ tk := &task.Task{
+ ID: fmt.Sprintf("lt-%d", i), Name: fmt.Sprintf("T%d", i),
+ Claude: task.ClaudeConfig{Instructions: "x"}, Priority: task.PriorityNormal,
+ Retry: task.RetryConfig{MaxAttempts: 1, Backoff: "linear"},
+ Tags: []string{}, DependsOn: []string{},
+ State: task.StatePending, CreatedAt: now.Add(time.Duration(i) * time.Second), UpdatedAt: now,
+ }
+ db.CreateTask(tk)
+ }
+
+ tasks, err := db.ListTasks(TaskFilter{Limit: 3})
+ if err != nil {
+ t.Fatal(err)
+ }
+ if len(tasks) != 3 {
+ t.Errorf("want 3, got %d", len(tasks))
+ }
+}
+
+func TestCreateExecution_AndGet(t *testing.T) {
+ db := testDB(t)
+ now := time.Now().UTC().Truncate(time.Second)
+
+ // Need a task first.
+ tk := &task.Task{
+ ID: "etask", Name: "E", Claude: task.ClaudeConfig{Instructions: "x"},
+ Priority: task.PriorityNormal, Retry: task.RetryConfig{MaxAttempts: 1, Backoff: "linear"},
+ Tags: []string{}, DependsOn: []string{},
+ State: task.StatePending, CreatedAt: now, UpdatedAt: now,
+ }
+ db.CreateTask(tk)
+
+ exec := &Execution{
+ ID: "exec-1",
+ TaskID: "etask",
+ StartTime: now,
+ EndTime: now.Add(5 * time.Minute),
+ ExitCode: 0,
+ Status: "COMPLETED",
+ StdoutPath: "/tmp/stdout.log",
+ StderrPath: "/tmp/stderr.log",
+ CostUSD: 0.42,
+ }
+ if err := db.CreateExecution(exec); err != nil {
+ t.Fatalf("creating execution: %v", err)
+ }
+
+ got, err := db.GetExecution("exec-1")
+ if err != nil {
+ t.Fatalf("getting execution: %v", err)
+ }
+ if got.Status != "COMPLETED" {
+ t.Errorf("status: want COMPLETED, got %q", got.Status)
+ }
+ if got.CostUSD != 0.42 {
+ t.Errorf("cost: want 0.42, got %f", got.CostUSD)
+ }
+ if got.ExitCode != 0 {
+ t.Errorf("exit code: want 0, got %d", got.ExitCode)
+ }
+}
+
+func TestListExecutions(t *testing.T) {
+ db := testDB(t)
+ now := time.Now().UTC()
+ tk := &task.Task{
+ ID: "ltask", Name: "L", Claude: task.ClaudeConfig{Instructions: "x"},
+ Priority: task.PriorityNormal, Retry: task.RetryConfig{MaxAttempts: 1, Backoff: "linear"},
+ Tags: []string{}, DependsOn: []string{},
+ State: task.StatePending, CreatedAt: now, UpdatedAt: now,
+ }
+ db.CreateTask(tk)
+
+ for i := 0; i < 3; i++ {
+ db.CreateExecution(&Execution{
+ ID: fmt.Sprintf("le-%d", i), TaskID: "ltask",
+ StartTime: now.Add(time.Duration(i) * time.Minute), EndTime: now.Add(time.Duration(i+1) * time.Minute),
+ Status: "COMPLETED",
+ })
+ }
+
+ execs, err := db.ListExecutions("ltask")
+ if err != nil {
+ t.Fatal(err)
+ }
+ if len(execs) != 3 {
+ t.Errorf("want 3, got %d", len(execs))
+ }
+}
+
+func TestUpdateExecution(t *testing.T) {
+ db := testDB(t)
+ now := time.Now().UTC()
+ tk := &task.Task{
+ ID: "utask", Name: "U", Claude: task.ClaudeConfig{Instructions: "x"},
+ Priority: task.PriorityNormal, Retry: task.RetryConfig{MaxAttempts: 1, Backoff: "linear"},
+ Tags: []string{}, DependsOn: []string{},
+ State: task.StatePending, CreatedAt: now, UpdatedAt: now,
+ }
+ db.CreateTask(tk)
+
+ exec := &Execution{
+ ID: "ue-1", TaskID: "utask", StartTime: now, EndTime: now, Status: "RUNNING",
+ }
+ db.CreateExecution(exec)
+
+ exec.Status = "FAILED"
+ exec.ExitCode = 1
+ exec.ErrorMsg = "something broke"
+ exec.EndTime = now.Add(2 * time.Minute)
+ if err := db.UpdateExecution(exec); err != nil {
+ t.Fatal(err)
+ }
+
+ got, _ := db.GetExecution("ue-1")
+ if got.Status != "FAILED" {
+ t.Errorf("status: want FAILED, got %q", got.Status)
+ }
+ if got.ErrorMsg != "something broke" {
+ t.Errorf("error: want 'something broke', got %q", got.ErrorMsg)
+ }
+}