diff options
Diffstat (limited to 'internal/store/sqlite.go')
| -rw-r--r-- | internal/store/sqlite.go | 345 |
1 files changed, 345 insertions, 0 deletions
diff --git a/internal/store/sqlite.go b/internal/store/sqlite.go index 396ac54..b324e9f 100644 --- a/internal/store/sqlite.go +++ b/internal/store/sqlite.go @@ -796,3 +796,348 @@ func (s *Store) GetCardsByDateRange(start, end time.Time) ([]models.Card, error) return cards, rows.Err() } + +// Agent operations + +// CreateAgentSession creates a new pending agent session +func (s *Store) CreateAgentSession(session *models.AgentSession) error { + result, err := s.db.Exec(` + INSERT INTO agent_sessions (request_token, agent_name, agent_id, status, expires_at) + VALUES (?, ?, ?, 'pending', ?) + `, session.RequestToken, session.AgentName, session.AgentID, session.ExpiresAt) + if err != nil { + return err + } + id, err := result.LastInsertId() + if err != nil { + return err + } + session.ID = id + return nil +} + +// GetAgentSessionByRequestToken retrieves a session by request token +func (s *Store) GetAgentSessionByRequestToken(token string) (*models.AgentSession, error) { + var session models.AgentSession + var sessionToken sql.NullString + var sessionExpiresAt sql.NullTime + + err := s.db.QueryRow(` + SELECT id, request_token, agent_name, agent_id, status, created_at, expires_at, session_token, session_expires_at + FROM agent_sessions + WHERE request_token = ? + `, token).Scan( + &session.ID, + &session.RequestToken, + &session.AgentName, + &session.AgentID, + &session.Status, + &session.CreatedAt, + &session.ExpiresAt, + &sessionToken, + &sessionExpiresAt, + ) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, nil + } + return nil, err + } + if sessionToken.Valid { + session.SessionToken = sessionToken.String + } + if sessionExpiresAt.Valid { + session.SessionExpiresAt = &sessionExpiresAt.Time + } + return &session, nil +} + +// GetPendingAgentSessionByAgentID retrieves an existing pending session for an agent +func (s *Store) GetPendingAgentSessionByAgentID(agentID string) (*models.AgentSession, error) { + var session models.AgentSession + + err := s.db.QueryRow(` + SELECT id, request_token, agent_name, agent_id, status, created_at, expires_at + FROM agent_sessions + WHERE agent_id = ? AND status = 'pending' AND expires_at > datetime('now', 'localtime') + ORDER BY created_at DESC + LIMIT 1 + `, agentID).Scan( + &session.ID, + &session.RequestToken, + &session.AgentName, + &session.AgentID, + &session.Status, + &session.CreatedAt, + &session.ExpiresAt, + ) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, nil + } + return nil, err + } + return &session, nil +} + +// GetAgentSessionBySessionToken retrieves a session by session token +func (s *Store) GetAgentSessionBySessionToken(token string) (*models.AgentSession, error) { + var session models.AgentSession + var sessionToken sql.NullString + var sessionExpiresAt sql.NullTime + + err := s.db.QueryRow(` + SELECT id, request_token, agent_name, agent_id, status, created_at, expires_at, session_token, session_expires_at + FROM agent_sessions + WHERE session_token = ? AND status = 'approved' + `, token).Scan( + &session.ID, + &session.RequestToken, + &session.AgentName, + &session.AgentID, + &session.Status, + &session.CreatedAt, + &session.ExpiresAt, + &sessionToken, + &sessionExpiresAt, + ) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, nil + } + return nil, err + } + if sessionToken.Valid { + session.SessionToken = sessionToken.String + } + if sessionExpiresAt.Valid { + session.SessionExpiresAt = &sessionExpiresAt.Time + } + return &session, nil +} + +// ApproveAgentSession approves a pending session +func (s *Store) ApproveAgentSession(requestToken, sessionToken string, sessionExpiresAt time.Time) error { + result, err := s.db.Exec(` + UPDATE agent_sessions + SET status = 'approved', session_token = ?, session_expires_at = ? + WHERE request_token = ? AND status = 'pending' + `, sessionToken, sessionExpiresAt, requestToken) + if err != nil { + return err + } + affected, err := result.RowsAffected() + if err != nil { + return err + } + if affected == 0 { + return errors.New("session not found or already processed") + } + return nil +} + +// DenyAgentSession denies a pending session +func (s *Store) DenyAgentSession(requestToken string) error { + result, err := s.db.Exec(` + UPDATE agent_sessions + SET status = 'denied' + WHERE request_token = ? AND status = 'pending' + `, requestToken) + if err != nil { + return err + } + affected, err := result.RowsAffected() + if err != nil { + return err + } + if affected == 0 { + return errors.New("session not found or already processed") + } + return nil +} + +// GetPendingAgentSessions retrieves all unexpired pending sessions +func (s *Store) GetPendingAgentSessions() ([]models.AgentSession, error) { + rows, err := s.db.Query(` + SELECT id, request_token, agent_name, agent_id, status, created_at, expires_at + FROM agent_sessions + WHERE status = 'pending' AND expires_at > datetime('now', 'localtime') + ORDER BY created_at DESC + `) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + + var sessions []models.AgentSession + for rows.Next() { + var session models.AgentSession + if err := rows.Scan( + &session.ID, + &session.RequestToken, + &session.AgentName, + &session.AgentID, + &session.Status, + &session.CreatedAt, + &session.ExpiresAt, + ); err != nil { + return nil, err + } + sessions = append(sessions, session) + } + return sessions, rows.Err() +} + +// InvalidatePreviousAgentSessions marks previous sessions for an agent as expired +func (s *Store) InvalidatePreviousAgentSessions(agentID string) error { + _, err := s.db.Exec(` + UPDATE agent_sessions + SET status = 'expired' + WHERE agent_id = ? AND status IN ('pending', 'approved') + `, agentID) + return err +} + +// GetAgentByAgentID retrieves an agent by their agent_id (UUID) +func (s *Store) GetAgentByAgentID(agentID string) (*models.Agent, error) { + var agent models.Agent + var lastSeen sql.NullTime + + err := s.db.QueryRow(` + SELECT id, name, agent_id, created_at, last_seen, trusted + FROM agents + WHERE agent_id = ? + `, agentID).Scan( + &agent.ID, + &agent.Name, + &agent.AgentID, + &agent.CreatedAt, + &lastSeen, + &agent.Trusted, + ) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, nil + } + return nil, err + } + if lastSeen.Valid { + agent.LastSeen = &lastSeen.Time + } + return &agent, nil +} + +// GetAgentByName retrieves an agent by name +func (s *Store) GetAgentByName(name string) (*models.Agent, error) { + var agent models.Agent + var lastSeen sql.NullTime + + err := s.db.QueryRow(` + SELECT id, name, agent_id, created_at, last_seen, trusted + FROM agents + WHERE name = ? + `, name).Scan( + &agent.ID, + &agent.Name, + &agent.AgentID, + &agent.CreatedAt, + &lastSeen, + &agent.Trusted, + ) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, nil + } + return nil, err + } + if lastSeen.Valid { + agent.LastSeen = &lastSeen.Time + } + return &agent, nil +} + +// CreateOrUpdateAgent creates or updates an agent record +func (s *Store) CreateOrUpdateAgent(name, agentID string) error { + _, err := s.db.Exec(` + INSERT INTO agents (name, agent_id, last_seen, trusted) + VALUES (?, ?, datetime('now'), 1) + ON CONFLICT(agent_id) DO UPDATE SET + name = excluded.name, + last_seen = datetime('now') + `, name, agentID) + return err +} + +// UpdateAgentLastSeen updates the last_seen timestamp for an agent +func (s *Store) UpdateAgentLastSeen(agentID string) error { + _, err := s.db.Exec(` + UPDATE agents SET last_seen = datetime('now') + WHERE agent_id = ? + `, agentID) + return err +} + +// GetAllAgents retrieves all agents +func (s *Store) GetAllAgents() ([]models.Agent, error) { + rows, err := s.db.Query(` + SELECT id, name, agent_id, created_at, last_seen, trusted + FROM agents + ORDER BY last_seen DESC NULLS LAST + `) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + + var agents []models.Agent + for rows.Next() { + var agent models.Agent + var lastSeen sql.NullTime + if err := rows.Scan( + &agent.ID, + &agent.Name, + &agent.AgentID, + &agent.CreatedAt, + &lastSeen, + &agent.Trusted, + ); err != nil { + return nil, err + } + if lastSeen.Valid { + agent.LastSeen = &lastSeen.Time + } + agents = append(agents, agent) + } + return agents, rows.Err() +} + +// RevokeAgent sets trusted=false for an agent +func (s *Store) RevokeAgent(agentID string) error { + _, err := s.db.Exec(`UPDATE agents SET trusted = 0 WHERE agent_id = ?`, agentID) + return err +} + +// CheckAgentTrust determines trust level for an agent request +func (s *Store) CheckAgentTrust(name, agentID string) (models.AgentTrustLevel, error) { + // Check if this exact agent_id is known + existingByID, err := s.GetAgentByAgentID(agentID) + if err != nil { + return "", err + } + + // Check if this name is known with a different ID + existingByName, err := s.GetAgentByName(name) + if err != nil { + return "", err + } + + if existingByID != nil && existingByID.Name == name && existingByID.Trusted { + return models.AgentTrustRecognized, nil + } + + if existingByName != nil && existingByName.AgentID != agentID { + return models.AgentTrustSuspicious, nil + } + + return models.AgentTrustNew, nil +} |
