summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--internal/api/server_test.go124
1 files changed, 120 insertions, 4 deletions
diff --git a/internal/api/server_test.go b/internal/api/server_test.go
index 765b813..5047af6 100644
--- a/internal/api/server_test.go
+++ b/internal/api/server_test.go
@@ -3,6 +3,7 @@ package api
import (
"bytes"
"encoding/json"
+ "errors"
"fmt"
"log/slog"
"net/http"
@@ -75,6 +76,11 @@ func TestServer_ProcessResult_CallsNotifier(t *testing.T) {
func testServer(t *testing.T) (*Server, *storage.DB) {
t.Helper()
+ return testServerWithRunner(t, &mockRunner{})
+}
+
+func testServerWithRunner(t *testing.T, runner executor.Runner) (*Server, *storage.DB) {
+ t.Helper()
dbPath := filepath.Join(t.TempDir(), "test.db")
store, err := storage.Open(dbPath)
if err != nil {
@@ -83,7 +89,6 @@ func testServer(t *testing.T) (*Server, *storage.DB) {
t.Cleanup(func() { store.Close() })
logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))
- runner := &mockRunner{}
runners := map[string]executor.Runner{
"claude": runner,
"gemini": runner,
@@ -93,10 +98,38 @@ func testServer(t *testing.T) (*Server, *storage.DB) {
return srv, store
}
-type mockRunner struct{}
+type mockRunner struct {
+ err error
+ sleep time.Duration
+}
-func (m *mockRunner) Run(_ context.Context, _ *task.Task, _ *storage.Execution) error {
- return nil
+func (m *mockRunner) Run(ctx context.Context, _ *task.Task, _ *storage.Execution) error {
+ if m.sleep > 0 {
+ select {
+ case <-time.After(m.sleep):
+ case <-ctx.Done():
+ return ctx.Err()
+ }
+ }
+ return m.err
+}
+
+// pollState polls store.GetTask until the task reaches wantState or the timeout elapses.
+func pollState(t *testing.T, store *storage.DB, taskID string, wantState task.State, timeout time.Duration) task.State {
+ t.Helper()
+ deadline := time.Now().Add(timeout)
+ for time.Now().Before(deadline) {
+ tk, err := store.GetTask(taskID)
+ if err == nil && tk.State == wantState {
+ return tk.State
+ }
+ time.Sleep(5 * time.Millisecond)
+ }
+ tk, _ := store.GetTask(taskID)
+ if tk != nil {
+ return tk.State
+ }
+ return ""
}
func TestListWorkspaces_UsesConfiguredRoot(t *testing.T) {
@@ -1026,3 +1059,86 @@ func TestRateLimit_ValidateRejectsExcess(t *testing.T) {
t.Errorf("second validate request from same IP should be 429, got %d", code)
}
}
+
+func TestRunTask_AgentFails_TaskSetToFailed(t *testing.T) {
+ runner := &mockRunner{err: errors.New("agent error")}
+ srv, store := testServerWithRunner(t, runner)
+ createTaskWithState(t, store, "async-fail-1", task.StatePending)
+
+ req := httptest.NewRequest("POST", "/api/tasks/async-fail-1/run", nil)
+ w := httptest.NewRecorder()
+ srv.Handler().ServeHTTP(w, req)
+
+ if w.Code != http.StatusAccepted {
+ t.Fatalf("status: want 202, got %d; body: %s", w.Code, w.Body.String())
+ }
+
+ got := pollState(t, store, "async-fail-1", task.StateFailed, 2*time.Second)
+ if got != task.StateFailed {
+ t.Errorf("task state: want FAILED, got %v", got)
+ }
+}
+
+func TestRunTask_AgentTimesOut_TaskSetToTimedOut(t *testing.T) {
+ runner := &mockRunner{sleep: 5 * time.Second}
+ srv, store := testServerWithRunner(t, runner)
+
+ tk := &task.Task{
+ ID: "async-timeout-1",
+ Name: "timeout-test",
+ Agent: task.AgentConfig{Type: "claude", Instructions: "do something"},
+ Priority: task.PriorityNormal,
+ Retry: task.RetryConfig{MaxAttempts: 1, Backoff: "linear"},
+ Tags: []string{},
+ DependsOn: []string{},
+ State: task.StatePending,
+ Timeout: task.Duration{Duration: 50 * time.Millisecond},
+ }
+ if err := store.CreateTask(tk); err != nil {
+ t.Fatal(err)
+ }
+
+ req := httptest.NewRequest("POST", "/api/tasks/async-timeout-1/run", nil)
+ w := httptest.NewRecorder()
+ srv.Handler().ServeHTTP(w, req)
+
+ if w.Code != http.StatusAccepted {
+ t.Fatalf("status: want 202, got %d; body: %s", w.Code, w.Body.String())
+ }
+
+ got := pollState(t, store, "async-timeout-1", task.StateTimedOut, 2*time.Second)
+ if got != task.StateTimedOut {
+ t.Errorf("task state: want TIMED_OUT, got %v", got)
+ }
+}
+
+func TestRunTask_AgentCancelled_TaskSetToCancelled(t *testing.T) {
+ runner := &mockRunner{sleep: 5 * time.Second}
+ srv, store := testServerWithRunner(t, runner)
+ createTaskWithState(t, store, "async-cancel-1", task.StatePending)
+
+ req := httptest.NewRequest("POST", "/api/tasks/async-cancel-1/run", nil)
+ w := httptest.NewRecorder()
+ srv.Handler().ServeHTTP(w, req)
+
+ if w.Code != http.StatusAccepted {
+ t.Fatalf("POST /run status: want 202, got %d; body: %s", w.Code, w.Body.String())
+ }
+
+ // Wait for the pool to start the task (cancel func must be registered).
+ pollState(t, store, "async-cancel-1", task.StateRunning, 2*time.Second)
+
+ // Cancel via the API — pool.Cancel() cancels the context; runner returns ctx.Err().
+ cancelReq := httptest.NewRequest("POST", "/api/tasks/async-cancel-1/cancel", nil)
+ cancelW := httptest.NewRecorder()
+ srv.Handler().ServeHTTP(cancelW, cancelReq)
+
+ if cancelW.Code != http.StatusOK {
+ t.Fatalf("POST /cancel status: want 200, got %d; body: %s", cancelW.Code, cancelW.Body.String())
+ }
+
+ got := pollState(t, store, "async-cancel-1", task.StateCancelled, 2*time.Second)
+ if got != task.StateCancelled {
+ t.Errorf("task state: want CANCELLED, got %v", got)
+ }
+}