summaryrefslogtreecommitdiff
path: root/internal
diff options
context:
space:
mode:
Diffstat (limited to 'internal')
-rw-r--r--internal/auth/middleware.go9
-rw-r--r--internal/handlers/handlers_test.go9
-rw-r--r--internal/middleware/security.go2
-rw-r--r--internal/models/types.go51
-rw-r--r--internal/store/sqlite.go345
5 files changed, 410 insertions, 6 deletions
diff --git a/internal/auth/middleware.go b/internal/auth/middleware.go
index ecdde82..78f3b53 100644
--- a/internal/auth/middleware.go
+++ b/internal/auth/middleware.go
@@ -6,6 +6,7 @@ import (
"crypto/subtle"
"encoding/base64"
"net/http"
+ "strings"
"github.com/alexedwards/scs/v2"
)
@@ -63,6 +64,12 @@ func (m *Middleware) ClearSession(r *http.Request) error {
// CSRFProtect checks for a valid CSRF token on state-changing requests
func (m *Middleware) CSRFProtect(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ // Skip CSRF for agent API endpoints (they use token-based auth, not cookies)
+ if strings.HasPrefix(r.URL.Path, "/agent/") {
+ next.ServeHTTP(w, r)
+ return
+ }
+
// Ensure a token exists in the session
if !m.sessions.Exists(r.Context(), SessionKeyCSRF) {
token, err := generateToken()
@@ -78,7 +85,7 @@ func (m *Middleware) CSRFProtect(next http.Handler) http.Handler {
// Check token for state-changing methods
if r.Method == "POST" || r.Method == "PUT" || r.Method == "DELETE" || r.Method == "PATCH" {
requestToken := r.Header.Get("X-CSRF-Token")
-
+
if requestToken == "" {
requestToken = r.FormValue("csrf_token")
}
diff --git a/internal/handlers/handlers_test.go b/internal/handlers/handlers_test.go
index d863546..3367ef6 100644
--- a/internal/handlers/handlers_test.go
+++ b/internal/handlers/handlers_test.go
@@ -75,10 +75,11 @@ func loadTestTemplates(t *testing.T) *template.Template {
}
}
- // Parse partials
- tmpl, err = tmpl.ParseGlob(filepath.Join("web", "templates", "partials", "*.html"))
- if err != nil {
- tmpl, _ = tmpl.ParseGlob(filepath.Join("..", "..", "web", "templates", "partials", "*.html"))
+ // Parse partials - don't reassign tmpl if parsing fails
+ if parsed, err := tmpl.ParseGlob(filepath.Join("web", "templates", "partials", "*.html")); err == nil {
+ tmpl = parsed
+ } else if parsed, err := tmpl.ParseGlob(filepath.Join("..", "..", "web", "templates", "partials", "*.html")); err == nil {
+ tmpl = parsed
}
return tmpl
diff --git a/internal/middleware/security.go b/internal/middleware/security.go
index e048645..8d1e619 100644
--- a/internal/middleware/security.go
+++ b/internal/middleware/security.go
@@ -29,7 +29,7 @@ func SecurityHeaders(debug bool) func(http.Handler) http.Handler {
"style-src 'self' 'unsafe-inline' https://fonts.googleapis.com; "+
"font-src 'self' https://fonts.gstatic.com; "+
"frame-src https://www.youtube.com https://embed.windy.com; "+
- "connect-src 'self'")
+ "connect-src 'self' wss: ws:")
next.ServeHTTP(w, r)
})
diff --git a/internal/models/types.go b/internal/models/types.go
index 4bf8462..5214bf8 100644
--- a/internal/models/types.go
+++ b/internal/models/types.go
@@ -146,3 +146,54 @@ type DashboardData struct {
LastUpdated time.Time `json:"last_updated"`
Errors []string `json:"errors,omitempty"`
}
+
+// Agent represents a registered external agent
+type Agent struct {
+ ID int64 `json:"id"`
+ Name string `json:"name"`
+ AgentID string `json:"agent_id"` // UUID from agent
+ CreatedAt time.Time `json:"created_at"`
+ LastSeen *time.Time `json:"last_seen,omitempty"`
+ Trusted bool `json:"trusted"`
+}
+
+// AgentSession represents a pending request or active session
+type AgentSession struct {
+ ID int64 `json:"id"`
+ RequestToken string `json:"request_token"`
+ AgentName string `json:"agent_name"`
+ AgentID string `json:"agent_id"`
+ Status string `json:"status"` // pending, approved, denied, expired
+ CreatedAt time.Time `json:"created_at"`
+ ExpiresAt time.Time `json:"expires_at"`
+ SessionToken string `json:"session_token,omitempty"`
+ SessionExpiresAt *time.Time `json:"session_expires_at,omitempty"`
+}
+
+// AgentAuthRequest is the request body for agent auth
+type AgentAuthRequest struct {
+ Name string `json:"name"`
+ AgentID string `json:"agent_id"`
+}
+
+// AgentAuthResponse is the response for auth request
+type AgentAuthResponse struct {
+ RequestToken string `json:"request_token"`
+ Status string `json:"status"`
+}
+
+// AgentPollResponse is the response for poll endpoint
+type AgentPollResponse struct {
+ Status string `json:"status"`
+ SessionToken string `json:"session_token,omitempty"`
+ ExpiresAt *time.Time `json:"expires_at,omitempty"`
+}
+
+// AgentTrustLevel indicates the trust state of an agent
+type AgentTrustLevel string
+
+const (
+ AgentTrustNew AgentTrustLevel = "new"
+ AgentTrustRecognized AgentTrustLevel = "recognized"
+ AgentTrustSuspicious AgentTrustLevel = "suspicious"
+)
diff --git a/internal/store/sqlite.go b/internal/store/sqlite.go
index 396ac54..b324e9f 100644
--- a/internal/store/sqlite.go
+++ b/internal/store/sqlite.go
@@ -796,3 +796,348 @@ func (s *Store) GetCardsByDateRange(start, end time.Time) ([]models.Card, error)
return cards, rows.Err()
}
+
+// Agent operations
+
+// CreateAgentSession creates a new pending agent session
+func (s *Store) CreateAgentSession(session *models.AgentSession) error {
+ result, err := s.db.Exec(`
+ INSERT INTO agent_sessions (request_token, agent_name, agent_id, status, expires_at)
+ VALUES (?, ?, ?, 'pending', ?)
+ `, session.RequestToken, session.AgentName, session.AgentID, session.ExpiresAt)
+ if err != nil {
+ return err
+ }
+ id, err := result.LastInsertId()
+ if err != nil {
+ return err
+ }
+ session.ID = id
+ return nil
+}
+
+// GetAgentSessionByRequestToken retrieves a session by request token
+func (s *Store) GetAgentSessionByRequestToken(token string) (*models.AgentSession, error) {
+ var session models.AgentSession
+ var sessionToken sql.NullString
+ var sessionExpiresAt sql.NullTime
+
+ err := s.db.QueryRow(`
+ SELECT id, request_token, agent_name, agent_id, status, created_at, expires_at, session_token, session_expires_at
+ FROM agent_sessions
+ WHERE request_token = ?
+ `, token).Scan(
+ &session.ID,
+ &session.RequestToken,
+ &session.AgentName,
+ &session.AgentID,
+ &session.Status,
+ &session.CreatedAt,
+ &session.ExpiresAt,
+ &sessionToken,
+ &sessionExpiresAt,
+ )
+ if err != nil {
+ if errors.Is(err, sql.ErrNoRows) {
+ return nil, nil
+ }
+ return nil, err
+ }
+ if sessionToken.Valid {
+ session.SessionToken = sessionToken.String
+ }
+ if sessionExpiresAt.Valid {
+ session.SessionExpiresAt = &sessionExpiresAt.Time
+ }
+ return &session, nil
+}
+
+// GetPendingAgentSessionByAgentID retrieves an existing pending session for an agent
+func (s *Store) GetPendingAgentSessionByAgentID(agentID string) (*models.AgentSession, error) {
+ var session models.AgentSession
+
+ err := s.db.QueryRow(`
+ SELECT id, request_token, agent_name, agent_id, status, created_at, expires_at
+ FROM agent_sessions
+ WHERE agent_id = ? AND status = 'pending' AND expires_at > datetime('now', 'localtime')
+ ORDER BY created_at DESC
+ LIMIT 1
+ `, agentID).Scan(
+ &session.ID,
+ &session.RequestToken,
+ &session.AgentName,
+ &session.AgentID,
+ &session.Status,
+ &session.CreatedAt,
+ &session.ExpiresAt,
+ )
+ if err != nil {
+ if errors.Is(err, sql.ErrNoRows) {
+ return nil, nil
+ }
+ return nil, err
+ }
+ return &session, nil
+}
+
+// GetAgentSessionBySessionToken retrieves a session by session token
+func (s *Store) GetAgentSessionBySessionToken(token string) (*models.AgentSession, error) {
+ var session models.AgentSession
+ var sessionToken sql.NullString
+ var sessionExpiresAt sql.NullTime
+
+ err := s.db.QueryRow(`
+ SELECT id, request_token, agent_name, agent_id, status, created_at, expires_at, session_token, session_expires_at
+ FROM agent_sessions
+ WHERE session_token = ? AND status = 'approved'
+ `, token).Scan(
+ &session.ID,
+ &session.RequestToken,
+ &session.AgentName,
+ &session.AgentID,
+ &session.Status,
+ &session.CreatedAt,
+ &session.ExpiresAt,
+ &sessionToken,
+ &sessionExpiresAt,
+ )
+ if err != nil {
+ if errors.Is(err, sql.ErrNoRows) {
+ return nil, nil
+ }
+ return nil, err
+ }
+ if sessionToken.Valid {
+ session.SessionToken = sessionToken.String
+ }
+ if sessionExpiresAt.Valid {
+ session.SessionExpiresAt = &sessionExpiresAt.Time
+ }
+ return &session, nil
+}
+
+// ApproveAgentSession approves a pending session
+func (s *Store) ApproveAgentSession(requestToken, sessionToken string, sessionExpiresAt time.Time) error {
+ result, err := s.db.Exec(`
+ UPDATE agent_sessions
+ SET status = 'approved', session_token = ?, session_expires_at = ?
+ WHERE request_token = ? AND status = 'pending'
+ `, sessionToken, sessionExpiresAt, requestToken)
+ if err != nil {
+ return err
+ }
+ affected, err := result.RowsAffected()
+ if err != nil {
+ return err
+ }
+ if affected == 0 {
+ return errors.New("session not found or already processed")
+ }
+ return nil
+}
+
+// DenyAgentSession denies a pending session
+func (s *Store) DenyAgentSession(requestToken string) error {
+ result, err := s.db.Exec(`
+ UPDATE agent_sessions
+ SET status = 'denied'
+ WHERE request_token = ? AND status = 'pending'
+ `, requestToken)
+ if err != nil {
+ return err
+ }
+ affected, err := result.RowsAffected()
+ if err != nil {
+ return err
+ }
+ if affected == 0 {
+ return errors.New("session not found or already processed")
+ }
+ return nil
+}
+
+// GetPendingAgentSessions retrieves all unexpired pending sessions
+func (s *Store) GetPendingAgentSessions() ([]models.AgentSession, error) {
+ rows, err := s.db.Query(`
+ SELECT id, request_token, agent_name, agent_id, status, created_at, expires_at
+ FROM agent_sessions
+ WHERE status = 'pending' AND expires_at > datetime('now', 'localtime')
+ ORDER BY created_at DESC
+ `)
+ if err != nil {
+ return nil, err
+ }
+ defer func() { _ = rows.Close() }()
+
+ var sessions []models.AgentSession
+ for rows.Next() {
+ var session models.AgentSession
+ if err := rows.Scan(
+ &session.ID,
+ &session.RequestToken,
+ &session.AgentName,
+ &session.AgentID,
+ &session.Status,
+ &session.CreatedAt,
+ &session.ExpiresAt,
+ ); err != nil {
+ return nil, err
+ }
+ sessions = append(sessions, session)
+ }
+ return sessions, rows.Err()
+}
+
+// InvalidatePreviousAgentSessions marks previous sessions for an agent as expired
+func (s *Store) InvalidatePreviousAgentSessions(agentID string) error {
+ _, err := s.db.Exec(`
+ UPDATE agent_sessions
+ SET status = 'expired'
+ WHERE agent_id = ? AND status IN ('pending', 'approved')
+ `, agentID)
+ return err
+}
+
+// GetAgentByAgentID retrieves an agent by their agent_id (UUID)
+func (s *Store) GetAgentByAgentID(agentID string) (*models.Agent, error) {
+ var agent models.Agent
+ var lastSeen sql.NullTime
+
+ err := s.db.QueryRow(`
+ SELECT id, name, agent_id, created_at, last_seen, trusted
+ FROM agents
+ WHERE agent_id = ?
+ `, agentID).Scan(
+ &agent.ID,
+ &agent.Name,
+ &agent.AgentID,
+ &agent.CreatedAt,
+ &lastSeen,
+ &agent.Trusted,
+ )
+ if err != nil {
+ if errors.Is(err, sql.ErrNoRows) {
+ return nil, nil
+ }
+ return nil, err
+ }
+ if lastSeen.Valid {
+ agent.LastSeen = &lastSeen.Time
+ }
+ return &agent, nil
+}
+
+// GetAgentByName retrieves an agent by name
+func (s *Store) GetAgentByName(name string) (*models.Agent, error) {
+ var agent models.Agent
+ var lastSeen sql.NullTime
+
+ err := s.db.QueryRow(`
+ SELECT id, name, agent_id, created_at, last_seen, trusted
+ FROM agents
+ WHERE name = ?
+ `, name).Scan(
+ &agent.ID,
+ &agent.Name,
+ &agent.AgentID,
+ &agent.CreatedAt,
+ &lastSeen,
+ &agent.Trusted,
+ )
+ if err != nil {
+ if errors.Is(err, sql.ErrNoRows) {
+ return nil, nil
+ }
+ return nil, err
+ }
+ if lastSeen.Valid {
+ agent.LastSeen = &lastSeen.Time
+ }
+ return &agent, nil
+}
+
+// CreateOrUpdateAgent creates or updates an agent record
+func (s *Store) CreateOrUpdateAgent(name, agentID string) error {
+ _, err := s.db.Exec(`
+ INSERT INTO agents (name, agent_id, last_seen, trusted)
+ VALUES (?, ?, datetime('now'), 1)
+ ON CONFLICT(agent_id) DO UPDATE SET
+ name = excluded.name,
+ last_seen = datetime('now')
+ `, name, agentID)
+ return err
+}
+
+// UpdateAgentLastSeen updates the last_seen timestamp for an agent
+func (s *Store) UpdateAgentLastSeen(agentID string) error {
+ _, err := s.db.Exec(`
+ UPDATE agents SET last_seen = datetime('now')
+ WHERE agent_id = ?
+ `, agentID)
+ return err
+}
+
+// GetAllAgents retrieves all agents
+func (s *Store) GetAllAgents() ([]models.Agent, error) {
+ rows, err := s.db.Query(`
+ SELECT id, name, agent_id, created_at, last_seen, trusted
+ FROM agents
+ ORDER BY last_seen DESC NULLS LAST
+ `)
+ if err != nil {
+ return nil, err
+ }
+ defer func() { _ = rows.Close() }()
+
+ var agents []models.Agent
+ for rows.Next() {
+ var agent models.Agent
+ var lastSeen sql.NullTime
+ if err := rows.Scan(
+ &agent.ID,
+ &agent.Name,
+ &agent.AgentID,
+ &agent.CreatedAt,
+ &lastSeen,
+ &agent.Trusted,
+ ); err != nil {
+ return nil, err
+ }
+ if lastSeen.Valid {
+ agent.LastSeen = &lastSeen.Time
+ }
+ agents = append(agents, agent)
+ }
+ return agents, rows.Err()
+}
+
+// RevokeAgent sets trusted=false for an agent
+func (s *Store) RevokeAgent(agentID string) error {
+ _, err := s.db.Exec(`UPDATE agents SET trusted = 0 WHERE agent_id = ?`, agentID)
+ return err
+}
+
+// CheckAgentTrust determines trust level for an agent request
+func (s *Store) CheckAgentTrust(name, agentID string) (models.AgentTrustLevel, error) {
+ // Check if this exact agent_id is known
+ existingByID, err := s.GetAgentByAgentID(agentID)
+ if err != nil {
+ return "", err
+ }
+
+ // Check if this name is known with a different ID
+ existingByName, err := s.GetAgentByName(name)
+ if err != nil {
+ return "", err
+ }
+
+ if existingByID != nil && existingByID.Name == name && existingByID.Trusted {
+ return models.AgentTrustRecognized, nil
+ }
+
+ if existingByName != nil && existingByName.AgentID != agentID {
+ return models.AgentTrustSuspicious, nil
+ }
+
+ return models.AgentTrustNew, nil
+}