599 lines
17 KiB
Go
599 lines
17 KiB
Go
package middleware
|
|
|
|
import (
|
|
"context"
|
|
"crypto/hmac"
|
|
"crypto/sha256"
|
|
"encoding/hex"
|
|
"fmt"
|
|
"net"
|
|
"net/http"
|
|
"io"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
"go.uber.org/zap"
|
|
"golang.org/x/time/rate"
|
|
|
|
"github.com/kms/api-key-service/internal/cache"
|
|
"github.com/kms/api-key-service/internal/config"
|
|
"github.com/kms/api-key-service/internal/repository"
|
|
)
|
|
|
|
// SecurityMiddleware provides various security features
|
|
type SecurityMiddleware struct {
|
|
config config.ConfigProvider
|
|
logger *zap.Logger
|
|
cacheManager *cache.CacheManager
|
|
appRepo repository.ApplicationRepository
|
|
rateLimiters map[string]*rate.Limiter
|
|
authRateLimiters map[string]*rate.Limiter
|
|
mu sync.RWMutex
|
|
}
|
|
|
|
// NewSecurityMiddleware creates a new security middleware
|
|
func NewSecurityMiddleware(config config.ConfigProvider, logger *zap.Logger, appRepo repository.ApplicationRepository) *SecurityMiddleware {
|
|
cacheManager := cache.NewCacheManager(config, logger)
|
|
return &SecurityMiddleware{
|
|
config: config,
|
|
logger: logger,
|
|
cacheManager: cacheManager,
|
|
appRepo: appRepo,
|
|
rateLimiters: make(map[string]*rate.Limiter),
|
|
authRateLimiters: make(map[string]*rate.Limiter),
|
|
}
|
|
}
|
|
|
|
// RateLimitMiddleware implements per-IP rate limiting
|
|
func (s *SecurityMiddleware) RateLimitMiddleware(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if !s.config.GetBool("RATE_LIMIT_ENABLED") {
|
|
next.ServeHTTP(w, r)
|
|
return
|
|
}
|
|
|
|
// Get client IP
|
|
clientIP := s.getClientIP(r)
|
|
|
|
// Get or create rate limiter for this IP
|
|
limiter := s.getRateLimiter(clientIP)
|
|
|
|
// Check if request is allowed
|
|
if !limiter.Allow() {
|
|
s.logger.Warn("Rate limit exceeded",
|
|
zap.String("client_ip", clientIP),
|
|
zap.String("path", r.URL.Path))
|
|
|
|
// Track rate limit violations
|
|
s.trackRateLimitViolation(clientIP)
|
|
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.WriteHeader(http.StatusTooManyRequests)
|
|
w.Write([]byte(`{"error":"rate_limit_exceeded","message":"Too many requests"}`))
|
|
return
|
|
}
|
|
|
|
next.ServeHTTP(w, r)
|
|
})
|
|
}
|
|
|
|
// AuthRateLimitMiddleware implements stricter rate limiting for authentication endpoints
|
|
func (s *SecurityMiddleware) AuthRateLimitMiddleware(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if !s.config.GetBool("RATE_LIMIT_ENABLED") {
|
|
next.ServeHTTP(w, r)
|
|
return
|
|
}
|
|
|
|
clientIP := s.getClientIP(r)
|
|
|
|
// Use stricter rate limits for auth endpoints
|
|
limiter := s.getAuthRateLimiter(clientIP)
|
|
|
|
// Check if request is allowed
|
|
if !limiter.Allow() {
|
|
s.logger.Warn("Auth rate limit exceeded",
|
|
zap.String("client_ip", clientIP),
|
|
zap.String("path", r.URL.Path))
|
|
|
|
// Track authentication failures for brute force protection
|
|
s.TrackAuthenticationFailure(clientIP, "")
|
|
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.WriteHeader(http.StatusTooManyRequests)
|
|
w.Write([]byte(`{"error":"auth_rate_limit_exceeded","message":"Too many authentication attempts"}`))
|
|
return
|
|
}
|
|
|
|
next.ServeHTTP(w, r)
|
|
})
|
|
}
|
|
|
|
// BruteForceProtectionMiddleware implements brute force protection
|
|
func (s *SecurityMiddleware) BruteForceProtectionMiddleware(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
clientIP := s.getClientIP(r)
|
|
|
|
// Check if IP is temporarily blocked
|
|
if s.isIPBlocked(clientIP) {
|
|
s.logger.Warn("Blocked IP attempted access",
|
|
zap.String("client_ip", clientIP),
|
|
zap.String("path", r.URL.Path))
|
|
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.WriteHeader(http.StatusForbidden)
|
|
w.Write([]byte(`{"error":"ip_blocked","message":"IP temporarily blocked due to suspicious activity"}`))
|
|
return
|
|
}
|
|
|
|
next.ServeHTTP(w, r)
|
|
})
|
|
}
|
|
|
|
// IPWhitelistMiddleware implements IP whitelisting
|
|
func (s *SecurityMiddleware) IPWhitelistMiddleware(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
whitelist := s.config.GetStringSlice("IP_WHITELIST")
|
|
if len(whitelist) == 0 {
|
|
// No whitelist configured, allow all
|
|
next.ServeHTTP(w, r)
|
|
return
|
|
}
|
|
|
|
clientIP := s.getClientIP(r)
|
|
|
|
// Check if IP is in whitelist
|
|
if !s.isIPInList(clientIP, whitelist) {
|
|
s.logger.Warn("Non-whitelisted IP attempted access",
|
|
zap.String("client_ip", clientIP),
|
|
zap.String("path", r.URL.Path))
|
|
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.WriteHeader(http.StatusForbidden)
|
|
w.Write([]byte(`{"error":"ip_not_whitelisted","message":"IP not in whitelist"}`))
|
|
return
|
|
}
|
|
|
|
next.ServeHTTP(w, r)
|
|
})
|
|
}
|
|
|
|
// SecurityHeadersMiddleware adds security headers
|
|
func (s *SecurityMiddleware) SecurityHeadersMiddleware(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
// Add security headers
|
|
w.Header().Set("X-Content-Type-Options", "nosniff")
|
|
w.Header().Set("X-Frame-Options", "DENY")
|
|
w.Header().Set("X-XSS-Protection", "1; mode=block")
|
|
w.Header().Set("Referrer-Policy", "strict-origin-when-cross-origin")
|
|
|
|
// Set Content Security Policy - more permissive for test pages in development
|
|
csp := "default-src 'self'"
|
|
if !s.config.IsProduction() && strings.HasPrefix(r.URL.Path, "/test/") {
|
|
// Allow inline styles and scripts for test pages in development
|
|
csp = "default-src 'self' 'unsafe-inline' 'unsafe-eval'; style-src 'self' 'unsafe-inline'; script-src 'self' 'unsafe-inline' 'unsafe-eval'"
|
|
s.logger.Debug("Using permissive CSP for test page", zap.String("path", r.URL.Path), zap.String("csp", csp))
|
|
} else {
|
|
s.logger.Debug("Using default CSP", zap.String("path", r.URL.Path), zap.Bool("is_production", s.config.IsProduction()))
|
|
}
|
|
w.Header().Set("Content-Security-Policy", csp)
|
|
|
|
// Add HSTS header for HTTPS
|
|
if r.TLS != nil {
|
|
w.Header().Set("Strict-Transport-Security", "max-age=31536000; includeSubDomains")
|
|
}
|
|
|
|
next.ServeHTTP(w, r)
|
|
})
|
|
}
|
|
|
|
// GinSecurityHeaders returns a Gin-compatible middleware function
|
|
func (s *SecurityMiddleware) GinSecurityHeaders() gin.HandlerFunc {
|
|
return func(c *gin.Context) {
|
|
// Add security headers
|
|
c.Header("X-Content-Type-Options", "nosniff")
|
|
c.Header("X-Frame-Options", "DENY")
|
|
c.Header("X-XSS-Protection", "1; mode=block")
|
|
c.Header("Referrer-Policy", "strict-origin-when-cross-origin")
|
|
|
|
// Set Content Security Policy - more permissive for test pages in development
|
|
csp := "default-src 'self'"
|
|
if !s.config.IsProduction() && strings.HasPrefix(c.Request.URL.Path, "/test/") {
|
|
// Allow inline styles and scripts for test pages in development
|
|
csp = "default-src 'self' 'unsafe-inline' 'unsafe-eval'; style-src 'self' 'unsafe-inline'; script-src 'self' 'unsafe-inline' 'unsafe-eval'"
|
|
s.logger.Debug("Using permissive CSP for test page", zap.String("path", c.Request.URL.Path), zap.String("csp", csp))
|
|
} else {
|
|
s.logger.Debug("Using default CSP", zap.String("path", c.Request.URL.Path), zap.Bool("is_production", s.config.IsProduction()))
|
|
}
|
|
c.Header("Content-Security-Policy", csp)
|
|
|
|
// Add HSTS header for HTTPS
|
|
if c.Request.TLS != nil {
|
|
c.Header("Strict-Transport-Security", "max-age=31536000; includeSubDomains")
|
|
}
|
|
|
|
c.Next()
|
|
}
|
|
}
|
|
|
|
// AuthenticationFailureTracker tracks authentication failures for brute force protection
|
|
func (s *SecurityMiddleware) TrackAuthenticationFailure(clientIP, userID string) {
|
|
ctx := context.Background()
|
|
|
|
// Track failures by IP
|
|
ipKey := cache.CacheKey("auth_failures_ip", clientIP)
|
|
s.incrementFailureCount(ctx, ipKey)
|
|
|
|
// Track failures by user ID if provided
|
|
if userID != "" {
|
|
userKey := cache.CacheKey("auth_failures_user", userID)
|
|
s.incrementFailureCount(ctx, userKey)
|
|
}
|
|
|
|
// Check if we should block the IP
|
|
s.checkAndBlockIP(clientIP)
|
|
}
|
|
|
|
// ClearAuthenticationFailures clears failure count on successful authentication
|
|
func (s *SecurityMiddleware) ClearAuthenticationFailures(clientIP, userID string) {
|
|
ctx := context.Background()
|
|
|
|
// Clear failures by IP
|
|
ipKey := cache.CacheKey("auth_failures_ip", clientIP)
|
|
s.cacheManager.Delete(ctx, ipKey)
|
|
|
|
// Clear failures by user ID if provided
|
|
if userID != "" {
|
|
userKey := cache.CacheKey("auth_failures_user", userID)
|
|
s.cacheManager.Delete(ctx, userKey)
|
|
}
|
|
}
|
|
|
|
// Helper methods
|
|
|
|
func (s *SecurityMiddleware) getClientIP(r *http.Request) string {
|
|
// Check X-Forwarded-For header first
|
|
xff := r.Header.Get("X-Forwarded-For")
|
|
if xff != "" {
|
|
// Take the first IP in the chain
|
|
ips := strings.Split(xff, ",")
|
|
return strings.TrimSpace(ips[0])
|
|
}
|
|
|
|
// Check X-Real-IP header
|
|
xri := r.Header.Get("X-Real-IP")
|
|
if xri != "" {
|
|
return xri
|
|
}
|
|
|
|
// Fall back to RemoteAddr
|
|
ip, _, err := net.SplitHostPort(r.RemoteAddr)
|
|
if err != nil {
|
|
return r.RemoteAddr
|
|
}
|
|
return ip
|
|
}
|
|
|
|
func (s *SecurityMiddleware) getRateLimiter(clientIP string) *rate.Limiter {
|
|
s.mu.RLock()
|
|
limiter, exists := s.rateLimiters[clientIP]
|
|
s.mu.RUnlock()
|
|
|
|
if exists {
|
|
return limiter
|
|
}
|
|
|
|
// Create new rate limiter
|
|
rps := s.config.GetInt("RATE_LIMIT_RPS")
|
|
if rps <= 0 {
|
|
rps = 100 // Default
|
|
}
|
|
|
|
burst := s.config.GetInt("RATE_LIMIT_BURST")
|
|
if burst <= 0 {
|
|
burst = 200 // Default
|
|
}
|
|
|
|
limiter = rate.NewLimiter(rate.Limit(rps), burst)
|
|
|
|
s.mu.Lock()
|
|
s.rateLimiters[clientIP] = limiter
|
|
s.mu.Unlock()
|
|
|
|
return limiter
|
|
}
|
|
|
|
func (s *SecurityMiddleware) getAuthRateLimiter(clientIP string) *rate.Limiter {
|
|
s.mu.RLock()
|
|
limiter, exists := s.authRateLimiters[clientIP]
|
|
s.mu.RUnlock()
|
|
|
|
if exists {
|
|
return limiter
|
|
}
|
|
|
|
// Create new auth rate limiter with stricter limits
|
|
authRPS := s.config.GetInt("AUTH_RATE_LIMIT_RPS")
|
|
if authRPS <= 0 {
|
|
authRPS = 5 // Very strict default for auth endpoints
|
|
}
|
|
|
|
authBurst := s.config.GetInt("AUTH_RATE_LIMIT_BURST")
|
|
if authBurst <= 0 {
|
|
authBurst = 10 // Allow small bursts
|
|
}
|
|
|
|
limiter = rate.NewLimiter(rate.Limit(authRPS), authBurst)
|
|
|
|
s.mu.Lock()
|
|
s.authRateLimiters[clientIP] = limiter
|
|
s.mu.Unlock()
|
|
|
|
return limiter
|
|
}
|
|
|
|
func (s *SecurityMiddleware) trackRateLimitViolation(clientIP string) {
|
|
ctx := context.Background()
|
|
key := cache.CacheKey("rate_limit_violations", clientIP)
|
|
s.incrementFailureCount(ctx, key)
|
|
}
|
|
|
|
func (s *SecurityMiddleware) isIPBlocked(clientIP string) bool {
|
|
ctx := context.Background()
|
|
key := cache.CacheKey("blocked_ips", clientIP)
|
|
|
|
exists, err := s.cacheManager.Exists(ctx, key)
|
|
if err != nil {
|
|
s.logger.Error("Failed to check IP block status",
|
|
zap.String("client_ip", clientIP),
|
|
zap.Error(err))
|
|
return false
|
|
}
|
|
|
|
return exists
|
|
}
|
|
|
|
func (s *SecurityMiddleware) isIPInList(clientIP string, ipList []string) bool {
|
|
for _, allowedIP := range ipList {
|
|
allowedIP = strings.TrimSpace(allowedIP)
|
|
|
|
// Support CIDR notation
|
|
if strings.Contains(allowedIP, "/") {
|
|
_, network, err := net.ParseCIDR(allowedIP)
|
|
if err != nil {
|
|
s.logger.Warn("Invalid CIDR in IP list", zap.String("cidr", allowedIP))
|
|
continue
|
|
}
|
|
|
|
ip := net.ParseIP(clientIP)
|
|
if ip != nil && network.Contains(ip) {
|
|
return true
|
|
}
|
|
} else {
|
|
// Exact IP match
|
|
if clientIP == allowedIP {
|
|
return true
|
|
}
|
|
}
|
|
}
|
|
|
|
return false
|
|
}
|
|
|
|
func (s *SecurityMiddleware) incrementFailureCount(ctx context.Context, key string) {
|
|
// Get current count
|
|
var count int
|
|
err := s.cacheManager.GetJSON(ctx, key, &count)
|
|
if err != nil {
|
|
// Key doesn't exist, start with 0
|
|
count = 0
|
|
}
|
|
|
|
count++
|
|
|
|
// Store updated count with TTL
|
|
ttl := s.config.GetDuration("AUTH_FAILURE_WINDOW")
|
|
if ttl <= 0 {
|
|
ttl = 15 * time.Minute // Default window
|
|
}
|
|
|
|
s.cacheManager.SetJSON(ctx, key, count, ttl)
|
|
}
|
|
|
|
func (s *SecurityMiddleware) checkAndBlockIP(clientIP string) {
|
|
ctx := context.Background()
|
|
key := cache.CacheKey("auth_failures_ip", clientIP)
|
|
|
|
var count int
|
|
err := s.cacheManager.GetJSON(ctx, key, &count)
|
|
if err != nil {
|
|
return // No failures recorded
|
|
}
|
|
|
|
maxFailures := s.config.GetInt("MAX_AUTH_FAILURES")
|
|
if maxFailures <= 0 {
|
|
maxFailures = 5 // Default
|
|
}
|
|
|
|
if count >= maxFailures {
|
|
// Block the IP
|
|
blockKey := cache.CacheKey("blocked_ips", clientIP)
|
|
blockDuration := s.config.GetDuration("IP_BLOCK_DURATION")
|
|
if blockDuration <= 0 {
|
|
blockDuration = 1 * time.Hour // Default
|
|
}
|
|
|
|
blockInfo := map[string]interface{}{
|
|
"blocked_at": time.Now().Unix(),
|
|
"failure_count": count,
|
|
"reason": "excessive_auth_failures",
|
|
}
|
|
|
|
s.cacheManager.SetJSON(ctx, blockKey, blockInfo, blockDuration)
|
|
|
|
s.logger.Warn("IP blocked due to excessive authentication failures",
|
|
zap.String("client_ip", clientIP),
|
|
zap.Int("failure_count", count),
|
|
zap.Duration("block_duration", blockDuration))
|
|
}
|
|
}
|
|
|
|
// RequestSignatureMiddleware validates request signatures (for API key requests)
|
|
func (s *SecurityMiddleware) RequestSignatureMiddleware(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
// Only validate signatures for certain endpoints
|
|
if !s.shouldValidateSignature(r) {
|
|
next.ServeHTTP(w, r)
|
|
return
|
|
}
|
|
|
|
signature := r.Header.Get("X-Signature")
|
|
timestamp := r.Header.Get("X-Timestamp")
|
|
|
|
if signature == "" || timestamp == "" {
|
|
s.logger.Warn("Missing signature headers",
|
|
zap.String("path", r.URL.Path),
|
|
zap.String("client_ip", s.getClientIP(r)))
|
|
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.WriteHeader(http.StatusBadRequest)
|
|
w.Write([]byte(`{"error":"missing_signature","message":"Request signature required"}`))
|
|
return
|
|
}
|
|
|
|
// Validate timestamp (prevent replay attacks)
|
|
if !s.isTimestampValid(timestamp) {
|
|
s.logger.Warn("Invalid timestamp in request",
|
|
zap.String("timestamp", timestamp),
|
|
zap.String("client_ip", s.getClientIP(r)))
|
|
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.WriteHeader(http.StatusBadRequest)
|
|
w.Write([]byte(`{"error":"invalid_timestamp","message":"Request timestamp is invalid or too old"}`))
|
|
return
|
|
}
|
|
|
|
// Implement HMAC signature validation
|
|
appID := r.Header.Get("X-App-ID")
|
|
if appID == "" {
|
|
s.logger.Warn("Missing App-ID header for signature validation",
|
|
zap.String("path", r.URL.Path),
|
|
zap.String("client_ip", s.getClientIP(r)))
|
|
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.WriteHeader(http.StatusBadRequest)
|
|
w.Write([]byte(`{"error":"missing_app_id","message":"X-App-ID header required for signature validation"}`))
|
|
return
|
|
}
|
|
|
|
// Retrieve application to get HMAC key
|
|
ctx := r.Context()
|
|
app, err := s.appRepo.GetByID(ctx, appID)
|
|
if err != nil {
|
|
s.logger.Warn("Failed to retrieve application for signature validation",
|
|
zap.String("app_id", appID),
|
|
zap.Error(err),
|
|
zap.String("client_ip", s.getClientIP(r)))
|
|
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.WriteHeader(http.StatusUnauthorized)
|
|
w.Write([]byte(`{"error":"invalid_application","message":"Invalid application ID"}`))
|
|
return
|
|
}
|
|
|
|
// Validate the signature
|
|
if !s.validateHMACSignature(r, app.HMACKey, signature, timestamp) {
|
|
s.logger.Warn("Invalid request signature",
|
|
zap.String("app_id", appID),
|
|
zap.String("path", r.URL.Path),
|
|
zap.String("client_ip", s.getClientIP(r)))
|
|
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.WriteHeader(http.StatusUnauthorized)
|
|
w.Write([]byte(`{"error":"invalid_signature","message":"Request signature is invalid"}`))
|
|
return
|
|
}
|
|
|
|
next.ServeHTTP(w, r)
|
|
})
|
|
}
|
|
|
|
func (s *SecurityMiddleware) shouldValidateSignature(r *http.Request) bool {
|
|
// Define which endpoints require signature validation
|
|
signatureRequiredPaths := []string{
|
|
"/api/v1/tokens",
|
|
"/api/v1/applications",
|
|
}
|
|
|
|
for _, path := range signatureRequiredPaths {
|
|
if strings.HasPrefix(r.URL.Path, path) {
|
|
return true
|
|
}
|
|
}
|
|
|
|
return false
|
|
}
|
|
|
|
func (s *SecurityMiddleware) isTimestampValid(timestampStr string) bool {
|
|
// Parse timestamp
|
|
timestamp, err := time.Parse(time.RFC3339, timestampStr)
|
|
if err != nil {
|
|
return false
|
|
}
|
|
|
|
// Check if timestamp is within acceptable window
|
|
now := time.Now()
|
|
maxAge := s.config.GetDuration("REQUEST_MAX_AGE")
|
|
if maxAge <= 0 {
|
|
maxAge = 5 * time.Minute // Default
|
|
}
|
|
|
|
return now.Sub(timestamp) <= maxAge && timestamp.Before(now.Add(1*time.Minute))
|
|
}
|
|
|
|
// GetSecurityMetrics returns security-related metrics
|
|
func (s *SecurityMiddleware) GetSecurityMetrics() map[string]interface{} {
|
|
// This is a simplified version - in production you'd want more comprehensive metrics
|
|
metrics := map[string]interface{}{
|
|
"active_rate_limiters": len(s.rateLimiters),
|
|
"timestamp": time.Now().Unix(),
|
|
}
|
|
|
|
// Count blocked IPs (this is expensive, so you might want to cache this)
|
|
// For now, we'll just return the basic metrics
|
|
|
|
return metrics
|
|
}
|
|
|
|
// validateHMACSignature validates HMAC-SHA256 signature for request integrity
|
|
func (s *SecurityMiddleware) validateHMACSignature(r *http.Request, hmacKey, signature, timestamp string) bool {
|
|
// Create the signing string: METHOD + PATH + BODY + TIMESTAMP
|
|
var bodyBytes []byte
|
|
if r.Body != nil {
|
|
var err error
|
|
bodyBytes, err = io.ReadAll(r.Body)
|
|
if err != nil {
|
|
s.logger.Warn("Failed to read request body for signature validation", zap.Error(err))
|
|
return false
|
|
}
|
|
// Restore the body for downstream handlers
|
|
r.Body = io.NopCloser(strings.NewReader(string(bodyBytes)))
|
|
}
|
|
|
|
signingString := fmt.Sprintf("%s\n%s\n%s\n%s",
|
|
r.Method,
|
|
r.URL.Path,
|
|
string(bodyBytes),
|
|
timestamp)
|
|
|
|
// Calculate expected signature
|
|
mac := hmac.New(sha256.New, []byte(hmacKey))
|
|
mac.Write([]byte(signingString))
|
|
expectedSignature := hex.EncodeToString(mac.Sum(nil))
|
|
|
|
// Compare signatures (constant time comparison to prevent timing attacks)
|
|
return hmac.Equal([]byte(signature), []byte(expectedSignature))
|
|
}
|