diff options
Diffstat (limited to 'internal/storage')
| -rw-r--r-- | internal/storage/db.go | 569 | ||||
| -rw-r--r-- | internal/storage/db_test.go | 372 | ||||
| -rw-r--r-- | internal/storage/seed.go | 62 | ||||
| -rw-r--r-- | internal/storage/sqlite_cgo.go | 5 | ||||
| -rw-r--r-- | internal/storage/sqlite_nocgo.go | 21 |
5 files changed, 1010 insertions, 19 deletions
diff --git a/internal/storage/db.go b/internal/storage/db.go index ce60e2f..4adc1ba 100644 --- a/internal/storage/db.go +++ b/internal/storage/db.go @@ -8,7 +8,6 @@ import ( "time" "github.com/thepeterstone/claudomator/internal/task" - _ "github.com/mattn/go-sqlite3" ) type DB struct { @@ -20,6 +19,10 @@ func Open(path string) (*DB, error) { if err != nil { return nil, fmt.Errorf("opening database: %w", err) } + // SQLite only allows one concurrent writer. Limiting to one open connection + // prevents "database is locked" errors when multiple goroutines write + // simultaneously via database/sql's connection pool. + db.SetMaxOpenConns(1) s := &DB{db: db} if err := s.migrate(); err != nil { db.Close() @@ -86,6 +89,54 @@ func (s *DB) migrate() error { `ALTER TABLE executions ADD COLUMN changestats_json TEXT`, `ALTER TABLE executions ADD COLUMN commits_json TEXT NOT NULL DEFAULT '[]'`, `ALTER TABLE tasks ADD COLUMN elaboration_input TEXT`, + `ALTER TABLE tasks ADD COLUMN project TEXT`, + `ALTER TABLE tasks ADD COLUMN repository_url TEXT`, + `CREATE TABLE IF NOT EXISTS push_subscriptions ( + id TEXT PRIMARY KEY, + endpoint TEXT NOT NULL UNIQUE, + p256dh_key TEXT NOT NULL, + auth_key TEXT NOT NULL, + created_at DATETIME DEFAULT CURRENT_TIMESTAMP + )`, + `CREATE TABLE IF NOT EXISTS settings ( + key TEXT PRIMARY KEY, + value TEXT NOT NULL + )`, + `CREATE TABLE IF NOT EXISTS agent_events ( + id TEXT PRIMARY KEY, + agent TEXT NOT NULL, + event TEXT NOT NULL, + timestamp DATETIME NOT NULL, + until DATETIME, + reason TEXT + )`, + `CREATE INDEX IF NOT EXISTS idx_agent_events_agent ON agent_events(agent)`, + `CREATE INDEX IF NOT EXISTS idx_agent_events_timestamp ON agent_events(timestamp)`, + `CREATE TABLE IF NOT EXISTS projects ( + id TEXT PRIMARY KEY, + name TEXT NOT NULL, + remote_url TEXT NOT NULL DEFAULT '', + local_path TEXT NOT NULL DEFAULT '', + type TEXT NOT NULL DEFAULT 'web', + deploy_script TEXT NOT NULL DEFAULT '', + created_at DATETIME NOT NULL, + updated_at DATETIME NOT NULL + )`, + `CREATE TABLE IF NOT EXISTS stories ( + id TEXT PRIMARY KEY, + name TEXT NOT NULL, + project_id TEXT NOT NULL DEFAULT '', + branch_name TEXT NOT NULL DEFAULT '', + deploy_config TEXT NOT NULL DEFAULT '', + validation_json TEXT NOT NULL DEFAULT '', + status TEXT NOT NULL DEFAULT 'PENDING', + created_at DATETIME NOT NULL, + updated_at DATETIME NOT NULL + )`, + `ALTER TABLE tasks ADD COLUMN story_id TEXT`, + `ALTER TABLE tasks ADD COLUMN acceptance_criteria TEXT NOT NULL DEFAULT ''`, + `ALTER TABLE tasks ADD COLUMN checker_for_task_id TEXT NOT NULL DEFAULT ''`, + `ALTER TABLE tasks ADD COLUMN checker_report TEXT NOT NULL DEFAULT ''`, `ALTER TABLE executions ADD COLUMN tokens_in INTEGER`, `ALTER TABLE executions ADD COLUMN tokens_out INTEGER`, } @@ -125,24 +176,25 @@ func (s *DB) CreateTask(t *task.Task) error { } _, err = s.db.Exec(` - INSERT INTO tasks (id, name, description, elaboration_input, config_json, priority, timeout_ns, retry_json, tags_json, depends_on_json, parent_task_id, state, created_at, updated_at) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, - t.ID, t.Name, t.Description, t.ElaborationInput, string(configJSON), string(t.Priority), + INSERT INTO tasks (id, name, description, elaboration_input, project, repository_url, config_json, priority, timeout_ns, retry_json, tags_json, depends_on_json, parent_task_id, state, created_at, updated_at, story_id, acceptance_criteria, checker_for_task_id, checker_report) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, + t.ID, t.Name, t.Description, t.ElaborationInput, t.Project, t.RepositoryURL, string(configJSON), string(t.Priority), t.Timeout.Duration.Nanoseconds(), string(retryJSON), string(tagsJSON), string(depsJSON), - t.ParentTaskID, string(t.State), t.CreatedAt.UTC(), t.UpdatedAt.UTC(), + t.ParentTaskID, string(t.State), t.CreatedAt.UTC(), t.UpdatedAt.UTC(), t.StoryID, + t.AcceptanceCriteria, t.CheckerForTaskID, t.CheckerReport, ) return err } // GetTask retrieves a task by ID. func (s *DB) GetTask(id string) (*task.Task, error) { - row := s.db.QueryRow(`SELECT id, name, description, elaboration_input, config_json, priority, timeout_ns, retry_json, tags_json, depends_on_json, parent_task_id, state, created_at, updated_at, rejection_comment, question_json, summary, interactions_json FROM tasks WHERE id = ?`, id) + row := s.db.QueryRow(`SELECT id, name, description, elaboration_input, project, repository_url, config_json, priority, timeout_ns, retry_json, tags_json, depends_on_json, parent_task_id, state, created_at, updated_at, rejection_comment, question_json, summary, interactions_json, story_id, acceptance_criteria, checker_for_task_id, checker_report FROM tasks WHERE id = ?`, id) return scanTask(row) } // ListTasks returns tasks matching the given filter. func (s *DB) ListTasks(filter TaskFilter) ([]*task.Task, error) { - query := `SELECT id, name, description, elaboration_input, config_json, priority, timeout_ns, retry_json, tags_json, depends_on_json, parent_task_id, state, created_at, updated_at, rejection_comment, question_json, summary, interactions_json FROM tasks WHERE 1=1` + query := `SELECT id, name, description, elaboration_input, project, repository_url, config_json, priority, timeout_ns, retry_json, tags_json, depends_on_json, parent_task_id, state, created_at, updated_at, rejection_comment, question_json, summary, interactions_json, story_id, acceptance_criteria, checker_for_task_id, checker_report FROM tasks WHERE 1=1` var args []interface{} if filter.State != "" { @@ -178,7 +230,7 @@ func (s *DB) ListTasks(filter TaskFilter) ([]*task.Task, error) { // ListSubtasks returns all tasks whose parent_task_id matches the given ID. func (s *DB) ListSubtasks(parentID string) ([]*task.Task, error) { - rows, err := s.db.Query(`SELECT id, name, description, elaboration_input, config_json, priority, timeout_ns, retry_json, tags_json, depends_on_json, parent_task_id, state, created_at, updated_at, rejection_comment, question_json, summary, interactions_json FROM tasks WHERE parent_task_id = ? ORDER BY created_at ASC`, parentID) + rows, err := s.db.Query(`SELECT id, name, description, elaboration_input, project, repository_url, config_json, priority, timeout_ns, retry_json, tags_json, depends_on_json, parent_task_id, state, created_at, updated_at, rejection_comment, question_json, summary, interactions_json, story_id, acceptance_criteria, checker_for_task_id, checker_report FROM tasks WHERE parent_task_id = ? ORDER BY created_at ASC`, parentID) if err != nil { return nil, err } @@ -231,7 +283,7 @@ func (s *DB) ResetTaskForRetry(id string) (*task.Task, error) { } defer tx.Rollback() //nolint:errcheck - t, err := scanTask(tx.QueryRow(`SELECT id, name, description, elaboration_input, config_json, priority, timeout_ns, retry_json, tags_json, depends_on_json, parent_task_id, state, created_at, updated_at, rejection_comment, question_json, summary, interactions_json FROM tasks WHERE id = ?`, id)) + t, err := scanTask(tx.QueryRow(`SELECT id, name, description, elaboration_input, project, repository_url, config_json, priority, timeout_ns, retry_json, tags_json, depends_on_json, parent_task_id, state, created_at, updated_at, rejection_comment, question_json, summary, interactions_json, story_id, acceptance_criteria, checker_for_task_id, checker_report FROM tasks WHERE id = ?`, id)) if err != nil { if err == sql.ErrNoRows { return nil, fmt.Errorf("task %q not found", id) @@ -292,9 +344,10 @@ func (s *DB) RejectTask(id, comment string) error { // TaskUpdate holds the fields that UpdateTask may change. type TaskUpdate struct { - Name string - Description string - Config task.AgentConfig + Name string + Description string + RepositoryURL string + Config task.AgentConfig Priority task.Priority TimeoutNS int64 Retry task.RetryConfig @@ -333,13 +386,11 @@ func (s *DB) UpdateTask(id string, u TaskUpdate) error { now := time.Now().UTC() result, err := s.db.Exec(` UPDATE tasks - SET name = ?, description = ?, config_json = ?, priority = ?, timeout_ns = ?, + SET name = ?, description = ?, repository_url = ?, config_json = ?, priority = ?, timeout_ns = ?, retry_json = ?, tags_json = ?, depends_on_json = ?, state = ?, updated_at = ? WHERE id = ?`, - u.Name, u.Description, string(configJSON), string(u.Priority), u.TimeoutNS, - string(retryJSON), string(tagsJSON), string(depsJSON), string(task.StatePending), now, - id, - ) + u.Name, u.Description, u.RepositoryURL, configJSON, string(u.Priority), u.TimeoutNS, + retryJSON, tagsJSON, depsJSON, string(task.StatePending), now, id) if err != nil { return err } @@ -376,6 +427,8 @@ func (s *DB) GetMaxUpdatedAt() (time.Time, error) { "2006-01-02T15:04:05Z07:00", time.RFC3339, "2006-01-02 15:04:05", + "2006-01-02 15:04:05 +0000 UTC", + "2006-01-02 15:04:05.999999999 +0000 UTC", } for _, f := range formats { parsed, err := time.Parse(f, t.String) @@ -417,6 +470,55 @@ type Execution struct { Summary string } +// CreateExecutionAndSetRunning inserts an execution record and transitions the +// task to RUNNING in a single transaction, preventing a crash-window where the +// task stays PENDING with an orphaned RUNNING execution record. +func (s *DB) CreateExecutionAndSetRunning(e *Execution) error { + tx, err := s.db.Begin() + if err != nil { + return err + } + defer tx.Rollback() //nolint:errcheck + + // Validate state transition. + var currentState string + if err := tx.QueryRow(`SELECT state FROM tasks WHERE id = ?`, e.TaskID).Scan(¤tState); err != nil { + if err == sql.ErrNoRows { + return fmt.Errorf("task %q not found", e.TaskID) + } + return err + } + if !task.ValidTransition(task.State(currentState), task.StateRunning) { + return fmt.Errorf("invalid state transition %s → RUNNING for task %q", currentState, e.TaskID) + } + + // Insert execution record. + commitsJSON := "[]" + if len(e.Commits) > 0 { + b, err := json.Marshal(e.Commits) + if err != nil { + return fmt.Errorf("marshaling commits: %w", err) + } + commitsJSON = string(b) + } + if _, err := tx.Exec(` + INSERT INTO executions (id, task_id, start_time, end_time, exit_code, status, stdout_path, stderr_path, artifact_dir, cost_usd, error_msg, session_id, sandbox_dir, changestats_json, commits_json) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, NULL, ?)`, + e.ID, e.TaskID, e.StartTime.UTC(), e.EndTime.UTC(), e.ExitCode, e.Status, + e.StdoutPath, e.StderrPath, e.ArtifactDir, e.CostUSD, e.ErrorMsg, e.SessionID, e.SandboxDir, commitsJSON, + ); err != nil { + return err + } + + // Transition task to RUNNING. + now := time.Now().UTC() + if _, err := tx.Exec(`UPDATE tasks SET state = ?, updated_at = ? WHERE id = ?`, string(task.StateRunning), now, e.TaskID); err != nil { + return err + } + + return tx.Commit() +} + // CreateExecution inserts an execution record. func (s *DB) CreateExecution(e *Execution) error { var changestatsJSON *string @@ -544,6 +646,141 @@ type RecentExecution struct { StdoutPath string `json:"stdout_path"` } +// ThroughputBucket is one time-bucket of execution counts by outcome. +type ThroughputBucket struct { + Hour string `json:"hour"` // RFC3339 truncated to hour + Completed int `json:"completed"` + Failed int `json:"failed"` + Other int `json:"other"` +} + +// BillingDay is the aggregated cost and run count for a calendar day. +type BillingDay struct { + Day string `json:"day"` // YYYY-MM-DD + CostUSD float64 `json:"cost_usd"` + Runs int `json:"runs"` +} + +// FailedExecution is a failed/timed-out/budget-exceeded execution with its error. +type FailedExecution struct { + ID string `json:"id"` + TaskID string `json:"task_id"` + TaskName string `json:"task_name"` + Status string `json:"status"` + ErrorMsg string `json:"error_msg"` + Category string `json:"category"` // quota | timeout | rate_limit | git | failed + StartedAt time.Time `json:"started_at"` +} + +// DashboardStats is returned by QueryDashboardStats. +type DashboardStats struct { + Throughput []ThroughputBucket `json:"throughput"` + Billing []BillingDay `json:"billing"` + Failures []FailedExecution `json:"failures"` +} + +// QueryDashboardStats returns pre-aggregated stats for the given window. +func (s *DB) QueryDashboardStats(since time.Time) (*DashboardStats, error) { + stats := &DashboardStats{ + Throughput: []ThroughputBucket{}, + Billing: []BillingDay{}, + Failures: []FailedExecution{}, + } + + // Throughput: completions per hour bucket + tpRows, err := s.db.Query(` + SELECT strftime('%Y-%m-%dT%H:00:00Z', start_time) as hour, + SUM(CASE WHEN status IN ('COMPLETED','READY') THEN 1 ELSE 0 END), + SUM(CASE WHEN status IN ('FAILED','TIMED_OUT','BUDGET_EXCEEDED') THEN 1 ELSE 0 END), + SUM(CASE WHEN status NOT IN ('COMPLETED','READY','FAILED','TIMED_OUT','BUDGET_EXCEEDED') THEN 1 ELSE 0 END) + FROM executions + WHERE start_time >= ? AND status NOT IN ('RUNNING','QUEUED','PENDING') + GROUP BY hour ORDER BY hour ASC`, since.UTC()) + if err != nil { + return nil, err + } + defer tpRows.Close() + for tpRows.Next() { + var b ThroughputBucket + if err := tpRows.Scan(&b.Hour, &b.Completed, &b.Failed, &b.Other); err != nil { + return nil, err + } + stats.Throughput = append(stats.Throughput, b) + } + if err := tpRows.Err(); err != nil { + return nil, err + } + + // Billing: cost per day + billRows, err := s.db.Query(` + SELECT date(start_time) as day, COALESCE(SUM(cost_usd),0), COUNT(*) + FROM executions + WHERE start_time >= ? + GROUP BY day ORDER BY day ASC`, since.UTC()) + if err != nil { + return nil, err + } + defer billRows.Close() + for billRows.Next() { + var b BillingDay + if err := billRows.Scan(&b.Day, &b.CostUSD, &b.Runs); err != nil { + return nil, err + } + stats.Billing = append(stats.Billing, b) + } + if err := billRows.Err(); err != nil { + return nil, err + } + + // Failures: recent failed executions with error messages + failRows, err := s.db.Query(` + SELECT e.id, e.task_id, t.name, e.status, COALESCE(e.error_msg,''), e.start_time + FROM executions e JOIN tasks t ON e.task_id = t.id + WHERE e.start_time >= ? AND e.status IN ('FAILED','TIMED_OUT','BUDGET_EXCEEDED') + ORDER BY e.start_time DESC LIMIT 50`, since.UTC()) + if err != nil { + return nil, err + } + defer failRows.Close() + for failRows.Next() { + var f FailedExecution + if err := failRows.Scan(&f.ID, &f.TaskID, &f.TaskName, &f.Status, &f.ErrorMsg, &f.StartedAt); err != nil { + return nil, err + } + f.Category = classifyError(f.Status, f.ErrorMsg) + stats.Failures = append(stats.Failures, f) + } + if err := failRows.Err(); err != nil { + return nil, err + } + + return stats, nil +} + +// classifyError maps a status + error message to a human category. +func classifyError(status, msg string) string { + if status == "TIMED_OUT" { + return "timeout" + } + if status == "BUDGET_EXCEEDED" { + return "quota" + } + low := strings.ToLower(msg) + if strings.Contains(low, "quota") || strings.Contains(low, "exhausted") || strings.Contains(low, "terminalquota") { + return "quota" + } + if strings.Contains(low, "rate limit") || strings.Contains(low, "429") || strings.Contains(low, "too many requests") { + return "rate_limit" + } + if strings.Contains(low, "git push") || strings.Contains(low, "git pull") { + return "git" + } + if strings.Contains(low, "timeout") || strings.Contains(low, "deadline") { + return "timeout" + } + return "failed" +} + // ListRecentExecutions returns executions since the given time, joined with task names. // If taskID is non-empty, only executions for that task are returned. func (s *DB) ListRecentExecutions(since time.Time, limit int, taskID string) ([]*RecentExecution, error) { @@ -600,6 +837,24 @@ func (s *DB) UpdateTaskSummary(taskID, summary string) error { return err } +// UpdateTaskCheckerReport sets the checker_report field on a task. +func (s *DB) UpdateTaskCheckerReport(id, report string) error { + now := time.Now().UTC() + _, err := s.db.Exec(`UPDATE tasks SET checker_report = ?, updated_at = ? WHERE id = ?`, report, now, id) + return err +} + +// GetCheckerTask returns the checker task for the given checked task ID, +// or nil if no checker task exists. +func (s *DB) GetCheckerTask(checkedTaskID string) (*task.Task, error) { + row := s.db.QueryRow(`SELECT id, name, description, elaboration_input, project, repository_url, config_json, priority, timeout_ns, retry_json, tags_json, depends_on_json, parent_task_id, state, created_at, updated_at, rejection_comment, question_json, summary, interactions_json, story_id, acceptance_criteria, checker_for_task_id, checker_report FROM tasks WHERE checker_for_task_id = ? LIMIT 1`, checkedTaskID) + t, err := scanTask(row) + if err == sql.ErrNoRows { + return nil, nil + } + return t, err +} + // AppendTaskInteraction appends a Q&A interaction to the task's interaction history. func (s *DB) AppendTaskInteraction(taskID string, interaction task.Interaction) error { tx, err := s.db.Begin() @@ -682,17 +937,35 @@ func scanTask(row scanner) (*task.Task, error) { timeoutNS int64 parentTaskID sql.NullString elaborationInput sql.NullString + project sql.NullString + repositoryURL sql.NullString rejectionComment sql.NullString questionJSON sql.NullString summary sql.NullString interactionsJSON sql.NullString + storyID sql.NullString + acceptanceCriteria sql.NullString + checkerForTaskID sql.NullString + checkerReport sql.NullString + ) + err := row.Scan( + &t.ID, &t.Name, &t.Description, &elaborationInput, &project, &repositoryURL, + &configJSON, &priority, &timeoutNS, &retryJSON, &tagsJSON, &depsJSON, + &parentTaskID, &state, &t.CreatedAt, &t.UpdatedAt, + &rejectionComment, &questionJSON, &summary, &interactionsJSON, &storyID, + &acceptanceCriteria, &checkerForTaskID, &checkerReport, ) - err := row.Scan(&t.ID, &t.Name, &t.Description, &elaborationInput, &configJSON, &priority, &timeoutNS, &retryJSON, &tagsJSON, &depsJSON, &parentTaskID, &state, &t.CreatedAt, &t.UpdatedAt, &rejectionComment, &questionJSON, &summary, &interactionsJSON) t.ParentTaskID = parentTaskID.String t.ElaborationInput = elaborationInput.String + t.Project = project.String + t.RepositoryURL = repositoryURL.String t.RejectionComment = rejectionComment.String t.QuestionJSON = questionJSON.String t.Summary = summary.String + t.StoryID = storyID.String + t.AcceptanceCriteria = acceptanceCriteria.String + t.CheckerForTaskID = checkerForTaskID.String + t.CheckerReport = checkerReport.String if err != nil { return nil, err } @@ -772,3 +1045,263 @@ func (s *DB) UpdateExecutionChangestats(execID string, stats *task.Changestats) func scanExecutionRows(rows *sql.Rows) (*Execution, error) { return scanExecution(rows) } + +// PushSubscription represents a browser push subscription. +type PushSubscription struct { + ID string `json:"id"` + Endpoint string `json:"endpoint"` + P256DHKey string `json:"p256dh_key"` + AuthKey string `json:"auth_key"` + CreatedAt time.Time `json:"created_at"` +} + +// SavePushSubscription inserts or replaces a push subscription by endpoint. +func (s *DB) SavePushSubscription(sub PushSubscription) error { + _, err := s.db.Exec(` + INSERT INTO push_subscriptions (id, endpoint, p256dh_key, auth_key) + VALUES (?, ?, ?, ?) + ON CONFLICT(endpoint) DO UPDATE SET + id = excluded.id, + p256dh_key = excluded.p256dh_key, + auth_key = excluded.auth_key`, + sub.ID, sub.Endpoint, sub.P256DHKey, sub.AuthKey, + ) + return err +} + +// DeletePushSubscription removes the subscription with the given endpoint. +func (s *DB) DeletePushSubscription(endpoint string) error { + _, err := s.db.Exec(`DELETE FROM push_subscriptions WHERE endpoint = ?`, endpoint) + return err +} + +// ListPushSubscriptions returns all registered push subscriptions. +func (s *DB) ListPushSubscriptions() ([]PushSubscription, error) { + rows, err := s.db.Query(`SELECT id, endpoint, p256dh_key, auth_key, created_at FROM push_subscriptions ORDER BY created_at`) + if err != nil { + return nil, err + } + defer rows.Close() + + var subs []PushSubscription + for rows.Next() { + var sub PushSubscription + var createdAt string + if err := rows.Scan(&sub.ID, &sub.Endpoint, &sub.P256DHKey, &sub.AuthKey, &createdAt); err != nil { + return nil, err + } + // Parse created_at; ignore errors (use zero time on failure). + for _, layout := range []string{time.RFC3339, "2006-01-02 15:04:05", "2006-01-02T15:04:05Z"} { + if t, err := time.Parse(layout, createdAt); err == nil { + sub.CreatedAt = t + break + } + } + subs = append(subs, sub) + } + if subs == nil { + subs = []PushSubscription{} + } + return subs, rows.Err() +} + +// GetSetting returns the value for a key, or ("", nil) if not found. +func (s *DB) GetSetting(key string) (string, error) { + var value string + err := s.db.QueryRow(`SELECT value FROM settings WHERE key = ?`, key).Scan(&value) + if err == sql.ErrNoRows { + return "", nil + } + return value, err +} + +// SetSetting upserts a key/value pair in the settings table. +func (s *DB) SetSetting(key, value string) error { + _, err := s.db.Exec(`INSERT INTO settings (key, value) VALUES (?, ?) ON CONFLICT(key) DO UPDATE SET value = excluded.value`, key, value) + return err +} + +// AgentEvent records a rate-limit state change for an agent. +type AgentEvent struct { + ID string `json:"id"` + Agent string `json:"agent"` + Event string `json:"event"` // "rate_limited" | "available" + Timestamp time.Time `json:"timestamp"` + Until *time.Time `json:"until,omitempty"` // non-nil for "rate_limited" events + Reason string `json:"reason"` // "transient" | "quota" +} + +// RecordAgentEvent inserts an agent rate-limit event. +func (s *DB) RecordAgentEvent(e AgentEvent) error { + _, err := s.db.Exec( + `INSERT INTO agent_events (id, agent, event, timestamp, until, reason) VALUES (?, ?, ?, ?, ?, ?)`, + e.ID, e.Agent, e.Event, e.Timestamp.UTC(), timeOrNull(e.Until), e.Reason, + ) + return err +} + +// ListAgentEvents returns agent events since the given time, newest first. +func (s *DB) ListAgentEvents(since time.Time) ([]AgentEvent, error) { + rows, err := s.db.Query( + `SELECT id, agent, event, timestamp, until, reason FROM agent_events WHERE timestamp >= ? ORDER BY timestamp DESC LIMIT 500`, + since.UTC(), + ) + if err != nil { + return nil, err + } + defer rows.Close() + var events []AgentEvent + for rows.Next() { + var e AgentEvent + var until sql.NullTime + var reason sql.NullString + if err := rows.Scan(&e.ID, &e.Agent, &e.Event, &e.Timestamp, &until, &reason); err != nil { + return nil, err + } + if until.Valid { + e.Until = &until.Time + } + e.Reason = reason.String + events = append(events, e) + } + return events, rows.Err() +} + +func timeOrNull(t *time.Time) interface{} { + if t == nil { + return nil + } + return t.UTC() +} + +// CreateProject inserts a new project. +func (s *DB) CreateProject(p *task.Project) error { + now := time.Now().UTC() + _, err := s.db.Exec( + `INSERT INTO projects (id, name, remote_url, local_path, type, deploy_script, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?)`, + p.ID, p.Name, p.RemoteURL, p.LocalPath, p.Type, p.DeployScript, now, now, + ) + return err +} + +// GetProject retrieves a project by ID. +func (s *DB) GetProject(id string) (*task.Project, error) { + row := s.db.QueryRow(`SELECT id, name, remote_url, local_path, type, deploy_script FROM projects WHERE id = ?`, id) + p := &task.Project{} + if err := row.Scan(&p.ID, &p.Name, &p.RemoteURL, &p.LocalPath, &p.Type, &p.DeployScript); err != nil { + return nil, err + } + return p, nil +} + +// ListProjects returns all projects. +func (s *DB) ListProjects() ([]*task.Project, error) { + rows, err := s.db.Query(`SELECT id, name, remote_url, local_path, type, deploy_script FROM projects ORDER BY name`) + if err != nil { + return nil, err + } + defer rows.Close() + var projects []*task.Project + for rows.Next() { + p := &task.Project{} + if err := rows.Scan(&p.ID, &p.Name, &p.RemoteURL, &p.LocalPath, &p.Type, &p.DeployScript); err != nil { + return nil, err + } + projects = append(projects, p) + } + return projects, rows.Err() +} + +// UpdateProject updates an existing project. +func (s *DB) UpdateProject(p *task.Project) error { + now := time.Now().UTC() + _, err := s.db.Exec( + `UPDATE projects SET name = ?, remote_url = ?, local_path = ?, type = ?, deploy_script = ?, updated_at = ? WHERE id = ?`, + p.Name, p.RemoteURL, p.LocalPath, p.Type, p.DeployScript, now, p.ID, + ) + return err +} + +// UpsertProject inserts or updates a project by ID (used for seeding). +func (s *DB) UpsertProject(p *task.Project) error { + now := time.Now().UTC() + _, err := s.db.Exec( + `INSERT INTO projects (id, name, remote_url, local_path, type, deploy_script, created_at, updated_at) + VALUES (?, ?, ?, ?, ?, ?, ?, ?) + ON CONFLICT(id) DO UPDATE SET name=excluded.name, remote_url=excluded.remote_url, + local_path=excluded.local_path, type=excluded.type, deploy_script=excluded.deploy_script, updated_at=excluded.updated_at`, + p.ID, p.Name, p.RemoteURL, p.LocalPath, p.Type, p.DeployScript, now, now, + ) + return err +} + +// CreateStory inserts a new story. +func (s *DB) CreateStory(st *task.Story) error { + now := time.Now().UTC() + _, err := s.db.Exec( + `INSERT INTO stories (id, name, project_id, branch_name, deploy_config, validation_json, status, created_at, updated_at) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)`, + st.ID, st.Name, st.ProjectID, st.BranchName, st.DeployConfig, st.ValidationJSON, string(st.Status), now, now, + ) + return err +} + +// GetStory retrieves a story by ID. +func (s *DB) GetStory(id string) (*task.Story, error) { + row := s.db.QueryRow(`SELECT id, name, project_id, branch_name, deploy_config, validation_json, status, created_at, updated_at FROM stories WHERE id = ?`, id) + st := &task.Story{} + var status string + if err := row.Scan(&st.ID, &st.Name, &st.ProjectID, &st.BranchName, &st.DeployConfig, &st.ValidationJSON, &status, &st.CreatedAt, &st.UpdatedAt); err != nil { + return nil, err + } + st.Status = task.StoryState(status) + return st, nil +} + +// ListStories returns all stories ordered by creation time descending. +func (s *DB) ListStories() ([]*task.Story, error) { + rows, err := s.db.Query(`SELECT id, name, project_id, branch_name, deploy_config, validation_json, status, created_at, updated_at FROM stories ORDER BY created_at DESC`) + if err != nil { + return nil, err + } + defer rows.Close() + var stories []*task.Story + for rows.Next() { + st := &task.Story{} + var status string + if err := rows.Scan(&st.ID, &st.Name, &st.ProjectID, &st.BranchName, &st.DeployConfig, &st.ValidationJSON, &status, &st.CreatedAt, &st.UpdatedAt); err != nil { + return nil, err + } + st.Status = task.StoryState(status) + stories = append(stories, st) + } + return stories, rows.Err() +} + +// UpdateStoryStatus updates the status of a story. +func (s *DB) UpdateStoryStatus(id string, status task.StoryState) error { + now := time.Now().UTC() + _, err := s.db.Exec(`UPDATE stories SET status = ?, updated_at = ? WHERE id = ?`, string(status), now, id) + return err +} + +// ListTasksByStory returns all tasks associated with a story, ordered by creation time ascending. +func (s *DB) ListTasksByStory(storyID string) ([]*task.Task, error) { + rows, err := s.db.Query( + `SELECT id, name, description, elaboration_input, project, repository_url, config_json, priority, timeout_ns, retry_json, tags_json, depends_on_json, parent_task_id, state, created_at, updated_at, rejection_comment, question_json, summary, interactions_json, story_id, acceptance_criteria, checker_for_task_id, checker_report FROM tasks WHERE story_id = ? ORDER BY created_at ASC`, + storyID, + ) + if err != nil { + return nil, err + } + defer rows.Close() + var tasks []*task.Task + for rows.Next() { + t, err := scanTaskRows(rows) + if err != nil { + return nil, err + } + tasks = append(tasks, t) + } + return tasks, rows.Err() +} diff --git a/internal/storage/db_test.go b/internal/storage/db_test.go index 752c5b1..0e67e02 100644 --- a/internal/storage/db_test.go +++ b/internal/storage/db_test.go @@ -41,7 +41,6 @@ func TestCreateTask_AndGetTask(t *testing.T) { Type: "claude", Model: "sonnet", Instructions: "do it", - ProjectDir: "/tmp", MaxBudgetUSD: 2.5, }, Priority: task.PriorityHigh, @@ -990,6 +989,128 @@ func TestAppendTaskInteraction_NotFound(t *testing.T) { } } +func TestCreateTask_Project_RoundTrip(t *testing.T) { + db := testDB(t) + now := time.Now().UTC().Truncate(time.Second) + + tk := &task.Task{ + ID: "proj-1", + Name: "Project Task", + Project: "my-project", + Agent: task.AgentConfig{Type: "claude", Instructions: "do it"}, + Priority: task.PriorityNormal, + Tags: []string{}, + DependsOn: []string{}, + Retry: task.RetryConfig{MaxAttempts: 1, Backoff: "linear"}, + State: task.StatePending, + CreatedAt: now, + UpdatedAt: now, + } + if err := db.CreateTask(tk); err != nil { + t.Fatalf("creating task: %v", err) + } + + got, err := db.GetTask("proj-1") + if err != nil { + t.Fatalf("getting task: %v", err) + } + if got.Project != "my-project" { + t.Errorf("project: want %q, got %q", "my-project", got.Project) + } +} + +// ── Push subscription tests ─────────────────────────────────────────────────── + +func TestPushSubscription_SaveAndList(t *testing.T) { + db := testDB(t) + + sub := PushSubscription{ + ID: "sub-1", + Endpoint: "https://push.example.com/endpoint1", + P256DHKey: "p256dhkey1", + AuthKey: "authkey1", + } + if err := db.SavePushSubscription(sub); err != nil { + t.Fatalf("SavePushSubscription: %v", err) + } + + subs, err := db.ListPushSubscriptions() + if err != nil { + t.Fatalf("ListPushSubscriptions: %v", err) + } + if len(subs) != 1 { + t.Fatalf("want 1 subscription, got %d", len(subs)) + } + if subs[0].Endpoint != sub.Endpoint { + t.Errorf("endpoint: want %q, got %q", sub.Endpoint, subs[0].Endpoint) + } + if subs[0].P256DHKey != sub.P256DHKey { + t.Errorf("p256dh_key: want %q, got %q", sub.P256DHKey, subs[0].P256DHKey) + } + if subs[0].AuthKey != sub.AuthKey { + t.Errorf("auth_key: want %q, got %q", sub.AuthKey, subs[0].AuthKey) + } +} + +func TestPushSubscription_Delete(t *testing.T) { + db := testDB(t) + + sub := PushSubscription{ + ID: "sub-del", + Endpoint: "https://push.example.com/todelete", + P256DHKey: "key", + AuthKey: "auth", + } + if err := db.SavePushSubscription(sub); err != nil { + t.Fatalf("SavePushSubscription: %v", err) + } + + if err := db.DeletePushSubscription(sub.Endpoint); err != nil { + t.Fatalf("DeletePushSubscription: %v", err) + } + + subs, err := db.ListPushSubscriptions() + if err != nil { + t.Fatalf("ListPushSubscriptions: %v", err) + } + if len(subs) != 0 { + t.Errorf("want 0 subscriptions after delete, got %d", len(subs)) + } +} + +func TestPushSubscription_UniqueEndpoint(t *testing.T) { + db := testDB(t) + + sub := PushSubscription{ + ID: "sub-uq", + Endpoint: "https://push.example.com/unique", + P256DHKey: "key1", + AuthKey: "auth1", + } + if err := db.SavePushSubscription(sub); err != nil { + t.Fatalf("SavePushSubscription first: %v", err) + } + + // Save again with same endpoint — should update or replace, not error. + sub2 := PushSubscription{ + ID: "sub-uq2", + Endpoint: "https://push.example.com/unique", + P256DHKey: "key2", + AuthKey: "auth2", + } + if err := db.SavePushSubscription(sub2); err != nil { + t.Fatalf("SavePushSubscription second (upsert): %v", err) + } + + subs, err := db.ListPushSubscriptions() + if err != nil { + t.Fatalf("ListPushSubscriptions: %v", err) + } + if len(subs) != 1 { + t.Errorf("want 1 subscription after upsert, got %d", len(subs)) + } +} + func TestExecution_StoreAndRetrieveChangestats(t *testing.T) { db := testDB(t) now := time.Now().UTC().Truncate(time.Second) @@ -1032,3 +1153,252 @@ func TestExecution_StoreAndRetrieveChangestats(t *testing.T) { } } +func TestCreateProject(t *testing.T) { + db := testDB(t) + defer db.Close() + + p := &task.Project{ + ID: "proj-1", + Name: "claudomator", + RemoteURL: "/bare/claudomator.git", + LocalPath: "/workspace/claudomator", + Type: "web", + } + if err := db.CreateProject(p); err != nil { + t.Fatalf("CreateProject: %v", err) + } + got, err := db.GetProject("proj-1") + if err != nil { + t.Fatalf("GetProject: %v", err) + } + if got.Name != "claudomator" { + t.Errorf("Name: want claudomator, got %q", got.Name) + } + if got.LocalPath != "/workspace/claudomator" { + t.Errorf("LocalPath: want /workspace/claudomator, got %q", got.LocalPath) + } +} + +func TestListProjects(t *testing.T) { + db := testDB(t) + defer db.Close() + + for _, p := range []*task.Project{ + {ID: "p1", Name: "alpha", Type: "web"}, + {ID: "p2", Name: "beta", Type: "android"}, + } { + if err := db.CreateProject(p); err != nil { + t.Fatalf("CreateProject: %v", err) + } + } + list, err := db.ListProjects() + if err != nil { + t.Fatalf("ListProjects: %v", err) + } + if len(list) != 2 { + t.Errorf("want 2 projects, got %d", len(list)) + } +} + +func TestUpdateProject(t *testing.T) { + db := testDB(t) + defer db.Close() + + p := &task.Project{ID: "p1", Name: "original", Type: "web"} + if err := db.CreateProject(p); err != nil { + t.Fatalf("CreateProject: %v", err) + } + p.Name = "updated" + if err := db.UpdateProject(p); err != nil { + t.Fatalf("UpdateProject: %v", err) + } + got, _ := db.GetProject("p1") + if got.Name != "updated" { + t.Errorf("Name after update: want updated, got %q", got.Name) + } +} + +func TestCreateStory(t *testing.T) { + db := testDB(t) + st := &task.Story{ + ID: "story-1", + Name: "My Story", + Status: task.StoryPending, + } + if err := db.CreateStory(st); err != nil { + t.Fatalf("CreateStory: %v", err) + } +} + +func TestGetStory(t *testing.T) { + db := testDB(t) + st := &task.Story{ + ID: "story-2", + Name: "Get Story", + ProjectID: "proj-1", + Status: task.StoryPending, + } + if err := db.CreateStory(st); err != nil { + t.Fatalf("CreateStory: %v", err) + } + got, err := db.GetStory("story-2") + if err != nil { + t.Fatalf("GetStory: %v", err) + } + if got.Name != "Get Story" { + t.Errorf("Name: want 'Get Story', got %q", got.Name) + } + if got.ProjectID != "proj-1" { + t.Errorf("ProjectID: want 'proj-1', got %q", got.ProjectID) + } + if got.Status != task.StoryPending { + t.Errorf("Status: want PENDING, got %q", got.Status) + } +} + +func TestListStories(t *testing.T) { + db := testDB(t) + for _, name := range []string{"A", "B", "C"} { + if err := db.CreateStory(&task.Story{ID: name, Name: name, Status: task.StoryPending}); err != nil { + t.Fatalf("CreateStory %s: %v", name, err) + } + } + stories, err := db.ListStories() + if err != nil { + t.Fatalf("ListStories: %v", err) + } + if len(stories) != 3 { + t.Errorf("want 3 stories, got %d", len(stories)) + } +} + +func TestUpdateStoryStatus(t *testing.T) { + db := testDB(t) + st := &task.Story{ID: "story-upd", Name: "Upd", Status: task.StoryPending} + if err := db.CreateStory(st); err != nil { + t.Fatalf("CreateStory: %v", err) + } + if err := db.UpdateStoryStatus("story-upd", task.StoryInProgress); err != nil { + t.Fatalf("UpdateStoryStatus: %v", err) + } + got, _ := db.GetStory("story-upd") + if got.Status != task.StoryInProgress { + t.Errorf("Status: want IN_PROGRESS, got %q", got.Status) + } +} + +func TestListTasksByStory(t *testing.T) { + db := testDB(t) + now := time.Now().UTC() + + if err := db.CreateStory(&task.Story{ID: "story-tasks", Name: "S", Status: task.StoryPending}); err != nil { + t.Fatalf("CreateStory: %v", err) + } + + makeTask := func(id string) *task.Task { + return &task.Task{ + ID: id, + Name: id, + StoryID: "story-tasks", + Agent: task.AgentConfig{Type: "claude"}, + Priority: task.PriorityNormal, + Tags: []string{}, + DependsOn: []string{}, + Retry: task.RetryConfig{MaxAttempts: 1}, + State: task.StatePending, + CreatedAt: now, + UpdatedAt: now, + } + } + + if err := db.CreateTask(makeTask("t1")); err != nil { + t.Fatal(err) + } + if err := db.CreateTask(makeTask("t2")); err != nil { + t.Fatal(err) + } + + tasks, err := db.ListTasksByStory("story-tasks") + if err != nil { + t.Fatalf("ListTasksByStory: %v", err) + } + if len(tasks) != 2 { + t.Errorf("want 2 tasks, got %d", len(tasks)) + } + for _, tk := range tasks { + if tk.StoryID != "story-tasks" { + t.Errorf("task %s: StoryID want 'story-tasks', got %q", tk.ID, tk.StoryID) + } + } +} + +func TestUpdateTaskCheckerReport(t *testing.T) { + db := testDB(t) + tk := &task.Task{ + ID: "cr-1", Name: "orig", RepositoryURL: "https://github.com/x/y", + Agent: task.AgentConfig{Type: "claude", Instructions: "x"}, + Priority: task.PriorityNormal, + Retry: task.RetryConfig{MaxAttempts: 1, Backoff: "linear"}, + Tags: []string{}, DependsOn: []string{}, + State: task.StatePending, CreatedAt: time.Now().UTC(), UpdatedAt: time.Now().UTC(), + } + if err := db.CreateTask(tk); err != nil { + t.Fatalf("CreateTask: %v", err) + } + if err := db.UpdateTaskCheckerReport("cr-1", "Tests failed: missing endpoint"); err != nil { + t.Fatalf("UpdateTaskCheckerReport: %v", err) + } + got, err := db.GetTask("cr-1") + if err != nil { + t.Fatalf("GetTask: %v", err) + } + if got.CheckerReport != "Tests failed: missing endpoint" { + t.Errorf("expected checker report, got %q", got.CheckerReport) + } +} + +func TestGetCheckerTask(t *testing.T) { + db := testDB(t) + checked := &task.Task{ + ID: "chk-orig", Name: "orig", RepositoryURL: "https://github.com/x/y", + Agent: task.AgentConfig{Type: "claude", Instructions: "x"}, + Priority: task.PriorityNormal, + Retry: task.RetryConfig{MaxAttempts: 1, Backoff: "linear"}, + Tags: []string{}, DependsOn: []string{}, + State: task.StatePending, CreatedAt: time.Now().UTC(), UpdatedAt: time.Now().UTC(), + } + if err := db.CreateTask(checked); err != nil { + t.Fatalf("CreateTask checked: %v", err) + } + checker := &task.Task{ + ID: "chk-checker", Name: "Check: orig", CheckerForTaskID: "chk-orig", + RepositoryURL: "https://github.com/x/y", + Agent: task.AgentConfig{Type: "claude", Instructions: "validate"}, + Priority: task.PriorityNormal, + Retry: task.RetryConfig{MaxAttempts: 1, Backoff: "linear"}, + Tags: []string{}, DependsOn: []string{}, + State: task.StatePending, CreatedAt: time.Now().UTC(), UpdatedAt: time.Now().UTC(), + } + if err := db.CreateTask(checker); err != nil { + t.Fatalf("CreateTask checker: %v", err) + } + + // Should find the checker task. + got, err := db.GetCheckerTask("chk-orig") + if err != nil { + t.Fatalf("GetCheckerTask: %v", err) + } + if got == nil || got.ID != "chk-checker" { + t.Errorf("expected checker task ID chk-checker, got %v", got) + } + + // Should return nil when no checker exists. + none, err := db.GetCheckerTask("nonexistent") + if err != nil { + t.Fatalf("GetCheckerTask nonexistent: %v", err) + } + if none != nil { + t.Errorf("expected nil for task with no checker, got %v", none) + } +} + diff --git a/internal/storage/seed.go b/internal/storage/seed.go new file mode 100644 index 0000000..c2df84f --- /dev/null +++ b/internal/storage/seed.go @@ -0,0 +1,62 @@ +package storage + +import ( + "os/exec" + "strings" + + "github.com/thepeterstone/claudomator/internal/task" +) + +// SeedProjects upserts the default project registry on startup. +func (s *DB) SeedProjects() error { + projects := []*task.Project{ + { + ID: "claudomator", + Name: "claudomator", + LocalPath: "/workspace/claudomator", + RemoteURL: localBareRemote("/workspace/claudomator"), + Type: "web", + DeployScript: "/workspace/claudomator/scripts/deploy", + }, + { + ID: "nav", + Name: "nav", + LocalPath: "/workspace/nav", + RemoteURL: localBareRemote("/workspace/nav"), + Type: "android", + }, + { + ID: "doot", + Name: "doot", + LocalPath: "/workspace/doot", + RemoteURL: localBareRemote("/workspace/doot"), + Type: "web", + DeployScript: "/workspace/doot/scripts/deploy", + }, + { + ID: "modal-shell", + Name: "modal-shell", + LocalPath: "/workspace/modal-shell", + RemoteURL: localBareRemote("/workspace/modal-shell"), + Type: "web", + }, + } + for _, p := range projects { + if err := s.UpsertProject(p); err != nil { + return err + } + } + return nil +} + +// localBareRemote returns the URL of the "local" git remote for dir, +// falling back to dir itself if the remote is not configured. +func localBareRemote(dir string) string { + out, err := exec.Command("git", "-C", dir, "remote", "get-url", "local").Output() + if err == nil { + if url := strings.TrimSpace(string(out)); url != "" { + return url + } + } + return dir +} diff --git a/internal/storage/sqlite_cgo.go b/internal/storage/sqlite_cgo.go new file mode 100644 index 0000000..0956852 --- /dev/null +++ b/internal/storage/sqlite_cgo.go @@ -0,0 +1,5 @@ +//go:build cgo + +package storage + +import _ "github.com/mattn/go-sqlite3" diff --git a/internal/storage/sqlite_nocgo.go b/internal/storage/sqlite_nocgo.go new file mode 100644 index 0000000..9862440 --- /dev/null +++ b/internal/storage/sqlite_nocgo.go @@ -0,0 +1,21 @@ +//go:build !cgo + +package storage + +import ( + "database/sql" + "database/sql/driver" + + modernc "modernc.org/sqlite" +) + +// Register the modernc pure-Go SQLite driver under the "sqlite3" name so that +// the rest of the codebase can use sql.Open("sqlite3", ...) regardless of +// whether CGO is enabled. +func init() { + sql.Register("sqlite3", &modernc.Driver{}) +} + +// modernc.Driver satisfies driver.Driver; this blank-import ensures the +// compiler sees the interface is satisfied. +var _ driver.Driver = (*modernc.Driver)(nil) |
