Files
skybridge/kms/internal/middleware/csrf.go
2025-08-26 19:16:41 -04:00

235 lines
6.3 KiB
Go

package middleware
import (
"crypto/hmac"
"crypto/rand"
"crypto/sha256"
"encoding/hex"
"net/http"
"strconv"
"strings"
"time"
"go.uber.org/zap"
"github.com/kms/api-key-service/internal/config"
)
// CSRFMiddleware guards state-changing HTTP requests against cross-site
// request forgery using HMAC-signed tokens that are bound to a user ID and
// carry an issue timestamp checked against a maximum age.
type CSRFMiddleware struct {
	// config supplies AUTH_SIGNING_KEY, AUTH_HEADER_USER_EMAIL and
	// CSRF_TOKEN_MAX_AGE; the methods read these values on each call.
	config config.ConfigProvider
	// logger receives validation failures and configuration errors.
	logger *zap.Logger
}
// NewCSRFMiddleware creates a CSRFMiddleware backed by cfg and logger.
//
// The middleware's methods read AUTH_SIGNING_KEY, AUTH_HEADER_USER_EMAIL and
// CSRF_TOKEN_MAX_AGE from cfg at call time. A nil logger is replaced with a
// no-op logger so the middleware cannot panic while logging.
func NewCSRFMiddleware(cfg config.ConfigProvider, logger *zap.Logger) *CSRFMiddleware {
	// The parameter is named cfg (not config) so it does not shadow the
	// imported config package inside the function body.
	if logger == nil {
		logger = zap.NewNop()
	}
	return &CSRFMiddleware{
		config: cfg,
		logger: logger,
	}
}
// CSRFProtection returns middleware that enforces CSRF protection on
// state-changing requests.
//
// A request is passed straight to next when:
//   - its method is GET, HEAD or OPTIONS; or
//   - shouldSkipCSRF reports its path as exempt.
//
// Any other request must carry, in the X-CSRF-Token header, a token accepted
// by validateCSRFToken. Otherwise the failure is logged at Warn level, a 403
// response with a JSON error body is written, and next is not called.
func (cm *CSRFMiddleware) CSRFProtection(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		switch r.Method {
		case http.MethodGet, http.MethodHead, http.MethodOptions:
			next.ServeHTTP(w, r)
			return
		}

		// Endpoints that authenticate by other means (see shouldSkipCSRF).
		if cm.shouldSkipCSRF(r) {
			next.ServeHTTP(w, r)
			return
		}

		// reject logs the failure with request details and writes a 403 JSON body.
		reject := func(logMsg, body string) {
			cm.logger.Warn(logMsg,
				zap.String("path", r.URL.Path),
				zap.String("method", r.Method),
				zap.String("remote_addr", r.RemoteAddr))
			w.Header().Set("Content-Type", "application/json")
			w.WriteHeader(http.StatusForbidden)
			// The status is already sent; a failed body write leaves nothing to recover.
			_, _ = w.Write([]byte(body))
		}

		csrfToken := r.Header.Get("X-CSRF-Token")
		if csrfToken == "" {
			reject("Missing CSRF token",
				`{"error":"csrf_token_missing","message":"CSRF token required"}`)
			return
		}
		if !cm.validateCSRFToken(csrfToken, r) {
			reject("Invalid CSRF token",
				`{"error":"csrf_token_invalid","message":"Invalid CSRF token"}`)
			return
		}

		cm.logger.Debug("CSRF token validated successfully",
			zap.String("path", r.URL.Path))
		next.ServeHTTP(w, r)
	})
}
// GenerateCSRFToken mints a CSRF token bound to userID.
//
// The token has the form "<nonce>.<issuedAt>.<mac>", where:
//   - nonce is 32 bytes from crypto/rand, hex-encoded;
//   - issuedAt is the current Unix time in seconds, in decimal;
//   - mac is signData("<userID>:<issuedAt>:<nonce>").
//
// An error is returned, and logged, only if reading random bytes fails.
//
// NOTE(review): signData returns "" when AUTH_SIGNING_KEY is unset. This
// function does not check for that and still returns a token with an empty
// mac segment; confirm deployment always configures the key.
func (cm *CSRFMiddleware) GenerateCSRFToken(userID string) (string, error) {
	var nonce [32]byte
	if _, err := rand.Read(nonce[:]); err != nil {
		cm.logger.Error("Failed to generate CSRF token", zap.Error(err))
		return "", err
	}

	issuedAt := strconv.FormatInt(time.Now().Unix(), 10)
	nonceHex := hex.EncodeToString(nonce[:])

	mac := cm.signData(strings.Join([]string{userID, issuedAt, nonceHex}, ":"))
	return strings.Join([]string{nonceHex, issuedAt, mac}, "."), nil
}
// validateCSRFToken reports whether token is a valid, unexpired CSRF token for
// the user associated with r.
//
// The token must have the form "<tokenData>.<unix seconds>.<signature>", as
// produced by GenerateCSRFToken. Validation fails (returns false) when:
//   - the token does not split into exactly three dot-separated parts;
//   - getUserIDFromRequest yields no user ID;
//   - no signing key is configured (signData returns "");
//   - the signature does not match signData("<userID>:<timestamp>:<tokenData>");
//   - the timestamp is not a base-10 integer; or
//   - the token is older than CSRF_TOKEN_MAX_AGE (1 hour when unset or <= 0).
func (cm *CSRFMiddleware) validateCSRFToken(token string, r *http.Request) bool {
	parts := strings.Split(token, ".")
	if len(parts) != 3 {
		cm.logger.Debug("Invalid CSRF token format")
		return false
	}
	tokenData, timestampStr, signature := parts[0], parts[1], parts[2]

	// The user ID is part of the signed data, so a token issued for one user
	// does not verify for a request attributed to another.
	userID := cm.getUserIDFromRequest(r)
	if userID == "" {
		cm.logger.Debug("No user ID found for CSRF validation")
		return false
	}

	expectedSignature := cm.signData(userID + ":" + timestampStr + ":" + tokenData)
	// signData returns "" when AUTH_SIGNING_KEY is not configured. Without this
	// guard a token whose signature segment is empty ("data.ts.") would compare
	// equal to "" and be accepted, i.e. protection would fail open on a
	// misconfiguration. Fail closed instead.
	if expectedSignature == "" {
		cm.logger.Debug("CSRF token cannot be verified: no signing key")
		return false
	}
	if !hmac.Equal([]byte(signature), []byte(expectedSignature)) {
		cm.logger.Debug("CSRF token signature verification failed")
		return false
	}

	timestampInt, err := strconv.ParseInt(timestampStr, 10, 64)
	if err != nil {
		cm.logger.Debug("Invalid timestamp in CSRF token", zap.Error(err))
		return false
	}
	timestamp := time.Unix(timestampInt, 0)

	maxAge := cm.config.GetDuration("CSRF_TOKEN_MAX_AGE")
	if maxAge <= 0 {
		maxAge = 1 * time.Hour
	}
	if age := time.Since(timestamp); age > maxAge {
		cm.logger.Debug("CSRF token expired",
			zap.Time("timestamp", timestamp),
			zap.Duration("age", age),
			zap.Duration("max_age", maxAge))
		return false
	}

	return true
}
// signData returns the hex-encoded HMAC-SHA256 of data, keyed with the
// AUTH_SIGNING_KEY config value (the same key used for authentication).
//
// If the key is not configured it logs an error and returns "". Callers must
// treat an empty result as "cannot sign", never as a valid signature.
func (cm *CSRFMiddleware) signData(data string) string {
	key := cm.config.GetString("AUTH_SIGNING_KEY")
	if len(key) == 0 {
		cm.logger.Error("AUTH_SIGNING_KEY not configured for CSRF protection")
		return ""
	}

	h := hmac.New(sha256.New, []byte(key))
	// hash.Hash's Write is documented never to return an error.
	h.Write([]byte(data))
	digest := h.Sum(nil)
	return hex.EncodeToString(digest)
}
// getUserIDFromRequest returns the identity a CSRF token is bound to.
//
// It first reads the request header whose name is the AUTH_HEADER_USER_EMAIL
// config value. If that is empty, it falls back to a string stored in the
// request context under the "user_id" key. It returns "" if neither yields a
// value.
//
// NOTE(review): the header takes precedence over the context value. Confirm
// that this header is set or stripped by a trusted proxy and cannot be freely
// supplied by clients.
func (cm *CSRFMiddleware) getUserIDFromRequest(r *http.Request) string {
	headerName := cm.config.GetString("AUTH_HEADER_USER_EMAIL")
	if email := r.Header.Get(headerName); email != "" {
		return email
	}

	// A missing or non-string context value yields "".
	id, _ := r.Context().Value("user_id").(string)
	return id
}
// shouldSkipCSRF reports whether r's path is exempt from CSRF checks.
//
// Exempt paths:
//   - /api/verify and anything beneath /api/verify/ (endpoints that use API
//     key authentication);
//   - /health and /ready (health checks);
//   - anything beneath /webhook/.
//
// The /api/verify rule matches on a path-segment boundary. A bare prefix test
// would also exempt unrelated routes such as /api/verify-admin or /api/verifyx.
func (cm *CSRFMiddleware) shouldSkipCSRF(r *http.Request) bool {
	p := r.URL.Path
	switch {
	case p == "/api/verify", strings.HasPrefix(p, "/api/verify/"):
		return true
	case p == "/health", p == "/ready":
		return true
	case strings.HasPrefix(p, "/webhook/"):
		return true
	}
	return false
}
// SetCSRFCookie writes token to the response as the "csrf_token" cookie.
//
// Cookie attributes:
//   - Path "/";
//   - not HttpOnly, because JavaScript needs to read it for AJAX requests;
//   - Secure (HTTPS only);
//   - SameSite=Strict.
//
// Max-Age follows CSRF_TOKEN_MAX_AGE (1 hour when unset or <= 0), the same
// limit applied when tokens are validated. The cookie's lifetime therefore
// matches the token's instead of being a fixed hour.
func (cm *CSRFMiddleware) SetCSRFCookie(w http.ResponseWriter, token string) {
	maxAge := cm.config.GetDuration("CSRF_TOKEN_MAX_AGE")
	if maxAge <= 0 {
		maxAge = 1 * time.Hour
	}
	// http.Cookie treats MaxAge == 0 as "omit Max-Age" (a session cookie), so
	// a sub-second configuration must not round down to 0.
	maxAgeSeconds := int(maxAge / time.Second)
	if maxAgeSeconds < 1 {
		maxAgeSeconds = 1
	}

	http.SetCookie(w, &http.Cookie{
		Name:     "csrf_token",
		Value:    token,
		Path:     "/",
		MaxAge:   maxAgeSeconds,
		HttpOnly: false, // JavaScript needs to read this for AJAX requests
		Secure:   true,  // HTTPS only
		SameSite: http.SameSiteStrictMode,
	})
}
// GetCSRFTokenFromCookie returns the value of the request's "csrf_token"
// cookie, or "" when the request does not carry that cookie.
func (cm *CSRFMiddleware) GetCSRFTokenFromCookie(r *http.Request) string {
	c, err := r.Cookie("csrf_token")
	if err == nil {
		return c.Value
	}
	return ""
}