summaryrefslogtreecommitdiff
path: root/internal/executor/executor.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/executor/executor.go')
-rw-r--r--internal/executor/executor.go88
1 files changed, 68 insertions, 20 deletions
diff --git a/internal/executor/executor.go b/internal/executor/executor.go
index 4bb1f2c..baeb399 100644
--- a/internal/executor/executor.go
+++ b/internal/executor/executor.go
@@ -15,6 +15,18 @@ import (
"github.com/google/uuid"
)
+// Store is the subset of storage.DB methods used by the Pool.
+// Defining it as an interface allows test doubles to be injected.
+type Store interface {
+ GetTask(id string) (*task.Task, error)
+ ListTasks(filter storage.TaskFilter) ([]*task.Task, error)
+ ListExecutions(taskID string) ([]*storage.Execution, error)
+ CreateExecution(e *storage.Execution) error
+ UpdateExecution(e *storage.Execution) error
+ UpdateTaskState(id string, newState task.State) error
+ UpdateTaskQuestion(taskID, questionJSON string) error
+}
+
// LogPather is an optional interface runners can implement to provide the log
// directory for an execution before it starts. The pool uses this to persist
// log paths at CreateExecution time rather than waiting until execution ends.
@@ -38,7 +50,7 @@ type workItem struct {
type Pool struct {
maxConcurrent int
runners map[string]Runner
- store *storage.DB
+ store Store
logger *slog.Logger
depPollInterval time.Duration // how often waitForDependencies polls; defaults to 5s
@@ -61,7 +73,7 @@ type Result struct {
Err error
}
-func NewPool(maxConcurrent int, runners map[string]Runner, store *storage.DB, logger *slog.Logger) *Pool {
+func NewPool(maxConcurrent int, runners map[string]Runner, store Store, logger *slog.Logger) *Pool {
if maxConcurrent < 1 {
maxConcurrent = 1
}
@@ -252,32 +264,48 @@ func (p *Pool) executeResume(ctx context.Context, t *task.Task, exec *storage.Ex
var blockedErr *BlockedError
if errors.As(err, &blockedErr) {
exec.Status = "BLOCKED"
- p.store.UpdateTaskState(t.ID, task.StateBlocked)
- p.store.UpdateTaskQuestion(t.ID, blockedErr.QuestionJSON)
+ if err := p.store.UpdateTaskState(t.ID, task.StateBlocked); err != nil {
+ p.logger.Error("failed to update task state", "taskID", t.ID, "state", task.StateBlocked, "error", err)
+ }
+ if err := p.store.UpdateTaskQuestion(t.ID, blockedErr.QuestionJSON); err != nil {
+ p.logger.Error("failed to update task question", "taskID", t.ID, "error", err)
+ }
} else if ctx.Err() == context.DeadlineExceeded {
exec.Status = "TIMED_OUT"
exec.ErrorMsg = "execution timed out"
- p.store.UpdateTaskState(t.ID, task.StateTimedOut)
+ if err := p.store.UpdateTaskState(t.ID, task.StateTimedOut); err != nil {
+ p.logger.Error("failed to update task state", "taskID", t.ID, "state", task.StateTimedOut, "error", err)
+ }
} else if ctx.Err() == context.Canceled {
exec.Status = "CANCELLED"
exec.ErrorMsg = "execution cancelled"
- p.store.UpdateTaskState(t.ID, task.StateCancelled)
+ if err := p.store.UpdateTaskState(t.ID, task.StateCancelled); err != nil {
+ p.logger.Error("failed to update task state", "taskID", t.ID, "state", task.StateCancelled, "error", err)
+ }
} else if isQuotaExhausted(err) {
exec.Status = "BUDGET_EXCEEDED"
exec.ErrorMsg = err.Error()
- p.store.UpdateTaskState(t.ID, task.StateBudgetExceeded)
+ if err := p.store.UpdateTaskState(t.ID, task.StateBudgetExceeded); err != nil {
+ p.logger.Error("failed to update task state", "taskID", t.ID, "state", task.StateBudgetExceeded, "error", err)
+ }
} else {
exec.Status = "FAILED"
exec.ErrorMsg = err.Error()
- p.store.UpdateTaskState(t.ID, task.StateFailed)
+ if err := p.store.UpdateTaskState(t.ID, task.StateFailed); err != nil {
+ p.logger.Error("failed to update task state", "taskID", t.ID, "state", task.StateFailed, "error", err)
+ }
}
} else {
if t.ParentTaskID == "" {
exec.Status = "READY"
- p.store.UpdateTaskState(t.ID, task.StateReady)
+ if err := p.store.UpdateTaskState(t.ID, task.StateReady); err != nil {
+ p.logger.Error("failed to update task state", "taskID", t.ID, "state", task.StateReady, "error", err)
+ }
} else {
exec.Status = "COMPLETED"
- p.store.UpdateTaskState(t.ID, task.StateCompleted)
+ if err := p.store.UpdateTaskState(t.ID, task.StateCompleted); err != nil {
+ p.logger.Error("failed to update task state", "taskID", t.ID, "state", task.StateCompleted, "error", err)
+ }
}
}
@@ -371,7 +399,9 @@ func (p *Pool) execute(ctx context.Context, t *task.Task) {
if createErr := p.store.CreateExecution(exec); createErr != nil {
p.logger.Error("failed to create execution record", "error", createErr)
}
- p.store.UpdateTaskState(t.ID, task.StateFailed)
+ if err := p.store.UpdateTaskState(t.ID, task.StateFailed); err != nil {
+ p.logger.Error("failed to update task state", "taskID", t.ID, "state", task.StateFailed, "error", err)
+ }
p.resultCh <- &Result{TaskID: t.ID, Execution: exec, Err: err}
return
}
@@ -391,7 +421,9 @@ func (p *Pool) execute(ctx context.Context, t *task.Task) {
if createErr := p.store.CreateExecution(exec); createErr != nil {
p.logger.Error("failed to create execution record", "error", createErr)
}
- p.store.UpdateTaskState(t.ID, task.StateFailed)
+ if err := p.store.UpdateTaskState(t.ID, task.StateFailed); err != nil {
+ p.logger.Error("failed to update task state", "taskID", t.ID, "state", task.StateFailed, "error", err)
+ }
p.resultCh <- &Result{TaskID: t.ID, Execution: exec, Err: err}
return
}
@@ -467,32 +499,48 @@ func (p *Pool) execute(ctx context.Context, t *task.Task) {
var blockedErr *BlockedError
if errors.As(err, &blockedErr) {
exec.Status = "BLOCKED"
- p.store.UpdateTaskState(t.ID, task.StateBlocked)
- p.store.UpdateTaskQuestion(t.ID, blockedErr.QuestionJSON)
+ if err := p.store.UpdateTaskState(t.ID, task.StateBlocked); err != nil {
+ p.logger.Error("failed to update task state", "taskID", t.ID, "state", task.StateBlocked, "error", err)
+ }
+ if err := p.store.UpdateTaskQuestion(t.ID, blockedErr.QuestionJSON); err != nil {
+ p.logger.Error("failed to update task question", "taskID", t.ID, "error", err)
+ }
} else if ctx.Err() == context.DeadlineExceeded {
exec.Status = "TIMED_OUT"
exec.ErrorMsg = "execution timed out"
- p.store.UpdateTaskState(t.ID, task.StateTimedOut)
+ if err := p.store.UpdateTaskState(t.ID, task.StateTimedOut); err != nil {
+ p.logger.Error("failed to update task state", "taskID", t.ID, "state", task.StateTimedOut, "error", err)
+ }
} else if ctx.Err() == context.Canceled {
exec.Status = "CANCELLED"
exec.ErrorMsg = "execution cancelled"
- p.store.UpdateTaskState(t.ID, task.StateCancelled)
+ if err := p.store.UpdateTaskState(t.ID, task.StateCancelled); err != nil {
+ p.logger.Error("failed to update task state", "taskID", t.ID, "state", task.StateCancelled, "error", err)
+ }
} else if isQuotaExhausted(err) {
exec.Status = "BUDGET_EXCEEDED"
exec.ErrorMsg = err.Error()
- p.store.UpdateTaskState(t.ID, task.StateBudgetExceeded)
+ if err := p.store.UpdateTaskState(t.ID, task.StateBudgetExceeded); err != nil {
+ p.logger.Error("failed to update task state", "taskID", t.ID, "state", task.StateBudgetExceeded, "error", err)
+ }
} else {
exec.Status = "FAILED"
exec.ErrorMsg = err.Error()
- p.store.UpdateTaskState(t.ID, task.StateFailed)
+ if err := p.store.UpdateTaskState(t.ID, task.StateFailed); err != nil {
+ p.logger.Error("failed to update task state", "taskID", t.ID, "state", task.StateFailed, "error", err)
+ }
}
} else {
if t.ParentTaskID == "" {
exec.Status = "READY"
- p.store.UpdateTaskState(t.ID, task.StateReady)
+ if err := p.store.UpdateTaskState(t.ID, task.StateReady); err != nil {
+ p.logger.Error("failed to update task state", "taskID", t.ID, "state", task.StateReady, "error", err)
+ }
} else {
exec.Status = "COMPLETED"
- p.store.UpdateTaskState(t.ID, task.StateCompleted)
+ if err := p.store.UpdateTaskState(t.ID, task.StateCompleted); err != nil {
+ p.logger.Error("failed to update task state", "taskID", t.ID, "state", task.StateCompleted, "error", err)
+ }
}
}