summaryrefslogtreecommitdiff
path: root/internal/auth/middleware.go
blob: 78f3b53e157acf184c8e9d0e1f35837524088846 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
package auth

import (
	"context"
	"crypto/rand"
	"crypto/subtle"
	"encoding/base64"
	"net/http"
	"strings"

	"github.com/alexedwards/scs/v2"
)

// Keys under which this package stores auth state in the scs session.
const (
	SessionKeyUserID = "user_id"    // int64 user ID written by SetUserID
	SessionKeyCSRF   = "csrf_token" // per-session CSRF token string
)

// contextKey is unexported so no other package can build a colliding
// context key.
type contextKey string

// ContextKeyCSRF is the request-context key under which CSRFProtect stores
// the session's CSRF token for downstream handlers.
const ContextKeyCSRF = contextKey("csrf_token")

// Middleware groups the authentication and CSRF HTTP middleware together with
// the session helpers they rely on. All state lives in an scs session manager.
type Middleware struct {
	sessions *scs.SessionManager // every helper reads/writes through this manager
}

// NewMiddleware returns a Middleware that keeps auth state in sessions.
// NOTE(review): scs session access generally requires the manager's
// LoadAndSave middleware to wrap these handlers — confirm the router does so.
func NewMiddleware(sessions *scs.SessionManager) *Middleware {
	mw := new(Middleware)
	mw.sessions = sessions
	return mw
}

// RequireAuth wraps next so that only requests whose session carries a user
// ID reach it; all others get a 303 See Other redirect to /login.
func (m *Middleware) RequireAuth(next http.Handler) http.Handler {
	guard := func(w http.ResponseWriter, r *http.Request) {
		if m.IsAuthenticated(r) {
			next.ServeHTTP(w, r)
			return
		}
		http.Redirect(w, r, "/login", http.StatusSeeOther)
	}
	return http.HandlerFunc(guard)
}

// IsAuthenticated reports whether the request's session contains a user ID
// (the value SetUserID stores). It checks presence only, not the ID's value.
func (m *Middleware) IsAuthenticated(r *http.Request) bool {
	ctx := r.Context()
	return m.sessions.Exists(ctx, SessionKeyUserID)
}

// GetUserID returns the user ID that SetUserID stored in the session.
// Check IsAuthenticated first: when no ID is stored, scs's GetInt64 returns 0.
func (m *Middleware) GetUserID(r *http.Request) int64 {
	ctx := r.Context()
	return m.sessions.GetInt64(ctx, SessionKeyUserID)
}

// SetUserID marks the session as authenticated for userID. Call it after a
// successful login.
//
// The session token is renewed before the ID is stored. Without this, a
// session token an attacker planted in the victim's browser before login
// (session fixation) would become an authenticated session. scs's RenewToken
// keeps the existing session data, so the CSRF token survives the renewal.
//
// A non-nil error comes from RenewToken. In that case the user ID has NOT been
// stored, and the caller should treat the login as failed.
func (m *Middleware) SetUserID(r *http.Request, userID int64) error {
	ctx := r.Context()
	if err := m.sessions.RenewToken(ctx); err != nil {
		return err
	}
	m.sessions.Put(ctx, SessionKeyUserID, userID)
	return nil
}

// ClearSession destroys the whole session on logout. It does not merely remove
// the user ID: scs's Destroy deletes all session data from the store, so the
// CSRF token is discarded as well. CSRFProtect generates a fresh token on the
// next request it handles. The error, if any, comes from scs's Destroy.
func (m *Middleware) ClearSession(r *http.Request) error {
	return m.sessions.Destroy(r.Context())
}

// CSRFProtect returns middleware that enforces a per-session (synchronizer
// token) CSRF check.
//
// Session token:
//   - For every request outside /agent/, it ensures the session holds a CSRF
//     token, generating and storing one if needed.
//
// Validation:
//   - Requests using any method other than the RFC 9110 safe methods (GET,
//     HEAD, OPTIONS, TRACE) must present that token.
//   - The token is read from the X-CSRF-Token header or, failing that, from a
//     "csrf_token" form field.
//   - A missing or mismatched token is rejected with 403.
//   - Methods are matched against an allow-list of safe methods rather than a
//     deny-list of unsafe ones. HTTP method names are case-sensitive, so a
//     deny-list would let lowercase or non-standard methods bypass the check.
//
// Success:
//   - The token is placed in the request context under ContextKeyCSRF for
//     downstream handlers.
//
// Paths under /agent/ pass through untouched. Those endpoints use token-based
// auth, not cookies.
func (m *Middleware) CSRFProtect(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if strings.HasPrefix(r.URL.Path, "/agent/") {
			next.ServeHTTP(w, r)
			return
		}

		ctx := r.Context()

		// A single lookup: GetString yields "" both when the key is absent and
		// when the stored value is not a string. Regenerating in either case
		// means a corrupt or empty entry cannot lock the session out of every
		// state-changing request.
		token := m.sessions.GetString(ctx, SessionKeyCSRF)
		if token == "" {
			var err error
			if token, err = generateToken(); err != nil {
				http.Error(w, "Internal Server Error", http.StatusInternalServerError)
				return
			}
			m.sessions.Put(ctx, SessionKeyCSRF, token)
		}

		switch r.Method {
		case http.MethodGet, http.MethodHead, http.MethodOptions, http.MethodTrace:
			// Safe methods must not change state, so no token is required.
		default:
			requestToken := r.Header.Get("X-CSRF-Token")
			if requestToken == "" {
				// FormValue parses url-encoded and multipart bodies as needed.
				requestToken = r.FormValue("csrf_token")
			}
			// Constant-time comparison so response timing does not reveal how
			// much of a guessed token matched.
			if requestToken == "" || subtle.ConstantTimeCompare([]byte(requestToken), []byte(token)) != 1 {
				http.Error(w, "Forbidden - CSRF Token Mismatch", http.StatusForbidden)
				return
			}
		}

		// Expose the token to handlers (e.g. for rendering forms).
		next.ServeHTTP(w, r.WithContext(context.WithValue(ctx, ContextKeyCSRF, token)))
	})
}

// GetCSRFToken returns the CSRF token stored in the request's session, or ""
// when there is none. No separate existence check is needed: scs's GetString
// already returns "" for a missing key (or a non-string value).
func (m *Middleware) GetCSRFToken(r *http.Request) string {
	return m.sessions.GetString(r.Context(), SessionKeyCSRF)
}

// GetCSRFTokenFromContext returns the CSRF token that CSRFProtect attached to
// ctx under ContextKeyCSRF, or "" if ctx holds no string under that key.
func GetCSRFTokenFromContext(ctx context.Context) string {
	// A failed type assertion yields the zero value "", which is exactly the
	// "no token" result, so the ok flag is deliberately ignored.
	token, _ := ctx.Value(ContextKeyCSRF).(string)
	return token
}

// generateToken returns 32 bytes from crypto/rand encoded as URL-safe, padded
// base64 (44 characters). It fails only if reading system randomness fails.
func generateToken() (string, error) {
	var raw [32]byte
	if _, err := rand.Read(raw[:]); err != nil {
		return "", err
	}
	return base64.URLEncoding.EncodeToString(raw[:]), nil
}