This commit is contained in:
2025-08-26 19:16:41 -04:00
parent 7ca61eb712
commit 6725529b01
113 changed files with 0 additions and 337 deletions

View File

@ -0,0 +1,235 @@
package middleware
import (
"crypto/hmac"
"crypto/rand"
"crypto/sha256"
"encoding/hex"
"net/http"
"strconv"
"strings"
"time"
"go.uber.org/zap"
"github.com/kms/api-key-service/internal/config"
)
// CSRFMiddleware provides CSRF protection
type CSRFMiddleware struct {
config config.ConfigProvider
logger *zap.Logger
}
// NewCSRFMiddleware creates a new CSRF middleware
func NewCSRFMiddleware(config config.ConfigProvider, logger *zap.Logger) *CSRFMiddleware {
return &CSRFMiddleware{
config: config,
logger: logger,
}
}
// CSRFProtection implements CSRF protection for state-changing operations
func (cm *CSRFMiddleware) CSRFProtection(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Skip CSRF protection for safe methods
if r.Method == "GET" || r.Method == "HEAD" || r.Method == "OPTIONS" {
next.ServeHTTP(w, r)
return
}
// Skip CSRF protection for specific endpoints that use other authentication
if cm.shouldSkipCSRF(r) {
next.ServeHTTP(w, r)
return
}
// Get CSRF token from header
csrfToken := r.Header.Get("X-CSRF-Token")
if csrfToken == "" {
cm.logger.Warn("Missing CSRF token",
zap.String("path", r.URL.Path),
zap.String("method", r.Method),
zap.String("remote_addr", r.RemoteAddr))
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusForbidden)
w.Write([]byte(`{"error":"csrf_token_missing","message":"CSRF token required"}`))
return
}
// Validate CSRF token
if !cm.validateCSRFToken(csrfToken, r) {
cm.logger.Warn("Invalid CSRF token",
zap.String("path", r.URL.Path),
zap.String("method", r.Method),
zap.String("remote_addr", r.RemoteAddr))
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusForbidden)
w.Write([]byte(`{"error":"csrf_token_invalid","message":"Invalid CSRF token"}`))
return
}
cm.logger.Debug("CSRF token validated successfully",
zap.String("path", r.URL.Path))
next.ServeHTTP(w, r)
})
}
// GenerateCSRFToken generates a new CSRF token for a user session
func (cm *CSRFMiddleware) GenerateCSRFToken(userID string) (string, error) {
// Generate random bytes for token
tokenBytes := make([]byte, 32)
if _, err := rand.Read(tokenBytes); err != nil {
cm.logger.Error("Failed to generate CSRF token", zap.Error(err))
return "", err
}
// Create timestamp
timestamp := time.Now().Unix()
// Create token data
tokenData := hex.EncodeToString(tokenBytes)
// Create signing string: userID:timestamp:tokenData
timestampStr := strconv.FormatInt(timestamp, 10)
signingString := userID + ":" + timestampStr + ":" + tokenData
// Sign the token with HMAC
signature := cm.signData(signingString)
// Return encoded token: tokenData.timestamp.signature
token := tokenData + "." + timestampStr + "." + signature
return token, nil
}
// validateCSRFToken validates a CSRF token
func (cm *CSRFMiddleware) validateCSRFToken(token string, r *http.Request) bool {
// Parse token parts
parts := strings.Split(token, ".")
if len(parts) != 3 {
cm.logger.Debug("Invalid CSRF token format")
return false
}
tokenData, timestampStr, signature := parts[0], parts[1], parts[2]
// Get user ID from request context or headers
userID := cm.getUserIDFromRequest(r)
if userID == "" {
cm.logger.Debug("No user ID found for CSRF validation")
return false
}
// Recreate signing string
signingString := userID + ":" + timestampStr + ":" + tokenData
// Verify signature
expectedSignature := cm.signData(signingString)
if !hmac.Equal([]byte(signature), []byte(expectedSignature)) {
cm.logger.Debug("CSRF token signature verification failed")
return false
}
// Parse timestamp
timestampInt, err := strconv.ParseInt(timestampStr, 10, 64)
if err != nil {
cm.logger.Debug("Invalid timestamp in CSRF token", zap.Error(err))
return false
}
timestamp := time.Unix(timestampInt, 0)
// Check if token is expired (valid for 1 hour by default)
maxAge := cm.config.GetDuration("CSRF_TOKEN_MAX_AGE")
if maxAge <= 0 {
maxAge = 1 * time.Hour
}
if time.Since(timestamp) > maxAge {
cm.logger.Debug("CSRF token expired",
zap.Time("timestamp", timestamp),
zap.Duration("age", time.Since(timestamp)),
zap.Duration("max_age", maxAge))
return false
}
return true
}
// signData signs data with HMAC
func (cm *CSRFMiddleware) signData(data string) string {
// Use the same signing key as for authentication
signingKey := cm.config.GetString("AUTH_SIGNING_KEY")
if signingKey == "" {
cm.logger.Error("AUTH_SIGNING_KEY not configured for CSRF protection")
return ""
}
mac := hmac.New(sha256.New, []byte(signingKey))
mac.Write([]byte(data))
return hex.EncodeToString(mac.Sum(nil))
}
// getUserIDFromRequest extracts user ID from request
func (cm *CSRFMiddleware) getUserIDFromRequest(r *http.Request) string {
// Try to get from X-User-Email header
userEmail := r.Header.Get(cm.config.GetString("AUTH_HEADER_USER_EMAIL"))
if userEmail != "" {
return userEmail
}
// Try to get from context (if set by authentication middleware)
if userID := r.Context().Value("user_id"); userID != nil {
if id, ok := userID.(string); ok {
return id
}
}
return ""
}
// shouldSkipCSRF determines if CSRF protection should be skipped for a request
func (cm *CSRFMiddleware) shouldSkipCSRF(r *http.Request) bool {
// Skip for API endpoints that use API key authentication
if strings.HasPrefix(r.URL.Path, "/api/verify") {
return true
}
// Skip for health check endpoints
if r.URL.Path == "/health" || r.URL.Path == "/ready" {
return true
}
// Skip for webhook endpoints (if any)
if strings.HasPrefix(r.URL.Path, "/webhook/") {
return true
}
return false
}
// SetCSRFCookie sets a secure CSRF token cookie
func (cm *CSRFMiddleware) SetCSRFCookie(w http.ResponseWriter, token string) {
cookie := &http.Cookie{
Name: "csrf_token",
Value: token,
Path: "/",
MaxAge: 3600, // 1 hour
HttpOnly: false, // JavaScript needs to read this for AJAX requests
Secure: true, // HTTPS only
SameSite: http.SameSiteStrictMode,
}
http.SetCookie(w, cookie)
}
// GetCSRFTokenFromCookie gets CSRF token from cookie
func (cm *CSRFMiddleware) GetCSRFTokenFromCookie(r *http.Request) string {
cookie, err := r.Cookie("csrf_token")
if err != nil {
return ""
}
return cookie.Value
}

