// Package auth provides username/password and WebAuthn (passkey)
// authentication backed by a SQL database.
package auth

import (
	"database/sql"
	"encoding/base64"
	"encoding/binary"
	"errors"
	"sync"
	"time"

	"github.com/go-webauthn/webauthn/webauthn"
	"golang.org/x/crypto/bcrypt"
)

var (
	ErrInvalidCredentials = errors.New("invalid username or password")
	ErrUserNotFound       = errors.New("user not found")
	ErrUserExists         = errors.New("username already exists")
	// ErrCredentialNotFound is returned by DeleteWebAuthnCredential when no
	// stored credential has the given ID.
	ErrCredentialNotFound = errors.New("credential not found")
)

// dummyPasswordHash lazily produces a bcrypt hash at DefaultCost that is
// compared against when a login names an unknown user, so that path costs
// roughly the same as a wrong password for a real user. If generation fails
// it yields nil, in which case the comparison simply fails fast.
var dummyPasswordHash = sync.OnceValue(func() []byte {
	h, err := bcrypt.GenerateFromPassword([]byte("timing-equalization-placeholder"), bcrypt.DefaultCost)
	if err != nil {
		return nil
	}
	return h
})

// User represents an authenticated user.
type User struct {
	ID           int64
	Username     string
	PasswordHash string
	CreatedAt    time.Time
}

// Service handles authentication operations.
type Service struct {
	db *sql.DB
}

// NewService creates a new auth service that uses db for storage.
func NewService(db *sql.DB) *Service {
	return &Service{db: db}
}

// Authenticate verifies username and password and returns the matching user.
//
// It returns ErrInvalidCredentials when the username is unknown, when the
// password does not match, or when the stored hash cannot be compared, so
// callers cannot tell these cases apart. Other lookup errors (e.g. database
// failures) are returned unchanged.
func (s *Service) Authenticate(username, password string) (*User, error) {
	user, err := s.GetUserByUsername(username)
	if err != nil {
		if errors.Is(err, ErrUserNotFound) {
			// Spend a bcrypt comparison anyway so an unknown username takes
			// about as long as a wrong password; otherwise response time
			// reveals which usernames exist.
			_ = bcrypt.CompareHashAndPassword(dummyPasswordHash(), []byte(password))
			return nil, ErrInvalidCredentials
		}
		return nil, err
	}
	if err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(password)); err != nil {
		return nil, ErrInvalidCredentials
	}
	return user, nil
}

// queryUser runs a single-row query against the users table, whose columns
// must be (id, username, password_hash, created_at), and scans the result.
// sql.ErrNoRows is translated to ErrUserNotFound.
func (s *Service) queryUser(query string, arg any) (*User, error) {
	var user User
	err := s.db.QueryRow(query, arg).Scan(&user.ID, &user.Username, &user.PasswordHash, &user.CreatedAt)
	if errors.Is(err, sql.ErrNoRows) {
		return nil, ErrUserNotFound
	}
	if err != nil {
		return nil, err
	}
	return &user, nil
}

// GetUserByUsername retrieves a user by username. It returns ErrUserNotFound
// when no such user exists.
func (s *Service) GetUserByUsername(username string) (*User, error) {
	return s.queryUser(
		`SELECT id, username, password_hash, created_at FROM users WHERE username = ?`,
		username,
	)
}

// GetUserByID retrieves a user by ID. It returns ErrUserNotFound when no such
// user exists.
func (s *Service) GetUserByID(id int64) (*User, error) {
	return s.queryUser(
		`SELECT id, username, password_hash, created_at FROM users WHERE id = ?`,
		id,
	)
}

// CreateUser creates a new user with the given username and password and
// returns the stored record. It returns ErrUserExists if the username is
// already taken, and any error from bcrypt (e.g. an over-long password).
//
// NOTE(review): the existence check and the INSERT are not atomic. A
// concurrent insert of the same username surfaces as whatever error the
// database returns (assuming a UNIQUE constraint on users.username, which
// is not visible here), not as ErrUserExists.
func (s *Service) CreateUser(username, password string) (*User, error) {
	// Check if user exists.
	_, err := s.GetUserByUsername(username)
	if err == nil {
		return nil, ErrUserExists
	}
	if !errors.Is(err, ErrUserNotFound) {
		return nil, err
	}

	hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
	if err != nil {
		return nil, err
	}

	result, err := s.db.Exec(
		`INSERT INTO users (username, password_hash) VALUES (?, ?)`,
		username, string(hash),
	)
	if err != nil {
		return nil, err
	}

	id, err := result.LastInsertId()
	if err != nil {
		return nil, err
	}

	// Re-read so the returned User includes DB-populated columns such as created_at.
	return s.GetUserByID(id)
}

// UserCount returns the number of users in the database.
func (s *Service) UserCount() (int, error) {
	var count int
	err := s.db.QueryRow(`SELECT COUNT(*) FROM users`).Scan(&count)
	return count, err
}

// EnsureDefaultUser creates a user with the given credentials if the users
// table is empty; otherwise it does nothing.
func (s *Service) EnsureDefaultUser(username, password string) error {
	count, err := s.UserCount()
	if err != nil {
		return err
	}
	if count == 0 {
		_, err = s.CreateUser(username, password)
		return err
	}
	return nil
}

// WebAuthnUser wraps User to implement the webauthn.User interface.
type WebAuthnUser struct {
	*User
	credentials []webauthn.Credential
}

// WebAuthnID returns the user's numeric ID encoded as 8 little-endian bytes,
// which serves as the WebAuthn user handle.
func (u *WebAuthnUser) WebAuthnID() []byte {
	b := make([]byte, 8)
	binary.LittleEndian.PutUint64(b, uint64(u.ID))
	return b
}

