-
This commit is contained in:
235
internal/middleware/csrf.go
Normal file
235
internal/middleware/csrf.go
Normal file
@ -0,0 +1,235 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
|
||||
"github.com/kms/api-key-service/internal/config"
|
||||
)
|
||||
|
||||
// CSRFMiddleware provides CSRF protection
|
||||
type CSRFMiddleware struct {
|
||||
config config.ConfigProvider
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewCSRFMiddleware creates a new CSRF middleware
|
||||
func NewCSRFMiddleware(config config.ConfigProvider, logger *zap.Logger) *CSRFMiddleware {
|
||||
return &CSRFMiddleware{
|
||||
config: config,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// CSRFProtection implements CSRF protection for state-changing operations
|
||||
func (cm *CSRFMiddleware) CSRFProtection(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Skip CSRF protection for safe methods
|
||||
if r.Method == "GET" || r.Method == "HEAD" || r.Method == "OPTIONS" {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
// Skip CSRF protection for specific endpoints that use other authentication
|
||||
if cm.shouldSkipCSRF(r) {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
// Get CSRF token from header
|
||||
csrfToken := r.Header.Get("X-CSRF-Token")
|
||||
if csrfToken == "" {
|
||||
cm.logger.Warn("Missing CSRF token",
|
||||
zap.String("path", r.URL.Path),
|
||||
zap.String("method", r.Method),
|
||||
zap.String("remote_addr", r.RemoteAddr))
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusForbidden)
|
||||
w.Write([]byte(`{"error":"csrf_token_missing","message":"CSRF token required"}`))
|
||||
return
|
||||
}
|
||||
|
||||
// Validate CSRF token
|
||||
if !cm.validateCSRFToken(csrfToken, r) {
|
||||
cm.logger.Warn("Invalid CSRF token",
|
||||
zap.String("path", r.URL.Path),
|
||||
zap.String("method", r.Method),
|
||||
zap.String("remote_addr", r.RemoteAddr))
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusForbidden)
|
||||
w.Write([]byte(`{"error":"csrf_token_invalid","message":"Invalid CSRF token"}`))
|
||||
return
|
||||
}
|
||||
|
||||
cm.logger.Debug("CSRF token validated successfully",
|
||||
zap.String("path", r.URL.Path))
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
// GenerateCSRFToken generates a new CSRF token for a user session
|
||||
func (cm *CSRFMiddleware) GenerateCSRFToken(userID string) (string, error) {
|
||||
// Generate random bytes for token
|
||||
tokenBytes := make([]byte, 32)
|
||||
if _, err := rand.Read(tokenBytes); err != nil {
|
||||
cm.logger.Error("Failed to generate CSRF token", zap.Error(err))
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Create timestamp
|
||||
timestamp := time.Now().Unix()
|
||||
|
||||
// Create token data
|
||||
tokenData := hex.EncodeToString(tokenBytes)
|
||||
|
||||
// Create signing string: userID:timestamp:tokenData
|
||||
timestampStr := strconv.FormatInt(timestamp, 10)
|
||||
signingString := userID + ":" + timestampStr + ":" + tokenData
|
||||
|
||||
// Sign the token with HMAC
|
||||
signature := cm.signData(signingString)
|
||||
|
||||
// Return encoded token: tokenData.timestamp.signature
|
||||
token := tokenData + "." + timestampStr + "." + signature
|
||||
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// validateCSRFToken validates a CSRF token
|
||||
func (cm *CSRFMiddleware) validateCSRFToken(token string, r *http.Request) bool {
|
||||
// Parse token parts
|
||||
parts := strings.Split(token, ".")
|
||||
if len(parts) != 3 {
|
||||
cm.logger.Debug("Invalid CSRF token format")
|
||||
return false
|
||||
}
|
||||
|
||||
tokenData, timestampStr, signature := parts[0], parts[1], parts[2]
|
||||
|
||||
// Get user ID from request context or headers
|
||||
userID := cm.getUserIDFromRequest(r)
|
||||
if userID == "" {
|
||||
cm.logger.Debug("No user ID found for CSRF validation")
|
||||
return false
|
||||
}
|
||||
|
||||
// Recreate signing string
|
||||
signingString := userID + ":" + timestampStr + ":" + tokenData
|
||||
|
||||
// Verify signature
|
||||
expectedSignature := cm.signData(signingString)
|
||||
if !hmac.Equal([]byte(signature), []byte(expectedSignature)) {
|
||||
cm.logger.Debug("CSRF token signature verification failed")
|
||||
return false
|
||||
}
|
||||
|
||||
// Parse timestamp
|
||||
timestampInt, err := strconv.ParseInt(timestampStr, 10, 64)
|
||||
if err != nil {
|
||||
cm.logger.Debug("Invalid timestamp in CSRF token", zap.Error(err))
|
||||
return false
|
||||
}
|
||||
timestamp := time.Unix(timestampInt, 0)
|
||||
|
||||
// Check if token is expired (valid for 1 hour by default)
|
||||
maxAge := cm.config.GetDuration("CSRF_TOKEN_MAX_AGE")
|
||||
if maxAge <= 0 {
|
||||
maxAge = 1 * time.Hour
|
||||
}
|
||||
|
||||
if time.Since(timestamp) > maxAge {
|
||||
cm.logger.Debug("CSRF token expired",
|
||||
zap.Time("timestamp", timestamp),
|
||||
zap.Duration("age", time.Since(timestamp)),
|
||||
zap.Duration("max_age", maxAge))
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// signData signs data with HMAC
|
||||
func (cm *CSRFMiddleware) signData(data string) string {
|
||||
// Use the same signing key as for authentication
|
||||
signingKey := cm.config.GetString("AUTH_SIGNING_KEY")
|
||||
if signingKey == "" {
|
||||
cm.logger.Error("AUTH_SIGNING_KEY not configured for CSRF protection")
|
||||
return ""
|
||||
}
|
||||
|
||||
mac := hmac.New(sha256.New, []byte(signingKey))
|
||||
mac.Write([]byte(data))
|
||||
return hex.EncodeToString(mac.Sum(nil))
|
||||
}
|
||||
|
||||
// getUserIDFromRequest extracts user ID from request
|
||||
func (cm *CSRFMiddleware) getUserIDFromRequest(r *http.Request) string {
|
||||
// Try to get from X-User-Email header
|
||||
userEmail := r.Header.Get(cm.config.GetString("AUTH_HEADER_USER_EMAIL"))
|
||||
if userEmail != "" {
|
||||
return userEmail
|
||||
}
|
||||
|
||||
// Try to get from context (if set by authentication middleware)
|
||||
if userID := r.Context().Value("user_id"); userID != nil {
|
||||
if id, ok := userID.(string); ok {
|
||||
return id
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// shouldSkipCSRF determines if CSRF protection should be skipped for a request
|
||||
func (cm *CSRFMiddleware) shouldSkipCSRF(r *http.Request) bool {
|
||||
// Skip for API endpoints that use API key authentication
|
||||
if strings.HasPrefix(r.URL.Path, "/api/verify") {
|
||||
return true
|
||||
}
|
||||
|
||||
// Skip for health check endpoints
|
||||
if r.URL.Path == "/health" || r.URL.Path == "/ready" {
|
||||
return true
|
||||
}
|
||||
|
||||
// Skip for webhook endpoints (if any)
|
||||
if strings.HasPrefix(r.URL.Path, "/webhook/") {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// SetCSRFCookie sets a secure CSRF token cookie
|
||||
func (cm *CSRFMiddleware) SetCSRFCookie(w http.ResponseWriter, token string) {
|
||||
cookie := &http.Cookie{
|
||||
Name: "csrf_token",
|
||||
Value: token,
|
||||
Path: "/",
|
||||
MaxAge: 3600, // 1 hour
|
||||
HttpOnly: false, // JavaScript needs to read this for AJAX requests
|
||||
Secure: true, // HTTPS only
|
||||
SameSite: http.SameSiteStrictMode,
|
||||
}
|
||||
http.SetCookie(w, cookie)
|
||||
}
|
||||
|
||||
// GetCSRFTokenFromCookie gets CSRF token from cookie
|
||||
func (cm *CSRFMiddleware) GetCSRFTokenFromCookie(r *http.Request) string {
|
||||
cookie, err := r.Cookie("csrf_token")
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return cookie.Value
|
||||
}
|
||||
@ -23,23 +23,25 @@ import (
|
||||
|
||||
// SecurityMiddleware provides various security features
|
||||
type SecurityMiddleware struct {
|
||||
config config.ConfigProvider
|
||||
logger *zap.Logger
|
||||
cacheManager *cache.CacheManager
|
||||
appRepo repository.ApplicationRepository
|
||||
rateLimiters map[string]*rate.Limiter
|
||||
mu sync.RWMutex
|
||||
config config.ConfigProvider
|
||||
logger *zap.Logger
|
||||
cacheManager *cache.CacheManager
|
||||
appRepo repository.ApplicationRepository
|
||||
rateLimiters map[string]*rate.Limiter
|
||||
authRateLimiters map[string]*rate.Limiter
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewSecurityMiddleware creates a new security middleware
|
||||
func NewSecurityMiddleware(config config.ConfigProvider, logger *zap.Logger, appRepo repository.ApplicationRepository) *SecurityMiddleware {
|
||||
cacheManager := cache.NewCacheManager(config, logger)
|
||||
return &SecurityMiddleware{
|
||||
config: config,
|
||||
logger: logger,
|
||||
cacheManager: cacheManager,
|
||||
appRepo: appRepo,
|
||||
rateLimiters: make(map[string]*rate.Limiter),
|
||||
config: config,
|
||||
logger: logger,
|
||||
cacheManager: cacheManager,
|
||||
appRepo: appRepo,
|
||||
rateLimiters: make(map[string]*rate.Limiter),
|
||||
authRateLimiters: make(map[string]*rate.Limiter),
|
||||
}
|
||||
}
|
||||
|
||||
@ -76,6 +78,38 @@ func (s *SecurityMiddleware) RateLimitMiddleware(next http.Handler) http.Handler
|
||||
})
|
||||
}
|
||||
|
||||
// AuthRateLimitMiddleware implements stricter rate limiting for authentication endpoints
|
||||
func (s *SecurityMiddleware) AuthRateLimitMiddleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if !s.config.GetBool("RATE_LIMIT_ENABLED") {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
clientIP := s.getClientIP(r)
|
||||
|
||||
// Use stricter rate limits for auth endpoints
|
||||
limiter := s.getAuthRateLimiter(clientIP)
|
||||
|
||||
// Check if request is allowed
|
||||
if !limiter.Allow() {
|
||||
s.logger.Warn("Auth rate limit exceeded",
|
||||
zap.String("client_ip", clientIP),
|
||||
zap.String("path", r.URL.Path))
|
||||
|
||||
// Track authentication failures for brute force protection
|
||||
s.TrackAuthenticationFailure(clientIP, "")
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusTooManyRequests)
|
||||
w.Write([]byte(`{"error":"auth_rate_limit_exceeded","message":"Too many authentication attempts"}`))
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
// BruteForceProtectionMiddleware implements brute force protection
|
||||
func (s *SecurityMiddleware) BruteForceProtectionMiddleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
@ -231,6 +265,35 @@ func (s *SecurityMiddleware) getRateLimiter(clientIP string) *rate.Limiter {
|
||||
return limiter
|
||||
}
|
||||
|
||||
func (s *SecurityMiddleware) getAuthRateLimiter(clientIP string) *rate.Limiter {
|
||||
s.mu.RLock()
|
||||
limiter, exists := s.authRateLimiters[clientIP]
|
||||
s.mu.RUnlock()
|
||||
|
||||
if exists {
|
||||
return limiter
|
||||
}
|
||||
|
||||
// Create new auth rate limiter with stricter limits
|
||||
authRPS := s.config.GetInt("AUTH_RATE_LIMIT_RPS")
|
||||
if authRPS <= 0 {
|
||||
authRPS = 5 // Very strict default for auth endpoints
|
||||
}
|
||||
|
||||
authBurst := s.config.GetInt("AUTH_RATE_LIMIT_BURST")
|
||||
if authBurst <= 0 {
|
||||
authBurst = 10 // Allow small bursts
|
||||
}
|
||||
|
||||
limiter = rate.NewLimiter(rate.Limit(authRPS), authBurst)
|
||||
|
||||
s.mu.Lock()
|
||||
s.authRateLimiters[clientIP] = limiter
|
||||
s.mu.Unlock()
|
||||
|
||||
return limiter
|
||||
}
|
||||
|
||||
func (s *SecurityMiddleware) trackRateLimitViolation(clientIP string) {
|
||||
ctx := context.Background()
|
||||
key := cache.CacheKey("rate_limit_violations", clientIP)
|
||||
|
||||
Reference in New Issue
Block a user