This commit is contained in:
2025-08-23 22:31:47 -04:00
parent 9ca9c53baf
commit e5bccc85c2
22 changed files with 2405 additions and 209 deletions

235
internal/middleware/csrf.go Normal file
View File

@ -0,0 +1,235 @@
package middleware
import (
"crypto/hmac"
"crypto/rand"
"crypto/sha256"
"encoding/hex"
"net/http"
"strconv"
"strings"
"time"
"go.uber.org/zap"
"github.com/kms/api-key-service/internal/config"
)
// CSRFMiddleware provides CSRF protection
type CSRFMiddleware struct {
config config.ConfigProvider
logger *zap.Logger
}
// NewCSRFMiddleware creates a new CSRF middleware
func NewCSRFMiddleware(config config.ConfigProvider, logger *zap.Logger) *CSRFMiddleware {
return &CSRFMiddleware{
config: config,
logger: logger,
}
}
// CSRFProtection implements CSRF protection for state-changing operations
func (cm *CSRFMiddleware) CSRFProtection(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Skip CSRF protection for safe methods
if r.Method == "GET" || r.Method == "HEAD" || r.Method == "OPTIONS" {
next.ServeHTTP(w, r)
return
}
// Skip CSRF protection for specific endpoints that use other authentication
if cm.shouldSkipCSRF(r) {
next.ServeHTTP(w, r)
return
}
// Get CSRF token from header
csrfToken := r.Header.Get("X-CSRF-Token")
if csrfToken == "" {
cm.logger.Warn("Missing CSRF token",
zap.String("path", r.URL.Path),
zap.String("method", r.Method),
zap.String("remote_addr", r.RemoteAddr))
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusForbidden)
w.Write([]byte(`{"error":"csrf_token_missing","message":"CSRF token required"}`))
return
}
// Validate CSRF token
if !cm.validateCSRFToken(csrfToken, r) {
cm.logger.Warn("Invalid CSRF token",
zap.String("path", r.URL.Path),
zap.String("method", r.Method),
zap.String("remote_addr", r.RemoteAddr))
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusForbidden)
w.Write([]byte(`{"error":"csrf_token_invalid","message":"Invalid CSRF token"}`))
return
}
cm.logger.Debug("CSRF token validated successfully",
zap.String("path", r.URL.Path))
next.ServeHTTP(w, r)
})
}
// GenerateCSRFToken generates a new CSRF token for a user session
func (cm *CSRFMiddleware) GenerateCSRFToken(userID string) (string, error) {
// Generate random bytes for token
tokenBytes := make([]byte, 32)
if _, err := rand.Read(tokenBytes); err != nil {
cm.logger.Error("Failed to generate CSRF token", zap.Error(err))
return "", err
}
// Create timestamp
timestamp := time.Now().Unix()
// Create token data
tokenData := hex.EncodeToString(tokenBytes)
// Create signing string: userID:timestamp:tokenData
timestampStr := strconv.FormatInt(timestamp, 10)
signingString := userID + ":" + timestampStr + ":" + tokenData
// Sign the token with HMAC
signature := cm.signData(signingString)
// Return encoded token: tokenData.timestamp.signature
token := tokenData + "." + timestampStr + "." + signature
return token, nil
}
// validateCSRFToken validates a CSRF token
func (cm *CSRFMiddleware) validateCSRFToken(token string, r *http.Request) bool {
// Parse token parts
parts := strings.Split(token, ".")
if len(parts) != 3 {
cm.logger.Debug("Invalid CSRF token format")
return false
}
tokenData, timestampStr, signature := parts[0], parts[1], parts[2]
// Get user ID from request context or headers
userID := cm.getUserIDFromRequest(r)
if userID == "" {
cm.logger.Debug("No user ID found for CSRF validation")
return false
}
// Recreate signing string
signingString := userID + ":" + timestampStr + ":" + tokenData
// Verify signature
expectedSignature := cm.signData(signingString)
if !hmac.Equal([]byte(signature), []byte(expectedSignature)) {
cm.logger.Debug("CSRF token signature verification failed")
return false
}
// Parse timestamp
timestampInt, err := strconv.ParseInt(timestampStr, 10, 64)
if err != nil {
cm.logger.Debug("Invalid timestamp in CSRF token", zap.Error(err))
return false
}
timestamp := time.Unix(timestampInt, 0)
// Check if token is expired (valid for 1 hour by default)
maxAge := cm.config.GetDuration("CSRF_TOKEN_MAX_AGE")
if maxAge <= 0 {
maxAge = 1 * time.Hour
}
if time.Since(timestamp) > maxAge {
cm.logger.Debug("CSRF token expired",
zap.Time("timestamp", timestamp),
zap.Duration("age", time.Since(timestamp)),
zap.Duration("max_age", maxAge))
return false
}
return true
}
// signData signs data with HMAC
func (cm *CSRFMiddleware) signData(data string) string {
// Use the same signing key as for authentication
signingKey := cm.config.GetString("AUTH_SIGNING_KEY")
if signingKey == "" {
cm.logger.Error("AUTH_SIGNING_KEY not configured for CSRF protection")
return ""
}
mac := hmac.New(sha256.New, []byte(signingKey))
mac.Write([]byte(data))
return hex.EncodeToString(mac.Sum(nil))
}
// getUserIDFromRequest extracts user ID from request
func (cm *CSRFMiddleware) getUserIDFromRequest(r *http.Request) string {
// Try to get from X-User-Email header
userEmail := r.Header.Get(cm.config.GetString("AUTH_HEADER_USER_EMAIL"))
if userEmail != "" {
return userEmail
}
// Try to get from context (if set by authentication middleware)
if userID := r.Context().Value("user_id"); userID != nil {
if id, ok := userID.(string); ok {
return id
}
}
return ""
}
// shouldSkipCSRF determines if CSRF protection should be skipped for a request
func (cm *CSRFMiddleware) shouldSkipCSRF(r *http.Request) bool {
// Skip for API endpoints that use API key authentication
if strings.HasPrefix(r.URL.Path, "/api/verify") {
return true
}
// Skip for health check endpoints
if r.URL.Path == "/health" || r.URL.Path == "/ready" {
return true
}
// Skip for webhook endpoints (if any)
if strings.HasPrefix(r.URL.Path, "/webhook/") {
return true
}
return false
}
// SetCSRFCookie sets a secure CSRF token cookie
func (cm *CSRFMiddleware) SetCSRFCookie(w http.ResponseWriter, token string) {
cookie := &http.Cookie{
Name: "csrf_token",
Value: token,
Path: "/",
MaxAge: 3600, // 1 hour
HttpOnly: false, // JavaScript needs to read this for AJAX requests
Secure: true, // HTTPS only
SameSite: http.SameSiteStrictMode,
}
http.SetCookie(w, cookie)
}
// GetCSRFTokenFromCookie gets CSRF token from cookie
func (cm *CSRFMiddleware) GetCSRFTokenFromCookie(r *http.Request) string {
cookie, err := r.Cookie("csrf_token")
if err != nil {
return ""
}
return cookie.Value
}

