-
This commit is contained in:
423
internal/middleware/security.go
Normal file
423
internal/middleware/security.go
Normal file
@ -0,0 +1,423 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"golang.org/x/time/rate"
|
||||
|
||||
"github.com/kms/api-key-service/internal/cache"
|
||||
"github.com/kms/api-key-service/internal/config"
|
||||
"github.com/kms/api-key-service/internal/errors"
|
||||
)
|
||||
|
||||
// SecurityMiddleware provides various security features
|
||||
type SecurityMiddleware struct {
|
||||
config config.ConfigProvider
|
||||
logger *zap.Logger
|
||||
cacheManager *cache.CacheManager
|
||||
rateLimiters map[string]*rate.Limiter
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewSecurityMiddleware creates a new security middleware
|
||||
func NewSecurityMiddleware(config config.ConfigProvider, logger *zap.Logger) *SecurityMiddleware {
|
||||
cacheManager := cache.NewCacheManager(config, logger)
|
||||
return &SecurityMiddleware{
|
||||
config: config,
|
||||
logger: logger,
|
||||
cacheManager: cacheManager,
|
||||
rateLimiters: make(map[string]*rate.Limiter),
|
||||
}
|
||||
}
|
||||
|
||||
// RateLimitMiddleware implements per-IP rate limiting
|
||||
func (s *SecurityMiddleware) RateLimitMiddleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if !s.config.GetBool("RATE_LIMIT_ENABLED") {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
// Get client IP
|
||||
clientIP := s.getClientIP(r)
|
||||
|
||||
// Get or create rate limiter for this IP
|
||||
limiter := s.getRateLimiter(clientIP)
|
||||
|
||||
// Check if request is allowed
|
||||
if !limiter.Allow() {
|
||||
s.logger.Warn("Rate limit exceeded",
|
||||
zap.String("client_ip", clientIP),
|
||||
zap.String("path", r.URL.Path))
|
||||
|
||||
// Track rate limit violations
|
||||
s.trackRateLimitViolation(clientIP)
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusTooManyRequests)
|
||||
w.Write([]byte(`{"error":"rate_limit_exceeded","message":"Too many requests"}`))
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
// BruteForceProtectionMiddleware implements brute force protection
|
||||
func (s *SecurityMiddleware) BruteForceProtectionMiddleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
clientIP := s.getClientIP(r)
|
||||
|
||||
// Check if IP is temporarily blocked
|
||||
if s.isIPBlocked(clientIP) {
|
||||
s.logger.Warn("Blocked IP attempted access",
|
||||
zap.String("client_ip", clientIP),
|
||||
zap.String("path", r.URL.Path))
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusForbidden)
|
||||
w.Write([]byte(`{"error":"ip_blocked","message":"IP temporarily blocked due to suspicious activity"}`))
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
// IPWhitelistMiddleware implements IP whitelisting
|
||||
func (s *SecurityMiddleware) IPWhitelistMiddleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
whitelist := s.config.GetStringSlice("IP_WHITELIST")
|
||||
if len(whitelist) == 0 {
|
||||
// No whitelist configured, allow all
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
clientIP := s.getClientIP(r)
|
||||
|
||||
// Check if IP is in whitelist
|
||||
if !s.isIPInList(clientIP, whitelist) {
|
||||
s.logger.Warn("Non-whitelisted IP attempted access",
|
||||
zap.String("client_ip", clientIP),
|
||||
zap.String("path", r.URL.Path))
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusForbidden)
|
||||
w.Write([]byte(`{"error":"ip_not_whitelisted","message":"IP not in whitelist"}`))
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
// SecurityHeadersMiddleware adds security headers
|
||||
func (s *SecurityMiddleware) SecurityHeadersMiddleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Add security headers
|
||||
w.Header().Set("X-Content-Type-Options", "nosniff")
|
||||
w.Header().Set("X-Frame-Options", "DENY")
|
||||
w.Header().Set("X-XSS-Protection", "1; mode=block")
|
||||
w.Header().Set("Referrer-Policy", "strict-origin-when-cross-origin")
|
||||
w.Header().Set("Content-Security-Policy", "default-src 'self'")
|
||||
|
||||
// Add HSTS header for HTTPS
|
||||
if r.TLS != nil {
|
||||
w.Header().Set("Strict-Transport-Security", "max-age=31536000; includeSubDomains")
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
// AuthenticationFailureTracker tracks authentication failures for brute force protection
|
||||
func (s *SecurityMiddleware) TrackAuthenticationFailure(clientIP, userID string) {
|
||||
ctx := context.Background()
|
||||
|
||||
// Track failures by IP
|
||||
ipKey := cache.CacheKey("auth_failures_ip", clientIP)
|
||||
s.incrementFailureCount(ctx, ipKey)
|
||||
|
||||
// Track failures by user ID if provided
|
||||
if userID != "" {
|
||||
userKey := cache.CacheKey("auth_failures_user", userID)
|
||||
s.incrementFailureCount(ctx, userKey)
|
||||
}
|
||||
|
||||
// Check if we should block the IP
|
||||
s.checkAndBlockIP(clientIP)
|
||||
}
|
||||
|
||||
// ClearAuthenticationFailures clears failure count on successful authentication
|
||||
func (s *SecurityMiddleware) ClearAuthenticationFailures(clientIP, userID string) {
|
||||
ctx := context.Background()
|
||||
|
||||
// Clear failures by IP
|
||||
ipKey := cache.CacheKey("auth_failures_ip", clientIP)
|
||||
s.cacheManager.Delete(ctx, ipKey)
|
||||
|
||||
// Clear failures by user ID if provided
|
||||
if userID != "" {
|
||||
userKey := cache.CacheKey("auth_failures_user", userID)
|
||||
s.cacheManager.Delete(ctx, userKey)
|
||||
}
|
||||
}
|
||||
|
||||
// Helper methods
|
||||
|
||||
func (s *SecurityMiddleware) getClientIP(r *http.Request) string {
|
||||
// Check X-Forwarded-For header first
|
||||
xff := r.Header.Get("X-Forwarded-For")
|
||||
if xff != "" {
|
||||
// Take the first IP in the chain
|
||||
ips := strings.Split(xff, ",")
|
||||
return strings.TrimSpace(ips[0])
|
||||
}
|
||||
|
||||
// Check X-Real-IP header
|
||||
xri := r.Header.Get("X-Real-IP")
|
||||
if xri != "" {
|
||||
return xri
|
||||
}
|
||||
|
||||
// Fall back to RemoteAddr
|
||||
ip, _, err := net.SplitHostPort(r.RemoteAddr)
|
||||
if err != nil {
|
||||
return r.RemoteAddr
|
||||
}
|
||||
return ip
|
||||
}
|
||||
|
||||
func (s *SecurityMiddleware) getRateLimiter(clientIP string) *rate.Limiter {
|
||||
s.mu.RLock()
|
||||
limiter, exists := s.rateLimiters[clientIP]
|
||||
s.mu.RUnlock()
|
||||
|
||||
if exists {
|
||||
return limiter
|
||||
}
|
||||
|
||||
// Create new rate limiter
|
||||
rps := s.config.GetInt("RATE_LIMIT_RPS")
|
||||
if rps <= 0 {
|
||||
rps = 100 // Default
|
||||
}
|
||||
|
||||
burst := s.config.GetInt("RATE_LIMIT_BURST")
|
||||
if burst <= 0 {
|
||||
burst = 200 // Default
|
||||
}
|
||||
|
||||
limiter = rate.NewLimiter(rate.Limit(rps), burst)
|
||||
|
||||
s.mu.Lock()
|
||||
s.rateLimiters[clientIP] = limiter
|
||||
s.mu.Unlock()
|
||||
|
||||
return limiter
|
||||
}
|
||||
|
||||
func (s *SecurityMiddleware) trackRateLimitViolation(clientIP string) {
|
||||
ctx := context.Background()
|
||||
key := cache.CacheKey("rate_limit_violations", clientIP)
|
||||
s.incrementFailureCount(ctx, key)
|
||||
}
|
||||
|
||||
func (s *SecurityMiddleware) isIPBlocked(clientIP string) bool {
|
||||
ctx := context.Background()
|
||||
key := cache.CacheKey("blocked_ips", clientIP)
|
||||
|
||||
exists, err := s.cacheManager.Exists(ctx, key)
|
||||
if err != nil {
|
||||
s.logger.Error("Failed to check IP block status",
|
||||
zap.String("client_ip", clientIP),
|
||||
zap.Error(err))
|
||||
return false
|
||||
}
|
||||
|
||||
return exists
|
||||
}
|
||||
|
||||
func (s *SecurityMiddleware) isIPInList(clientIP string, ipList []string) bool {
|
||||
for _, allowedIP := range ipList {
|
||||
allowedIP = strings.TrimSpace(allowedIP)
|
||||
|
||||
// Support CIDR notation
|
||||
if strings.Contains(allowedIP, "/") {
|
||||
_, network, err := net.ParseCIDR(allowedIP)
|
||||
if err != nil {
|
||||
s.logger.Warn("Invalid CIDR in IP list", zap.String("cidr", allowedIP))
|
||||
continue
|
||||
}
|
||||
|
||||
ip := net.ParseIP(clientIP)
|
||||
if ip != nil && network.Contains(ip) {
|
||||
return true
|
||||
}
|
||||
} else {
|
||||
// Exact IP match
|
||||
if clientIP == allowedIP {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (s *SecurityMiddleware) incrementFailureCount(ctx context.Context, key string) {
|
||||
// Get current count
|
||||
var count int
|
||||
err := s.cacheManager.GetJSON(ctx, key, &count)
|
||||
if err != nil {
|
||||
// Key doesn't exist, start with 0
|
||||
count = 0
|
||||
}
|
||||
|
||||
count++
|
||||
|
||||
// Store updated count with TTL
|
||||
ttl := s.config.GetDuration("AUTH_FAILURE_WINDOW")
|
||||
if ttl <= 0 {
|
||||
ttl = 15 * time.Minute // Default window
|
||||
}
|
||||
|
||||
s.cacheManager.SetJSON(ctx, key, count, ttl)
|
||||
}
|
||||
|
||||
func (s *SecurityMiddleware) checkAndBlockIP(clientIP string) {
|
||||
ctx := context.Background()
|
||||
key := cache.CacheKey("auth_failures_ip", clientIP)
|
||||
|
||||
var count int
|
||||
err := s.cacheManager.GetJSON(ctx, key, &count)
|
||||
if err != nil {
|
||||
return // No failures recorded
|
||||
}
|
||||
|
||||
maxFailures := s.config.GetInt("MAX_AUTH_FAILURES")
|
||||
if maxFailures <= 0 {
|
||||
maxFailures = 5 // Default
|
||||
}
|
||||
|
||||
if count >= maxFailures {
|
||||
// Block the IP
|
||||
blockKey := cache.CacheKey("blocked_ips", clientIP)
|
||||
blockDuration := s.config.GetDuration("IP_BLOCK_DURATION")
|
||||
if blockDuration <= 0 {
|
||||
blockDuration = 1 * time.Hour // Default
|
||||
}
|
||||
|
||||
blockInfo := map[string]interface{}{
|
||||
"blocked_at": time.Now().Unix(),
|
||||
"failure_count": count,
|
||||
"reason": "excessive_auth_failures",
|
||||
}
|
||||
|
||||
s.cacheManager.SetJSON(ctx, blockKey, blockInfo, blockDuration)
|
||||
|
||||
s.logger.Warn("IP blocked due to excessive authentication failures",
|
||||
zap.String("client_ip", clientIP),
|
||||
zap.Int("failure_count", count),
|
||||
zap.Duration("block_duration", blockDuration))
|
||||
}
|
||||
}
|
||||
|
||||
// RequestSignatureMiddleware validates request signatures (for API key requests)
|
||||
func (s *SecurityMiddleware) RequestSignatureMiddleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Only validate signatures for certain endpoints
|
||||
if !s.shouldValidateSignature(r) {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
signature := r.Header.Get("X-Signature")
|
||||
timestamp := r.Header.Get("X-Timestamp")
|
||||
|
||||
if signature == "" || timestamp == "" {
|
||||
s.logger.Warn("Missing signature headers",
|
||||
zap.String("path", r.URL.Path),
|
||||
zap.String("client_ip", s.getClientIP(r)))
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
w.Write([]byte(`{"error":"missing_signature","message":"Request signature required"}`))
|
||||
return
|
||||
}
|
||||
|
||||
// Validate timestamp (prevent replay attacks)
|
||||
if !s.isTimestampValid(timestamp) {
|
||||
s.logger.Warn("Invalid timestamp in request",
|
||||
zap.String("timestamp", timestamp),
|
||||
zap.String("client_ip", s.getClientIP(r)))
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
w.Write([]byte(`{"error":"invalid_timestamp","message":"Request timestamp is invalid or too old"}`))
|
||||
return
|
||||
}
|
||||
|
||||
// TODO: Implement actual signature validation
|
||||
// This would involve validating the HMAC signature using the client's secret
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *SecurityMiddleware) shouldValidateSignature(r *http.Request) bool {
|
||||
// Define which endpoints require signature validation
|
||||
signatureRequiredPaths := []string{
|
||||
"/api/v1/tokens",
|
||||
"/api/v1/applications",
|
||||
}
|
||||
|
||||
for _, path := range signatureRequiredPaths {
|
||||
if strings.HasPrefix(r.URL.Path, path) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (s *SecurityMiddleware) isTimestampValid(timestampStr string) bool {
|
||||
// Parse timestamp
|
||||
timestamp, err := time.Parse(time.RFC3339, timestampStr)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check if timestamp is within acceptable window
|
||||
now := time.Now()
|
||||
maxAge := s.config.GetDuration("REQUEST_MAX_AGE")
|
||||
if maxAge <= 0 {
|
||||
maxAge = 5 * time.Minute // Default
|
||||
}
|
||||
|
||||
return now.Sub(timestamp) <= maxAge && timestamp.Before(now.Add(1*time.Minute))
|
||||
}
|
||||
|
||||
// GetSecurityMetrics returns security-related metrics
|
||||
func (s *SecurityMiddleware) GetSecurityMetrics() map[string]interface{} {
|
||||
ctx := context.Background()
|
||||
|
||||
// This is a simplified version - in production you'd want more comprehensive metrics
|
||||
metrics := map[string]interface{}{
|
||||
"active_rate_limiters": len(s.rateLimiters),
|
||||
"timestamp": time.Now().Unix(),
|
||||
}
|
||||
|
||||
// Count blocked IPs (this is expensive, so you might want to cache this)
|
||||
// For now, we'll just return the basic metrics
|
||||
|
||||
return metrics
|
||||
}
|
||||
Reference in New Issue
Block a user