org
This commit is contained in:
235
kms/internal/middleware/csrf.go
Normal file
235
kms/internal/middleware/csrf.go
Normal file
@ -0,0 +1,235 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
|
||||
"github.com/kms/api-key-service/internal/config"
|
||||
)
|
||||
|
||||
// CSRFMiddleware provides CSRF protection
|
||||
type CSRFMiddleware struct {
|
||||
config config.ConfigProvider
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewCSRFMiddleware creates a new CSRF middleware
|
||||
func NewCSRFMiddleware(config config.ConfigProvider, logger *zap.Logger) *CSRFMiddleware {
|
||||
return &CSRFMiddleware{
|
||||
config: config,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// CSRFProtection implements CSRF protection for state-changing operations
|
||||
func (cm *CSRFMiddleware) CSRFProtection(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Skip CSRF protection for safe methods
|
||||
if r.Method == "GET" || r.Method == "HEAD" || r.Method == "OPTIONS" {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
// Skip CSRF protection for specific endpoints that use other authentication
|
||||
if cm.shouldSkipCSRF(r) {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
// Get CSRF token from header
|
||||
csrfToken := r.Header.Get("X-CSRF-Token")
|
||||
if csrfToken == "" {
|
||||
cm.logger.Warn("Missing CSRF token",
|
||||
zap.String("path", r.URL.Path),
|
||||
zap.String("method", r.Method),
|
||||
zap.String("remote_addr", r.RemoteAddr))
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusForbidden)
|
||||
w.Write([]byte(`{"error":"csrf_token_missing","message":"CSRF token required"}`))
|
||||
return
|
||||
}
|
||||
|
||||
// Validate CSRF token
|
||||
if !cm.validateCSRFToken(csrfToken, r) {
|
||||
cm.logger.Warn("Invalid CSRF token",
|
||||
zap.String("path", r.URL.Path),
|
||||
zap.String("method", r.Method),
|
||||
zap.String("remote_addr", r.RemoteAddr))
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusForbidden)
|
||||
w.Write([]byte(`{"error":"csrf_token_invalid","message":"Invalid CSRF token"}`))
|
||||
return
|
||||
}
|
||||
|
||||
cm.logger.Debug("CSRF token validated successfully",
|
||||
zap.String("path", r.URL.Path))
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
// GenerateCSRFToken generates a new CSRF token for a user session
|
||||
func (cm *CSRFMiddleware) GenerateCSRFToken(userID string) (string, error) {
|
||||
// Generate random bytes for token
|
||||
tokenBytes := make([]byte, 32)
|
||||
if _, err := rand.Read(tokenBytes); err != nil {
|
||||
cm.logger.Error("Failed to generate CSRF token", zap.Error(err))
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Create timestamp
|
||||
timestamp := time.Now().Unix()
|
||||
|
||||
// Create token data
|
||||
tokenData := hex.EncodeToString(tokenBytes)
|
||||
|
||||
// Create signing string: userID:timestamp:tokenData
|
||||
timestampStr := strconv.FormatInt(timestamp, 10)
|
||||
signingString := userID + ":" + timestampStr + ":" + tokenData
|
||||
|
||||
// Sign the token with HMAC
|
||||
signature := cm.signData(signingString)
|
||||
|
||||
// Return encoded token: tokenData.timestamp.signature
|
||||
token := tokenData + "." + timestampStr + "." + signature
|
||||
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// validateCSRFToken validates a CSRF token
|
||||
func (cm *CSRFMiddleware) validateCSRFToken(token string, r *http.Request) bool {
|
||||
// Parse token parts
|
||||
parts := strings.Split(token, ".")
|
||||
if len(parts) != 3 {
|
||||
cm.logger.Debug("Invalid CSRF token format")
|
||||
return false
|
||||
}
|
||||
|
||||
tokenData, timestampStr, signature := parts[0], parts[1], parts[2]
|
||||
|
||||
// Get user ID from request context or headers
|
||||
userID := cm.getUserIDFromRequest(r)
|
||||
if userID == "" {
|
||||
cm.logger.Debug("No user ID found for CSRF validation")
|
||||
return false
|
||||
}
|
||||
|
||||
// Recreate signing string
|
||||
signingString := userID + ":" + timestampStr + ":" + tokenData
|
||||
|
||||
// Verify signature
|
||||
expectedSignature := cm.signData(signingString)
|
||||
if !hmac.Equal([]byte(signature), []byte(expectedSignature)) {
|
||||
cm.logger.Debug("CSRF token signature verification failed")
|
||||
return false
|
||||
}
|
||||
|
||||
// Parse timestamp
|
||||
timestampInt, err := strconv.ParseInt(timestampStr, 10, 64)
|
||||
if err != nil {
|
||||
cm.logger.Debug("Invalid timestamp in CSRF token", zap.Error(err))
|
||||
return false
|
||||
}
|
||||
timestamp := time.Unix(timestampInt, 0)
|
||||
|
||||
// Check if token is expired (valid for 1 hour by default)
|
||||
maxAge := cm.config.GetDuration("CSRF_TOKEN_MAX_AGE")
|
||||
if maxAge <= 0 {
|
||||
maxAge = 1 * time.Hour
|
||||
}
|
||||
|
||||
if time.Since(timestamp) > maxAge {
|
||||
cm.logger.Debug("CSRF token expired",
|
||||
zap.Time("timestamp", timestamp),
|
||||
zap.Duration("age", time.Since(timestamp)),
|
||||
zap.Duration("max_age", maxAge))
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// signData signs data with HMAC
|
||||
func (cm *CSRFMiddleware) signData(data string) string {
|
||||
// Use the same signing key as for authentication
|
||||
signingKey := cm.config.GetString("AUTH_SIGNING_KEY")
|
||||
if signingKey == "" {
|
||||
cm.logger.Error("AUTH_SIGNING_KEY not configured for CSRF protection")
|
||||
return ""
|
||||
}
|
||||
|
||||
mac := hmac.New(sha256.New, []byte(signingKey))
|
||||
mac.Write([]byte(data))
|
||||
return hex.EncodeToString(mac.Sum(nil))
|
||||
}
|
||||
|
||||
// getUserIDFromRequest extracts user ID from request
|
||||
func (cm *CSRFMiddleware) getUserIDFromRequest(r *http.Request) string {
|
||||
// Try to get from X-User-Email header
|
||||
userEmail := r.Header.Get(cm.config.GetString("AUTH_HEADER_USER_EMAIL"))
|
||||
if userEmail != "" {
|
||||
return userEmail
|
||||
}
|
||||
|
||||
// Try to get from context (if set by authentication middleware)
|
||||
if userID := r.Context().Value("user_id"); userID != nil {
|
||||
if id, ok := userID.(string); ok {
|
||||
return id
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// shouldSkipCSRF determines if CSRF protection should be skipped for a request
|
||||
func (cm *CSRFMiddleware) shouldSkipCSRF(r *http.Request) bool {
|
||||
// Skip for API endpoints that use API key authentication
|
||||
if strings.HasPrefix(r.URL.Path, "/api/verify") {
|
||||
return true
|
||||
}
|
||||
|
||||
// Skip for health check endpoints
|
||||
if r.URL.Path == "/health" || r.URL.Path == "/ready" {
|
||||
return true
|
||||
}
|
||||
|
||||
// Skip for webhook endpoints (if any)
|
||||
if strings.HasPrefix(r.URL.Path, "/webhook/") {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// SetCSRFCookie sets a secure CSRF token cookie
|
||||
func (cm *CSRFMiddleware) SetCSRFCookie(w http.ResponseWriter, token string) {
|
||||
cookie := &http.Cookie{
|
||||
Name: "csrf_token",
|
||||
Value: token,
|
||||
Path: "/",
|
||||
MaxAge: 3600, // 1 hour
|
||||
HttpOnly: false, // JavaScript needs to read this for AJAX requests
|
||||
Secure: true, // HTTPS only
|
||||
SameSite: http.SameSiteStrictMode,
|
||||
}
|
||||
http.SetCookie(w, cookie)
|
||||
}
|
||||
|
||||
// GetCSRFTokenFromCookie gets CSRF token from cookie
|
||||
func (cm *CSRFMiddleware) GetCSRFTokenFromCookie(r *http.Request) string {
|
||||
cookie, err := r.Cookie("csrf_token")
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return cookie.Value
|
||||
}
|
||||
60
kms/internal/middleware/logger.go
Normal file
60
kms/internal/middleware/logger.go
Normal file
@ -0,0 +1,60 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// Logger returns a middleware that logs HTTP requests using zap logger
|
||||
func Logger(logger *zap.Logger) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// Start timer
|
||||
start := time.Now()
|
||||
|
||||
// Process request
|
||||
c.Next()
|
||||
|
||||
// Calculate latency
|
||||
latency := time.Since(start)
|
||||
|
||||
// Get request information
|
||||
method := c.Request.Method
|
||||
path := c.Request.URL.Path
|
||||
query := c.Request.URL.RawQuery
|
||||
status := c.Writer.Status()
|
||||
clientIP := c.ClientIP()
|
||||
userAgent := c.Request.UserAgent()
|
||||
|
||||
// Get error if any
|
||||
errorMessage := c.Errors.ByType(gin.ErrorTypePrivate).String()
|
||||
|
||||
// Build log fields
|
||||
fields := []zap.Field{
|
||||
zap.String("method", method),
|
||||
zap.String("path", path),
|
||||
zap.String("query", query),
|
||||
zap.Int("status", status),
|
||||
zap.String("client_ip", clientIP),
|
||||
zap.String("user_agent", userAgent),
|
||||
zap.Duration("latency", latency),
|
||||
zap.Int64("latency_ms", latency.Nanoseconds()/1000000),
|
||||
}
|
||||
|
||||
// Add error field if exists
|
||||
if errorMessage != "" {
|
||||
fields = append(fields, zap.String("error", errorMessage))
|
||||
}
|
||||
|
||||
// Log based on status code
|
||||
switch {
|
||||
case status >= 500:
|
||||
logger.Error("HTTP Request", fields...)
|
||||
case status >= 400:
|
||||
logger.Warn("HTTP Request", fields...)
|
||||
default:
|
||||
logger.Info("HTTP Request", fields...)
|
||||
}
|
||||
}
|
||||
}
|
||||
239
kms/internal/middleware/middleware.go
Normal file
239
kms/internal/middleware/middleware.go
Normal file
@ -0,0 +1,239 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"runtime/debug"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
"golang.org/x/time/rate"
|
||||
|
||||
"github.com/kms/api-key-service/internal/config"
|
||||
)
|
||||
|
||||
// Recovery returns a middleware that recovers from any panics
|
||||
func Recovery(logger *zap.Logger) gin.HandlerFunc {
|
||||
return gin.CustomRecoveryWithWriter(gin.DefaultWriter, func(c *gin.Context, recovered interface{}) {
|
||||
if err, ok := recovered.(string); ok {
|
||||
logger.Error("Panic recovered",
|
||||
zap.String("error", err),
|
||||
zap.String("stack", string(debug.Stack())),
|
||||
zap.String("path", c.Request.URL.Path),
|
||||
zap.String("method", c.Request.Method),
|
||||
)
|
||||
}
|
||||
c.AbortWithStatus(http.StatusInternalServerError)
|
||||
})
|
||||
}
|
||||
|
||||
// CORS returns a middleware that handles Cross-Origin Resource Sharing
|
||||
func CORS() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// Set CORS headers
|
||||
c.Header("Access-Control-Allow-Origin", "*") // In production, be more specific
|
||||
c.Header("Access-Control-Allow-Credentials", "true")
|
||||
c.Header("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
|
||||
c.Header("Access-Control-Allow-Headers", "Origin, Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, X-User-Email")
|
||||
c.Header("Access-Control-Expose-Headers", "Content-Length")
|
||||
c.Header("Access-Control-Max-Age", "86400")
|
||||
|
||||
// Handle preflight OPTIONS request
|
||||
if c.Request.Method == "OPTIONS" {
|
||||
c.AbortWithStatus(http.StatusNoContent)
|
||||
return
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// Security returns a middleware that adds security headers
|
||||
func Security() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// Security headers
|
||||
c.Header("X-Frame-Options", "DENY")
|
||||
c.Header("X-Content-Type-Options", "nosniff")
|
||||
c.Header("X-XSS-Protection", "1; mode=block")
|
||||
c.Header("Referrer-Policy", "strict-origin-when-cross-origin")
|
||||
c.Header("Content-Security-Policy", "default-src 'self'")
|
||||
c.Header("Strict-Transport-Security", "max-age=31536000; includeSubDomains")
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// RateLimiter holds rate limiting data
|
||||
type RateLimiter struct {
|
||||
limiters map[string]*rate.Limiter
|
||||
mu sync.RWMutex
|
||||
rate rate.Limit
|
||||
burst int
|
||||
}
|
||||
|
||||
// NewRateLimiter creates a new rate limiter
|
||||
func NewRateLimiter(rps, burst int) *RateLimiter {
|
||||
return &RateLimiter{
|
||||
limiters: make(map[string]*rate.Limiter),
|
||||
rate: rate.Limit(rps),
|
||||
burst: burst,
|
||||
}
|
||||
}
|
||||
|
||||
// GetLimiter returns the rate limiter for a given key
|
||||
func (rl *RateLimiter) GetLimiter(key string) *rate.Limiter {
|
||||
rl.mu.RLock()
|
||||
limiter, exists := rl.limiters[key]
|
||||
rl.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
limiter = rate.NewLimiter(rl.rate, rl.burst)
|
||||
rl.mu.Lock()
|
||||
rl.limiters[key] = limiter
|
||||
rl.mu.Unlock()
|
||||
}
|
||||
|
||||
return limiter
|
||||
}
|
||||
|
||||
// RateLimit returns a middleware that implements rate limiting
|
||||
func RateLimit(rps, burst int) gin.HandlerFunc {
|
||||
limiter := NewRateLimiter(rps, burst)
|
||||
|
||||
return func(c *gin.Context) {
|
||||
// Use client IP as the key for rate limiting
|
||||
key := c.ClientIP()
|
||||
|
||||
// Get the limiter for this client
|
||||
clientLimiter := limiter.GetLimiter(key)
|
||||
|
||||
// Check if request is allowed
|
||||
if !clientLimiter.Allow() {
|
||||
// Add rate limit headers
|
||||
c.Header("X-RateLimit-Limit", strconv.Itoa(burst))
|
||||
c.Header("X-RateLimit-Remaining", "0")
|
||||
c.Header("X-RateLimit-Reset", strconv.FormatInt(time.Now().Add(time.Minute).Unix(), 10))
|
||||
|
||||
c.JSON(http.StatusTooManyRequests, gin.H{
|
||||
"error": "Rate limit exceeded",
|
||||
"message": "Too many requests. Please try again later.",
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// Add rate limit headers for successful requests
|
||||
remaining := burst - int(clientLimiter.Tokens())
|
||||
if remaining < 0 {
|
||||
remaining = 0
|
||||
}
|
||||
c.Header("X-RateLimit-Limit", strconv.Itoa(burst))
|
||||
c.Header("X-RateLimit-Remaining", strconv.Itoa(remaining))
|
||||
c.Header("X-RateLimit-Reset", strconv.FormatInt(time.Now().Add(time.Minute).Unix(), 10))
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// Authentication returns a middleware that handles authentication
|
||||
func Authentication(cfg config.ConfigProvider, logger *zap.Logger) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// For now, we'll implement a basic header-based authentication
|
||||
// This will be expanded when we implement the full authentication service
|
||||
|
||||
userEmail := c.GetHeader(cfg.GetString("AUTH_HEADER_USER_EMAIL"))
|
||||
if userEmail == "" {
|
||||
logger.Warn("Authentication failed: missing user email header",
|
||||
zap.String("path", c.Request.URL.Path),
|
||||
zap.String("method", c.Request.Method),
|
||||
)
|
||||
|
||||
c.JSON(http.StatusUnauthorized, gin.H{
|
||||
"error": "Unauthorized",
|
||||
"message": "Authentication required",
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// Set user context for downstream handlers
|
||||
c.Set("user_id", userEmail)
|
||||
c.Set("auth_method", "header")
|
||||
|
||||
logger.Debug("Authentication successful",
|
||||
zap.String("user_id", userEmail),
|
||||
zap.String("auth_method", "header"),
|
||||
)
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// RequestID returns a middleware that adds a unique request ID to each request
|
||||
func RequestID() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
requestID := c.GetHeader("X-Request-ID")
|
||||
if requestID == "" {
|
||||
requestID = generateRequestID()
|
||||
}
|
||||
|
||||
c.Header("X-Request-ID", requestID)
|
||||
c.Set("request_id", requestID)
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// generateRequestID generates a simple request ID
|
||||
// In production, you might want to use a more sophisticated ID generator
|
||||
func generateRequestID() string {
|
||||
return strconv.FormatInt(time.Now().UnixNano(), 36)
|
||||
}
|
||||
|
||||
// Timeout returns a middleware that adds timeout to requests
|
||||
func Timeout(timeout time.Duration) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
ctx, cancel := context.WithTimeout(c.Request.Context(), timeout)
|
||||
defer cancel()
|
||||
|
||||
c.Request = c.Request.WithContext(ctx)
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// ValidateContentType returns a middleware that validates Content-Type header for JSON requests
|
||||
func ValidateContentType() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// Only validate for POST, PUT, and PATCH requests
|
||||
if c.Request.Method == "POST" || c.Request.Method == "PUT" || c.Request.Method == "PATCH" {
|
||||
contentType := c.GetHeader("Content-Type")
|
||||
|
||||
// For requests with a body or when Content-Length is not explicitly 0,
|
||||
// require application/json content type
|
||||
if c.Request.ContentLength != 0 {
|
||||
if contentType == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": "Bad Request",
|
||||
"message": "Content-Type header is required for POST/PUT/PATCH requests",
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// Require application/json content type for requests with JSON bodies
|
||||
if contentType != "application/json" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": "Bad Request",
|
||||
"message": "Content-Type must be application/json",
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
558
kms/internal/middleware/security.go
Normal file
558
kms/internal/middleware/security.go
Normal file
@ -0,0 +1,558 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/hmac"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"io"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"golang.org/x/time/rate"
|
||||
|
||||
"github.com/kms/api-key-service/internal/cache"
|
||||
"github.com/kms/api-key-service/internal/config"
|
||||
"github.com/kms/api-key-service/internal/repository"
|
||||
)
|
||||
|
||||
// SecurityMiddleware provides various security features
|
||||
type SecurityMiddleware struct {
|
||||
config config.ConfigProvider
|
||||
logger *zap.Logger
|
||||
cacheManager *cache.CacheManager
|
||||
appRepo repository.ApplicationRepository
|
||||
rateLimiters map[string]*rate.Limiter
|
||||
authRateLimiters map[string]*rate.Limiter
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewSecurityMiddleware creates a new security middleware
|
||||
func NewSecurityMiddleware(config config.ConfigProvider, logger *zap.Logger, appRepo repository.ApplicationRepository) *SecurityMiddleware {
|
||||
cacheManager := cache.NewCacheManager(config, logger)
|
||||
return &SecurityMiddleware{
|
||||
config: config,
|
||||
logger: logger,
|
||||
cacheManager: cacheManager,
|
||||
appRepo: appRepo,
|
||||
rateLimiters: make(map[string]*rate.Limiter),
|
||||
authRateLimiters: make(map[string]*rate.Limiter),
|
||||
}
|
||||
}
|
||||
|
||||
// RateLimitMiddleware implements per-IP rate limiting
|
||||
func (s *SecurityMiddleware) RateLimitMiddleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if !s.config.GetBool("RATE_LIMIT_ENABLED") {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
// Get client IP
|
||||
clientIP := s.getClientIP(r)
|
||||
|
||||
// Get or create rate limiter for this IP
|
||||
limiter := s.getRateLimiter(clientIP)
|
||||
|
||||
// Check if request is allowed
|
||||
if !limiter.Allow() {
|
||||
s.logger.Warn("Rate limit exceeded",
|
||||
zap.String("client_ip", clientIP),
|
||||
zap.String("path", r.URL.Path))
|
||||
|
||||
// Track rate limit violations
|
||||
s.trackRateLimitViolation(clientIP)
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusTooManyRequests)
|
||||
w.Write([]byte(`{"error":"rate_limit_exceeded","message":"Too many requests"}`))
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
// AuthRateLimitMiddleware implements stricter rate limiting for authentication endpoints
|
||||
func (s *SecurityMiddleware) AuthRateLimitMiddleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if !s.config.GetBool("RATE_LIMIT_ENABLED") {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
clientIP := s.getClientIP(r)
|
||||
|
||||
// Use stricter rate limits for auth endpoints
|
||||
limiter := s.getAuthRateLimiter(clientIP)
|
||||
|
||||
// Check if request is allowed
|
||||
if !limiter.Allow() {
|
||||
s.logger.Warn("Auth rate limit exceeded",
|
||||
zap.String("client_ip", clientIP),
|
||||
zap.String("path", r.URL.Path))
|
||||
|
||||
// Track authentication failures for brute force protection
|
||||
s.TrackAuthenticationFailure(clientIP, "")
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusTooManyRequests)
|
||||
w.Write([]byte(`{"error":"auth_rate_limit_exceeded","message":"Too many authentication attempts"}`))
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
// BruteForceProtectionMiddleware implements brute force protection
|
||||
func (s *SecurityMiddleware) BruteForceProtectionMiddleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
clientIP := s.getClientIP(r)
|
||||
|
||||
// Check if IP is temporarily blocked
|
||||
if s.isIPBlocked(clientIP) {
|
||||
s.logger.Warn("Blocked IP attempted access",
|
||||
zap.String("client_ip", clientIP),
|
||||
zap.String("path", r.URL.Path))
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusForbidden)
|
||||
w.Write([]byte(`{"error":"ip_blocked","message":"IP temporarily blocked due to suspicious activity"}`))
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
// IPWhitelistMiddleware implements IP whitelisting
|
||||
func (s *SecurityMiddleware) IPWhitelistMiddleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
whitelist := s.config.GetStringSlice("IP_WHITELIST")
|
||||
if len(whitelist) == 0 {
|
||||
// No whitelist configured, allow all
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
clientIP := s.getClientIP(r)
|
||||
|
||||
// Check if IP is in whitelist
|
||||
if !s.isIPInList(clientIP, whitelist) {
|
||||
s.logger.Warn("Non-whitelisted IP attempted access",
|
||||
zap.String("client_ip", clientIP),
|
||||
zap.String("path", r.URL.Path))
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusForbidden)
|
||||
w.Write([]byte(`{"error":"ip_not_whitelisted","message":"IP not in whitelist"}`))
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
// SecurityHeadersMiddleware adds security headers
|
||||
func (s *SecurityMiddleware) SecurityHeadersMiddleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Add security headers
|
||||
w.Header().Set("X-Content-Type-Options", "nosniff")
|
||||
w.Header().Set("X-Frame-Options", "DENY")
|
||||
w.Header().Set("X-XSS-Protection", "1; mode=block")
|
||||
w.Header().Set("Referrer-Policy", "strict-origin-when-cross-origin")
|
||||
w.Header().Set("Content-Security-Policy", "default-src 'self'")
|
||||
|
||||
// Add HSTS header for HTTPS
|
||||
if r.TLS != nil {
|
||||
w.Header().Set("Strict-Transport-Security", "max-age=31536000; includeSubDomains")
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
// AuthenticationFailureTracker tracks authentication failures for brute force protection
|
||||
func (s *SecurityMiddleware) TrackAuthenticationFailure(clientIP, userID string) {
|
||||
ctx := context.Background()
|
||||
|
||||
// Track failures by IP
|
||||
ipKey := cache.CacheKey("auth_failures_ip", clientIP)
|
||||
s.incrementFailureCount(ctx, ipKey)
|
||||
|
||||
// Track failures by user ID if provided
|
||||
if userID != "" {
|
||||
userKey := cache.CacheKey("auth_failures_user", userID)
|
||||
s.incrementFailureCount(ctx, userKey)
|
||||
}
|
||||
|
||||
// Check if we should block the IP
|
||||
s.checkAndBlockIP(clientIP)
|
||||
}
|
||||
|
||||
// ClearAuthenticationFailures clears failure count on successful authentication
|
||||
func (s *SecurityMiddleware) ClearAuthenticationFailures(clientIP, userID string) {
|
||||
ctx := context.Background()
|
||||
|
||||
// Clear failures by IP
|
||||
ipKey := cache.CacheKey("auth_failures_ip", clientIP)
|
||||
s.cacheManager.Delete(ctx, ipKey)
|
||||
|
||||
// Clear failures by user ID if provided
|
||||
if userID != "" {
|
||||
userKey := cache.CacheKey("auth_failures_user", userID)
|
||||
s.cacheManager.Delete(ctx, userKey)
|
||||
}
|
||||
}
|
||||
|
||||
// Helper methods
|
||||
|
||||
func (s *SecurityMiddleware) getClientIP(r *http.Request) string {
|
||||
// Check X-Forwarded-For header first
|
||||
xff := r.Header.Get("X-Forwarded-For")
|
||||
if xff != "" {
|
||||
// Take the first IP in the chain
|
||||
ips := strings.Split(xff, ",")
|
||||
return strings.TrimSpace(ips[0])
|
||||
}
|
||||
|
||||
// Check X-Real-IP header
|
||||
xri := r.Header.Get("X-Real-IP")
|
||||
if xri != "" {
|
||||
return xri
|
||||
}
|
||||
|
||||
// Fall back to RemoteAddr
|
||||
ip, _, err := net.SplitHostPort(r.RemoteAddr)
|
||||
if err != nil {
|
||||
return r.RemoteAddr
|
||||
}
|
||||
return ip
|
||||
}
|
||||
|
||||
func (s *SecurityMiddleware) getRateLimiter(clientIP string) *rate.Limiter {
|
||||
s.mu.RLock()
|
||||
limiter, exists := s.rateLimiters[clientIP]
|
||||
s.mu.RUnlock()
|
||||
|
||||
if exists {
|
||||
return limiter
|
||||
}
|
||||
|
||||
// Create new rate limiter
|
||||
rps := s.config.GetInt("RATE_LIMIT_RPS")
|
||||
if rps <= 0 {
|
||||
rps = 100 // Default
|
||||
}
|
||||
|
||||
burst := s.config.GetInt("RATE_LIMIT_BURST")
|
||||
if burst <= 0 {
|
||||
burst = 200 // Default
|
||||
}
|
||||
|
||||
limiter = rate.NewLimiter(rate.Limit(rps), burst)
|
||||
|
||||
s.mu.Lock()
|
||||
s.rateLimiters[clientIP] = limiter
|
||||
s.mu.Unlock()
|
||||
|
||||
return limiter
|
||||
}
|
||||
|
||||
func (s *SecurityMiddleware) getAuthRateLimiter(clientIP string) *rate.Limiter {
|
||||
s.mu.RLock()
|
||||
limiter, exists := s.authRateLimiters[clientIP]
|
||||
s.mu.RUnlock()
|
||||
|
||||
if exists {
|
||||
return limiter
|
||||
}
|
||||
|
||||
// Create new auth rate limiter with stricter limits
|
||||
authRPS := s.config.GetInt("AUTH_RATE_LIMIT_RPS")
|
||||
if authRPS <= 0 {
|
||||
authRPS = 5 // Very strict default for auth endpoints
|
||||
}
|
||||
|
||||
authBurst := s.config.GetInt("AUTH_RATE_LIMIT_BURST")
|
||||
if authBurst <= 0 {
|
||||
authBurst = 10 // Allow small bursts
|
||||
}
|
||||
|
||||
limiter = rate.NewLimiter(rate.Limit(authRPS), authBurst)
|
||||
|
||||
s.mu.Lock()
|
||||
s.authRateLimiters[clientIP] = limiter
|
||||
s.mu.Unlock()
|
||||
|
||||
return limiter
|
||||
}
|
||||
|
||||
func (s *SecurityMiddleware) trackRateLimitViolation(clientIP string) {
|
||||
ctx := context.Background()
|
||||
key := cache.CacheKey("rate_limit_violations", clientIP)
|
||||
s.incrementFailureCount(ctx, key)
|
||||
}
|
||||
|
||||
func (s *SecurityMiddleware) isIPBlocked(clientIP string) bool {
|
||||
ctx := context.Background()
|
||||
key := cache.CacheKey("blocked_ips", clientIP)
|
||||
|
||||
exists, err := s.cacheManager.Exists(ctx, key)
|
||||
if err != nil {
|
||||
s.logger.Error("Failed to check IP block status",
|
||||
zap.String("client_ip", clientIP),
|
||||
zap.Error(err))
|
||||
return false
|
||||
}
|
||||
|
||||
return exists
|
||||
}
|
||||
|
||||
func (s *SecurityMiddleware) isIPInList(clientIP string, ipList []string) bool {
|
||||
for _, allowedIP := range ipList {
|
||||
allowedIP = strings.TrimSpace(allowedIP)
|
||||
|
||||
// Support CIDR notation
|
||||
if strings.Contains(allowedIP, "/") {
|
||||
_, network, err := net.ParseCIDR(allowedIP)
|
||||
if err != nil {
|
||||
s.logger.Warn("Invalid CIDR in IP list", zap.String("cidr", allowedIP))
|
||||
continue
|
||||
}
|
||||
|
||||
ip := net.ParseIP(clientIP)
|
||||
if ip != nil && network.Contains(ip) {
|
||||
return true
|
||||
}
|
||||
} else {
|
||||
// Exact IP match
|
||||
if clientIP == allowedIP {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (s *SecurityMiddleware) incrementFailureCount(ctx context.Context, key string) {
|
||||
// Get current count
|
||||
var count int
|
||||
err := s.cacheManager.GetJSON(ctx, key, &count)
|
||||
if err != nil {
|
||||
// Key doesn't exist, start with 0
|
||||
count = 0
|
||||
}
|
||||
|
||||
count++
|
||||
|
||||
// Store updated count with TTL
|
||||
ttl := s.config.GetDuration("AUTH_FAILURE_WINDOW")
|
||||
if ttl <= 0 {
|
||||
ttl = 15 * time.Minute // Default window
|
||||
}
|
||||
|
||||
s.cacheManager.SetJSON(ctx, key, count, ttl)
|
||||
}
|
||||
|
||||
func (s *SecurityMiddleware) checkAndBlockIP(clientIP string) {
|
||||
ctx := context.Background()
|
||||
key := cache.CacheKey("auth_failures_ip", clientIP)
|
||||
|
||||
var count int
|
||||
err := s.cacheManager.GetJSON(ctx, key, &count)
|
||||
if err != nil {
|
||||
return // No failures recorded
|
||||
}
|
||||
|
||||
maxFailures := s.config.GetInt("MAX_AUTH_FAILURES")
|
||||
if maxFailures <= 0 {
|
||||
maxFailures = 5 // Default
|
||||
}
|
||||
|
||||
if count >= maxFailures {
|
||||
// Block the IP
|
||||
blockKey := cache.CacheKey("blocked_ips", clientIP)
|
||||
blockDuration := s.config.GetDuration("IP_BLOCK_DURATION")
|
||||
if blockDuration <= 0 {
|
||||
blockDuration = 1 * time.Hour // Default
|
||||
}
|
||||
|
||||
blockInfo := map[string]interface{}{
|
||||
"blocked_at": time.Now().Unix(),
|
||||
"failure_count": count,
|
||||
"reason": "excessive_auth_failures",
|
||||
}
|
||||
|
||||
s.cacheManager.SetJSON(ctx, blockKey, blockInfo, blockDuration)
|
||||
|
||||
s.logger.Warn("IP blocked due to excessive authentication failures",
|
||||
zap.String("client_ip", clientIP),
|
||||
zap.Int("failure_count", count),
|
||||
zap.Duration("block_duration", blockDuration))
|
||||
}
|
||||
}
|
||||
|
||||
// RequestSignatureMiddleware validates request signatures (for API key requests)
|
||||
func (s *SecurityMiddleware) RequestSignatureMiddleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Only validate signatures for certain endpoints
|
||||
if !s.shouldValidateSignature(r) {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
signature := r.Header.Get("X-Signature")
|
||||
timestamp := r.Header.Get("X-Timestamp")
|
||||
|
||||
if signature == "" || timestamp == "" {
|
||||
s.logger.Warn("Missing signature headers",
|
||||
zap.String("path", r.URL.Path),
|
||||
zap.String("client_ip", s.getClientIP(r)))
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
w.Write([]byte(`{"error":"missing_signature","message":"Request signature required"}`))
|
||||
return
|
||||
}
|
||||
|
||||
// Validate timestamp (prevent replay attacks)
|
||||
if !s.isTimestampValid(timestamp) {
|
||||
s.logger.Warn("Invalid timestamp in request",
|
||||
zap.String("timestamp", timestamp),
|
||||
zap.String("client_ip", s.getClientIP(r)))
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
w.Write([]byte(`{"error":"invalid_timestamp","message":"Request timestamp is invalid or too old"}`))
|
||||
return
|
||||
}
|
||||
|
||||
// Implement HMAC signature validation
|
||||
appID := r.Header.Get("X-App-ID")
|
||||
if appID == "" {
|
||||
s.logger.Warn("Missing App-ID header for signature validation",
|
||||
zap.String("path", r.URL.Path),
|
||||
zap.String("client_ip", s.getClientIP(r)))
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
w.Write([]byte(`{"error":"missing_app_id","message":"X-App-ID header required for signature validation"}`))
|
||||
return
|
||||
}
|
||||
|
||||
// Retrieve application to get HMAC key
|
||||
ctx := r.Context()
|
||||
app, err := s.appRepo.GetByID(ctx, appID)
|
||||
if err != nil {
|
||||
s.logger.Warn("Failed to retrieve application for signature validation",
|
||||
zap.String("app_id", appID),
|
||||
zap.Error(err),
|
||||
zap.String("client_ip", s.getClientIP(r)))
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
w.Write([]byte(`{"error":"invalid_application","message":"Invalid application ID"}`))
|
||||
return
|
||||
}
|
||||
|
||||
// Validate the signature
|
||||
if !s.validateHMACSignature(r, app.HMACKey, signature, timestamp) {
|
||||
s.logger.Warn("Invalid request signature",
|
||||
zap.String("app_id", appID),
|
||||
zap.String("path", r.URL.Path),
|
||||
zap.String("client_ip", s.getClientIP(r)))
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
w.Write([]byte(`{"error":"invalid_signature","message":"Request signature is invalid"}`))
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *SecurityMiddleware) shouldValidateSignature(r *http.Request) bool {
|
||||
// Define which endpoints require signature validation
|
||||
signatureRequiredPaths := []string{
|
||||
"/api/v1/tokens",
|
||||
"/api/v1/applications",
|
||||
}
|
||||
|
||||
for _, path := range signatureRequiredPaths {
|
||||
if strings.HasPrefix(r.URL.Path, path) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (s *SecurityMiddleware) isTimestampValid(timestampStr string) bool {
|
||||
// Parse timestamp
|
||||
timestamp, err := time.Parse(time.RFC3339, timestampStr)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check if timestamp is within acceptable window
|
||||
now := time.Now()
|
||||
maxAge := s.config.GetDuration("REQUEST_MAX_AGE")
|
||||
if maxAge <= 0 {
|
||||
maxAge = 5 * time.Minute // Default
|
||||
}
|
||||
|
||||
return now.Sub(timestamp) <= maxAge && timestamp.Before(now.Add(1*time.Minute))
|
||||
}
|
||||
|
||||
// GetSecurityMetrics returns security-related metrics
|
||||
func (s *SecurityMiddleware) GetSecurityMetrics() map[string]interface{} {
|
||||
// This is a simplified version - in production you'd want more comprehensive metrics
|
||||
metrics := map[string]interface{}{
|
||||
"active_rate_limiters": len(s.rateLimiters),
|
||||
"timestamp": time.Now().Unix(),
|
||||
}
|
||||
|
||||
// Count blocked IPs (this is expensive, so you might want to cache this)
|
||||
// For now, we'll just return the basic metrics
|
||||
|
||||
return metrics
|
||||
}
|
||||
|
||||
// validateHMACSignature validates HMAC-SHA256 signature for request integrity
|
||||
func (s *SecurityMiddleware) validateHMACSignature(r *http.Request, hmacKey, signature, timestamp string) bool {
|
||||
// Create the signing string: METHOD + PATH + BODY + TIMESTAMP
|
||||
var bodyBytes []byte
|
||||
if r.Body != nil {
|
||||
var err error
|
||||
bodyBytes, err = io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
s.logger.Warn("Failed to read request body for signature validation", zap.Error(err))
|
||||
return false
|
||||
}
|
||||
// Restore the body for downstream handlers
|
||||
r.Body = io.NopCloser(strings.NewReader(string(bodyBytes)))
|
||||
}
|
||||
|
||||
signingString := fmt.Sprintf("%s\n%s\n%s\n%s",
|
||||
r.Method,
|
||||
r.URL.Path,
|
||||
string(bodyBytes),
|
||||
timestamp)
|
||||
|
||||
// Calculate expected signature
|
||||
mac := hmac.New(sha256.New, []byte(hmacKey))
|
||||
mac.Write([]byte(signingString))
|
||||
expectedSignature := hex.EncodeToString(mac.Sum(nil))
|
||||
|
||||
// Compare signatures (constant time comparison to prevent timing attacks)
|
||||
return hmac.Equal([]byte(signature), []byte(expectedSignature))
|
||||
}
|
||||
265
kms/internal/middleware/validation.go
Normal file
265
kms/internal/middleware/validation.go
Normal file
@ -0,0 +1,265 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/go-playground/validator/v10"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// ValidationError represents a validation error
|
||||
type ValidationError struct {
|
||||
Field string `json:"field"`
|
||||
Tag string `json:"tag"`
|
||||
Value string `json:"value"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
// ValidationResponse represents the validation error response
|
||||
type ValidationResponse struct {
|
||||
Error string `json:"error"`
|
||||
Message string `json:"message"`
|
||||
Details []ValidationError `json:"details,omitempty"`
|
||||
}
|
||||
|
||||
var validate *validator.Validate
|
||||
|
||||
func init() {
|
||||
validate = validator.New()
|
||||
|
||||
// Register custom tag name function to use json tags
|
||||
validate.RegisterTagNameFunc(func(fld reflect.StructField) string {
|
||||
name := strings.SplitN(fld.Tag.Get("json"), ",", 2)[0]
|
||||
if name == "-" {
|
||||
return ""
|
||||
}
|
||||
return name
|
||||
})
|
||||
}
|
||||
|
||||
// ValidateJSON validates JSON request body against struct validation tags
|
||||
func ValidateJSON(logger *zap.Logger) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// Skip validation for GET requests and requests without body
|
||||
if c.Request.Method == "GET" || c.Request.ContentLength == 0 {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
// Store original body for potential re-reading
|
||||
c.Set("validation_enabled", true)
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// ValidateStruct validates a struct and returns formatted errors
|
||||
func ValidateStruct(s interface{}) []ValidationError {
|
||||
var errors []ValidationError
|
||||
|
||||
err := validate.Struct(s)
|
||||
if err != nil {
|
||||
for _, err := range err.(validator.ValidationErrors) {
|
||||
var element ValidationError
|
||||
element.Field = err.Field()
|
||||
element.Tag = err.Tag()
|
||||
element.Value = err.Param()
|
||||
element.Message = getErrorMessage(err)
|
||||
errors = append(errors, element)
|
||||
}
|
||||
}
|
||||
|
||||
return errors
|
||||
}
|
||||
|
||||
// ValidateAndBind validates and binds JSON request to struct
|
||||
func ValidateAndBind(c *gin.Context, obj interface{}) error {
|
||||
// Bind JSON to struct
|
||||
if err := c.ShouldBindJSON(obj); err != nil {
|
||||
c.JSON(http.StatusBadRequest, ValidationResponse{
|
||||
Error: "Invalid JSON",
|
||||
Message: "Request body contains invalid JSON: " + err.Error(),
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
// Validate struct
|
||||
if validationErrors := ValidateStruct(obj); len(validationErrors) > 0 {
|
||||
c.JSON(http.StatusBadRequest, ValidationResponse{
|
||||
Error: "Validation Failed",
|
||||
Message: "Request validation failed",
|
||||
Details: validationErrors,
|
||||
})
|
||||
return validator.ValidationErrors{}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// getErrorMessage returns a human-readable error message for validation errors
|
||||
func getErrorMessage(fe validator.FieldError) string {
|
||||
switch fe.Tag() {
|
||||
case "required":
|
||||
return "This field is required"
|
||||
case "email":
|
||||
return "Invalid email format"
|
||||
case "min":
|
||||
return "Value is too short (minimum " + fe.Param() + " characters)"
|
||||
case "max":
|
||||
return "Value is too long (maximum " + fe.Param() + " characters)"
|
||||
case "url":
|
||||
return "Invalid URL format"
|
||||
case "oneof":
|
||||
return "Value must be one of: " + fe.Param()
|
||||
case "uuid":
|
||||
return "Invalid UUID format"
|
||||
case "gte":
|
||||
return "Value must be greater than or equal to " + fe.Param()
|
||||
case "lte":
|
||||
return "Value must be less than or equal to " + fe.Param()
|
||||
case "len":
|
||||
return "Value must be exactly " + fe.Param() + " characters"
|
||||
case "dive":
|
||||
return "Invalid array element"
|
||||
default:
|
||||
return "Invalid value for " + fe.Field()
|
||||
}
|
||||
}
|
||||
|
||||
// RequiredFields validates that specific fields are present in the request
|
||||
func RequiredFields(fields ...string) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
var json map[string]interface{}
|
||||
|
||||
if err := c.ShouldBindJSON(&json); err != nil {
|
||||
c.JSON(http.StatusBadRequest, ValidationResponse{
|
||||
Error: "Invalid JSON",
|
||||
Message: "Request body contains invalid JSON",
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
var missingFields []string
|
||||
for _, field := range fields {
|
||||
if _, exists := json[field]; !exists {
|
||||
missingFields = append(missingFields, field)
|
||||
}
|
||||
}
|
||||
|
||||
if len(missingFields) > 0 {
|
||||
c.JSON(http.StatusBadRequest, ValidationResponse{
|
||||
Error: "Missing Required Fields",
|
||||
Message: "The following required fields are missing: " + strings.Join(missingFields, ", "),
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// Store the parsed JSON for use in handlers
|
||||
c.Set("parsed_json", json)
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// ValidateUUID validates that a URL parameter is a valid UUID
|
||||
func ValidateUUID(param string) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
value := c.Param(param)
|
||||
if value == "" {
|
||||
c.JSON(http.StatusBadRequest, ValidationResponse{
|
||||
Error: "Missing Parameter",
|
||||
Message: "Required parameter '" + param + "' is missing",
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// Validate UUID format
|
||||
if err := validate.Var(value, "uuid"); err != nil {
|
||||
c.JSON(http.StatusBadRequest, ValidationResponse{
|
||||
Error: "Invalid Parameter",
|
||||
Message: "Parameter '" + param + "' must be a valid UUID",
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// ValidateQueryParams validates query parameters
|
||||
func ValidateQueryParams(rules map[string]string) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
var errors []ValidationError
|
||||
|
||||
for param, rule := range rules {
|
||||
value := c.Query(param)
|
||||
if value != "" {
|
||||
if err := validate.Var(value, rule); err != nil {
|
||||
for _, err := range err.(validator.ValidationErrors) {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: param,
|
||||
Tag: err.Tag(),
|
||||
Value: err.Param(),
|
||||
Message: getErrorMessage(err),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(errors) > 0 {
|
||||
c.JSON(http.StatusBadRequest, ValidationResponse{
|
||||
Error: "Invalid Query Parameters",
|
||||
Message: "One or more query parameters are invalid",
|
||||
Details: errors,
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// SanitizeInput sanitizes input strings to prevent XSS and injection attacks
|
||||
func SanitizeInput() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// This is a basic implementation - in production you might want to use
|
||||
// a more sophisticated sanitization library like bluemonday
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// ValidatePermissions validates that permission scopes follow the expected format
|
||||
func ValidatePermissions(c *gin.Context, permissions []string) []ValidationError {
|
||||
var errors []ValidationError
|
||||
|
||||
for i, perm := range permissions {
|
||||
// Check basic format: should contain only alphanumeric, dots, and underscores
|
||||
if err := validate.Var(perm, "required,min=1,max=255,alphanum|contains=.|contains=_"); err != nil {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "permissions[" + string(rune(i)) + "]",
|
||||
Tag: "format",
|
||||
Value: perm,
|
||||
Message: "Permission scope must contain only alphanumeric characters, dots, and underscores",
|
||||
})
|
||||
}
|
||||
|
||||
// Check for dangerous patterns
|
||||
if strings.Contains(perm, "..") || strings.HasPrefix(perm, ".") || strings.HasSuffix(perm, ".") {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "permissions[" + string(rune(i)) + "]",
|
||||
Tag: "format",
|
||||
Value: perm,
|
||||
Message: "Permission scope has invalid format",
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return errors
|
||||
}
|
||||
Reference in New Issue
Block a user