235 lines
6.3 KiB
Go
235 lines
6.3 KiB
Go
package middleware
|
|
|
|
import (
|
|
"crypto/hmac"
|
|
"crypto/rand"
|
|
"crypto/sha256"
|
|
"encoding/hex"
|
|
"net/http"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
|
|
"go.uber.org/zap"
|
|
|
|
"github.com/kms/api-key-service/internal/config"
|
|
)
|
|
|
|
// CSRFMiddleware provides CSRF protection for state-changing HTTP requests.
//
// Tokens are HMAC-signed with the AUTH_SIGNING_KEY configuration value and
// bound to the requesting user's identity (see GenerateCSRFToken and
// CSRFProtection).
type CSRFMiddleware struct {
	// config supplies the signing key, the user-identity header name and the
	// token lifetime (CSRF_TOKEN_MAX_AGE).
	config config.ConfigProvider
	// logger receives diagnostics about missing, malformed or rejected tokens.
	logger *zap.Logger
}
|
|
|
|
// NewCSRFMiddleware creates a new CSRF middleware
|
|
func NewCSRFMiddleware(config config.ConfigProvider, logger *zap.Logger) *CSRFMiddleware {
|
|
return &CSRFMiddleware{
|
|
config: config,
|
|
logger: logger,
|
|
}
|
|
}
|
|
|
|
// CSRFProtection wraps next with CSRF enforcement for state-changing requests.
//
// GET, HEAD and OPTIONS requests, as well as paths exempted by shouldSkipCSRF,
// are passed straight to next. Any other request must carry a token in the
// X-CSRF-Token header that validateCSRFToken accepts; otherwise a warning is
// logged and the client receives 403 Forbidden with a JSON error body, and
// next is not called.
func (cm *CSRFMiddleware) CSRFProtection(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Case expressions are evaluated in order, so shouldSkipCSRF only runs
		// for non-safe methods.
		switch {
		case r.Method == "GET", r.Method == "HEAD", r.Method == "OPTIONS", cm.shouldSkipCSRF(r):
			next.ServeHTTP(w, r)
			return
		}

		// reject logs the failure with request context and writes a 403 JSON body.
		reject := func(reason, body string) {
			cm.logger.Warn(reason,
				zap.String("path", r.URL.Path),
				zap.String("method", r.Method),
				zap.String("remote_addr", r.RemoteAddr))

			w.Header().Set("Content-Type", "application/json")
			w.WriteHeader(http.StatusForbidden)
			w.Write([]byte(body))
		}

		tok := r.Header.Get("X-CSRF-Token")
		switch {
		case tok == "":
			reject("Missing CSRF token", `{"error":"csrf_token_missing","message":"CSRF token required"}`)
			return
		case !cm.validateCSRFToken(tok, r):
			reject("Invalid CSRF token", `{"error":"csrf_token_invalid","message":"Invalid CSRF token"}`)
			return
		}

		cm.logger.Debug("CSRF token validated successfully",
			zap.String("path", r.URL.Path))

		next.ServeHTTP(w, r)
	})
}
|
|
|
|
// csrfError is a minimal string-backed error type for this file's sentinel errors.
type csrfError string

func (e csrfError) Error() string { return string(e) }

// errCSRFSigningKeyMissing is returned by GenerateCSRFToken when no HMAC
// signing key (AUTH_SIGNING_KEY) is configured, so no verifiable token can be
// produced.
const errCSRFSigningKeyMissing = csrfError("csrf: AUTH_SIGNING_KEY not configured")

