summaryrefslogtreecommitdiff
path: root/internal/cli
diff options
context:
space:
mode:
Diffstat (limited to 'internal/cli')
-rw-r--r--internal/cli/create_test.go125
-rw-r--r--internal/cli/http.go10
-rw-r--r--internal/cli/report.go74
-rw-r--r--internal/cli/report_test.go32
-rw-r--r--internal/cli/root.go21
-rw-r--r--internal/cli/run.go7
-rw-r--r--internal/cli/serve.go15
-rw-r--r--internal/cli/serve_test.go91
-rw-r--r--internal/cli/start.go12
9 files changed, 369 insertions, 18 deletions
diff --git a/internal/cli/create_test.go b/internal/cli/create_test.go
new file mode 100644
index 0000000..22ce6bd
--- /dev/null
+++ b/internal/cli/create_test.go
@@ -0,0 +1,125 @@
+package cli
+
+import (
+ "net/http"
+ "net/http/httptest"
+ "strings"
+ "testing"
+ "time"
+)
+
+func TestCreateTask_TimesOut(t *testing.T) {
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ select {
+ case <-r.Context().Done():
+ case <-time.After(5 * time.Second): // fallback so srv.Close() never deadlocks
+ }
+ }))
+ defer srv.Close()
+
+ orig := httpClient
+ httpClient = &http.Client{Timeout: 50 * time.Millisecond}
+ defer func() { httpClient = orig }()
+
+ err := createTask(srv.URL, "test", "do something", "", "", "", 1.0, "15m", "normal", false)
+ if err == nil {
+ t.Fatal("expected timeout error, got nil")
+ }
+ if !strings.Contains(err.Error(), "POST /api/tasks") {
+ t.Errorf("expected error mentioning POST /api/tasks, got: %v", err)
+ }
+}
+
+func TestStartTask_EscapesTaskID(t *testing.T) {
+ var capturedPath string
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ capturedPath = r.URL.RawPath
+ if capturedPath == "" {
+ capturedPath = r.URL.Path
+ }
+ w.WriteHeader(http.StatusOK)
+ w.Write([]byte(`{}`))
+ }))
+ defer srv.Close()
+
+ err := startTask(srv.URL, "task/with/slashes")
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ if strings.Contains(capturedPath, "task/with/slashes") {
+ t.Errorf("task ID was not escaped; raw path contains unescaped slashes: %s", capturedPath)
+ }
+ if !strings.Contains(capturedPath, "task%2Fwith%2Fslashes") {
+ t.Errorf("expected escaped path segment, got: %s", capturedPath)
+ }
+}
+
+func TestCreateTask_MissingIDField_ReturnsError(t *testing.T) {
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(http.StatusOK)
+ w.Write([]byte(`{"name":"test"}`)) // no "id" field
+ }))
+ defer srv.Close()
+
+ err := createTask(srv.URL, "test", "do something", "", "", "", 1.0, "15m", "normal", false)
+ if err == nil {
+ t.Fatal("expected error for missing id field, got nil")
+ }
+ if !strings.Contains(err.Error(), "without id") {
+ t.Errorf("expected error mentioning missing id, got: %v", err)
+ }
+}
+
+func TestCreateTask_NonJSONResponse_ReturnsError(t *testing.T) {
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(http.StatusBadGateway)
+ w.Write([]byte(`<html>502 Bad Gateway</html>`))
+ }))
+ defer srv.Close()
+
+ err := createTask(srv.URL, "test", "do something", "", "", "", 1.0, "15m", "normal", false)
+ if err == nil {
+ t.Fatal("expected error for non-JSON response, got nil")
+ }
+ if !strings.Contains(err.Error(), "invalid JSON") {
+ t.Errorf("expected error mentioning invalid JSON, got: %v", err)
+ }
+}
+
+func TestStartTask_NonJSONResponse_ReturnsError(t *testing.T) {
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(http.StatusBadGateway)
+ w.Write([]byte(`<html>502 Bad Gateway</html>`))
+ }))
+ defer srv.Close()
+
+ err := startTask(srv.URL, "task-abc")
+ if err == nil {
+ t.Fatal("expected error for non-JSON response, got nil")
+ }
+ if !strings.Contains(err.Error(), "invalid JSON") {
+ t.Errorf("expected error mentioning invalid JSON, got: %v", err)
+ }
+}
+
+func TestStartTask_TimesOut(t *testing.T) {
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ select {
+ case <-r.Context().Done():
+ case <-time.After(5 * time.Second): // fallback so srv.Close() never deadlocks
+ }
+ }))
+ defer srv.Close()
+
+ orig := httpClient
+ httpClient = &http.Client{Timeout: 50 * time.Millisecond}
+ defer func() { httpClient = orig }()
+
+ err := startTask(srv.URL, "task-abc")
+ if err == nil {
+ t.Fatal("expected timeout error, got nil")
+ }
+ if !strings.Contains(err.Error(), "POST") {
+ t.Errorf("expected error mentioning POST, got: %v", err)
+ }
+}
diff --git a/internal/cli/http.go b/internal/cli/http.go
new file mode 100644
index 0000000..907818a
--- /dev/null
+++ b/internal/cli/http.go
@@ -0,0 +1,10 @@
+package cli
+
+import (
+ "net/http"
+ "time"
+)
+
+const httpTimeout = 30 * time.Second
+
+var httpClient = &http.Client{Timeout: httpTimeout}
diff --git a/internal/cli/report.go b/internal/cli/report.go
new file mode 100644
index 0000000..7f95c80
--- /dev/null
+++ b/internal/cli/report.go
@@ -0,0 +1,74 @@
+package cli
+
+import (
+ "fmt"
+ "os"
+ "time"
+
+ "github.com/spf13/cobra"
+ "github.com/thepeterstone/claudomator/internal/reporter"
+ "github.com/thepeterstone/claudomator/internal/storage"
+)
+
+func newReportCmd() *cobra.Command {
+ var format string
+ var limit int
+ var taskID string
+
+ cmd := &cobra.Command{
+ Use: "report",
+ Short: "Report execution history",
+ RunE: func(cmd *cobra.Command, args []string) error {
+ return runReport(format, limit, taskID)
+ },
+ }
+
+ cmd.Flags().StringVar(&format, "format", "table", "output format: table, json, html")
+ cmd.Flags().IntVar(&limit, "limit", 50, "maximum number of executions to show")
+ cmd.Flags().StringVar(&taskID, "task", "", "filter by task ID")
+
+ return cmd
+}
+
+func runReport(format string, limit int, taskID string) error {
+ var rep reporter.Reporter
+ switch format {
+ case "table", "":
+ rep = &reporter.ConsoleReporter{}
+ case "json":
+ rep = &reporter.JSONReporter{Pretty: true}
+ case "html":
+ rep = &reporter.HTMLReporter{}
+ default:
+ return fmt.Errorf("invalid format %q: must be table, json, or html", format)
+ }
+
+ store, err := storage.Open(cfg.DBPath)
+ if err != nil {
+ return fmt.Errorf("opening db: %w", err)
+ }
+ defer store.Close()
+
+ recent, err := store.ListRecentExecutions(time.Time{}, limit, taskID)
+ if err != nil {
+ return fmt.Errorf("listing executions: %w", err)
+ }
+
+ execs := make([]*storage.Execution, len(recent))
+ for i, r := range recent {
+ e := &storage.Execution{
+ ID: r.ID,
+ TaskID: r.TaskID,
+ Status: r.State,
+ StartTime: r.StartedAt,
+ ExitCode: r.ExitCode,
+ CostUSD: r.CostUSD,
+ }
+ if r.FinishedAt != nil {
+ e.EndTime = *r.FinishedAt
+ }
+ execs[i] = e
+ }
+
+ return rep.Generate(os.Stdout, execs)
+}
diff --git a/internal/cli/report_test.go b/internal/cli/report_test.go
new file mode 100644
index 0000000..3ef96f4
--- /dev/null
+++ b/internal/cli/report_test.go
@@ -0,0 +1,32 @@
+package cli
+
+import (
+ "strings"
+ "testing"
+)
+
+func TestReportCmd_InvalidFormat(t *testing.T) {
+ cmd := newReportCmd()
+ cmd.SetArgs([]string{"--format", "xml"})
+ err := cmd.Execute()
+ if err == nil {
+ t.Fatal("expected error for invalid format, got nil")
+ }
+ if !strings.Contains(err.Error(), "format") {
+ t.Errorf("expected error to mention 'format', got: %v", err)
+ }
+}
+
+func TestReportCmd_DefaultsRegistered(t *testing.T) {
+ cmd := newReportCmd()
+ f := cmd.Flags()
+ if f.Lookup("format") == nil {
+ t.Error("missing --format flag")
+ }
+ if f.Lookup("limit") == nil {
+ t.Error("missing --limit flag")
+ }
+ if f.Lookup("task") == nil {
+ t.Error("missing --task flag")
+ }
+}
diff --git a/internal/cli/root.go b/internal/cli/root.go
index 1a528fb..ab6ac1f 100644
--- a/internal/cli/root.go
+++ b/internal/cli/root.go
@@ -1,12 +1,17 @@
package cli
import (
+ "fmt"
+ "log/slog"
+ "os"
"path/filepath"
"github.com/thepeterstone/claudomator/internal/config"
"github.com/spf13/cobra"
)
+const defaultServerURL = "http://localhost:8484"
+
var (
cfgFile string
verbose bool
@@ -14,7 +19,12 @@ var (
)
func NewRootCmd() *cobra.Command {
- cfg = config.Default()
+ var err error
+ cfg, err = config.Default()
+ if err != nil {
+ fmt.Fprintf(os.Stderr, "fatal: %v\n", err)
+ os.Exit(1)
+ }
cmd := &cobra.Command{
Use: "claudomator",
@@ -43,6 +53,7 @@ func NewRootCmd() *cobra.Command {
newLogsCmd(),
newStartCmd(),
newCreateCmd(),
+ newReportCmd(),
)
return cmd
@@ -51,3 +62,11 @@ func NewRootCmd() *cobra.Command {
func Execute() error {
return NewRootCmd().Execute()
}
+
+func newLogger(v bool) *slog.Logger {
+ level := slog.LevelInfo
+ if v {
+ level = slog.LevelDebug
+ }
+ return slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: level}))
+}
diff --git a/internal/cli/run.go b/internal/cli/run.go
index ed831f5..ebf371c 100644
--- a/internal/cli/run.go
+++ b/internal/cli/run.go
@@ -3,7 +3,6 @@ package cli
import (
"context"
"fmt"
- "log/slog"
"os"
"os/signal"
"syscall"
@@ -71,11 +70,7 @@ func runTasks(file string, parallel int, dryRun bool) error {
}
defer store.Close()
- level := slog.LevelInfo
- if verbose {
- level = slog.LevelDebug
- }
- logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: level}))
+ logger := newLogger(verbose)
runner := &executor.ClaudeRunner{
BinaryPath: cfg.ClaudeBinaryPath,
diff --git a/internal/cli/serve.go b/internal/cli/serve.go
index 363e276..cd5bfce 100644
--- a/internal/cli/serve.go
+++ b/internal/cli/serve.go
@@ -3,7 +3,6 @@ package cli
import (
"context"
"fmt"
- "log/slog"
"net/http"
"os"
"os/signal"
@@ -12,6 +11,7 @@ import (
"github.com/thepeterstone/claudomator/internal/api"
"github.com/thepeterstone/claudomator/internal/executor"
+ "github.com/thepeterstone/claudomator/internal/notify"
"github.com/thepeterstone/claudomator/internal/storage"
"github.com/thepeterstone/claudomator/internal/version"
"github.com/spf13/cobra"
@@ -44,11 +44,7 @@ func serve(addr string) error {
}
defer store.Close()
- level := slog.LevelInfo
- if verbose {
- level = slog.LevelDebug
- }
- logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: level}))
+ logger := newLogger(verbose)
apiURL := "http://localhost" + addr
if len(addr) > 0 && addr[0] != ':' {
@@ -63,6 +59,9 @@ func serve(addr string) error {
pool := executor.NewPool(cfg.MaxConcurrent, runner, store, logger)
srv := api.NewServer(store, pool, logger, cfg.ClaudeBinaryPath)
+ if cfg.WebhookURL != "" {
+ srv.SetNotifier(notify.NewWebhookNotifier(cfg.WebhookURL, logger))
+ }
srv.StartHub()
httpSrv := &http.Server{
@@ -81,7 +80,9 @@ func serve(addr string) error {
logger.Info("shutting down server...")
shutdownCtx, shutdownCancel := context.WithTimeout(ctx, 5*time.Second)
defer shutdownCancel()
- httpSrv.Shutdown(shutdownCtx)
+ if err := httpSrv.Shutdown(shutdownCtx); err != nil {
+ logger.Warn("shutdown error", "err", err)
+ }
}()
fmt.Printf("Claudomator %s listening on %s\n", version.Version(), addr)
diff --git a/internal/cli/serve_test.go b/internal/cli/serve_test.go
new file mode 100644
index 0000000..6bd0e8f
--- /dev/null
+++ b/internal/cli/serve_test.go
@@ -0,0 +1,91 @@
+package cli
+
+import (
+ "context"
+ "log/slog"
+ "net"
+ "net/http"
+ "sync"
+ "testing"
+ "time"
+)
+
+// recordHandler captures log records for assertions.
+type recordHandler struct {
+ mu sync.Mutex
+ records []slog.Record
+}
+
+func (h *recordHandler) Enabled(_ context.Context, _ slog.Level) bool { return true }
+func (h *recordHandler) Handle(_ context.Context, r slog.Record) error {
+ h.mu.Lock()
+ h.records = append(h.records, r)
+ h.mu.Unlock()
+ return nil
+}
+func (h *recordHandler) WithAttrs(_ []slog.Attr) slog.Handler { return h }
+func (h *recordHandler) WithGroup(_ string) slog.Handler { return h }
+func (h *recordHandler) hasWarn(msg string) bool {
+ h.mu.Lock()
+ defer h.mu.Unlock()
+ for _, r := range h.records {
+ if r.Level == slog.LevelWarn && r.Message == msg {
+ return true
+ }
+ }
+ return false
+}
+
+// TestServe_ShutdownError_IsLogged verifies that a shutdown timeout error is
+// logged as a warning rather than silently dropped.
+func TestServe_ShutdownError_IsLogged(t *testing.T) {
+ // Start a real listener so we have an address.
+ ln, err := net.Listen("tcp", "127.0.0.1:0")
+ if err != nil {
+ t.Fatalf("listen: %v", err)
+ }
+
+ // Handler that hangs so the active connection prevents clean shutdown.
+ hang := make(chan struct{})
+ mux := http.NewServeMux()
+ mux.HandleFunc("/hang", func(w http.ResponseWriter, r *http.Request) {
+ <-hang
+ })
+
+ srv := &http.Server{Handler: mux}
+
+ // Serve in background.
+ go srv.Serve(ln) //nolint:errcheck
+
+ // Open a connection and start a hanging request so the server has an
+ // active connection when we call Shutdown.
+ addr := ln.Addr().String()
+ connReady := make(chan struct{})
+ go func() {
+ req, _ := http.NewRequest(http.MethodGet, "http://"+addr+"/hang", nil)
+ close(connReady)
+ http.DefaultClient.Do(req) //nolint:errcheck
+ }()
+ <-connReady
+ // Give the goroutine a moment to establish the request.
+ time.Sleep(20 * time.Millisecond)
+
+ // Shutdown with an already-expired deadline so it times out immediately.
+ expiredCtx, cancel := context.WithDeadline(context.Background(), time.Now().Add(-1*time.Second))
+ defer cancel()
+
+ h := &recordHandler{}
+ logger := slog.New(h)
+
+ // This is the exact logic from serve.go's shutdown goroutine.
+ if err := srv.Shutdown(expiredCtx); err != nil {
+ logger.Warn("shutdown error", "err", err)
+ }
+
+ // Unblock the hanging handler.
+ close(hang)
+
+ if !h.hasWarn("shutdown error") {
+ t.Error("expected shutdown error to be logged as Warn, but it was not")
+ }
+}
diff --git a/internal/cli/start.go b/internal/cli/start.go
index 6ec09b2..9e66e00 100644
--- a/internal/cli/start.go
+++ b/internal/cli/start.go
@@ -3,7 +3,8 @@ package cli
import (
"encoding/json"
"fmt"
- "net/http"
+ "io"
+ "net/url"
"github.com/spf13/cobra"
)
@@ -25,15 +26,18 @@ func newStartCmd() *cobra.Command {
}
func startTask(serverURL, id string) error {
- url := fmt.Sprintf("%s/api/tasks/%s/run", serverURL, id)
- resp, err := http.Post(url, "application/json", nil) //nolint:noctx
+ url := fmt.Sprintf("%s/api/tasks/%s/run", serverURL, url.PathEscape(id))
+ resp, err := httpClient.Post(url, "application/json", nil) //nolint:noctx
if err != nil {
return fmt.Errorf("POST %s: %w", url, err)
}
defer resp.Body.Close()
+ raw, _ := io.ReadAll(resp.Body)
var body map[string]string
- _ = json.NewDecoder(resp.Body).Decode(&body)
+ if err := json.Unmarshal(raw, &body); err != nil {
+ return fmt.Errorf("server returned invalid JSON (status %d): %s", resp.StatusCode, string(raw))
+ }
if resp.StatusCode >= 300 {
return fmt.Errorf("server returned %d: %s", resp.StatusCode, body["error"])