// WebAuthnName returns the username.
func (u *WebAuthnUser) WebAuthnName() string {
	return u.Username
}

// WebAuthnDisplayName returns the username; no separate display name is stored.
func (u *WebAuthnUser) WebAuthnDisplayName() string {
	return u.Username
}

// WebAuthnCredentials returns the credentials loaded for this user.
func (u *WebAuthnUser) WebAuthnCredentials() []webauthn.Credential {
	return u.credentials
}

// SaveWebAuthnCredential stores a new WebAuthn credential for a user.
// The raw credential ID is stored base64url-encoded (no padding) in the id
// column; the other ID-taking methods below expect or produce that form.
func (s *Service) SaveWebAuthnCredential(userID int64, cred *webauthn.Credential, name string) error {
	credID := base64.RawURLEncoding.EncodeToString(cred.ID)
	_, err := s.db.Exec(
		`INSERT INTO webauthn_credentials (id, user_id, public_key, attestation_type, aaguid, sign_count, name) VALUES (?, ?, ?, ?, ?, ?, ?)`,
		credID, userID, cred.PublicKey, cred.AttestationType, cred.Authenticator.AAGUID, cred.Authenticator.SignCount, name,
	)
	return err
}

// GetWebAuthnCredentials returns all WebAuthn credentials for a user.
//
// Only the ID, public key, attestation type, AAGUID and sign count are
// persisted, so other Credential fields (e.g. transports, flags) come back as
// zero values.
func (s *Service) GetWebAuthnCredentials(userID int64) ([]webauthn.Credential, error) {
	rows, err := s.db.Query(
		`SELECT id, public_key, attestation_type, aaguid, sign_count FROM webauthn_credentials WHERE user_id = ?`,
		userID,
	)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var creds []webauthn.Credential
	for rows.Next() {
		var (
			credIDStr       string
			publicKey       []byte
			attestationType string
			aaguid          []byte
			signCount       uint32
		)
		if err := rows.Scan(&credIDStr, &publicKey, &attestationType, &aaguid, &signCount); err != nil {
			return nil, err
		}
		// Undo the base64url encoding applied in SaveWebAuthnCredential.
		credID, err := base64.RawURLEncoding.DecodeString(credIDStr)
		if err != nil {
			return nil, err
		}
		creds = append(creds, webauthn.Credential{
			ID:              credID,
			PublicKey:       publicKey,
			AttestationType: attestationType,
			Authenticator: webauthn.Authenticator{
				AAGUID:    aaguid,
				SignCount: signCount,
			},
		})
	}
	return creds, rows.Err()
}

// WebAuthnCredentialInfo holds display info for a stored passkey.
// ID is the stored (base64url-encoded) credential ID.
type WebAuthnCredentialInfo struct {
	ID        string
	Name      string
	CreatedAt time.Time
}

// GetWebAuthnCredentialInfos returns display info for a user's passkeys,
// newest first.
func (s *Service) GetWebAuthnCredentialInfos(userID int64) ([]WebAuthnCredentialInfo, error) {
	rows, err := s.db.Query(
		`SELECT id, name, created_at FROM webauthn_credentials WHERE user_id = ? ORDER BY created_at DESC`,
		userID,
	)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var infos []WebAuthnCredentialInfo
	for rows.Next() {
		var info WebAuthnCredentialInfo
		if err := rows.Scan(&info.ID, &info.Name, &info.CreatedAt); err != nil {
			return nil, err
		}
		infos = append(infos, info)
	}
	return infos, rows.Err()
}

// DeleteWebAuthnCredential removes a WebAuthn credential by its stored
// (base64url-encoded) ID. It returns ErrCredentialNotFound when no row
// matched, or the driver's error if the affected-row count is unavailable.
//
// NOTE(review): the delete is not scoped to a user; callers must confirm the
// credential belongs to the requesting user before calling this.
func (s *Service) DeleteWebAuthnCredential(credID string) error {
	result, err := s.db.Exec(`DELETE FROM webauthn_credentials WHERE id = ?`, credID)
	if err != nil {
		return err
	}
	n, err := result.RowsAffected()
	if err != nil {
		return err
	}
	if n == 0 {
		return ErrCredentialNotFound
	}
	return nil
}

// UpdateWebAuthnCredentialSignCount updates the sign count after a successful
// assertion. credID is the stored (base64url-encoded) ID. Updating a
// nonexistent ID is not reported as an error.
func (s *Service) UpdateWebAuthnCredentialSignCount(credID string, signCount uint32) error {
	_, err := s.db.Exec(`UPDATE webauthn_credentials SET sign_count = ? WHERE id = ?`, signCount, credID)
	return err
}

// GetUserWithCredentials returns a WebAuthnUser with its credentials loaded.
func (s *Service) GetUserWithCredentials(userID int64) (*WebAuthnUser, error) {
	user, err := s.GetUserByID(userID)
	if err != nil {
		return nil, err
	}
	creds, err := s.GetWebAuthnCredentials(userID)
	if err != nil {
		return nil, err
	}
	return &WebAuthnUser{User: user, credentials: creds}, nil
}

// FindUserByCredentialID finds the user who owns the given raw credential ID
// (for discoverable login). It returns ErrUserNotFound when no stored
// credential matches.
func (s *Service) FindUserByCredentialID(credID []byte) (*WebAuthnUser, error) {
	credIDStr := base64.RawURLEncoding.EncodeToString(credID)
	var userID int64
	err := s.db.QueryRow(`SELECT user_id FROM webauthn_credentials WHERE id = ?`, credIDStr).Scan(&userID)
	if errors.Is(err, sql.ErrNoRows) {
		return nil, ErrUserNotFound
	}
	if err != nil {
		return nil, err
	}
	return s.GetUserWithCredentials(userID)
}