View File

@ -0,0 +1,60 @@
package middleware
import (
"time"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
// Logger returns a middleware that logs HTTP requests using zap logger
func Logger(logger *zap.Logger) gin.HandlerFunc {
return func(c *gin.Context) {
// Start timer
start := time.Now()
// Process request
c.Next()
// Calculate latency
latency := time.Since(start)
// Get request information
method := c.Request.Method
path := c.Request.URL.Path
query := c.Request.URL.RawQuery
status := c.Writer.Status()
clientIP := c.ClientIP()
userAgent := c.Request.UserAgent()
// Get error if any
errorMessage := c.Errors.ByType(gin.ErrorTypePrivate).String()
// Build log fields
fields := []zap.Field{
zap.String("method", method),
zap.String("path", path),
zap.String("query", query),
zap.Int("status", status),
zap.String("client_ip", clientIP),
zap.String("user_agent", userAgent),
zap.Duration("latency", latency),
zap.Int64("latency_ms", latency.Nanoseconds()/1000000),
}
// Add error field if exists
if errorMessage != "" {
fields = append(fields, zap.String("error", errorMessage))
}
// Log based on status code
switch {
case status >= 500:
logger.Error("HTTP Request", fields...)
case status >= 400:
logger.Warn("HTTP Request", fields...)
default:
logger.Info("HTTP Request", fields...)
}
}
}

View File

@ -0,0 +1,239 @@
package middleware
import (
"context"
"net/http"
"runtime/debug"
"strconv"
"sync"
"time"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
"golang.org/x/time/rate"
"github.com/kms/api-key-service/internal/config"
)
// Recovery returns a middleware that recovers from any panics
func Recovery(logger *zap.Logger) gin.HandlerFunc {
return gin.CustomRecoveryWithWriter(gin.DefaultWriter, func(c *gin.Context, recovered interface{}) {
if err, ok := recovered.(string); ok {
logger.Error("Panic recovered",
zap.String("error", err),
zap.String("stack", string(debug.Stack())),
zap.String("path", c.Request.URL.Path),
zap.String("method", c.Request.Method),
)
}
c.AbortWithStatus(http.StatusInternalServerError)
})
}
// CORS returns a middleware that handles Cross-Origin Resource Sharing
func CORS() gin.HandlerFunc {
return func(c *gin.Context) {
// Set CORS headers
c.Header("Access-Control-Allow-Origin", "*") // In production, be more specific
c.Header("Access-Control-Allow-Credentials", "true")
c.Header("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
c.Header("Access-Control-Allow-Headers", "Origin, Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, X-User-Email")
c.Header("Access-Control-Expose-Headers", "Content-Length")
c.Header("Access-Control-Max-Age", "86400")
// Handle preflight OPTIONS request
if c.Request.Method == "OPTIONS" {
c.AbortWithStatus(http.StatusNoContent)
return
}
c.Next()
}
}
// Security returns a middleware that adds security headers
func Security() gin.HandlerFunc {
return func(c *gin.Context) {
// Security headers
c.Header("X-Frame-Options", "DENY")
c.Header("X-Content-Type-Options", "nosniff")
c.Header("X-XSS-Protection", "1; mode=block")
c.Header("Referrer-Policy", "strict-origin-when-cross-origin")
c.Header("Content-Security-Policy", "default-src 'self'")
c.Header("Strict-Transport-Security", "max-age=31536000; includeSubDomains")
c.Next()
}
}
// RateLimiter holds rate limiting data
type RateLimiter struct {
limiters map[string]*rate.Limiter
mu sync.RWMutex
rate rate.Limit
burst int
}
// NewRateLimiter creates a new rate limiter
func NewRateLimiter(rps, burst int) *RateLimiter {
return &RateLimiter{
limiters: make(map[string]*rate.Limiter),
rate: rate.Limit(rps),
burst: burst,
}
}
// GetLimiter returns the rate limiter for a given key
func (rl *RateLimiter) GetLimiter(key string) *rate.Limiter {
rl.mu.RLock()
limiter, exists := rl.limiters[key]
rl.mu.RUnlock()
if !exists {
limiter = rate.NewLimiter(rl.rate, rl.burst)
rl.mu.Lock()
rl.limiters[key] = limiter
rl.mu.Unlock()
}
return limiter
}
// RateLimit returns a middleware that implements rate limiting
func RateLimit(rps, burst int) gin.HandlerFunc {
limiter := NewRateLimiter(rps, burst)
return func(c *gin.Context) {
// Use client IP as the key for rate limiting
key := c.ClientIP()
// Get the limiter for this client
clientLimiter := limiter.GetLimiter(key)
// Check if request is allowed
if !clientLimiter.Allow() {
// Add rate limit headers
c.Header("X-RateLimit-Limit", strconv.Itoa(burst))
c.Header("X-RateLimit-Remaining", "0")
c.Header("X-RateLimit-Reset", strconv.FormatInt(time.Now().Add(time.Minute).Unix(), 10))
c.JSON(http.StatusTooManyRequests, gin.H{
"error": "Rate limit exceeded",
"message": "Too many requests. Please try again later.",
})
c.Abort()
return
}
// Add rate limit headers for successful requests
remaining := burst - int(clientLimiter.Tokens())
if remaining < 0 {
remaining = 0
}
c.Header("X-RateLimit-Limit", strconv.Itoa(burst))
c.Header("X-RateLimit-Remaining", strconv.Itoa(remaining))
c.Header("X-RateLimit-Reset", strconv.FormatInt(time.Now().Add(time.Minute).Unix(), 10))
c.Next()
}
}
// Authentication returns a middleware that handles authentication
func Authentication(cfg config.ConfigProvider, logger *zap.Logger) gin.HandlerFunc {
return func(c *gin.Context) {
// For now, we'll implement a basic header-based authentication
// This will be expanded when we implement the full authentication service
userEmail := c.GetHeader(cfg.GetString("AUTH_HEADER_USER_EMAIL"))
if userEmail == "" {
logger.Warn("Authentication failed: missing user email header",
zap.String("path", c.Request.URL.Path),
zap.String("method", c.Request.Method),
)
c.JSON(http.StatusUnauthorized, gin.H{
"error": "Unauthorized",
"message": "Authentication required",
})
c.Abort()
return
}
// Set user context for downstream handlers
c.Set("user_id", userEmail)
c.Set("auth_method", "header")
logger.Debug("Authentication successful",
zap.String("user_id", userEmail),
zap.String("auth_method", "header"),
)
c.Next()
}
}
// RequestID returns a middleware that adds a unique request ID to each request
func RequestID() gin.HandlerFunc {
return func(c *gin.Context) {
requestID := c.GetHeader("X-Request-ID")
if requestID == "" {
requestID = generateRequestID()
}
c.Header("X-Request-ID", requestID)
c.Set("request_id", requestID)
c.Next()
}
}
// generateRequestID generates a simple request ID
// In production, you might want to use a more sophisticated ID generator
func generateRequestID() string {
return strconv.FormatInt(time.Now().UnixNano(), 36)
}
// Timeout returns a middleware that adds timeout to requests
func Timeout(timeout time.Duration) gin.HandlerFunc {
return func(c *gin.Context) {
ctx, cancel := context.WithTimeout(c.Request.Context(), timeout)
defer cancel()
c.Request = c.Request.WithContext(ctx)
c.Next()
}
}
// ValidateContentType returns a middleware that validates Content-Type header for JSON requests
func ValidateContentType() gin.HandlerFunc {
return func(c *gin.Context) {
// Only validate for POST, PUT, and PATCH requests
if c.Request.Method == "POST" || c.Request.Method == "PUT" || c.Request.Method == "PATCH" {
contentType := c.GetHeader("Content-Type")
// For requests with a body or when Content-Length is not explicitly 0,
// require application/json content type
if c.Request.ContentLength != 0 {
if contentType == "" {
c.JSON(http.StatusBadRequest, gin.H{
"error": "Bad Request",
"message": "Content-Type header is required for POST/PUT/PATCH requests",
})
c.Abort()
return
}
// Require application/json content type for requests with JSON bodies
if contentType != "application/json" {
c.JSON(http.StatusBadRequest, gin.H{
"error": "Bad Request",
"message": "Content-Type must be application/json",
})
c.Abort()
return
}
}
}
c.Next()
}
}

View File

@ -0,0 +1,558 @@
package middleware
import (
"context"
"crypto/hmac"
"crypto/sha256"
"encoding/hex"
"fmt"
"net"
"net/http"
"io"
"strings"
"sync"
"time"
"go.uber.org/zap"
"golang.org/x/time/rate"
"github.com/kms/api-key-service/internal/cache"
"github.com/kms/api-key-service/internal/config"
"github.com/kms/api-key-service/internal/repository"
)
// SecurityMiddleware provides various security features
type SecurityMiddleware struct {
config config.ConfigProvider
logger *zap.Logger
cacheManager *cache.CacheManager
appRepo repository.ApplicationRepository
rateLimiters map[string]*rate.Limiter
authRateLimiters map[string]*rate.Limiter
mu sync.RWMutex
}
// NewSecurityMiddleware creates a new security middleware
func NewSecurityMiddleware(config config.ConfigProvider, logger *zap.Logger, appRepo repository.ApplicationRepository) *SecurityMiddleware {
cacheManager := cache.NewCacheManager(config, logger)
return &SecurityMiddleware{
config: config,
logger: logger,
cacheManager: cacheManager,
appRepo: appRepo,
rateLimiters: make(map[string]*rate.Limiter),
authRateLimiters: make(map[string]*rate.Limiter),
}
}
// RateLimitMiddleware implements per-IP rate limiting
func (s *SecurityMiddleware) RateLimitMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if !s.config.GetBool("RATE_LIMIT_ENABLED") {
next.ServeHTTP(w, r)
return
}
// Get client IP
clientIP := s.getClientIP(r)
// Get or create rate limiter for this IP
limiter := s.getRateLimiter(clientIP)
// Check if request is allowed
if !limiter.Allow() {
s.logger.Warn("Rate limit exceeded",
zap.String("client_ip", clientIP),
zap.String("path", r.URL.Path))
// Track rate limit violations
s.trackRateLimitViolation(clientIP)
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusTooManyRequests)
w.Write([]byte(`{"error":"rate_limit_exceeded","message":"Too many requests"}`))
return
}
next.ServeHTTP(w, r)
})
}
// AuthRateLimitMiddleware implements stricter rate limiting for authentication endpoints
func (s *SecurityMiddleware) AuthRateLimitMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if !s.config.GetBool("RATE_LIMIT_ENABLED") {
next.ServeHTTP(w, r)
return
}
clientIP := s.getClientIP(r)
// Use stricter rate limits for auth endpoints
limiter := s.getAuthRateLimiter(clientIP)
// Check if request is allowed
if !limiter.Allow() {
s.logger.Warn("Auth rate limit exceeded",
zap.String("client_ip", clientIP),
zap.String("path", r.URL.Path))
// Track authentication failures for brute force protection
s.TrackAuthenticationFailure(clientIP, "")
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusTooManyRequests)
w.Write([]byte(`{"error":"auth_rate_limit_exceeded","message":"Too many authentication attempts"}`))
return
}
next.ServeHTTP(w, r)
})
}
// BruteForceProtectionMiddleware implements brute force protection
func (s *SecurityMiddleware) BruteForceProtectionMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
clientIP := s.getClientIP(r)
// Check if IP is temporarily blocked
if s.isIPBlocked(clientIP) {
s.logger.Warn("Blocked IP attempted access",
zap.String("client_ip", clientIP),
zap.String("path", r.URL.Path))
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusForbidden)
w.Write([]byte(`{"error":"ip_blocked","message":"IP temporarily blocked due to suspicious activity"}`))
return
}
next.ServeHTTP(w, r)
})
}
// IPWhitelistMiddleware implements IP whitelisting
func (s *SecurityMiddleware) IPWhitelistMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
whitelist := s.config.GetStringSlice("IP_WHITELIST")
if len(whitelist) == 0 {
// No whitelist configured, allow all
next.ServeHTTP(w, r)
return
}
clientIP := s.getClientIP(r)
// Check if IP is in whitelist
if !s.isIPInList(clientIP, whitelist) {
s.logger.Warn("Non-whitelisted IP attempted access",
zap.String("client_ip", clientIP),
zap.String("path", r.URL.Path))
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusForbidden)
w.Write([]byte(`{"error":"ip_not_whitelisted","message":"IP not in whitelist"}`))
return
}
next.ServeHTTP(w, r)
})
}
// SecurityHeadersMiddleware adds security headers
func (s *SecurityMiddleware) SecurityHeadersMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Add security headers
w.Header().Set("X-Content-Type-Options", "nosniff")
w.Header().Set("X-Frame-Options", "DENY")
w.Header().Set("X-XSS-Protection", "1; mode=block")
w.Header().Set("Referrer-Policy", "strict-origin-when-cross-origin")
w.Header().Set("Content-Security-Policy", "default-src 'self'")
// Add HSTS header for HTTPS
if r.TLS != nil {
w.Header().Set("Strict-Transport-Security", "max-age=31536000; includeSubDomains")
}
next.ServeHTTP(w, r)
})
}
// AuthenticationFailureTracker tracks authentication failures for brute force protection
func (s *SecurityMiddleware) TrackAuthenticationFailure(clientIP, userID string) {
ctx := context.Background()
// Track failures by IP
ipKey := cache.CacheKey("auth_failures_ip", clientIP)
s.incrementFailureCount(ctx, ipKey)
// Track failures by user ID if provided
if userID != "" {
userKey := cache.CacheKey("auth_failures_user", userID)
s.incrementFailureCount(ctx, userKey)
}
// Check if we should block the IP
s.checkAndBlockIP(clientIP)
}
// ClearAuthenticationFailures clears failure count on successful authentication
func (s *SecurityMiddleware) ClearAuthenticationFailures(clientIP, userID string) {
ctx := context.Background()
// Clear failures by IP
ipKey := cache.CacheKey("auth_failures_ip", clientIP)
s.cacheManager.Delete(ctx, ipKey)
// Clear failures by user ID if provided
if userID != "" {
userKey := cache.CacheKey("auth_failures_user", userID)
s.cacheManager.Delete(ctx, userKey)
}
}
// Helper methods
func (s *SecurityMiddleware) getClientIP(r *http.Request) string {
// Check X-Forwarded-For header first
xff := r.Header.Get("X-Forwarded-For")
if xff != "" {
// Take the first IP in the chain
ips := strings.Split(xff, ",")
return strings.TrimSpace(ips[0])
}
// Check X-Real-IP header
xri := r.Header.Get("X-Real-IP")
if xri != "" {
return xri
}
// Fall back to RemoteAddr
ip, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil {
return r.RemoteAddr
}
return ip
}
func (s *SecurityMiddleware) getRateLimiter(clientIP string) *rate.Limiter {
s.mu.RLock()
limiter, exists := s.rateLimiters[clientIP]
s.mu.RUnlock()
if exists {
return limiter
}
// Create new rate limiter
rps := s.config.GetInt("RATE_LIMIT_RPS")
if rps <= 0 {
rps = 100 // Default
}
burst := s.config.GetInt("RATE_LIMIT_BURST")
if burst <= 0 {
burst = 200 // Default
}
limiter = rate.NewLimiter(rate.Limit(rps), burst)
s.mu.Lock()
s.rateLimiters[clientIP] = limiter
s.mu.Unlock()
return limiter
}
func (s *SecurityMiddleware) getAuthRateLimiter(clientIP string) *rate.Limiter {
s.mu.RLock()
limiter, exists := s.authRateLimiters[clientIP]
s.mu.RUnlock()
if exists {
return limiter
}
// Create new auth rate limiter with stricter limits
authRPS := s.config.GetInt("AUTH_RATE_LIMIT_RPS")
if authRPS <= 0 {
authRPS = 5 // Very strict default for auth endpoints
}
authBurst := s.config.GetInt("AUTH_RATE_LIMIT_BURST")
if authBurst <= 0 {
authBurst = 10 // Allow small bursts
}
limiter = rate.NewLimiter(rate.Limit(authRPS), authBurst)
s.mu.Lock()
s.authRateLimiters[clientIP] = limiter
s.mu.Unlock()
return limiter
}
func (s *SecurityMiddleware) trackRateLimitViolation(clientIP string) {
ctx := context.Background()
key := cache.CacheKey("rate_limit_violations", clientIP)
s.incrementFailureCount(ctx, key)
}
func (s *SecurityMiddleware) isIPBlocked(clientIP string) bool {
ctx := context.Background()
key := cache.CacheKey("blocked_ips", clientIP)
exists, err := s.cacheManager.Exists(ctx, key)
if err != nil {
s.logger.Error("Failed to check IP block status",
zap.String("client_ip", clientIP),
zap.Error(err))
return false
}
return exists
}
func (s *SecurityMiddleware) isIPInList(clientIP string, ipList []string) bool {
for _, allowedIP := range ipList {
allowedIP = strings.TrimSpace(allowedIP)
// Support CIDR notation
if strings.Contains(allowedIP, "/") {
_, network, err := net.ParseCIDR(allowedIP)
if err != nil {
s.logger.Warn("Invalid CIDR in IP list", zap.String("cidr", allowedIP))
continue
}
ip := net.ParseIP(clientIP)
if ip != nil && network.Contains(ip) {
return true
}
} else {
// Exact IP match
if clientIP == allowedIP {
return true
}
}
}
return false
}
func (s *SecurityMiddleware) incrementFailureCount(ctx context.Context, key string) {
// Get current count
var count int
err := s.cacheManager.GetJSON(ctx, key, &count)
if err != nil {
// Key doesn't exist, start with 0
count = 0
}
count++
// Store updated count with TTL
ttl := s.config.GetDuration("AUTH_FAILURE_WINDOW")
if ttl <= 0 {
ttl = 15 * time.Minute // Default window
}
s.cacheManager.SetJSON(ctx, key, count, ttl)
}
func (s *SecurityMiddleware) checkAndBlockIP(clientIP string) {
ctx := context.Background()
key := cache.CacheKey("auth_failures_ip", clientIP)
var count int
err := s.cacheManager.GetJSON(ctx, key, &count)
if err != nil {
return // No failures recorded
}
maxFailures := s.config.GetInt("MAX_AUTH_FAILURES")
if maxFailures <= 0 {
maxFailures = 5 // Default
}
if count >= maxFailures {
// Block the IP
blockKey := cache.CacheKey("blocked_ips", clientIP)
blockDuration := s.config.GetDuration("IP_BLOCK_DURATION")
if blockDuration <= 0 {
blockDuration = 1 * time.Hour // Default
}
blockInfo := map[string]interface{}{
"blocked_at": time.Now().Unix(),
"failure_count": count,
"reason": "excessive_auth_failures",
}
s.cacheManager.SetJSON(ctx, blockKey, blockInfo, blockDuration)
s.logger.Warn("IP blocked due to excessive authentication failures",
zap.String("client_ip", clientIP),
zap.Int("failure_count", count),
zap.Duration("block_duration", blockDuration))
}
}
// RequestSignatureMiddleware validates request signatures (for API key requests)
func (s *SecurityMiddleware) RequestSignatureMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Only validate signatures for certain endpoints
if !s.shouldValidateSignature(r) {
next.ServeHTTP(w, r)
return
}
signature := r.Header.Get("X-Signature")
timestamp := r.Header.Get("X-Timestamp")
if signature == "" || timestamp == "" {
s.logger.Warn("Missing signature headers",
zap.String("path", r.URL.Path),
zap.String("client_ip", s.getClientIP(r)))
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusBadRequest)
w.Write([]byte(`{"error":"missing_signature","message":"Request signature required"}`))
return
}
// Validate timestamp (prevent replay attacks)
if !s.isTimestampValid(timestamp) {
s.logger.Warn("Invalid timestamp in request",
zap.String("timestamp", timestamp),
zap.String("client_ip", s.getClientIP(r)))
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusBadRequest)
w.Write([]byte(`{"error":"invalid_timestamp","message":"Request timestamp is invalid or too old"}`))
return
}
// Implement HMAC signature validation
appID := r.Header.Get("X-App-ID")
if appID == "" {
s.logger.Warn("Missing App-ID header for signature validation",
zap.String("path", r.URL.Path),
zap.String("client_ip", s.getClientIP(r)))
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusBadRequest)
w.Write([]byte(`{"error":"missing_app_id","message":"X-App-ID header required for signature validation"}`))
return
}
// Retrieve application to get HMAC key
ctx := r.Context()
app, err := s.appRepo.GetByID(ctx, appID)
if err != nil {
s.logger.Warn("Failed to retrieve application for signature validation",
zap.String("app_id", appID),
zap.Error(err),
zap.String("client_ip", s.getClientIP(r)))
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusUnauthorized)
w.Write([]byte(`{"error":"invalid_application","message":"Invalid application ID"}`))
return
}
// Validate the signature
if !s.validateHMACSignature(r, app.HMACKey, signature, timestamp) {
s.logger.Warn("Invalid request signature",
zap.String("app_id", appID),
zap.String("path", r.URL.Path),
zap.String("client_ip", s.getClientIP(r)))
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusUnauthorized)
w.Write([]byte(`{"error":"invalid_signature","message":"Request signature is invalid"}`))
return
}
next.ServeHTTP(w, r)
})
}
func (s *SecurityMiddleware) shouldValidateSignature(r *http.Request) bool {
// Define which endpoints require signature validation
signatureRequiredPaths := []string{
"/api/v1/tokens",
"/api/v1/applications",
}
for _, path := range signatureRequiredPaths {
if strings.HasPrefix(r.URL.Path, path) {
return true
}
}
return false
}
func (s *SecurityMiddleware) isTimestampValid(timestampStr string) bool {
// Parse timestamp
timestamp, err := time.Parse(time.RFC3339, timestampStr)
if err != nil {
return false
}
// Check if timestamp is within acceptable window
now := time.Now()
maxAge := s.config.GetDuration("REQUEST_MAX_AGE")
if maxAge <= 0 {
maxAge = 5 * time.Minute // Default
}
return now.Sub(timestamp) <= maxAge && timestamp.Before(now.Add(1*time.Minute))
}
// GetSecurityMetrics returns security-related metrics
func (s *SecurityMiddleware) GetSecurityMetrics() map[string]interface{} {
// This is a simplified version - in production you'd want more comprehensive metrics
metrics := map[string]interface{}{
"active_rate_limiters": len(s.rateLimiters),
"timestamp": time.Now().Unix(),
}
// Count blocked IPs (this is expensive, so you might want to cache this)
// For now, we'll just return the basic metrics
return metrics
}
// validateHMACSignature validates HMAC-SHA256 signature for request integrity
func (s *SecurityMiddleware) validateHMACSignature(r *http.Request, hmacKey, signature, timestamp string) bool {
// Create the signing string: METHOD + PATH + BODY + TIMESTAMP
var bodyBytes []byte
if r.Body != nil {
var err error
bodyBytes, err = io.ReadAll(r.Body)
if err != nil {
s.logger.Warn("Failed to read request body for signature validation", zap.Error(err))
return false
}
// Restore the body for downstream handlers
r.Body = io.NopCloser(strings.NewReader(string(bodyBytes)))
}
signingString := fmt.Sprintf("%s\n%s\n%s\n%s",
r.Method,
r.URL.Path,
string(bodyBytes),
timestamp)
// Calculate expected signature
mac := hmac.New(sha256.New, []byte(hmacKey))
mac.Write([]byte(signingString))
expectedSignature := hex.EncodeToString(mac.Sum(nil))
// Compare signatures (constant time comparison to prevent timing attacks)
return hmac.Equal([]byte(signature), []byte(expectedSignature))
}

