package middleware import ( "context" "crypto/hmac" "crypto/sha256" "encoding/hex" "fmt" "net" "net/http" "io" "strconv" "strings" "sync" "time" "go.uber.org/zap" "golang.org/x/time/rate" "github.com/kms/api-key-service/internal/cache" "github.com/kms/api-key-service/internal/config" "github.com/kms/api-key-service/internal/repository" ) // SecurityMiddleware provides various security features type SecurityMiddleware struct { config config.ConfigProvider logger *zap.Logger cacheManager *cache.CacheManager appRepo repository.ApplicationRepository rateLimiters map[string]*rate.Limiter mu sync.RWMutex } // NewSecurityMiddleware creates a new security middleware func NewSecurityMiddleware(config config.ConfigProvider, logger *zap.Logger, appRepo repository.ApplicationRepository) *SecurityMiddleware { cacheManager := cache.NewCacheManager(config, logger) return &SecurityMiddleware{ config: config, logger: logger, cacheManager: cacheManager, appRepo: appRepo, rateLimiters: make(map[string]*rate.Limiter), } } // RateLimitMiddleware implements per-IP rate limiting func (s *SecurityMiddleware) RateLimitMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if !s.config.GetBool("RATE_LIMIT_ENABLED") { next.ServeHTTP(w, r) return } // Get client IP clientIP := s.getClientIP(r) // Get or create rate limiter for this IP limiter := s.getRateLimiter(clientIP) // Check if request is allowed if !limiter.Allow() { s.logger.Warn("Rate limit exceeded", zap.String("client_ip", clientIP), zap.String("path", r.URL.Path)) // Track rate limit violations s.trackRateLimitViolation(clientIP) w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusTooManyRequests) w.Write([]byte(`{"error":"rate_limit_exceeded","message":"Too many requests"}`)) return } next.ServeHTTP(w, r) }) } // BruteForceProtectionMiddleware implements brute force protection func (s *SecurityMiddleware) BruteForceProtectionMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { clientIP := s.getClientIP(r) // Check if IP is temporarily blocked if s.isIPBlocked(clientIP) { s.logger.Warn("Blocked IP attempted access", zap.String("client_ip", clientIP), zap.String("path", r.URL.Path)) w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusForbidden) w.Write([]byte(`{"error":"ip_blocked","message":"IP temporarily blocked due to suspicious activity"}`)) return } next.ServeHTTP(w, r) }) } // IPWhitelistMiddleware implements IP whitelisting func (s *SecurityMiddleware) IPWhitelistMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { whitelist := s.config.GetStringSlice("IP_WHITELIST") if len(whitelist) == 0 { // No whitelist configured, allow all next.ServeHTTP(w, r) return } clientIP := s.getClientIP(r) // Check if IP is in whitelist if !s.isIPInList(clientIP, whitelist) { s.logger.Warn("Non-whitelisted IP attempted access", zap.String("client_ip", clientIP), zap.String("path", r.URL.Path)) w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusForbidden) w.Write([]byte(`{"error":"ip_not_whitelisted","message":"IP not in whitelist"}`)) return } next.ServeHTTP(w, r) }) } // SecurityHeadersMiddleware adds security headers func (s *SecurityMiddleware) SecurityHeadersMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Add security headers w.Header().Set("X-Content-Type-Options", "nosniff") w.Header().Set("X-Frame-Options", "DENY") w.Header().Set("X-XSS-Protection", "1; mode=block") w.Header().Set("Referrer-Policy", "strict-origin-when-cross-origin") w.Header().Set("Content-Security-Policy", "default-src 'self'") // Add HSTS header for HTTPS if r.TLS != nil { w.Header().Set("Strict-Transport-Security", "max-age=31536000; includeSubDomains") } next.ServeHTTP(w, r) }) } // AuthenticationFailureTracker tracks authentication failures for brute force protection func (s *SecurityMiddleware) TrackAuthenticationFailure(clientIP, userID string) { ctx := context.Background() // Track failures by IP ipKey := cache.CacheKey("auth_failures_ip", clientIP) s.incrementFailureCount(ctx, ipKey) // Track failures by user ID if provided if userID != "" { userKey := cache.CacheKey("auth_failures_user", userID) s.incrementFailureCount(ctx, userKey) } // Check if we should block the IP s.checkAndBlockIP(clientIP) } // ClearAuthenticationFailures clears failure count on successful authentication func (s *SecurityMiddleware) ClearAuthenticationFailures(clientIP, userID string) { ctx := context.Background() // Clear failures by IP ipKey := cache.CacheKey("auth_failures_ip", clientIP) s.cacheManager.Delete(ctx, ipKey) // Clear failures by user ID if provided if userID != "" { userKey := cache.CacheKey("auth_failures_user", userID) s.cacheManager.Delete(ctx, userKey) } } // Helper methods func (s *SecurityMiddleware) getClientIP(r *http.Request) string { // Check X-Forwarded-For header first xff := r.Header.Get("X-Forwarded-For") if xff != "" { // Take the first IP in the chain ips := strings.Split(xff, ",") return strings.TrimSpace(ips[0]) } // Check X-Real-IP header xri := r.Header.Get("X-Real-IP") if xri != "" { return xri } // Fall back to RemoteAddr ip, _, err := net.SplitHostPort(r.RemoteAddr) if err != nil { return r.RemoteAddr } return ip } func (s *SecurityMiddleware) getRateLimiter(clientIP string) *rate.Limiter { s.mu.RLock() limiter, exists := s.rateLimiters[clientIP] s.mu.RUnlock() if exists { return limiter } // Create new rate limiter rps := s.config.GetInt("RATE_LIMIT_RPS") if rps <= 0 { rps = 100 // Default } burst := s.config.GetInt("RATE_LIMIT_BURST") if burst <= 0 { burst = 200 // Default } limiter = rate.NewLimiter(rate.Limit(rps), burst) s.mu.Lock() s.rateLimiters[clientIP] = limiter s.mu.Unlock() return limiter } func (s *SecurityMiddleware) trackRateLimitViolation(clientIP string) { ctx := context.Background() key := cache.CacheKey("rate_limit_violations", clientIP) s.incrementFailureCount(ctx, key) } func (s *SecurityMiddleware) isIPBlocked(clientIP string) bool { ctx := context.Background() key := cache.CacheKey("blocked_ips", clientIP) exists, err := s.cacheManager.Exists(ctx, key) if err != nil { s.logger.Error("Failed to check IP block status", zap.String("client_ip", clientIP), zap.Error(err)) return false } return exists } func (s *SecurityMiddleware) isIPInList(clientIP string, ipList []string) bool { for _, allowedIP := range ipList { allowedIP = strings.TrimSpace(allowedIP) // Support CIDR notation if strings.Contains(allowedIP, "/") { _, network, err := net.ParseCIDR(allowedIP) if err != nil { s.logger.Warn("Invalid CIDR in IP list", zap.String("cidr", allowedIP)) continue } ip := net.ParseIP(clientIP) if ip != nil && network.Contains(ip) { return true } } else { // Exact IP match if clientIP == allowedIP { return true } } } return false } func (s *SecurityMiddleware) incrementFailureCount(ctx context.Context, key string) { // Get current count var count int err := s.cacheManager.GetJSON(ctx, key, &count) if err != nil { // Key doesn't exist, start with 0 count = 0 } count++ // Store updated count with TTL ttl := s.config.GetDuration("AUTH_FAILURE_WINDOW") if ttl <= 0 { ttl = 15 * time.Minute // Default window } s.cacheManager.SetJSON(ctx, key, count, ttl) } func (s *SecurityMiddleware) checkAndBlockIP(clientIP string) { ctx := context.Background() key := cache.CacheKey("auth_failures_ip", clientIP) var count int err := s.cacheManager.GetJSON(ctx, key, &count) if err != nil { return // No failures recorded } maxFailures := s.config.GetInt("MAX_AUTH_FAILURES") if maxFailures <= 0 { maxFailures = 5 // Default } if count >= maxFailures { // Block the IP blockKey := cache.CacheKey("blocked_ips", clientIP) blockDuration := s.config.GetDuration("IP_BLOCK_DURATION") if blockDuration <= 0 { blockDuration = 1 * time.Hour // Default } blockInfo := map[string]interface{}{ "blocked_at": time.Now().Unix(), "failure_count": count, "reason": "excessive_auth_failures", } s.cacheManager.SetJSON(ctx, blockKey, blockInfo, blockDuration) s.logger.Warn("IP blocked due to excessive authentication failures", zap.String("client_ip", clientIP), zap.Int("failure_count", count), zap.Duration("block_duration", blockDuration)) } } // RequestSignatureMiddleware validates request signatures (for API key requests) func (s *SecurityMiddleware) RequestSignatureMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Only validate signatures for certain endpoints if !s.shouldValidateSignature(r) { next.ServeHTTP(w, r) return } signature := r.Header.Get("X-Signature") timestamp := r.Header.Get("X-Timestamp") if signature == "" || timestamp == "" { s.logger.Warn("Missing signature headers", zap.String("path", r.URL.Path), zap.String("client_ip", s.getClientIP(r))) w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusBadRequest) w.Write([]byte(`{"error":"missing_signature","message":"Request signature required"}`)) return } // Validate timestamp (prevent replay attacks) if !s.isTimestampValid(timestamp) { s.logger.Warn("Invalid timestamp in request", zap.String("timestamp", timestamp), zap.String("client_ip", s.getClientIP(r))) w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusBadRequest) w.Write([]byte(`{"error":"invalid_timestamp","message":"Request timestamp is invalid or too old"}`)) return } // Implement HMAC signature validation appID := r.Header.Get("X-App-ID") if appID == "" { s.logger.Warn("Missing App-ID header for signature validation", zap.String("path", r.URL.Path), zap.String("client_ip", s.getClientIP(r))) w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusBadRequest) w.Write([]byte(`{"error":"missing_app_id","message":"X-App-ID header required for signature validation"}`)) return } // Retrieve application to get HMAC key ctx := r.Context() app, err := s.appRepo.GetByID(ctx, appID) if err != nil { s.logger.Warn("Failed to retrieve application for signature validation", zap.String("app_id", appID), zap.Error(err), zap.String("client_ip", s.getClientIP(r))) w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusUnauthorized) w.Write([]byte(`{"error":"invalid_application","message":"Invalid application ID"}`)) return } // Validate the signature if !s.validateHMACSignature(r, app.HMACKey, signature, timestamp) { s.logger.Warn("Invalid request signature", zap.String("app_id", appID), zap.String("path", r.URL.Path), zap.String("client_ip", s.getClientIP(r))) w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusUnauthorized) w.Write([]byte(`{"error":"invalid_signature","message":"Request signature is invalid"}`)) return } next.ServeHTTP(w, r) }) } func (s *SecurityMiddleware) shouldValidateSignature(r *http.Request) bool { // Define which endpoints require signature validation signatureRequiredPaths := []string{ "/api/v1/tokens", "/api/v1/applications", } for _, path := range signatureRequiredPaths { if strings.HasPrefix(r.URL.Path, path) { return true } } return false } func (s *SecurityMiddleware) isTimestampValid(timestampStr string) bool { // Parse timestamp timestamp, err := time.Parse(time.RFC3339, timestampStr) if err != nil { return false } // Check if timestamp is within acceptable window now := time.Now() maxAge := s.config.GetDuration("REQUEST_MAX_AGE") if maxAge <= 0 { maxAge = 5 * time.Minute // Default } return now.Sub(timestamp) <= maxAge && timestamp.Before(now.Add(1*time.Minute)) } // GetSecurityMetrics returns security-related metrics func (s *SecurityMiddleware) GetSecurityMetrics() map[string]interface{} { // This is a simplified version - in production you'd want more comprehensive metrics metrics := map[string]interface{}{ "active_rate_limiters": len(s.rateLimiters), "timestamp": time.Now().Unix(), } // Count blocked IPs (this is expensive, so you might want to cache this) // For now, we'll just return the basic metrics return metrics } // validateHMACSignature validates HMAC-SHA256 signature for request integrity func (s *SecurityMiddleware) validateHMACSignature(r *http.Request, hmacKey, signature, timestamp string) bool { // Create the signing string: METHOD + PATH + BODY + TIMESTAMP var bodyBytes []byte if r.Body != nil { var err error bodyBytes, err = io.ReadAll(r.Body) if err != nil { s.logger.Warn("Failed to read request body for signature validation", zap.Error(err)) return false } // Restore the body for downstream handlers r.Body = io.NopCloser(strings.NewReader(string(bodyBytes))) } signingString := fmt.Sprintf("%s\n%s\n%s\n%s", r.Method, r.URL.Path, string(bodyBytes), timestamp) // Calculate expected signature mac := hmac.New(sha256.New, []byte(hmacKey)) mac.Write([]byte(signingString)) expectedSignature := hex.EncodeToString(mac.Sum(nil)) // Compare signatures (constant time comparison to prevent timing attacks) return hmac.Equal([]byte(signature), []byte(expectedSignature)) }