summaryrefslogtreecommitdiff
path: root/internal/api
diff options
context:
space:
mode:
Diffstat (limited to 'internal/api')
-rw-r--r--internal/api/server.go23
-rw-r--r--internal/api/server_test.go5
2 files changed, 11 insertions, 17 deletions
diff --git a/internal/api/server.go b/internal/api/server.go
index a6e708a..944e450 100644
--- a/internal/api/server.go
+++ b/internal/api/server.go
@@ -459,24 +459,19 @@ func (s *Server) handleGetTask(w http.ResponseWriter, r *http.Request) {
func (s *Server) handleRunTask(w http.ResponseWriter, r *http.Request) {
id := r.PathValue("id")
- t, err := s.store.GetTask(id)
+ t, err := s.store.ResetTaskForRetry(id)
if err != nil {
- writeJSON(w, http.StatusNotFound, map[string]string{"error": "task not found"})
- return
- }
-
- if !task.ValidTransition(t.State, task.StateQueued) {
- writeJSON(w, http.StatusConflict, map[string]string{
- "error": fmt.Sprintf("task cannot be queued from state %s", t.State),
- })
- return
- }
-
- if err := s.store.UpdateTaskState(id, task.StateQueued); err != nil {
+ if strings.Contains(err.Error(), "not found") {
+ writeJSON(w, http.StatusNotFound, map[string]string{"error": "task not found"})
+ return
+ }
+ if strings.Contains(err.Error(), "invalid state transition") {
+ writeJSON(w, http.StatusConflict, map[string]string{"error": err.Error()})
+ return
+ }
writeJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()})
return
}
- t.State = task.StateQueued
if err := s.pool.Submit(context.Background(), t); err != nil {
writeJSON(w, http.StatusServiceUnavailable, map[string]string{"error": fmt.Sprintf("executor pool: %v", err)})
diff --git a/internal/api/server_test.go b/internal/api/server_test.go
index 9b5c7ae..afdc9d2 100644
--- a/internal/api/server_test.go
+++ b/internal/api/server_test.go
@@ -397,9 +397,8 @@ func TestRunTask_CompletedTask_Returns409(t *testing.T) {
}
var body map[string]string
json.NewDecoder(w.Body).Decode(&body)
- wantMsg := "task cannot be queued from state COMPLETED"
- if body["error"] != wantMsg {
- t.Errorf("error body: want %q, got %q", wantMsg, body["error"])
+ if !strings.Contains(body["error"], "invalid state transition") {
+ t.Errorf("error body: want it to contain 'invalid state transition', got %q", body["error"])
}
}