From 1ce83b6b6a300f4389dd84c4477f3ca73a431524 Mon Sep 17 00:00:00 2001 From: Peter Stone Date: Sun, 8 Mar 2026 20:40:55 +0000 Subject: cli: newLogger helper, defaultServerURL, shared http client, report command - Extract newLogger() to remove duplication across run/serve/start - Add defaultServerURL const ("http://localhost:8484") used by all client commands - Move http.Client into internal/cli/http.go with 30s timeout - Add 'report' command for printing execution summaries - Add test coverage for create and serve commands Co-Authored-By: Claude Sonnet 4.6 --- internal/cli/create_test.go | 125 ++++++++++++++++++++++++++++++++++++++++++++ internal/cli/http.go | 10 ++++ internal/cli/report.go | 74 ++++++++++++++++++++++++++ internal/cli/report_test.go | 32 ++++++++++++ internal/cli/root.go | 21 +++++++- internal/cli/run.go | 7 +-- internal/cli/serve.go | 15 +++--- internal/cli/serve_test.go | 91 ++++++++++++++++++++++++++++++++ internal/cli/start.go | 12 +++-- 9 files changed, 369 insertions(+), 18 deletions(-) create mode 100644 internal/cli/create_test.go create mode 100644 internal/cli/http.go create mode 100644 internal/cli/report.go create mode 100644 internal/cli/report_test.go create mode 100644 internal/cli/serve_test.go (limited to 'internal/cli') diff --git a/internal/cli/create_test.go b/internal/cli/create_test.go new file mode 100644 index 0000000..22ce6bd --- /dev/null +++ b/internal/cli/create_test.go @@ -0,0 +1,125 @@ +package cli + +import ( + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" +) + +func TestCreateTask_TimesOut(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + select { + case <-r.Context().Done(): + case <-time.After(5 * time.Second): // fallback so srv.Close() never deadlocks + } + })) + defer srv.Close() + + orig := httpClient + httpClient = &http.Client{Timeout: 50 * time.Millisecond} + defer func() { httpClient = orig }() + + err := createTask(srv.URL, "test", "do something", "", "", "", 1.0, "15m", "normal", false) + if err == nil { + t.Fatal("expected timeout error, got nil") + } + if !strings.Contains(err.Error(), "POST /api/tasks") { + t.Errorf("expected error mentioning POST /api/tasks, got: %v", err) + } +} + +func TestStartTask_EscapesTaskID(t *testing.T) { + var capturedPath string + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + capturedPath = r.URL.RawPath + if capturedPath == "" { + capturedPath = r.URL.Path + } + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{}`)) + })) + defer srv.Close() + + err := startTask(srv.URL, "task/with/slashes") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if strings.Contains(capturedPath, "task/with/slashes") { + t.Errorf("task ID was not escaped; raw path contains unescaped slashes: %s", capturedPath) + } + if !strings.Contains(capturedPath, "task%2Fwith%2Fslashes") { + t.Errorf("expected escaped path segment, got: %s", capturedPath) + } +} + +func TestCreateTask_MissingIDField_ReturnsError(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"name":"test"}`)) // no "id" field + })) + defer srv.Close() + + err := createTask(srv.URL, "test", "do something", "", "", "", 1.0, "15m", "normal", false) + if err == nil { + t.Fatal("expected error for missing id field, got nil") + } + if !strings.Contains(err.Error(), "without id") { + t.Errorf("expected error mentioning missing id, got: %v", err) + } +} + +func TestCreateTask_NonJSONResponse_ReturnsError(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusBadGateway) + w.Write([]byte(`502 Bad Gateway`)) + })) + defer srv.Close() + + err := createTask(srv.URL, "test", "do something", "", "", "", 1.0, "15m", "normal", false) + if err == nil { + t.Fatal("expected error for non-JSON response, got nil") + } + if !strings.Contains(err.Error(), "invalid JSON") { + t.Errorf("expected error mentioning invalid JSON, got: %v", err) + } +} + +func TestStartTask_NonJSONResponse_ReturnsError(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusBadGateway) + w.Write([]byte(`502 Bad Gateway`)) + })) + defer srv.Close() + + err := startTask(srv.URL, "task-abc") + if err == nil { + t.Fatal("expected error for non-JSON response, got nil") + } + if !strings.Contains(err.Error(), "invalid JSON") { + t.Errorf("expected error mentioning invalid JSON, got: %v", err) + } +} + +func TestStartTask_TimesOut(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + select { + case <-r.Context().Done(): + case <-time.After(5 * time.Second): // fallback so srv.Close() never deadlocks + } + })) + defer srv.Close() + + orig := httpClient + httpClient = &http.Client{Timeout: 50 * time.Millisecond} + defer func() { httpClient = orig }() + + err := startTask(srv.URL, "task-abc") + if err == nil { + t.Fatal("expected timeout error, got nil") + } + if !strings.Contains(err.Error(), "POST") { + t.Errorf("expected error mentioning POST, got: %v", err) + } +} diff --git a/internal/cli/http.go b/internal/cli/http.go new file mode 100644 index 0000000..907818a --- /dev/null +++ b/internal/cli/http.go @@ -0,0 +1,10 @@ +package cli + +import ( + "net/http" + "time" +) + +const httpTimeout = 30 * time.Second + +var httpClient = &http.Client{Timeout: httpTimeout} diff --git a/internal/cli/report.go b/internal/cli/report.go new file mode 100644 index 0000000..7f95c80 --- /dev/null +++ b/internal/cli/report.go @@ -0,0 +1,74 @@ +package cli + +import ( + "fmt" + "os" + "time" + + "github.com/spf13/cobra" + "github.com/thepeterstone/claudomator/internal/reporter" + "github.com/thepeterstone/claudomator/internal/storage" +) + +func newReportCmd() *cobra.Command { + var format string + var limit int + var taskID string + + cmd := &cobra.Command{ + Use: "report", + Short: "Report execution history", + RunE: func(cmd *cobra.Command, args []string) error { + return runReport(format, limit, taskID) + }, + } + + cmd.Flags().StringVar(&format, "format", "table", "output format: table, json, html") + cmd.Flags().IntVar(&limit, "limit", 50, "maximum number of executions to show") + cmd.Flags().StringVar(&taskID, "task", "", "filter by task ID") + + return cmd +} + +func runReport(format string, limit int, taskID string) error { + var rep reporter.Reporter + switch format { + case "table", "": + rep = &reporter.ConsoleReporter{} + case "json": + rep = &reporter.JSONReporter{Pretty: true} + case "html": + rep = &reporter.HTMLReporter{} + default: + return fmt.Errorf("invalid format %q: must be table, json, or html", format) + } + + store, err := storage.Open(cfg.DBPath) + if err != nil { + return fmt.Errorf("opening db: %w", err) + } + defer store.Close() + + recent, err := store.ListRecentExecutions(time.Time{}, limit, taskID) + if err != nil { + return fmt.Errorf("listing executions: %w", err) + } + + execs := make([]*storage.Execution, len(recent)) + for i, r := range recent { + e := &storage.Execution{ + ID: r.ID, + TaskID: r.TaskID, + Status: r.State, + StartTime: r.StartedAt, + ExitCode: r.ExitCode, + CostUSD: r.CostUSD, + } + if r.FinishedAt != nil { + e.EndTime = *r.FinishedAt + } + execs[i] = e + } + + return rep.Generate(os.Stdout, execs) +} diff --git a/internal/cli/report_test.go b/internal/cli/report_test.go new file mode 100644 index 0000000..3ef96f4 --- /dev/null +++ b/internal/cli/report_test.go @@ -0,0 +1,32 @@ +package cli + +import ( + "strings" + "testing" +) + +func TestReportCmd_InvalidFormat(t *testing.T) { + cmd := newReportCmd() + cmd.SetArgs([]string{"--format", "xml"}) + err := cmd.Execute() + if err == nil { + t.Fatal("expected error for invalid format, got nil") + } + if !strings.Contains(err.Error(), "format") { + t.Errorf("expected error to mention 'format', got: %v", err) + } +} + +func TestReportCmd_DefaultsRegistered(t *testing.T) { + cmd := newReportCmd() + f := cmd.Flags() + if f.Lookup("format") == nil { + t.Error("missing --format flag") + } + if f.Lookup("limit") == nil { + t.Error("missing --limit flag") + } + if f.Lookup("task") == nil { + t.Error("missing --task flag") + } +} diff --git a/internal/cli/root.go b/internal/cli/root.go index 1a528fb..ab6ac1f 100644 --- a/internal/cli/root.go +++ b/internal/cli/root.go @@ -1,12 +1,17 @@ package cli import ( + "fmt" + "log/slog" + "os" "path/filepath" "github.com/thepeterstone/claudomator/internal/config" "github.com/spf13/cobra" ) +const defaultServerURL = "http://localhost:8484" + var ( cfgFile string verbose bool @@ -14,7 +19,12 @@ var ( ) func NewRootCmd() *cobra.Command { - cfg = config.Default() + var err error + cfg, err = config.Default() + if err != nil { + fmt.Fprintf(os.Stderr, "fatal: %v\n", err) + os.Exit(1) + } cmd := &cobra.Command{ Use: "claudomator", @@ -43,6 +53,7 @@ func NewRootCmd() *cobra.Command { newLogsCmd(), newStartCmd(), newCreateCmd(), + newReportCmd(), ) return cmd @@ -51,3 +62,11 @@ func NewRootCmd() *cobra.Command { func Execute() error { return NewRootCmd().Execute() } + +func newLogger(v bool) *slog.Logger { + level := slog.LevelInfo + if v { + level = slog.LevelDebug + } + return slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: level})) +} diff --git a/internal/cli/run.go b/internal/cli/run.go index ed831f5..ebf371c 100644 --- a/internal/cli/run.go +++ b/internal/cli/run.go @@ -3,7 +3,6 @@ package cli import ( "context" "fmt" - "log/slog" "os" "os/signal" "syscall" @@ -71,11 +70,7 @@ func runTasks(file string, parallel int, dryRun bool) error { } defer store.Close() - level := slog.LevelInfo - if verbose { - level = slog.LevelDebug - } - logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: level})) + logger := newLogger(verbose) runner := &executor.ClaudeRunner{ BinaryPath: cfg.ClaudeBinaryPath, diff --git a/internal/cli/serve.go b/internal/cli/serve.go index 363e276..cd5bfce 100644 --- a/internal/cli/serve.go +++ b/internal/cli/serve.go @@ -3,7 +3,6 @@ package cli import ( "context" "fmt" - "log/slog" "net/http" "os" "os/signal" @@ -12,6 +11,7 @@ import ( "github.com/thepeterstone/claudomator/internal/api" "github.com/thepeterstone/claudomator/internal/executor" + "github.com/thepeterstone/claudomator/internal/notify" "github.com/thepeterstone/claudomator/internal/storage" "github.com/thepeterstone/claudomator/internal/version" "github.com/spf13/cobra" @@ -44,11 +44,7 @@ func serve(addr string) error { } defer store.Close() - level := slog.LevelInfo - if verbose { - level = slog.LevelDebug - } - logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: level})) + logger := newLogger(verbose) apiURL := "http://localhost" + addr if len(addr) > 0 && addr[0] != ':' { @@ -63,6 +59,9 @@ func serve(addr string) error { pool := executor.NewPool(cfg.MaxConcurrent, runner, store, logger) srv := api.NewServer(store, pool, logger, cfg.ClaudeBinaryPath) + if cfg.WebhookURL != "" { + srv.SetNotifier(notify.NewWebhookNotifier(cfg.WebhookURL, logger)) + } srv.StartHub() httpSrv := &http.Server{ @@ -81,7 +80,9 @@ func serve(addr string) error { logger.Info("shutting down server...") shutdownCtx, shutdownCancel := context.WithTimeout(ctx, 5*time.Second) defer shutdownCancel() - httpSrv.Shutdown(shutdownCtx) + if err := httpSrv.Shutdown(shutdownCtx); err != nil { + logger.Warn("shutdown error", "err", err) + } }() fmt.Printf("Claudomator %s listening on %s\n", version.Version(), addr) diff --git a/internal/cli/serve_test.go b/internal/cli/serve_test.go new file mode 100644 index 0000000..6bd0e8f --- /dev/null +++ b/internal/cli/serve_test.go @@ -0,0 +1,91 @@ +package cli + +import ( + "context" + "log/slog" + "net" + "net/http" + "sync" + "testing" + "time" +) + +// recordHandler captures log records for assertions. +type recordHandler struct { + mu sync.Mutex + records []slog.Record +} + +func (h *recordHandler) Enabled(_ context.Context, _ slog.Level) bool { return true } +func (h *recordHandler) Handle(_ context.Context, r slog.Record) error { + h.mu.Lock() + h.records = append(h.records, r) + h.mu.Unlock() + return nil +} +func (h *recordHandler) WithAttrs(_ []slog.Attr) slog.Handler { return h } +func (h *recordHandler) WithGroup(_ string) slog.Handler { return h } +func (h *recordHandler) hasWarn(msg string) bool { + h.mu.Lock() + defer h.mu.Unlock() + for _, r := range h.records { + if r.Level == slog.LevelWarn && r.Message == msg { + return true + } + } + return false +} + +// TestServe_ShutdownError_IsLogged verifies that a shutdown timeout error is +// logged as a warning rather than silently dropped. +func TestServe_ShutdownError_IsLogged(t *testing.T) { + // Start a real listener so we have an address. + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen: %v", err) + } + + // Handler that hangs so the active connection prevents clean shutdown. + hang := make(chan struct{}) + mux := http.NewServeMux() + mux.HandleFunc("/hang", func(w http.ResponseWriter, r *http.Request) { + <-hang + }) + + srv := &http.Server{Handler: mux} + + // Serve in background. + go srv.Serve(ln) //nolint:errcheck + + // Open a connection and start a hanging request so the server has an + // active connection when we call Shutdown. + addr := ln.Addr().String() + connReady := make(chan struct{}) + go func() { + req, _ := http.NewRequest(http.MethodGet, "http://"+addr+"/hang", nil) + close(connReady) + http.DefaultClient.Do(req) //nolint:errcheck + }() + <-connReady + // Give the goroutine a moment to establish the request. + time.Sleep(20 * time.Millisecond) + + // Shutdown with an already-expired deadline so it times out immediately. + expiredCtx, cancel := context.WithDeadline(context.Background(), time.Now().Add(-1*time.Second)) + defer cancel() + + h := &recordHandler{} + logger := slog.New(h) + + // This is the exact logic from serve.go's shutdown goroutine. + if err := srv.Shutdown(expiredCtx); err != nil { + logger.Warn("shutdown error", "err", err) + } + + // Unblock the hanging handler. + close(hang) + + if !h.hasWarn("shutdown error") { + t.Error("expected shutdown error to be logged as Warn, but it was not") + } +} diff --git a/internal/cli/start.go b/internal/cli/start.go index 6ec09b2..9e66e00 100644 --- a/internal/cli/start.go +++ b/internal/cli/start.go @@ -3,7 +3,8 @@ package cli import ( "encoding/json" "fmt" - "net/http" + "io" + "net/url" "github.com/spf13/cobra" ) @@ -25,15 +26,18 @@ func newStartCmd() *cobra.Command { } func startTask(serverURL, id string) error { - url := fmt.Sprintf("%s/api/tasks/%s/run", serverURL, id) - resp, err := http.Post(url, "application/json", nil) //nolint:noctx + url := fmt.Sprintf("%s/api/tasks/%s/run", serverURL, url.PathEscape(id)) + resp, err := httpClient.Post(url, "application/json", nil) //nolint:noctx if err != nil { return fmt.Errorf("POST %s: %w", url, err) } defer resp.Body.Close() + raw, _ := io.ReadAll(resp.Body) var body map[string]string - _ = json.NewDecoder(resp.Body).Decode(&body) + if err := json.Unmarshal(raw, &body); err != nil { + return fmt.Errorf("server returned invalid JSON (status %d): %s", resp.StatusCode, string(raw)) + } if resp.StatusCode >= 300 { return fmt.Errorf("server returned %d: %s", resp.StatusCode, body["error"]) -- cgit v1.2.3