diff options
Diffstat (limited to 'internal/auth/auth.go')
| -rw-r--r-- | internal/auth/auth.go | 147 |
1 files changed, 147 insertions, 0 deletions
diff --git a/internal/auth/auth.go b/internal/auth/auth.go index a602dad..ce62fa4 100644 --- a/internal/auth/auth.go +++ b/internal/auth/auth.go @@ -2,9 +2,11 @@ package auth import ( "database/sql" + "encoding/base64" "errors" "time" + "github.com/go-webauthn/webauthn/webauthn" "golang.org/x/crypto/bcrypt" ) @@ -140,3 +142,148 @@ func (s *Service) EnsureDefaultUser(username, password string) error { return nil } + +// WebAuthnUser wraps User to implement the webauthn.User interface +type WebAuthnUser struct { + *User + credentials []webauthn.Credential +} + +func (u *WebAuthnUser) WebAuthnID() []byte { + b := make([]byte, 8) + id := u.ID + for i := range 8 { + b[i] = byte(id >> (i * 8)) + } + return b +} + +func (u *WebAuthnUser) WebAuthnName() string { return u.Username } +func (u *WebAuthnUser) WebAuthnDisplayName() string { return u.Username } +func (u *WebAuthnUser) WebAuthnCredentials() []webauthn.Credential { return u.credentials } + +// SaveWebAuthnCredential stores a new WebAuthn credential for a user +func (s *Service) SaveWebAuthnCredential(userID int64, cred *webauthn.Credential, name string) error { + credID := base64.RawURLEncoding.EncodeToString(cred.ID) + _, err := s.db.Exec( + `INSERT INTO webauthn_credentials (id, user_id, public_key, attestation_type, aaguid, sign_count, name) + VALUES (?, ?, ?, ?, ?, ?, ?)`, + credID, userID, cred.PublicKey, cred.AttestationType, cred.Authenticator.AAGUID, cred.Authenticator.SignCount, name, + ) + return err +} + +// GetWebAuthnCredentials returns all WebAuthn credentials for a user +func (s *Service) GetWebAuthnCredentials(userID int64) ([]webauthn.Credential, error) { + rows, err := s.db.Query( + `SELECT id, public_key, attestation_type, aaguid, sign_count FROM webauthn_credentials WHERE user_id = ?`, + userID, + ) + if err != nil { + return nil, err + } + defer rows.Close() + + var creds []webauthn.Credential + for rows.Next() { + var ( + credIDStr string + publicKey []byte + attestationType string + aaguid []byte + signCount uint32 + ) + if err := rows.Scan(&credIDStr, &publicKey, &attestationType, &aaguid, &signCount); err != nil { + return nil, err + } + credID, err := base64.RawURLEncoding.DecodeString(credIDStr) + if err != nil { + return nil, err + } + creds = append(creds, webauthn.Credential{ + ID: credID, + PublicKey: publicKey, + AttestationType: attestationType, + Authenticator: webauthn.Authenticator{ + AAGUID: aaguid, + SignCount: signCount, + }, + }) + } + return creds, rows.Err() +} + +// WebAuthnCredentialInfo holds display info for a stored passkey +type WebAuthnCredentialInfo struct { + ID string + Name string + CreatedAt time.Time +} + +// GetWebAuthnCredentialInfos returns display info for a user's passkeys +func (s *Service) GetWebAuthnCredentialInfos(userID int64) ([]WebAuthnCredentialInfo, error) { + rows, err := s.db.Query( + `SELECT id, name, created_at FROM webauthn_credentials WHERE user_id = ? ORDER BY created_at DESC`, + userID, + ) + if err != nil { + return nil, err + } + defer rows.Close() + + var infos []WebAuthnCredentialInfo + for rows.Next() { + var info WebAuthnCredentialInfo + if err := rows.Scan(&info.ID, &info.Name, &info.CreatedAt); err != nil { + return nil, err + } + infos = append(infos, info) + } + return infos, rows.Err() +} + +// DeleteWebAuthnCredential removes a WebAuthn credential by its ID +func (s *Service) DeleteWebAuthnCredential(credID string) error { + result, err := s.db.Exec(`DELETE FROM webauthn_credentials WHERE id = ?`, credID) + if err != nil { + return err + } + n, _ := result.RowsAffected() + if n == 0 { + return errors.New("credential not found") + } + return nil +} + +// UpdateWebAuthnCredentialSignCount updates the sign count after a successful assertion +func (s *Service) UpdateWebAuthnCredentialSignCount(credID string, signCount uint32) error { + _, err := s.db.Exec(`UPDATE webauthn_credentials SET sign_count = ? WHERE id = ?`, signCount, credID) + return err +} + +// GetUserWithCredentials returns a WebAuthnUser with loaded credentials +func (s *Service) GetUserWithCredentials(userID int64) (*WebAuthnUser, error) { + user, err := s.GetUserByID(userID) + if err != nil { + return nil, err + } + creds, err := s.GetWebAuthnCredentials(userID) + if err != nil { + return nil, err + } + return &WebAuthnUser{User: user, credentials: creds}, nil +} + +// FindUserByCredentialID finds the user who owns a given credential ID (for discoverable login) +func (s *Service) FindUserByCredentialID(credID []byte) (*WebAuthnUser, error) { + credIDStr := base64.RawURLEncoding.EncodeToString(credID) + var userID int64 + err := s.db.QueryRow(`SELECT user_id FROM webauthn_credentials WHERE id = ?`, credIDStr).Scan(&userID) + if err == sql.ErrNoRows { + return nil, ErrUserNotFound + } + if err != nil { + return nil, err + } + return s.GetUserWithCredentials(userID) +} |
