// Files
// skybridge/internal/middleware/security.go
// 2025-08-26 19:15:37 -04:00
//
// 599 lines
// 17 KiB
// Go

package middleware
import (
"context"
"crypto/hmac"
"crypto/sha256"
"encoding/hex"
"fmt"
"net"
"net/http"
"io"
"strings"
"sync"
"time"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
"golang.org/x/time/rate"
"github.com/kms/api-key-service/internal/cache"
"github.com/kms/api-key-service/internal/config"
"github.com/kms/api-key-service/internal/repository"
)
// SecurityMiddleware groups the HTTP security middlewares defined in this
// file: per-IP rate limiting, brute-force blocking, IP whitelisting, security
// response headers and request-signature validation.
type SecurityMiddleware struct {
	// Collaborators. config, logger and appRepo are supplied by the caller of
	// NewSecurityMiddleware; cacheManager is created there and stores the
	// failure counters and blocked-IP markers.
	config       config.ConfigProvider
	logger       *zap.Logger
	cacheManager *cache.CacheManager
	appRepo      repository.ApplicationRepository

	// mu guards the two limiter maps below, both keyed by client IP.
	mu               sync.RWMutex
	rateLimiters     map[string]*rate.Limiter
	authRateLimiters map[string]*rate.Limiter
}
// NewSecurityMiddleware builds a SecurityMiddleware from the given
// configuration, logger and application repository. It creates its own cache
// manager from config and logger and starts with empty rate-limiter maps.
func NewSecurityMiddleware(config config.ConfigProvider, logger *zap.Logger, appRepo repository.ApplicationRepository) *SecurityMiddleware {
	mw := &SecurityMiddleware{
		config:           config,
		logger:           logger,
		appRepo:          appRepo,
		rateLimiters:     make(map[string]*rate.Limiter),
		authRateLimiters: make(map[string]*rate.Limiter),
	}
	mw.cacheManager = cache.NewCacheManager(config, logger)
	return mw
}
// RateLimitMiddleware wraps next with per-client-IP rate limiting.
//
// When the RATE_LIMIT_ENABLED setting is false the request is passed straight
// through. Otherwise the limiter for the caller's IP is consulted; a request
// over the limit is logged, recorded as a rate-limit violation and answered
// with 429 Too Many Requests and a JSON error body.
func (s *SecurityMiddleware) RateLimitMiddleware(next http.Handler) http.Handler {
	handle := func(w http.ResponseWriter, r *http.Request) {
		if !s.config.GetBool("RATE_LIMIT_ENABLED") {
			next.ServeHTTP(w, r)
			return
		}

		ip := s.getClientIP(r)
		if s.getRateLimiter(ip).Allow() {
			next.ServeHTTP(w, r)
			return
		}

		// Over the limit: log it, record the violation, reject the request.
		s.logger.Warn("Rate limit exceeded",
			zap.String("client_ip", ip),
			zap.String("path", r.URL.Path))
		s.trackRateLimitViolation(ip)

		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(http.StatusTooManyRequests)
		w.Write([]byte(`{"error":"rate_limit_exceeded","message":"Too many requests"}`))
	}
	return http.HandlerFunc(handle)
}
// AuthRateLimitMiddleware wraps next with the stricter per-client-IP limiter
// used for authentication endpoints.
//
// It is a no-op when RATE_LIMIT_ENABLED is false. A request over the limit is
// logged, counted as an authentication failure for the caller's IP (with no
// user ID) and answered with 429 Too Many Requests and a JSON error body.
func (s *SecurityMiddleware) AuthRateLimitMiddleware(next http.Handler) http.Handler {
	handle := func(w http.ResponseWriter, r *http.Request) {
		if !s.config.GetBool("RATE_LIMIT_ENABLED") {
			next.ServeHTTP(w, r)
			return
		}

		ip := s.getClientIP(r)
		if s.getAuthRateLimiter(ip).Allow() {
			next.ServeHTTP(w, r)
			return
		}

		s.logger.Warn("Auth rate limit exceeded",
			zap.String("client_ip", ip),
			zap.String("path", r.URL.Path))

		// Hitting the auth limiter feeds the brute-force counters.
		s.TrackAuthenticationFailure(ip, "")

		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(http.StatusTooManyRequests)
		w.Write([]byte(`{"error":"auth_rate_limit_exceeded","message":"Too many authentication attempts"}`))
	}
	return http.HandlerFunc(handle)
}
// BruteForceProtectionMiddleware wraps next and rejects requests from client
// IPs that currently have a block marker in the cache. Blocked callers get
// 403 Forbidden with a JSON error body; everyone else is passed through.
func (s *SecurityMiddleware) BruteForceProtectionMiddleware(next http.Handler) http.Handler {
	handle := func(w http.ResponseWriter, r *http.Request) {
		ip := s.getClientIP(r)
		if !s.isIPBlocked(ip) {
			next.ServeHTTP(w, r)
			return
		}

		s.logger.Warn("Blocked IP attempted access",
			zap.String("client_ip", ip),
			zap.String("path", r.URL.Path))

		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(http.StatusForbidden)
		w.Write([]byte(`{"error":"ip_blocked","message":"IP temporarily blocked due to suspicious activity"}`))
	}
	return http.HandlerFunc(handle)
}
// IPWhitelistMiddleware wraps next and, when the IP_WHITELIST setting is
// non-empty, only lets through callers whose IP matches an entry of that list.
// Other callers get 403 Forbidden with a JSON error body. An empty whitelist
// disables the check.
func (s *SecurityMiddleware) IPWhitelistMiddleware(next http.Handler) http.Handler {
	handle := func(w http.ResponseWriter, r *http.Request) {
		allowed := s.config.GetStringSlice("IP_WHITELIST")

		// The client IP is only resolved when there is a list to check against.
		if len(allowed) > 0 {
			ip := s.getClientIP(r)
			if !s.isIPInList(ip, allowed) {
				s.logger.Warn("Non-whitelisted IP attempted access",
					zap.String("client_ip", ip),
					zap.String("path", r.URL.Path))

				w.Header().Set("Content-Type", "application/json")
				w.WriteHeader(http.StatusForbidden)
				w.Write([]byte(`{"error":"ip_not_whitelisted","message":"IP not in whitelist"}`))
				return
			}
		}

		next.ServeHTTP(w, r)
	}
	return http.HandlerFunc(handle)
}
// SecurityHeadersMiddleware wraps next and sets a fixed group of security
// response headers before calling it: X-Content-Type-Options,
// X-Frame-Options, X-XSS-Protection, Referrer-Policy and
// Content-Security-Policy, plus Strict-Transport-Security when the request
// arrived over TLS.
//
// The CSP is "default-src 'self'" except for paths under /test/ outside
// production, which get a policy that also allows inline and eval'd content.
func (s *SecurityMiddleware) SecurityHeadersMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		headers := w.Header()
		requestPath := r.URL.Path

		headers.Set("X-Content-Type-Options", "nosniff")
		headers.Set("X-Frame-Options", "DENY")
		headers.Set("X-XSS-Protection", "1; mode=block")
		headers.Set("Referrer-Policy", "strict-origin-when-cross-origin")

		policy := "default-src 'self'"
		relaxed := !s.config.IsProduction() && strings.HasPrefix(requestPath, "/test/")
		if relaxed {
			// Development-only test pages need inline styles and scripts.
			policy = "default-src 'self' 'unsafe-inline' 'unsafe-eval'; style-src 'self' 'unsafe-inline'; script-src 'self' 'unsafe-inline' 'unsafe-eval'"
			s.logger.Debug("Using permissive CSP for test page", zap.String("path", requestPath), zap.String("csp", policy))
		} else {
			s.logger.Debug("Using default CSP", zap.String("path", requestPath), zap.Bool("is_production", s.config.IsProduction()))
		}
		headers.Set("Content-Security-Policy", policy)

		// HSTS is only sent on connections that are already TLS.
		if r.TLS != nil {
			headers.Set("Strict-Transport-Security", "max-age=31536000; includeSubDomains")
		}

		next.ServeHTTP(w, r)
	})
}
// GinSecurityHeaders returns a Gin handler that sets the same security
// response headers as SecurityHeadersMiddleware and then continues the chain
// with c.Next.
func (s *SecurityMiddleware) GinSecurityHeaders() gin.HandlerFunc {
	// Headers whose value does not depend on the request.
	fixed := [][2]string{
		{"X-Content-Type-Options", "nosniff"},
		{"X-Frame-Options", "DENY"},
		{"X-XSS-Protection", "1; mode=block"},
		{"Referrer-Policy", "strict-origin-when-cross-origin"},
	}

	return func(c *gin.Context) {
		for _, kv := range fixed {
			c.Header(kv[0], kv[1])
		}

		requestPath := c.Request.URL.Path

		// Strict CSP by default; paths under /test/ outside production get a
		// policy that allows inline and eval'd styles and scripts.
		policy := "default-src 'self'"
		relaxed := !s.config.IsProduction() && strings.HasPrefix(requestPath, "/test/")
		if relaxed {
			policy = "default-src 'self' 'unsafe-inline' 'unsafe-eval'; style-src 'self' 'unsafe-inline'; script-src 'self' 'unsafe-inline' 'unsafe-eval'"
			s.logger.Debug("Using permissive CSP for test page", zap.String("path", requestPath), zap.String("csp", policy))
		} else {
			s.logger.Debug("Using default CSP", zap.String("path", requestPath), zap.Bool("is_production", s.config.IsProduction()))
		}
		c.Header("Content-Security-Policy", policy)

		// HSTS is only sent on connections that are already TLS.
		if c.Request.TLS != nil {
			c.Header("Strict-Transport-Security", "max-age=31536000; includeSubDomains")
		}

		c.Next()
	}
}
// TrackAuthenticationFailure records one failed authentication attempt.
//
// It increments the failure counter for clientIP and, when userID is not
// empty, the counter for that user as well. Afterwards it checks whether the
// IP has reached the failure threshold and should be blocked.
func (s *SecurityMiddleware) TrackAuthenticationFailure(clientIP, userID string) {
	ctx := context.Background()

	s.incrementFailureCount(ctx, cache.CacheKey("auth_failures_ip", clientIP))
	if userID != "" {
		s.incrementFailureCount(ctx, cache.CacheKey("auth_failures_user", userID))
	}

	s.checkAndBlockIP(clientIP)
}
// ClearAuthenticationFailures removes the failure counter stored for clientIP
// and, when userID is not empty, the one stored for that user. It is meant to
// be called after a successful authentication.
func (s *SecurityMiddleware) ClearAuthenticationFailures(clientIP, userID string) {
	ctx := context.Background()

	s.cacheManager.Delete(ctx, cache.CacheKey("auth_failures_ip", clientIP))
	if userID == "" {
		return
	}
	s.cacheManager.Delete(ctx, cache.CacheKey("auth_failures_user", userID))
}
// Helper methods

