summaryrefslogtreecommitdiff
path: root/internal
diff options
context:
space:
mode:
Diffstat (limited to 'internal')
-rw-r--r--internal/auth/auth.go147
-rw-r--r--internal/auth/handlers.go259
-rw-r--r--internal/auth/handlers_test.go4
-rw-r--r--internal/config/config.go8
4 files changed, 407 insertions, 11 deletions
diff --git a/internal/auth/auth.go b/internal/auth/auth.go
index a602dad..ce62fa4 100644
--- a/internal/auth/auth.go
+++ b/internal/auth/auth.go
@@ -2,9 +2,11 @@ package auth
import (
"database/sql"
+ "encoding/base64"
"errors"
"time"
+ "github.com/go-webauthn/webauthn/webauthn"
"golang.org/x/crypto/bcrypt"
)
@@ -140,3 +142,148 @@ func (s *Service) EnsureDefaultUser(username, password string) error {
return nil
}
+
+// WebAuthnUser wraps User to implement the webauthn.User interface
+type WebAuthnUser struct {
+ *User
+ credentials []webauthn.Credential
+}
+
+func (u *WebAuthnUser) WebAuthnID() []byte {
+ b := make([]byte, 8)
+ id := u.ID
+ for i := range 8 {
+ b[i] = byte(id >> (i * 8))
+ }
+ return b
+}
+
+func (u *WebAuthnUser) WebAuthnName() string { return u.Username }
+func (u *WebAuthnUser) WebAuthnDisplayName() string { return u.Username }
+func (u *WebAuthnUser) WebAuthnCredentials() []webauthn.Credential { return u.credentials }
+
+// SaveWebAuthnCredential stores a new WebAuthn credential for a user
+func (s *Service) SaveWebAuthnCredential(userID int64, cred *webauthn.Credential, name string) error {
+ credID := base64.RawURLEncoding.EncodeToString(cred.ID)
+ _, err := s.db.Exec(
+ `INSERT INTO webauthn_credentials (id, user_id, public_key, attestation_type, aaguid, sign_count, name)
+ VALUES (?, ?, ?, ?, ?, ?, ?)`,
+ credID, userID, cred.PublicKey, cred.AttestationType, cred.Authenticator.AAGUID, cred.Authenticator.SignCount, name,
+ )
+ return err
+}
+
+// GetWebAuthnCredentials returns all WebAuthn credentials for a user
+func (s *Service) GetWebAuthnCredentials(userID int64) ([]webauthn.Credential, error) {
+ rows, err := s.db.Query(
+ `SELECT id, public_key, attestation_type, aaguid, sign_count FROM webauthn_credentials WHERE user_id = ?`,
+ userID,
+ )
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close()
+
+ var creds []webauthn.Credential
+ for rows.Next() {
+ var (
+ credIDStr string
+ publicKey []byte
+ attestationType string
+ aaguid []byte
+ signCount uint32
+ )
+ if err := rows.Scan(&credIDStr, &publicKey, &attestationType, &aaguid, &signCount); err != nil {
+ return nil, err
+ }
+ credID, err := base64.RawURLEncoding.DecodeString(credIDStr)
+ if err != nil {
+ return nil, err
+ }
+ creds = append(creds, webauthn.Credential{
+ ID: credID,
+ PublicKey: publicKey,
+ AttestationType: attestationType,
+ Authenticator: webauthn.Authenticator{
+ AAGUID: aaguid,
+ SignCount: signCount,
+ },
+ })
+ }
+ return creds, rows.Err()
+}
+
+// WebAuthnCredentialInfo holds display info for a stored passkey
+type WebAuthnCredentialInfo struct {
+ ID string
+ Name string
+ CreatedAt time.Time
+}
+
+// GetWebAuthnCredentialInfos returns display info for a user's passkeys
+func (s *Service) GetWebAuthnCredentialInfos(userID int64) ([]WebAuthnCredentialInfo, error) {
+ rows, err := s.db.Query(
+ `SELECT id, name, created_at FROM webauthn_credentials WHERE user_id = ? ORDER BY created_at DESC`,
+ userID,
+ )
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close()
+
+ var infos []WebAuthnCredentialInfo
+ for rows.Next() {
+ var info WebAuthnCredentialInfo
+ if err := rows.Scan(&info.ID, &info.Name, &info.CreatedAt); err != nil {
+ return nil, err
+ }
+ infos = append(infos, info)
+ }
+ return infos, rows.Err()
+}
+
+// DeleteWebAuthnCredential removes a WebAuthn credential by its ID
+func (s *Service) DeleteWebAuthnCredential(credID string) error {
+ result, err := s.db.Exec(`DELETE FROM webauthn_credentials WHERE id = ?`, credID)
+ if err != nil {
+ return err
+ }
+ n, _ := result.RowsAffected()
+ if n == 0 {
+ return errors.New("credential not found")
+ }
+ return nil
+}
+
+// UpdateWebAuthnCredentialSignCount updates the sign count after a successful assertion
+func (s *Service) UpdateWebAuthnCredentialSignCount(credID string, signCount uint32) error {
+ _, err := s.db.Exec(`UPDATE webauthn_credentials SET sign_count = ? WHERE id = ?`, signCount, credID)
+ return err
+}
+
+// GetUserWithCredentials returns a WebAuthnUser with loaded credentials
+func (s *Service) GetUserWithCredentials(userID int64) (*WebAuthnUser, error) {
+ user, err := s.GetUserByID(userID)
+ if err != nil {
+ return nil, err
+ }
+ creds, err := s.GetWebAuthnCredentials(userID)
+ if err != nil {
+ return nil, err
+ }
+ return &WebAuthnUser{User: user, credentials: creds}, nil
+}
+
+// FindUserByCredentialID finds the user who owns a given credential ID (for discoverable login)
+func (s *Service) FindUserByCredentialID(credID []byte) (*WebAuthnUser, error) {
+ credIDStr := base64.RawURLEncoding.EncodeToString(credID)
+ var userID int64
+ err := s.db.QueryRow(`SELECT user_id FROM webauthn_credentials WHERE id = ?`, credIDStr).Scan(&userID)
+ if err == sql.ErrNoRows {
+ return nil, ErrUserNotFound
+ }
+ if err != nil {
+ return nil, err
+ }
+ return s.GetUserWithCredentials(userID)
+}
diff --git a/internal/auth/handlers.go b/internal/auth/handlers.go
index c690d29..78595d0 100644
--- a/internal/auth/handlers.go
+++ b/internal/auth/handlers.go
@@ -1,11 +1,16 @@
package auth
import (
+ "encoding/base64"
+ "encoding/json"
"html/template"
"log"
"net/http"
"github.com/alexedwards/scs/v2"
+ "github.com/go-chi/chi/v5"
+ "github.com/go-webauthn/webauthn/protocol"
+ "github.com/go-webauthn/webauthn/webauthn"
)
// Handlers provides HTTP handlers for authentication
@@ -14,15 +19,17 @@ type Handlers struct {
sessions *scs.SessionManager
middleware *Middleware
templates *template.Template
+ webauthn *webauthn.WebAuthn // nil if WebAuthn is not configured
}
// NewHandlers creates new auth handlers
-func NewHandlers(service *Service, sessions *scs.SessionManager, templates *template.Template) *Handlers {
+func NewHandlers(service *Service, sessions *scs.SessionManager, templates *template.Template, wa *webauthn.WebAuthn) *Handlers {
return &Handlers{
service: service,
sessions: sessions,
middleware: NewMiddleware(sessions),
templates: templates,
+ webauthn: wa,
}
}
@@ -40,11 +47,13 @@ func (h *Handlers) HandleLoginPage(w http.ResponseWriter, r *http.Request) {
}
data := struct {
- Error string
- CSRFToken string
+ Error string
+ CSRFToken string
+ WebAuthnEnabled bool
}{
- Error: "",
- CSRFToken: h.middleware.GetCSRFToken(r),
+ Error: "",
+ CSRFToken: h.middleware.GetCSRFToken(r),
+ WebAuthnEnabled: h.webauthn != nil,
}
if err := h.templates.ExecuteTemplate(w, "login.html", data); err != nil {
@@ -100,11 +109,13 @@ func (h *Handlers) HandleLogout(w http.ResponseWriter, r *http.Request) {
func (h *Handlers) renderLoginError(w http.ResponseWriter, r *http.Request, errorMsg string) {
data := struct {
- Error string
- CSRFToken string
+ Error string
+ CSRFToken string
+ WebAuthnEnabled bool
}{
- Error: errorMsg,
- CSRFToken: h.middleware.GetCSRFToken(r),
+ Error: errorMsg,
+ CSRFToken: h.middleware.GetCSRFToken(r),
+ WebAuthnEnabled: h.webauthn != nil,
}
w.WriteHeader(http.StatusUnauthorized)
@@ -113,3 +124,233 @@ func (h *Handlers) renderLoginError(w http.ResponseWriter, r *http.Request, erro
log.Printf("Error rendering login template: %v", err)
}
}
+
+// HandleListPasskeys returns the passkeys list partial for the settings page
+func (h *Handlers) HandleListPasskeys(w http.ResponseWriter, r *http.Request) {
+ if h.webauthn == nil {
+ http.Error(w, "WebAuthn not configured", http.StatusNotFound)
+ return
+ }
+ userID := h.middleware.GetUserID(r)
+ infos, err := h.service.GetWebAuthnCredentialInfos(userID)
+ if err != nil {
+ log.Printf("Error getting passkeys: %v", err)
+ http.Error(w, "Failed to load passkeys", http.StatusInternalServerError)
+ return
+ }
+
+ data := struct {
+ Passkeys []WebAuthnCredentialInfo
+ CSRFToken string
+ }{
+ Passkeys: infos,
+ CSRFToken: h.middleware.GetCSRFToken(r),
+ }
+
+ if err := h.templates.ExecuteTemplate(w, "passkeys_list.html", data); err != nil {
+ log.Printf("Error rendering passkeys list: %v", err)
+ http.Error(w, "Failed to render template", http.StatusInternalServerError)
+ }
+}
+
+// HandlePasskeyRegisterBegin starts the WebAuthn registration ceremony
+func (h *Handlers) HandlePasskeyRegisterBegin(w http.ResponseWriter, r *http.Request) {
+ if h.webauthn == nil {
+ jsonError(w, "WebAuthn not configured", http.StatusNotFound)
+ return
+ }
+
+ userID := h.middleware.GetUserID(r)
+ user, err := h.service.GetUserWithCredentials(userID)
+ if err != nil {
+ log.Printf("Error getting user for passkey registration: %v", err)
+ jsonError(w, "Failed to get user", http.StatusInternalServerError)
+ return
+ }
+
+ options, session, err := h.webauthn.BeginRegistration(user,
+ webauthn.WithResidentKeyRequirement(protocol.ResidentKeyRequirementPreferred),
+ )
+ if err != nil {
+ log.Printf("Error beginning passkey registration: %v", err)
+ jsonError(w, "Failed to start registration", http.StatusInternalServerError)
+ return
+ }
+
+ // Store session data for the finish step
+ sessionBytes, err := json.Marshal(session)
+ if err != nil {
+ jsonError(w, "Failed to store session", http.StatusInternalServerError)
+ return
+ }
+ h.sessions.Put(r.Context(), "webauthn_registration", string(sessionBytes))
+
+ w.Header().Set("Content-Type", "application/json")
+ json.NewEncoder(w).Encode(options)
+}
+
+// HandlePasskeyRegisterFinish completes the WebAuthn registration ceremony
+func (h *Handlers) HandlePasskeyRegisterFinish(w http.ResponseWriter, r *http.Request) {
+ if h.webauthn == nil {
+ jsonError(w, "WebAuthn not configured", http.StatusNotFound)
+ return
+ }
+
+ userID := h.middleware.GetUserID(r)
+ user, err := h.service.GetUserWithCredentials(userID)
+ if err != nil {
+ log.Printf("Error getting user for passkey registration finish: %v", err)
+ jsonError(w, "Failed to get user", http.StatusInternalServerError)
+ return
+ }
+
+ // Recover session data
+ sessionJSON := h.sessions.GetString(r.Context(), "webauthn_registration")
+ if sessionJSON == "" {
+ jsonError(w, "No registration session found", http.StatusBadRequest)
+ return
+ }
+ h.sessions.Remove(r.Context(), "webauthn_registration")
+
+ var session webauthn.SessionData
+ if err := json.Unmarshal([]byte(sessionJSON), &session); err != nil {
+ jsonError(w, "Invalid session data", http.StatusBadRequest)
+ return
+ }
+
+ cred, err := h.webauthn.FinishRegistration(user, session, r)
+ if err != nil {
+ log.Printf("Error finishing passkey registration: %v", err)
+ jsonError(w, "Registration failed", http.StatusBadRequest)
+ return
+ }
+
+ // Get the friendly name from query param or use default
+ name := r.URL.Query().Get("name")
+ if name == "" {
+ name = "Passkey"
+ }
+
+ if err := h.service.SaveWebAuthnCredential(userID, cred, name); err != nil {
+ log.Printf("Error saving passkey: %v", err)
+ jsonError(w, "Failed to save passkey", http.StatusInternalServerError)
+ return
+ }
+
+ log.Printf("Passkey registered for user %d: %s", userID, name)
+ w.Header().Set("Content-Type", "application/json")
+ json.NewEncoder(w).Encode(map[string]string{"status": "ok"})
+}
+
+// HandleDeletePasskey removes a passkey
+func (h *Handlers) HandleDeletePasskey(w http.ResponseWriter, r *http.Request) {
+ if h.webauthn == nil {
+ http.Error(w, "WebAuthn not configured", http.StatusNotFound)
+ return
+ }
+
+ // Extract passkey ID from URL path
+ credID := chi.URLParam(r, "id")
+ if credID == "" {
+ http.Error(w, "Missing passkey ID", http.StatusBadRequest)
+ return
+ }
+
+ if err := h.service.DeleteWebAuthnCredential(credID); err != nil {
+ log.Printf("Error deleting passkey: %v", err)
+ http.Error(w, "Failed to delete passkey", http.StatusInternalServerError)
+ return
+ }
+
+ log.Printf("Passkey deleted: %s", credID)
+ w.WriteHeader(http.StatusOK)
+}
+
+// HandlePasskeyLoginBegin starts the WebAuthn authentication ceremony (discoverable credentials)
+func (h *Handlers) HandlePasskeyLoginBegin(w http.ResponseWriter, r *http.Request) {
+ if h.webauthn == nil {
+ jsonError(w, "WebAuthn not configured", http.StatusNotFound)
+ return
+ }
+
+ options, session, err := h.webauthn.BeginDiscoverableLogin()
+ if err != nil {
+ log.Printf("Error beginning passkey login: %v", err)
+ jsonError(w, "Failed to start login", http.StatusInternalServerError)
+ return
+ }
+
+ sessionBytes, err := json.Marshal(session)
+ if err != nil {
+ jsonError(w, "Failed to store session", http.StatusInternalServerError)
+ return
+ }
+ h.sessions.Put(r.Context(), "webauthn_login", string(sessionBytes))
+
+ w.Header().Set("Content-Type", "application/json")
+ json.NewEncoder(w).Encode(options)
+}
+
+// HandlePasskeyLoginFinish completes the WebAuthn authentication ceremony
+func (h *Handlers) HandlePasskeyLoginFinish(w http.ResponseWriter, r *http.Request) {
+ if h.webauthn == nil {
+ jsonError(w, "WebAuthn not configured", http.StatusNotFound)
+ return
+ }
+
+ sessionJSON := h.sessions.GetString(r.Context(), "webauthn_login")
+ if sessionJSON == "" {
+ jsonError(w, "No login session found", http.StatusBadRequest)
+ return
+ }
+ h.sessions.Remove(r.Context(), "webauthn_login")
+
+ var session webauthn.SessionData
+ if err := json.Unmarshal([]byte(sessionJSON), &session); err != nil {
+ jsonError(w, "Invalid session data", http.StatusBadRequest)
+ return
+ }
+
+ // User discovery handler - called by the library to find the user from the credential
+ userHandler := func(rawID, userHandle []byte) (webauthn.User, error) {
+ return h.service.FindUserByCredentialID(rawID)
+ }
+
+ cred, err := h.webauthn.FinishDiscoverableLogin(userHandler, session, r)
+ if err != nil {
+ log.Printf("Error finishing passkey login: %v", err)
+ jsonError(w, "Login failed", http.StatusUnauthorized)
+ return
+ }
+
+ // Find the user who owns this credential to create a session
+ credIDStr := base64.RawURLEncoding.EncodeToString(cred.ID)
+ user, err := h.service.FindUserByCredentialID(cred.ID)
+ if err != nil {
+ log.Printf("Error finding user for credential: %v", err)
+ jsonError(w, "Login failed", http.StatusUnauthorized)
+ return
+ }
+
+ // Update sign count
+ if err := h.service.UpdateWebAuthnCredentialSignCount(credIDStr, cred.Authenticator.SignCount); err != nil {
+ log.Printf("Error updating sign count: %v", err)
+ }
+
+ // Create session (same as password login)
+ if err := h.sessions.RenewToken(r.Context()); err != nil {
+ jsonError(w, "Failed to create session", http.StatusInternalServerError)
+ return
+ }
+ h.middleware.SetUserID(r, user.ID)
+
+ log.Printf("User %s logged in via passkey", user.Username)
+ w.Header().Set("Content-Type", "application/json")
+ json.NewEncoder(w).Encode(map[string]string{"status": "ok", "redirect": "/"})
+}
+
+func jsonError(w http.ResponseWriter, msg string, code int) {
+ w.Header().Set("Content-Type", "application/json")
+ w.WriteHeader(code)
+ json.NewEncoder(w).Encode(map[string]string{"error": msg})
+}
diff --git a/internal/auth/handlers_test.go b/internal/auth/handlers_test.go
index 128ae80..aed0e90 100644
--- a/internal/auth/handlers_test.go
+++ b/internal/auth/handlers_test.go
@@ -26,7 +26,7 @@ func TestHandleLogin(t *testing.T) {
sessionManager := scs.New()
templates := template.Must(template.New("login.html").Parse("{{.Error}}"))
- handlers := NewHandlers(service, sessionManager, templates)
+ handlers := NewHandlers(service, sessionManager, templates, nil)
// Setup mock user
password := "password"
@@ -74,7 +74,7 @@ func TestHandleLogin_InvalidCredentials(t *testing.T) {
sessionManager := scs.New()
templates := template.Must(template.New("login.html").Parse("{{.Error}}"))
- handlers := NewHandlers(service, sessionManager, templates)
+ handlers := NewHandlers(service, sessionManager, templates, nil)
mock.ExpectQuery("SELECT id, username, password_hash, created_at FROM users WHERE username = ?").
WithArgs("testuser").
diff --git a/internal/config/config.go b/internal/config/config.go
index 2d77025..86d0d5b 100644
--- a/internal/config/config.go
+++ b/internal/config/config.go
@@ -35,6 +35,10 @@ type Config struct {
// Display
Timezone string // IANA timezone name (e.g., "Pacific/Honolulu")
+
+ // WebAuthn
+ WebAuthnRPID string // Relying Party ID (domain, e.g., "doot.terst.org")
+ WebAuthnOrigin string // Expected origin (e.g., "https://doot.terst.org")
}
// Load reads configuration from environment variables
@@ -67,6 +71,10 @@ func Load() (*Config, error) {
// Display
Timezone: getEnvWithDefault("TIMEZONE", "Pacific/Honolulu"),
+
+ // WebAuthn
+ WebAuthnRPID: os.Getenv("WEBAUTHN_RP_ID"),
+ WebAuthnOrigin: os.Getenv("WEBAUTHN_ORIGIN"),
}
// Validate required fields