// GenerateCSRFToken issues a new CSRF token bound to userID.
//
// The token has the form "<tokenData>.<unixTimestamp>.<signature>":
//   - tokenData is 32 random bytes from crypto/rand, hex-encoded;
//   - unixTimestamp is the issue time in seconds;
//   - signature is the hex HMAC (see signData) of "userID:timestamp:tokenData".
//
// An error is returned if random bytes cannot be read, or if no signing key is
// configured. In the latter case, handing out a token with an empty signature
// would be useless at best.
func (cm *CSRFMiddleware) GenerateCSRFToken(userID string) (string, error) {
	tokenBytes := make([]byte, 32)
	if _, err := rand.Read(tokenBytes); err != nil {
		cm.logger.Error("Failed to generate CSRF token", zap.Error(err))
		return "", err
	}

	tokenData := hex.EncodeToString(tokenBytes)
	timestampStr := strconv.FormatInt(time.Now().Unix(), 10)

	// Signing string layout must match validateCSRFToken: userID:timestamp:tokenData.
	signature := cm.signData(userID + ":" + timestampStr + ":" + tokenData)
	if signature == "" {
		// signData returns "" (and logs) when the signing key is missing;
		// refuse to issue an unsigned token.
		return "", errCSRFSigningKeyMissing
	}

	return tokenData + "." + timestampStr + "." + signature, nil
}
|
|
|
|
// validateCSRFToken reports whether token is a well-formed, correctly signed,
// unexpired CSRF token for the user identified by r.
//
// token must have the form "<tokenData>.<unixTimestamp>.<signature>" as
// produced by GenerateCSRFToken. The signature is recomputed over
// "userID:timestamp:tokenData" using the user ID from getUserIDFromRequest.
// Tokens older than CSRF_TOKEN_MAX_AGE (default 1 hour) are rejected.
//
// Validation fails closed. A malformed token, a missing user identity, an
// unconfigured signing key, a signature mismatch or an expired/unparseable
// timestamp all yield false.
func (cm *CSRFMiddleware) validateCSRFToken(token string, r *http.Request) bool {
	parts := strings.Split(token, ".")
	if len(parts) != 3 {
		cm.logger.Debug("Invalid CSRF token format")
		return false
	}

	tokenData, timestampStr, signature := parts[0], parts[1], parts[2]

	// tokenData must be non-empty hex. Hex cannot contain ':', so the fields of
	// the "userID:timestamp:tokenData" signing string cannot be shifted.
	if _, err := hex.DecodeString(tokenData); err != nil || tokenData == "" || signature == "" {
		cm.logger.Debug("Invalid CSRF token format")
		return false
	}

	userID := cm.getUserIDFromRequest(r)
	if userID == "" {
		cm.logger.Debug("No user ID found for CSRF validation")
		return false
	}

	signingString := userID + ":" + timestampStr + ":" + tokenData

	expectedSignature := cm.signData(signingString)
	if expectedSignature == "" {
		// signData returns "" (and logs an error) when AUTH_SIGNING_KEY is unset.
		// hmac.Equal treats two empty inputs as equal, so without this guard a
		// token with an empty signature would be accepted.
		return false
	}
	if !hmac.Equal([]byte(signature), []byte(expectedSignature)) {
		cm.logger.Debug("CSRF token signature verification failed")
		return false
	}

	timestampInt, err := strconv.ParseInt(timestampStr, 10, 64)
	if err != nil {
		cm.logger.Debug("Invalid timestamp in CSRF token", zap.Error(err))
		return false
	}
	timestamp := time.Unix(timestampInt, 0)

	maxAge := cm.config.GetDuration("CSRF_TOKEN_MAX_AGE")
	if maxAge <= 0 {
		maxAge = 1 * time.Hour
	}

	if age := time.Since(timestamp); age > maxAge {
		cm.logger.Debug("CSRF token expired",
			zap.Time("timestamp", timestamp),
			zap.Duration("age", age),
			zap.Duration("max_age", maxAge))
		return false
	}

	return true
}
|
|
|
|
// signData returns the hex-encoded HMAC-SHA256 of data, keyed with the
// AUTH_SIGNING_KEY configuration value (the same key used for authentication).
//
// When that key is not configured, an error is logged and "" is returned. An
// empty result means no signature could be produced.
func (cm *CSRFMiddleware) signData(data string) string {
	key := cm.config.GetString("AUTH_SIGNING_KEY")
	if len(key) == 0 {
		cm.logger.Error("AUTH_SIGNING_KEY not configured for CSRF protection")
		return ""
	}

	h := hmac.New(sha256.New, []byte(key))
	h.Write([]byte(data)) // hash.Hash.Write never returns an error
	sum := h.Sum(nil)
	return hex.EncodeToString(sum)
}
|
|
|
|
// getUserIDFromRequest returns the identity that CSRF tokens are bound to.
//
// It first reads the request header named by the AUTH_HEADER_USER_EMAIL
// setting (presumably something like X-User-Email; confirm in config). If that
// header is absent or empty, it falls back to a string stored under the
// "user_id" context key, which is assumed to be set by the authentication
// middleware. It returns "" when neither source yields a value.
func (cm *CSRFMiddleware) getUserIDFromRequest(r *http.Request) string {
	headerName := cm.config.GetString("AUTH_HEADER_USER_EMAIL")
	if email := r.Header.Get(headerName); email != "" {
		return email
	}

	// A comma-ok assertion yields "" for a missing (nil) or non-string value.
	id, _ := r.Context().Value("user_id").(string)
	return id
}
|
|
|
|
// shouldSkipCSRF reports whether r targets an endpoint exempt from CSRF checks:
//   - any path starting with "/api/verify" (API-key authenticated endpoints),
//   - the exact health probe paths "/health" and "/ready",
//   - any path under "/webhook/".
//
// NOTE(review): "/api/verify" is matched as a raw string prefix, so paths such
// as "/api/verifyfoo" are exempt too; confirm that is intended.
func (cm *CSRFMiddleware) shouldSkipCSRF(r *http.Request) bool {
	p := r.URL.Path
	isHealthProbe := p == "/health" || p == "/ready"
	return isHealthProbe ||
		strings.HasPrefix(p, "/api/verify") ||
		strings.HasPrefix(p, "/webhook/")
}
|
|
|
|
// SetCSRFCookie writes token to the client as the "csrf_token" cookie.
//
// The cookie has these properties:
//   - HttpOnly is false, so front-end JavaScript can read it for AJAX requests
//     (presumably to echo it in the X-CSRF-Token header that CSRFProtection
//     checks).
//   - It is sent over HTTPS only and uses SameSite=Strict.
//   - Its Max-Age follows the CSRF_TOKEN_MAX_AGE setting (default 1 hour), the
//     same lifetime validateCSRFToken enforces, so the cookie does not vanish
//     while the token it carries is still valid.
func (cm *CSRFMiddleware) SetCSRFCookie(w http.ResponseWriter, token string) {
	maxAge := cm.config.GetDuration("CSRF_TOKEN_MAX_AGE")
	if maxAge <= 0 {
		maxAge = time.Hour
	}
	seconds := int(maxAge / time.Second)
	if seconds < 1 {
		// MaxAge 0 would omit the attribute and turn this into a session cookie.
		seconds = 1
	}

	http.SetCookie(w, &http.Cookie{
		Name:     "csrf_token",
		Value:    token,
		Path:     "/",
		MaxAge:   seconds,
		HttpOnly: false, // JavaScript needs to read this for AJAX requests
		Secure:   true,  // HTTPS only
		SameSite: http.SameSiteStrictMode,
	})
}
|
|
|
|
// GetCSRFTokenFromCookie returns the value of the request's "csrf_token"
// cookie, or "" when the request does not carry that cookie.
func (cm *CSRFMiddleware) GetCSRFTokenFromCookie(r *http.Request) string {
	if c, err := r.Cookie("csrf_token"); err == nil {
		return c.Value
	}
	return ""
}