diff options
Diffstat (limited to 'internal/api')
| -rw-r--r-- | internal/api/server.go | 23 | ||||
| -rw-r--r-- | internal/api/server_test.go | 5 |
2 files changed, 11 insertions, 17 deletions
diff --git a/internal/api/server.go b/internal/api/server.go index a6e708a..944e450 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -459,24 +459,19 @@ func (s *Server) handleGetTask(w http.ResponseWriter, r *http.Request) { func (s *Server) handleRunTask(w http.ResponseWriter, r *http.Request) { id := r.PathValue("id") - t, err := s.store.GetTask(id) + t, err := s.store.ResetTaskForRetry(id) if err != nil { - writeJSON(w, http.StatusNotFound, map[string]string{"error": "task not found"}) - return - } - - if !task.ValidTransition(t.State, task.StateQueued) { - writeJSON(w, http.StatusConflict, map[string]string{ - "error": fmt.Sprintf("task cannot be queued from state %s", t.State), - }) - return - } - - if err := s.store.UpdateTaskState(id, task.StateQueued); err != nil { + if strings.Contains(err.Error(), "not found") { + writeJSON(w, http.StatusNotFound, map[string]string{"error": "task not found"}) + return + } + if strings.Contains(err.Error(), "invalid state transition") { + writeJSON(w, http.StatusConflict, map[string]string{"error": err.Error()}) + return + } writeJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) return } - t.State = task.StateQueued if err := s.pool.Submit(context.Background(), t); err != nil { writeJSON(w, http.StatusServiceUnavailable, map[string]string{"error": fmt.Sprintf("executor pool: %v", err)}) diff --git a/internal/api/server_test.go b/internal/api/server_test.go index 9b5c7ae..afdc9d2 100644 --- a/internal/api/server_test.go +++ b/internal/api/server_test.go @@ -397,9 +397,8 @@ func TestRunTask_CompletedTask_Returns409(t *testing.T) { } var body map[string]string json.NewDecoder(w.Body).Decode(&body) - wantMsg := "task cannot be queued from state COMPLETED" - if body["error"] != wantMsg { - t.Errorf("error body: want %q, got %q", wantMsg, body["error"]) + if !strings.Contains(body["error"], "invalid state transition") { + t.Errorf("error body: want it to contain 'invalid state transition', got %q", body["error"]) } } |
