diff options
Diffstat (limited to 'internal')
| -rw-r--r-- | internal/api/server.go | 14 | ||||
| -rw-r--r-- | internal/cli/serve.go | 9 |
2 files changed, 16 insertions, 7 deletions
diff --git a/internal/api/server.go b/internal/api/server.go index 3bc4147..604f354 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -31,6 +31,7 @@ type questionStore interface { // Server provides the REST API and WebSocket endpoint for Claudomator. type Server struct { + ctx context.Context // server lifecycle context; used for pool submissions store *storage.DB logStore logStore // injectable for tests; defaults to store taskLogStore taskLogStore // injectable for tests; defaults to store @@ -63,6 +64,12 @@ func (s *Server) SetAPIToken(token string) { s.apiToken = token } +// SetContext replaces the server's lifecycle context used for pool submissions. +// Call this before StartHub to tie task submissions to the server's shutdown signal. +func (s *Server) SetContext(ctx context.Context) { + s.ctx = ctx +} + // SetNotifier configures a notifier that is called on every task completion. func (s *Server) SetNotifier(n notify.Notifier) { s.notifier = n @@ -85,6 +92,7 @@ func (s *Server) Pool() *executor.Pool { return s.pool } func NewServer(store *storage.DB, pool *executor.Pool, logger *slog.Logger, claudeBinPath, geminiBinPath string) *Server { wd, _ := os.Getwd() s := &Server{ + ctx: context.Background(), store: store, logStore: store, taskLogStore: store, @@ -344,7 +352,7 @@ func (s *Server) handleAnswerQuestion(w http.ResponseWriter, r *http.Request) { ResumeAnswer: input.Answer, SandboxDir: latest.SandboxDir, } - if err := s.pool.SubmitResume(context.Background(), tk, resumeExec); err != nil { + if err := s.pool.SubmitResume(s.ctx, tk, resumeExec); err != nil { writeJSON(w, http.StatusServiceUnavailable, map[string]string{"error": err.Error()}) return } @@ -389,7 +397,7 @@ func (s *Server) handleResumeTimedOutTask(w http.ResponseWriter, r *http.Request ResumeSessionID: latest.SessionID, ResumeAnswer: resumeMsg, } - if err := s.pool.SubmitResume(context.Background(), tk, resumeExec); err != nil { + if err := s.pool.SubmitResume(s.ctx, tk, resumeExec); err != nil { writeJSON(w, http.StatusServiceUnavailable, map[string]string{"error": err.Error()}) return } @@ -661,7 +669,7 @@ func (s *Server) handleRunTask(w http.ResponseWriter, r *http.Request) { // task isn't immediately re-cancelled by checkDepsReady. s.cascadeRetryDeps(r.Context(), originalTask) - if err := s.pool.Submit(context.Background(), t); err != nil { + if err := s.pool.Submit(s.ctx, t); err != nil { writeJSON(w, http.StatusServiceUnavailable, map[string]string{"error": fmt.Sprintf("executor pool: %v", err)}) return } diff --git a/internal/cli/serve.go b/internal/cli/serve.go index f7493ed..581a064 100644 --- a/internal/cli/serve.go +++ b/internal/cli/serve.go @@ -163,6 +163,11 @@ func serve(addr string) error { "deploy": filepath.Join(wd, "scripts", "deploy"), }) + // Graceful shutdown. + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + srv.SetContext(ctx) srv.StartHub() httpSrv := &http.Server{ @@ -170,10 +175,6 @@ func serve(addr string) error { Handler: srv.Handler(), } - // Graceful shutdown. - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - workerTimeout := 3 * time.Minute if cfg.ShutdownTimeout > 0 { workerTimeout = cfg.ShutdownTimeout |