View File

@ -0,0 +1,265 @@
package middleware
import (
"net/http"
"reflect"
"strings"
"github.com/gin-gonic/gin"
"github.com/go-playground/validator/v10"
"go.uber.org/zap"
)
// ValidationError represents a validation error
type ValidationError struct {
Field string `json:"field"`
Tag string `json:"tag"`
Value string `json:"value"`
Message string `json:"message"`
}
// ValidationResponse represents the validation error response
type ValidationResponse struct {
Error string `json:"error"`
Message string `json:"message"`
Details []ValidationError `json:"details,omitempty"`
}
var validate *validator.Validate
func init() {
validate = validator.New()
// Register custom tag name function to use json tags
validate.RegisterTagNameFunc(func(fld reflect.StructField) string {
name := strings.SplitN(fld.Tag.Get("json"), ",", 2)[0]
if name == "-" {
return ""
}
return name
})
}
// ValidateJSON validates JSON request body against struct validation tags
func ValidateJSON(logger *zap.Logger) gin.HandlerFunc {
return func(c *gin.Context) {
// Skip validation for GET requests and requests without body
if c.Request.Method == "GET" || c.Request.ContentLength == 0 {
c.Next()
return
}
// Store original body for potential re-reading
c.Set("validation_enabled", true)
c.Next()
}
}
// ValidateStruct validates a struct and returns formatted errors
func ValidateStruct(s interface{}) []ValidationError {
var errors []ValidationError
err := validate.Struct(s)
if err != nil {
for _, err := range err.(validator.ValidationErrors) {
var element ValidationError
element.Field = err.Field()
element.Tag = err.Tag()
element.Value = err.Param()
element.Message = getErrorMessage(err)
errors = append(errors, element)
}
}
return errors
}
// ValidateAndBind validates and binds JSON request to struct
func ValidateAndBind(c *gin.Context, obj interface{}) error {
// Bind JSON to struct
if err := c.ShouldBindJSON(obj); err != nil {
c.JSON(http.StatusBadRequest, ValidationResponse{
Error: "Invalid JSON",
Message: "Request body contains invalid JSON: " + err.Error(),
})
return err
}
// Validate struct
if validationErrors := ValidateStruct(obj); len(validationErrors) > 0 {
c.JSON(http.StatusBadRequest, ValidationResponse{
Error: "Validation Failed",
Message: "Request validation failed",
Details: validationErrors,
})
return validator.ValidationErrors{}
}
return nil
}
// getErrorMessage returns a human-readable error message for validation errors
func getErrorMessage(fe validator.FieldError) string {
switch fe.Tag() {
case "required":
return "This field is required"
case "email":
return "Invalid email format"
case "min":
return "Value is too short (minimum " + fe.Param() + " characters)"
case "max":
return "Value is too long (maximum " + fe.Param() + " characters)"
case "url":
return "Invalid URL format"
case "oneof":
return "Value must be one of: " + fe.Param()
case "uuid":
return "Invalid UUID format"
case "gte":
return "Value must be greater than or equal to " + fe.Param()
case "lte":
return "Value must be less than or equal to " + fe.Param()
case "len":
return "Value must be exactly " + fe.Param() + " characters"
case "dive":
return "Invalid array element"
default:
return "Invalid value for " + fe.Field()
}
}
// RequiredFields validates that specific fields are present in the request
func RequiredFields(fields ...string) gin.HandlerFunc {
return func(c *gin.Context) {
var json map[string]interface{}
if err := c.ShouldBindJSON(&json); err != nil {
c.JSON(http.StatusBadRequest, ValidationResponse{
Error: "Invalid JSON",
Message: "Request body contains invalid JSON",
})
c.Abort()
return
}
var missingFields []string
for _, field := range fields {
if _, exists := json[field]; !exists {
missingFields = append(missingFields, field)
}
}
if len(missingFields) > 0 {
c.JSON(http.StatusBadRequest, ValidationResponse{
Error: "Missing Required Fields",
Message: "The following required fields are missing: " + strings.Join(missingFields, ", "),
})
c.Abort()
return
}
// Store the parsed JSON for use in handlers
c.Set("parsed_json", json)
c.Next()
}
}
// ValidateUUID validates that a URL parameter is a valid UUID
func ValidateUUID(param string) gin.HandlerFunc {
return func(c *gin.Context) {
value := c.Param(param)
if value == "" {
c.JSON(http.StatusBadRequest, ValidationResponse{
Error: "Missing Parameter",
Message: "Required parameter '" + param + "' is missing",
})
c.Abort()
return
}
// Validate UUID format
if err := validate.Var(value, "uuid"); err != nil {
c.JSON(http.StatusBadRequest, ValidationResponse{
Error: "Invalid Parameter",
Message: "Parameter '" + param + "' must be a valid UUID",
})
c.Abort()
return
}
c.Next()
}
}
// ValidateQueryParams validates query parameters
func ValidateQueryParams(rules map[string]string) gin.HandlerFunc {
return func(c *gin.Context) {
var errors []ValidationError
for param, rule := range rules {
value := c.Query(param)
if value != "" {
if err := validate.Var(value, rule); err != nil {
for _, err := range err.(validator.ValidationErrors) {
errors = append(errors, ValidationError{
Field: param,
Tag: err.Tag(),
Value: err.Param(),
Message: getErrorMessage(err),
})
}
}
}
}
if len(errors) > 0 {
c.JSON(http.StatusBadRequest, ValidationResponse{
Error: "Invalid Query Parameters",
Message: "One or more query parameters are invalid",
Details: errors,
})
c.Abort()
return
}
c.Next()
}
}
// SanitizeInput sanitizes input strings to prevent XSS and injection attacks
func SanitizeInput() gin.HandlerFunc {
return func(c *gin.Context) {
// This is a basic implementation - in production you might want to use
// a more sophisticated sanitization library like bluemonday
c.Next()
}
}
// ValidatePermissions validates that permission scopes follow the expected format
func ValidatePermissions(c *gin.Context, permissions []string) []ValidationError {
var errors []ValidationError
for i, perm := range permissions {
// Check basic format: should contain only alphanumeric, dots, and underscores
if err := validate.Var(perm, "required,min=1,max=255,alphanum|contains=.|contains=_"); err != nil {
errors = append(errors, ValidationError{
Field: "permissions[" + string(rune(i)) + "]",
Tag: "format",
Value: perm,
Message: "Permission scope must contain only alphanumeric characters, dots, and underscores",
})
}
// Check for dangerous patterns
if strings.Contains(perm, "..") || strings.HasPrefix(perm, ".") || strings.HasSuffix(perm, ".") {
errors = append(errors, ValidationError{
Field: "permissions[" + string(rune(i)) + "]",
Tag: "format",
Value: perm,
Message: "Permission scope has invalid format",
})
}
}
return errors
}