diff options
| author | Peter Stone <thepeterstone@gmail.com> | 2026-01-28 22:19:28 -1000 |
|---|---|---|
| committer | Peter Stone <thepeterstone@gmail.com> | 2026-01-28 22:19:28 -1000 |
| commit | 05b1930e04ac222d73ffb2f45c1b1febb69f893d (patch) | |
| tree | bc451d72b5265ff044c4655ed90685c601688b6d /internal | |
| parent | 058ff7d699f088edb851336928dd3eea2934cc07 (diff) | |
Add Agent Context API for external agent integration
Phase 1: Authentication and read-only context
- POST /agent/auth/request - request access with name + agent_id
- GET /agent/auth/poll - poll for approval status
- POST /agent/auth/approve|deny - user approval (browser auth required)
- GET /agent/context - 7-day timeline context (agent session required)
Phase 1.5: Browser-only agent endpoints (HTML pages)
- GET /agent/web/request - request page with token
- GET /agent/web/status - status page with polling
- GET /agent/web/context - context page with timeline data
WebSocket notifications:
- GET /ws/notifications - push agent requests to browsers
- Approval modal with trust indicators and countdown timer
Database:
- agents table for registered agent tracking
- agent_sessions table for pending/active sessions
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Diffstat (limited to 'internal')
| -rw-r--r-- | internal/auth/middleware.go | 9 | ||||
| -rw-r--r-- | internal/handlers/handlers_test.go | 9 | ||||
| -rw-r--r-- | internal/middleware/security.go | 2 | ||||
| -rw-r--r-- | internal/models/types.go | 51 | ||||
| -rw-r--r-- | internal/store/sqlite.go | 345 |
5 files changed, 410 insertions, 6 deletions
diff --git a/internal/auth/middleware.go b/internal/auth/middleware.go index ecdde82..78f3b53 100644 --- a/internal/auth/middleware.go +++ b/internal/auth/middleware.go @@ -6,6 +6,7 @@ import ( "crypto/subtle" "encoding/base64" "net/http" + "strings" "github.com/alexedwards/scs/v2" ) @@ -63,6 +64,12 @@ func (m *Middleware) ClearSession(r *http.Request) error { // CSRFProtect checks for a valid CSRF token on state-changing requests func (m *Middleware) CSRFProtect(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Skip CSRF for agent API endpoints (they use token-based auth, not cookies) + if strings.HasPrefix(r.URL.Path, "/agent/") { + next.ServeHTTP(w, r) + return + } + // Ensure a token exists in the session if !m.sessions.Exists(r.Context(), SessionKeyCSRF) { token, err := generateToken() @@ -78,7 +85,7 @@ func (m *Middleware) CSRFProtect(next http.Handler) http.Handler { // Check token for state-changing methods if r.Method == "POST" || r.Method == "PUT" || r.Method == "DELETE" || r.Method == "PATCH" { requestToken := r.Header.Get("X-CSRF-Token") - + if requestToken == "" { requestToken = r.FormValue("csrf_token") } diff --git a/internal/handlers/handlers_test.go b/internal/handlers/handlers_test.go index d863546..3367ef6 100644 --- a/internal/handlers/handlers_test.go +++ b/internal/handlers/handlers_test.go @@ -75,10 +75,11 @@ func loadTestTemplates(t *testing.T) *template.Template { } } - // Parse partials - tmpl, err = tmpl.ParseGlob(filepath.Join("web", "templates", "partials", "*.html")) - if err != nil { - tmpl, _ = tmpl.ParseGlob(filepath.Join("..", "..", "web", "templates", "partials", "*.html")) + // Parse partials - don't reassign tmpl if parsing fails + if parsed, err := tmpl.ParseGlob(filepath.Join("web", "templates", "partials", "*.html")); err == nil { + tmpl = parsed + } else if parsed, err := tmpl.ParseGlob(filepath.Join("..", "..", "web", "templates", "partials", "*.html")); err == nil { + tmpl = parsed } return tmpl diff --git a/internal/middleware/security.go b/internal/middleware/security.go index e048645..8d1e619 100644 --- a/internal/middleware/security.go +++ b/internal/middleware/security.go @@ -29,7 +29,7 @@ func SecurityHeaders(debug bool) func(http.Handler) http.Handler { "style-src 'self' 'unsafe-inline' https://fonts.googleapis.com; "+ "font-src 'self' https://fonts.gstatic.com; "+ "frame-src https://www.youtube.com https://embed.windy.com; "+ - "connect-src 'self'") + "connect-src 'self' wss: ws:") next.ServeHTTP(w, r) }) diff --git a/internal/models/types.go b/internal/models/types.go index 4bf8462..5214bf8 100644 --- a/internal/models/types.go +++ b/internal/models/types.go @@ -146,3 +146,54 @@ type DashboardData struct { LastUpdated time.Time `json:"last_updated"` Errors []string `json:"errors,omitempty"` } + +// Agent represents a registered external agent +type Agent struct { + ID int64 `json:"id"` + Name string `json:"name"` + AgentID string `json:"agent_id"` // UUID from agent + CreatedAt time.Time `json:"created_at"` + LastSeen *time.Time `json:"last_seen,omitempty"` + Trusted bool `json:"trusted"` +} + +// AgentSession represents a pending request or active session +type AgentSession struct { + ID int64 `json:"id"` + RequestToken string `json:"request_token"` + AgentName string `json:"agent_name"` + AgentID string `json:"agent_id"` + Status string `json:"status"` // pending, approved, denied, expired + CreatedAt time.Time `json:"created_at"` + ExpiresAt time.Time `json:"expires_at"` + SessionToken string `json:"session_token,omitempty"` + SessionExpiresAt *time.Time `json:"session_expires_at,omitempty"` +} + +// AgentAuthRequest is the request body for agent auth +type AgentAuthRequest struct { + Name string `json:"name"` + AgentID string `json:"agent_id"` +} + +// AgentAuthResponse is the response for auth request +type AgentAuthResponse struct { + RequestToken string `json:"request_token"` + Status string `json:"status"` +} + +// AgentPollResponse is the response for poll endpoint +type AgentPollResponse struct { + Status string `json:"status"` + SessionToken string `json:"session_token,omitempty"` + ExpiresAt *time.Time `json:"expires_at,omitempty"` +} + +// AgentTrustLevel indicates the trust state of an agent +type AgentTrustLevel string + +const ( + AgentTrustNew AgentTrustLevel = "new" + AgentTrustRecognized AgentTrustLevel = "recognized" + AgentTrustSuspicious AgentTrustLevel = "suspicious" +) diff --git a/internal/store/sqlite.go b/internal/store/sqlite.go index 396ac54..b324e9f 100644 --- a/internal/store/sqlite.go +++ b/internal/store/sqlite.go @@ -796,3 +796,348 @@ func (s *Store) GetCardsByDateRange(start, end time.Time) ([]models.Card, error) return cards, rows.Err() } + +// Agent operations + +// CreateAgentSession creates a new pending agent session +func (s *Store) CreateAgentSession(session *models.AgentSession) error { + result, err := s.db.Exec(` + INSERT INTO agent_sessions (request_token, agent_name, agent_id, status, expires_at) + VALUES (?, ?, ?, 'pending', ?) + `, session.RequestToken, session.AgentName, session.AgentID, session.ExpiresAt) + if err != nil { + return err + } + id, err := result.LastInsertId() + if err != nil { + return err + } + session.ID = id + return nil +} + +// GetAgentSessionByRequestToken retrieves a session by request token +func (s *Store) GetAgentSessionByRequestToken(token string) (*models.AgentSession, error) { + var session models.AgentSession + var sessionToken sql.NullString + var sessionExpiresAt sql.NullTime + + err := s.db.QueryRow(` + SELECT id, request_token, agent_name, agent_id, status, created_at, expires_at, session_token, session_expires_at + FROM agent_sessions + WHERE request_token = ? + `, token).Scan( + &session.ID, + &session.RequestToken, + &session.AgentName, + &session.AgentID, + &session.Status, + &session.CreatedAt, + &session.ExpiresAt, + &sessionToken, + &sessionExpiresAt, + ) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, nil + } + return nil, err + } + if sessionToken.Valid { + session.SessionToken = sessionToken.String + } + if sessionExpiresAt.Valid { + session.SessionExpiresAt = &sessionExpiresAt.Time + } + return &session, nil +} + +// GetPendingAgentSessionByAgentID retrieves an existing pending session for an agent +func (s *Store) GetPendingAgentSessionByAgentID(agentID string) (*models.AgentSession, error) { + var session models.AgentSession + + err := s.db.QueryRow(` + SELECT id, request_token, agent_name, agent_id, status, created_at, expires_at + FROM agent_sessions + WHERE agent_id = ? AND status = 'pending' AND expires_at > datetime('now', 'localtime') + ORDER BY created_at DESC + LIMIT 1 + `, agentID).Scan( + &session.ID, + &session.RequestToken, + &session.AgentName, + &session.AgentID, + &session.Status, + &session.CreatedAt, + &session.ExpiresAt, + ) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, nil + } + return nil, err + } + return &session, nil +} + +// GetAgentSessionBySessionToken retrieves a session by session token +func (s *Store) GetAgentSessionBySessionToken(token string) (*models.AgentSession, error) { + var session models.AgentSession + var sessionToken sql.NullString + var sessionExpiresAt sql.NullTime + + err := s.db.QueryRow(` + SELECT id, request_token, agent_name, agent_id, status, created_at, expires_at, session_token, session_expires_at + FROM agent_sessions + WHERE session_token = ? AND status = 'approved' + `, token).Scan( + &session.ID, + &session.RequestToken, + &session.AgentName, + &session.AgentID, + &session.Status, + &session.CreatedAt, + &session.ExpiresAt, + &sessionToken, + &sessionExpiresAt, + ) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, nil + } + return nil, err + } + if sessionToken.Valid { + session.SessionToken = sessionToken.String + } + if sessionExpiresAt.Valid { + session.SessionExpiresAt = &sessionExpiresAt.Time + } + return &session, nil +} + +// ApproveAgentSession approves a pending session +func (s *Store) ApproveAgentSession(requestToken, sessionToken string, sessionExpiresAt time.Time) error { + result, err := s.db.Exec(` + UPDATE agent_sessions + SET status = 'approved', session_token = ?, session_expires_at = ? + WHERE request_token = ? AND status = 'pending' + `, sessionToken, sessionExpiresAt, requestToken) + if err != nil { + return err + } + affected, err := result.RowsAffected() + if err != nil { + return err + } + if affected == 0 { + return errors.New("session not found or already processed") + } + return nil +} + +// DenyAgentSession denies a pending session +func (s *Store) DenyAgentSession(requestToken string) error { + result, err := s.db.Exec(` + UPDATE agent_sessions + SET status = 'denied' + WHERE request_token = ? AND status = 'pending' + `, requestToken) + if err != nil { + return err + } + affected, err := result.RowsAffected() + if err != nil { + return err + } + if affected == 0 { + return errors.New("session not found or already processed") + } + return nil +} + +// GetPendingAgentSessions retrieves all unexpired pending sessions +func (s *Store) GetPendingAgentSessions() ([]models.AgentSession, error) { + rows, err := s.db.Query(` + SELECT id, request_token, agent_name, agent_id, status, created_at, expires_at + FROM agent_sessions + WHERE status = 'pending' AND expires_at > datetime('now', 'localtime') + ORDER BY created_at DESC + `) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + + var sessions []models.AgentSession + for rows.Next() { + var session models.AgentSession + if err := rows.Scan( + &session.ID, + &session.RequestToken, + &session.AgentName, + &session.AgentID, + &session.Status, + &session.CreatedAt, + &session.ExpiresAt, + ); err != nil { + return nil, err + } + sessions = append(sessions, session) + } + return sessions, rows.Err() +} + +// InvalidatePreviousAgentSessions marks previous sessions for an agent as expired +func (s *Store) InvalidatePreviousAgentSessions(agentID string) error { + _, err := s.db.Exec(` + UPDATE agent_sessions + SET status = 'expired' + WHERE agent_id = ? AND status IN ('pending', 'approved') + `, agentID) + return err +} + +// GetAgentByAgentID retrieves an agent by their agent_id (UUID) +func (s *Store) GetAgentByAgentID(agentID string) (*models.Agent, error) { + var agent models.Agent + var lastSeen sql.NullTime + + err := s.db.QueryRow(` + SELECT id, name, agent_id, created_at, last_seen, trusted + FROM agents + WHERE agent_id = ? + `, agentID).Scan( + &agent.ID, + &agent.Name, + &agent.AgentID, + &agent.CreatedAt, + &lastSeen, + &agent.Trusted, + ) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, nil + } + return nil, err + } + if lastSeen.Valid { + agent.LastSeen = &lastSeen.Time + } + return &agent, nil +} + +// GetAgentByName retrieves an agent by name +func (s *Store) GetAgentByName(name string) (*models.Agent, error) { + var agent models.Agent + var lastSeen sql.NullTime + + err := s.db.QueryRow(` + SELECT id, name, agent_id, created_at, last_seen, trusted + FROM agents + WHERE name = ? + `, name).Scan( + &agent.ID, + &agent.Name, + &agent.AgentID, + &agent.CreatedAt, + &lastSeen, + &agent.Trusted, + ) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, nil + } + return nil, err + } + if lastSeen.Valid { + agent.LastSeen = &lastSeen.Time + } + return &agent, nil +} + +// CreateOrUpdateAgent creates or updates an agent record +func (s *Store) CreateOrUpdateAgent(name, agentID string) error { + _, err := s.db.Exec(` + INSERT INTO agents (name, agent_id, last_seen, trusted) + VALUES (?, ?, datetime('now'), 1) + ON CONFLICT(agent_id) DO UPDATE SET + name = excluded.name, + last_seen = datetime('now') + `, name, agentID) + return err +} + +// UpdateAgentLastSeen updates the last_seen timestamp for an agent +func (s *Store) UpdateAgentLastSeen(agentID string) error { + _, err := s.db.Exec(` + UPDATE agents SET last_seen = datetime('now') + WHERE agent_id = ? + `, agentID) + return err +} + +// GetAllAgents retrieves all agents +func (s *Store) GetAllAgents() ([]models.Agent, error) { + rows, err := s.db.Query(` + SELECT id, name, agent_id, created_at, last_seen, trusted + FROM agents + ORDER BY last_seen DESC NULLS LAST + `) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + + var agents []models.Agent + for rows.Next() { + var agent models.Agent + var lastSeen sql.NullTime + if err := rows.Scan( + &agent.ID, + &agent.Name, + &agent.AgentID, + &agent.CreatedAt, + &lastSeen, + &agent.Trusted, + ); err != nil { + return nil, err + } + if lastSeen.Valid { + agent.LastSeen = &lastSeen.Time + } + agents = append(agents, agent) + } + return agents, rows.Err() +} + +// RevokeAgent sets trusted=false for an agent +func (s *Store) RevokeAgent(agentID string) error { + _, err := s.db.Exec(`UPDATE agents SET trusted = 0 WHERE agent_id = ?`, agentID) + return err +} + +// CheckAgentTrust determines trust level for an agent request +func (s *Store) CheckAgentTrust(name, agentID string) (models.AgentTrustLevel, error) { + // Check if this exact agent_id is known + existingByID, err := s.GetAgentByAgentID(agentID) + if err != nil { + return "", err + } + + // Check if this name is known with a different ID + existingByName, err := s.GetAgentByName(name) + if err != nil { + return "", err + } + + if existingByID != nil && existingByID.Name == name && existingByID.Trusted { + return models.AgentTrustRecognized, nil + } + + if existingByName != nil && existingByName.AgentID != agentID { + return models.AgentTrustSuspicious, nil + } + + return models.AgentTrustNew, nil +} |