// getClientIP returns the best-effort client address for r.
//
// Sources are tried in order: the first entry of X-Forwarded-For, then
// X-Real-IP, then the host part of r.RemoteAddr. Header values are trimmed,
// and an empty value falls through to the next source, so a malformed header
// such as "X-Forwarded-For: , 10.0.0.1" does not yield an empty client IP.
//
// SECURITY NOTE(review): both headers are supplied by the caller and are
// trusted here without checking that the request came through a known proxy.
// Everything keyed on this value (rate limits, IP blocks, the IP whitelist)
// can therefore be influenced by a client that sets these headers itself.
// Confirm the service is only reachable through a proxy that overwrites them.
func (s *SecurityMiddleware) getClientIP(r *http.Request) string {
	// X-Forwarded-For is a comma-separated chain; only the left-most entry is
	// considered.
	if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
		first := xff
		if i := strings.IndexByte(xff, ','); i >= 0 {
			first = xff[:i]
		}
		if first = strings.TrimSpace(first); first != "" {
			return first
		}
	}

	if xri := strings.TrimSpace(r.Header.Get("X-Real-IP")); xri != "" {
		return xri
	}

	// RemoteAddr is normally "host:port"; when it cannot be split it is
	// returned unchanged.
	host, _, err := net.SplitHostPort(r.RemoteAddr)
	if err != nil {
		return r.RemoteAddr
	}
	return host
}
// getRateLimiter returns the general rate limiter for clientIP, creating it on
// first use from RATE_LIMIT_RPS (default 100) and RATE_LIMIT_BURST (default
// 200). Non-positive settings fall back to the defaults.
//
// The map is re-checked under the write lock before inserting, so concurrent
// first requests from the same IP all end up sharing one limiter instead of
// each creating, and then overwriting, their own.
//
// NOTE(review): nothing in the visible code removes entries from the map, so
// it grows with the number of distinct client IPs seen.
func (s *SecurityMiddleware) getRateLimiter(clientIP string) *rate.Limiter {
	// Fast path: the limiter already exists.
	s.mu.RLock()
	limiter, exists := s.rateLimiters[clientIP]
	s.mu.RUnlock()
	if exists {
		return limiter
	}

	// Read the settings before taking the write lock to keep it short.
	rps := s.config.GetInt("RATE_LIMIT_RPS")
	if rps <= 0 {
		rps = 100 // Default
	}
	burst := s.config.GetInt("RATE_LIMIT_BURST")
	if burst <= 0 {
		burst = 200 // Default
	}

	s.mu.Lock()
	defer s.mu.Unlock()

	// Another goroutine may have inserted a limiter between the two locks.
	if existing, ok := s.rateLimiters[clientIP]; ok {
		return existing
	}
	limiter = rate.NewLimiter(rate.Limit(rps), burst)
	s.rateLimiters[clientIP] = limiter
	return limiter
}
// getAuthRateLimiter returns the authentication-endpoint rate limiter for
// clientIP, creating it on first use from AUTH_RATE_LIMIT_RPS (default 5) and
// AUTH_RATE_LIMIT_BURST (default 10). Non-positive settings fall back to the
// defaults.
//
// The map is re-checked under the write lock before inserting, so concurrent
// first requests from the same IP all end up sharing one limiter instead of
// each creating, and then overwriting, their own.
//
// NOTE(review): nothing in the visible code removes entries from the map, so
// it grows with the number of distinct client IPs seen.
func (s *SecurityMiddleware) getAuthRateLimiter(clientIP string) *rate.Limiter {
	// Fast path: the limiter already exists.
	s.mu.RLock()
	limiter, exists := s.authRateLimiters[clientIP]
	s.mu.RUnlock()
	if exists {
		return limiter
	}

	// Read the settings before taking the write lock to keep it short.
	authRPS := s.config.GetInt("AUTH_RATE_LIMIT_RPS")
	if authRPS <= 0 {
		authRPS = 5 // Very strict default for auth endpoints
	}
	authBurst := s.config.GetInt("AUTH_RATE_LIMIT_BURST")
	if authBurst <= 0 {
		authBurst = 10 // Allow small bursts
	}

	s.mu.Lock()
	defer s.mu.Unlock()

	// Another goroutine may have inserted a limiter between the two locks.
	if existing, ok := s.authRateLimiters[clientIP]; ok {
		return existing
	}
	limiter = rate.NewLimiter(rate.Limit(authRPS), authBurst)
	s.authRateLimiters[clientIP] = limiter
	return limiter
}
// trackRateLimitViolation increments the cached rate-limit violation counter
// for clientIP.
func (s *SecurityMiddleware) trackRateLimitViolation(clientIP string) {
	s.incrementFailureCount(context.Background(), cache.CacheKey("rate_limit_violations", clientIP))
}
// isIPBlocked reports whether a block marker for clientIP exists in the cache.
// If the cache lookup fails the error is logged and the IP is treated as not
// blocked.
func (s *SecurityMiddleware) isIPBlocked(clientIP string) bool {
	blocked, err := s.cacheManager.Exists(context.Background(), cache.CacheKey("blocked_ips", clientIP))
	if err == nil {
		return blocked
	}

	s.logger.Error("Failed to check IP block status",
		zap.String("client_ip", clientIP),
		zap.Error(err))
	return false
}
// isIPInList reports whether clientIP matches any entry of ipList.
//
// Entries are trimmed before use. An entry containing "/" is treated as a
// CIDR range; an invalid CIDR is logged and skipped. Any other entry matches
// when it is the same string as clientIP or when both parse to the same IP
// address, so different spellings of one address (for example "::1" and
// "0:0:0:0:0:0:0:1", or an IPv4 address and its IPv4-mapped IPv6 form) are
// treated as equal.
func (s *SecurityMiddleware) isIPInList(clientIP string, ipList []string) bool {
	// Parsed once for the whole list; nil when clientIP is not an IP literal.
	parsedClient := net.ParseIP(clientIP)

	for _, entry := range ipList {
		entry = strings.TrimSpace(entry)

		// Support CIDR notation
		if strings.Contains(entry, "/") {
			_, network, err := net.ParseCIDR(entry)
			if err != nil {
				s.logger.Warn("Invalid CIDR in IP list", zap.String("cidr", entry))
				continue
			}
			if parsedClient != nil && network.Contains(parsedClient) {
				return true
			}
			continue
		}

		// Plain string equality also covers values that are not IP literals.
		if clientIP == entry {
			return true
		}

		// Compare as addresses so that equivalent textual forms match.
		if parsedClient == nil {
			continue
		}
		if allowed := net.ParseIP(entry); allowed != nil && allowed.Equal(parsedClient) {
			return true
		}
	}
	return false
}
// incrementFailureCount adds one to the integer counter stored under key and
// writes it back with a TTL of AUTH_FAILURE_WINDOW (15 minutes when that
// setting is not positive). Any error reading the current value is treated as
// "no counter yet", so the count restarts at 1.
//
// NOTE(review): the counter is read and then written in two separate cache
// calls, so concurrent increments of the same key can overwrite each other.
func (s *SecurityMiddleware) incrementFailureCount(ctx context.Context, key string) {
	var failures int
	if err := s.cacheManager.GetJSON(ctx, key, &failures); err != nil {
		failures = 0
	}
	failures++

	window := s.config.GetDuration("AUTH_FAILURE_WINDOW")
	if window <= 0 {
		window = 15 * time.Minute
	}

	// Every write stores the counter with the full window as its TTL.
	s.cacheManager.SetJSON(ctx, key, failures, window)
}
// checkAndBlockIP blocks clientIP when its cached authentication-failure count
// has reached MAX_AUTH_FAILURES (default 5). The block is a cache entry that
// expires after IP_BLOCK_DURATION (default 1 hour) and records when and why
// the IP was blocked. If the failure counter cannot be read, nothing happens.
func (s *SecurityMiddleware) checkAndBlockIP(clientIP string) {
	ctx := context.Background()

	var failures int
	if err := s.cacheManager.GetJSON(ctx, cache.CacheKey("auth_failures_ip", clientIP), &failures); err != nil {
		return // No failures recorded
	}

	threshold := s.config.GetInt("MAX_AUTH_FAILURES")
	if threshold <= 0 {
		threshold = 5
	}
	if failures < threshold {
		return
	}

	// Threshold reached: write the block marker.
	markerKey := cache.CacheKey("blocked_ips", clientIP)
	blockFor := s.config.GetDuration("IP_BLOCK_DURATION")
	if blockFor <= 0 {
		blockFor = 1 * time.Hour
	}

	marker := map[string]interface{}{
		"blocked_at":    time.Now().Unix(),
		"failure_count": failures,
		"reason":        "excessive_auth_failures",
	}
	s.cacheManager.SetJSON(ctx, markerKey, marker, blockFor)

	s.logger.Warn("IP blocked due to excessive authentication failures",
		zap.String("client_ip", clientIP),
		zap.Int("failure_count", failures),
		zap.Duration("block_duration", blockFor))
}
// RequestSignatureMiddleware wraps next and, for the paths selected by
// shouldValidateSignature, requires a valid HMAC request signature.
//
// A signed request must carry three headers:
//   - X-Signature: hex HMAC-SHA256 of the request (see validateHMACSignature)
//   - X-Timestamp: RFC 3339 time, checked by isTimestampValid
//   - X-App-ID:    ID of the application whose HMAC key signs the request
//
// Missing headers or a bad timestamp are answered with 400; an unknown
// application, an application without an HMAC key, or a signature mismatch
// with 401. All rejections carry a JSON error body and are logged.
//
// An application whose HMAC key is empty is rejected outright: a signature
// made with an empty key can be computed by anyone, so accepting it would make
// the check meaningless.
func (s *SecurityMiddleware) RequestSignatureMiddleware(next http.Handler) http.Handler {
	// reject writes a JSON error response with the given status code.
	reject := func(w http.ResponseWriter, status int, body string) {
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(status)
		w.Write([]byte(body))
	}

	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Only validate signatures for certain endpoints
		if !s.shouldValidateSignature(r) {
			next.ServeHTTP(w, r)
			return
		}

		// Resolved once; used only for logging below.
		clientIP := s.getClientIP(r)

		signature := r.Header.Get("X-Signature")
		timestamp := r.Header.Get("X-Timestamp")
		if signature == "" || timestamp == "" {
			s.logger.Warn("Missing signature headers",
				zap.String("path", r.URL.Path),
				zap.String("client_ip", clientIP))
			reject(w, http.StatusBadRequest, `{"error":"missing_signature","message":"Request signature required"}`)
			return
		}

		// Validate timestamp (prevent replay attacks)
		if !s.isTimestampValid(timestamp) {
			s.logger.Warn("Invalid timestamp in request",
				zap.String("timestamp", timestamp),
				zap.String("client_ip", clientIP))
			reject(w, http.StatusBadRequest, `{"error":"invalid_timestamp","message":"Request timestamp is invalid or too old"}`)
			return
		}

		appID := r.Header.Get("X-App-ID")
		if appID == "" {
			s.logger.Warn("Missing App-ID header for signature validation",
				zap.String("path", r.URL.Path),
				zap.String("client_ip", clientIP))
			reject(w, http.StatusBadRequest, `{"error":"missing_app_id","message":"X-App-ID header required for signature validation"}`)
			return
		}

		// Retrieve application to get HMAC key
		app, err := s.appRepo.GetByID(r.Context(), appID)
		if err != nil {
			s.logger.Warn("Failed to retrieve application for signature validation",
				zap.String("app_id", appID),
				zap.Error(err),
				zap.String("client_ip", clientIP))
			reject(w, http.StatusUnauthorized, `{"error":"invalid_application","message":"Invalid application ID"}`)
			return
		}

		// Fail closed when the application has no signing key configured.
		if app.HMACKey == "" {
			s.logger.Warn("Application has no HMAC key configured; rejecting signed request",
				zap.String("app_id", appID),
				zap.String("path", r.URL.Path),
				zap.String("client_ip", clientIP))
			reject(w, http.StatusUnauthorized, `{"error":"invalid_signature","message":"Request signature is invalid"}`)
			return
		}

		// Validate the signature
		if !s.validateHMACSignature(r, app.HMACKey, signature, timestamp) {
			s.logger.Warn("Invalid request signature",
				zap.String("app_id", appID),
				zap.String("path", r.URL.Path),
				zap.String("client_ip", clientIP))
			reject(w, http.StatusUnauthorized, `{"error":"invalid_signature","message":"Request signature is invalid"}`)
			return
		}

		next.ServeHTTP(w, r)
	})
}
// shouldValidateSignature reports whether the request path requires a signed
// request, i.e. whether it starts with "/api/v1/tokens" or
// "/api/v1/applications". The match is a plain string-prefix test.
func (s *SecurityMiddleware) shouldValidateSignature(r *http.Request) bool {
	requestPath := r.URL.Path
	switch {
	case strings.HasPrefix(requestPath, "/api/v1/tokens"):
		return true
	case strings.HasPrefix(requestPath, "/api/v1/applications"):
		return true
	}
	return false
}
// isTimestampValid reports whether timestampStr is an RFC 3339 time that is
// at most REQUEST_MAX_AGE old (5 minutes when that setting is not positive)
// and less than one minute in the future. Unparseable input is invalid.
func (s *SecurityMiddleware) isTimestampValid(timestampStr string) bool {
	sentAt, parseErr := time.Parse(time.RFC3339, timestampStr)
	if parseErr != nil {
		return false
	}

	current := time.Now()

	allowedAge := s.config.GetDuration("REQUEST_MAX_AGE")
	if allowedAge <= 0 {
		allowedAge = 5 * time.Minute
	}

	// Too old.
	if current.Sub(sentAt) > allowedAge {
		return false
	}
	// Allow up to one minute of clock skew into the future.
	return sentAt.Before(current.Add(time.Minute))
}
// GetSecurityMetrics returns a small snapshot of security-related metrics:
//
//   - "active_rate_limiters": number of entries in the general rate-limiter map
//   - "timestamp":            current Unix time in seconds
//
// The limiter count is read under the read lock because the map is written
// concurrently by getRateLimiter. Blocked IPs are not counted here.
func (s *SecurityMiddleware) GetSecurityMetrics() map[string]interface{} {
	s.mu.RLock()
	activeLimiters := len(s.rateLimiters)
	s.mu.RUnlock()

	return map[string]interface{}{
		"active_rate_limiters": activeLimiters,
		"timestamp":            time.Now().Unix(),
	}
}
// validateHMACSignature reports whether signature is the hex-encoded
// HMAC-SHA256, keyed with hmacKey, of the signing input
//
//	METHOD "\n" PATH "\n" BODY "\n" TIMESTAMP
//
// where PATH is r.URL.Path. The query string is not part of the signing input.
//
// The request body is read in full and then replaced with an in-memory copy so
// that downstream handlers can still read it. A body read error is logged and
// fails validation.
//
// The signature is hex-decoded and compared as raw bytes in constant time, so
// upper-case and lower-case hex encodings of the same MAC are both accepted. A
// signature that is not valid hex never matches.
//
// NOTE(review): the body is read without a size limit; confirm that an
// upstream layer bounds request bodies on the signed endpoints.
func (s *SecurityMiddleware) validateHMACSignature(r *http.Request, hmacKey, signature, timestamp string) bool {
	var bodyBytes []byte
	if r.Body != nil {
		var err error
		bodyBytes, err = io.ReadAll(r.Body)
		if err != nil {
			s.logger.Warn("Failed to read request body for signature validation", zap.Error(err))
			return false
		}
		// Restore the body for downstream handlers
		r.Body = io.NopCloser(strings.NewReader(string(bodyBytes)))
	}

	provided, err := hex.DecodeString(signature)
	if err != nil {
		return false
	}

	// Feed the signing input to the MAC piece by piece; this hashes the same
	// bytes as the concatenated string without copying the body again.
	mac := hmac.New(sha256.New, []byte(hmacKey))
	fmt.Fprintf(mac, "%s\n%s\n", r.Method, r.URL.Path)
	mac.Write(bodyBytes)
	fmt.Fprintf(mac, "\n%s", timestamp)

	// Constant-time comparison to prevent timing attacks.
	return hmac.Equal(provided, mac.Sum(nil))
}