View File

@ -23,23 +23,25 @@ import (
// SecurityMiddleware provides various security features
type SecurityMiddleware struct {
config config.ConfigProvider
logger *zap.Logger
cacheManager *cache.CacheManager
appRepo repository.ApplicationRepository
rateLimiters map[string]*rate.Limiter
mu sync.RWMutex
config config.ConfigProvider
logger *zap.Logger
cacheManager *cache.CacheManager
appRepo repository.ApplicationRepository
rateLimiters map[string]*rate.Limiter
authRateLimiters map[string]*rate.Limiter
mu sync.RWMutex
}
// NewSecurityMiddleware creates a new security middleware
func NewSecurityMiddleware(config config.ConfigProvider, logger *zap.Logger, appRepo repository.ApplicationRepository) *SecurityMiddleware {
cacheManager := cache.NewCacheManager(config, logger)
return &SecurityMiddleware{
config: config,
logger: logger,
cacheManager: cacheManager,
appRepo: appRepo,
rateLimiters: make(map[string]*rate.Limiter),
config: config,
logger: logger,
cacheManager: cacheManager,
appRepo: appRepo,
rateLimiters: make(map[string]*rate.Limiter),
authRateLimiters: make(map[string]*rate.Limiter),
}
}
@ -76,6 +78,38 @@ func (s *SecurityMiddleware) RateLimitMiddleware(next http.Handler) http.Handler
})
}
// AuthRateLimitMiddleware implements stricter rate limiting for authentication endpoints
func (s *SecurityMiddleware) AuthRateLimitMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if !s.config.GetBool("RATE_LIMIT_ENABLED") {
next.ServeHTTP(w, r)
return
}
clientIP := s.getClientIP(r)
// Use stricter rate limits for auth endpoints
limiter := s.getAuthRateLimiter(clientIP)
// Check if request is allowed
if !limiter.Allow() {
s.logger.Warn("Auth rate limit exceeded",
zap.String("client_ip", clientIP),
zap.String("path", r.URL.Path))
// Track authentication failures for brute force protection
s.TrackAuthenticationFailure(clientIP, "")
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusTooManyRequests)
w.Write([]byte(`{"error":"auth_rate_limit_exceeded","message":"Too many authentication attempts"}`))
return
}
next.ServeHTTP(w, r)
})
}
// BruteForceProtectionMiddleware implements brute force protection
func (s *SecurityMiddleware) BruteForceProtectionMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@ -231,6 +265,35 @@ func (s *SecurityMiddleware) getRateLimiter(clientIP string) *rate.Limiter {
return limiter
}
func (s *SecurityMiddleware) getAuthRateLimiter(clientIP string) *rate.Limiter {
s.mu.RLock()
limiter, exists := s.authRateLimiters[clientIP]
s.mu.RUnlock()
if exists {
return limiter
}
// Create new auth rate limiter with stricter limits
authRPS := s.config.GetInt("AUTH_RATE_LIMIT_RPS")
if authRPS <= 0 {
authRPS = 5 // Very strict default for auth endpoints
}
authBurst := s.config.GetInt("AUTH_RATE_LIMIT_BURST")
if authBurst <= 0 {
authBurst = 10 // Allow small bursts
}
limiter = rate.NewLimiter(rate.Limit(authRPS), authBurst)
s.mu.Lock()
s.authRateLimiters[clientIP] = limiter
s.mu.Unlock()
return limiter
}
func (s *SecurityMiddleware) trackRateLimitViolation(clientIP string) {
ctx := context.Background()
key := cache.CacheKey("rate_limit_violations", clientIP)