summaryrefslogtreecommitdiff
path: root/internal/auth/auth.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/auth/auth.go')
-rw-r--r--internal/auth/auth.go147
1 files changed, 147 insertions, 0 deletions
diff --git a/internal/auth/auth.go b/internal/auth/auth.go
index a602dad..ce62fa4 100644
--- a/internal/auth/auth.go
+++ b/internal/auth/auth.go
@@ -2,9 +2,11 @@ package auth
import (
"database/sql"
+ "encoding/base64"
"errors"
"time"
+ "github.com/go-webauthn/webauthn/webauthn"
"golang.org/x/crypto/bcrypt"
)
@@ -140,3 +142,148 @@ func (s *Service) EnsureDefaultUser(username, password string) error {
return nil
}
+
+// WebAuthnUser wraps User to implement the webauthn.User interface
+type WebAuthnUser struct {
+ *User
+ credentials []webauthn.Credential
+}
+
+func (u *WebAuthnUser) WebAuthnID() []byte {
+ b := make([]byte, 8)
+ id := u.ID
+ for i := range 8 {
+ b[i] = byte(id >> (i * 8))
+ }
+ return b
+}
+
+func (u *WebAuthnUser) WebAuthnName() string { return u.Username }
+func (u *WebAuthnUser) WebAuthnDisplayName() string { return u.Username }
+func (u *WebAuthnUser) WebAuthnCredentials() []webauthn.Credential { return u.credentials }
+
+// SaveWebAuthnCredential stores a new WebAuthn credential for a user
+func (s *Service) SaveWebAuthnCredential(userID int64, cred *webauthn.Credential, name string) error {
+ credID := base64.RawURLEncoding.EncodeToString(cred.ID)
+ _, err := s.db.Exec(
+ `INSERT INTO webauthn_credentials (id, user_id, public_key, attestation_type, aaguid, sign_count, name)
+ VALUES (?, ?, ?, ?, ?, ?, ?)`,
+ credID, userID, cred.PublicKey, cred.AttestationType, cred.Authenticator.AAGUID, cred.Authenticator.SignCount, name,
+ )
+ return err
+}
+
+// GetWebAuthnCredentials returns all WebAuthn credentials for a user
+func (s *Service) GetWebAuthnCredentials(userID int64) ([]webauthn.Credential, error) {
+ rows, err := s.db.Query(
+ `SELECT id, public_key, attestation_type, aaguid, sign_count FROM webauthn_credentials WHERE user_id = ?`,
+ userID,
+ )
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close()
+
+ var creds []webauthn.Credential
+ for rows.Next() {
+ var (
+ credIDStr string
+ publicKey []byte
+ attestationType string
+ aaguid []byte
+ signCount uint32
+ )
+ if err := rows.Scan(&credIDStr, &publicKey, &attestationType, &aaguid, &signCount); err != nil {
+ return nil, err
+ }
+ credID, err := base64.RawURLEncoding.DecodeString(credIDStr)
+ if err != nil {
+ return nil, err
+ }
+ creds = append(creds, webauthn.Credential{
+ ID: credID,
+ PublicKey: publicKey,
+ AttestationType: attestationType,
+ Authenticator: webauthn.Authenticator{
+ AAGUID: aaguid,
+ SignCount: signCount,
+ },
+ })
+ }
+ return creds, rows.Err()
+}
+
+// WebAuthnCredentialInfo holds display info for a stored passkey
+type WebAuthnCredentialInfo struct {
+ ID string
+ Name string
+ CreatedAt time.Time
+}
+
+// GetWebAuthnCredentialInfos returns display info for a user's passkeys
+func (s *Service) GetWebAuthnCredentialInfos(userID int64) ([]WebAuthnCredentialInfo, error) {
+ rows, err := s.db.Query(
+ `SELECT id, name, created_at FROM webauthn_credentials WHERE user_id = ? ORDER BY created_at DESC`,
+ userID,
+ )
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close()
+
+ var infos []WebAuthnCredentialInfo
+ for rows.Next() {
+ var info WebAuthnCredentialInfo
+ if err := rows.Scan(&info.ID, &info.Name, &info.CreatedAt); err != nil {
+ return nil, err
+ }
+ infos = append(infos, info)
+ }
+ return infos, rows.Err()
+}
+
+// DeleteWebAuthnCredential removes a WebAuthn credential by its ID
+func (s *Service) DeleteWebAuthnCredential(credID string) error {
+ result, err := s.db.Exec(`DELETE FROM webauthn_credentials WHERE id = ?`, credID)
+ if err != nil {
+ return err
+ }
+ n, _ := result.RowsAffected()
+ if n == 0 {
+ return errors.New("credential not found")
+ }
+ return nil
+}
+
+// UpdateWebAuthnCredentialSignCount updates the sign count after a successful assertion
+func (s *Service) UpdateWebAuthnCredentialSignCount(credID string, signCount uint32) error {
+ _, err := s.db.Exec(`UPDATE webauthn_credentials SET sign_count = ? WHERE id = ?`, signCount, credID)
+ return err
+}
+
+// GetUserWithCredentials returns a WebAuthnUser with loaded credentials
+func (s *Service) GetUserWithCredentials(userID int64) (*WebAuthnUser, error) {
+ user, err := s.GetUserByID(userID)
+ if err != nil {
+ return nil, err
+ }
+ creds, err := s.GetWebAuthnCredentials(userID)
+ if err != nil {
+ return nil, err
+ }
+ return &WebAuthnUser{User: user, credentials: creds}, nil
+}
+
+// FindUserByCredentialID finds the user who owns a given credential ID (for discoverable login)
+func (s *Service) FindUserByCredentialID(credID []byte) (*WebAuthnUser, error) {
+ credIDStr := base64.RawURLEncoding.EncodeToString(credID)
+ var userID int64
+ err := s.db.QueryRow(`SELECT user_id FROM webauthn_credentials WHERE id = ?`, credIDStr).Scan(&userID)
+ if err == sql.ErrNoRows {
+ return nil, ErrUserNotFound
+ }
+ if err != nil {
+ return nil, err
+ }
+ return s.GetUserWithCredentials(userID)
+}