package middleware import ( "context" "net/http" "runtime/debug" "strconv" "sync" "time" "github.com/gin-gonic/gin" "go.uber.org/zap" "golang.org/x/time/rate" "github.com/kms/api-key-service/internal/config" ) // Recovery returns a middleware that recovers from any panics func Recovery(logger *zap.Logger) gin.HandlerFunc { return gin.CustomRecoveryWithWriter(gin.DefaultWriter, func(c *gin.Context, recovered interface{}) { if err, ok := recovered.(string); ok { logger.Error("Panic recovered", zap.String("error", err), zap.String("stack", string(debug.Stack())), zap.String("path", c.Request.URL.Path), zap.String("method", c.Request.Method), ) } c.AbortWithStatus(http.StatusInternalServerError) }) } // CORS returns a middleware that handles Cross-Origin Resource Sharing func CORS() gin.HandlerFunc { return func(c *gin.Context) { // Set CORS headers c.Header("Access-Control-Allow-Origin", "*") // In production, be more specific c.Header("Access-Control-Allow-Credentials", "true") c.Header("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS") c.Header("Access-Control-Allow-Headers", "Origin, Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, X-User-Email") c.Header("Access-Control-Expose-Headers", "Content-Length") c.Header("Access-Control-Max-Age", "86400") // Handle preflight OPTIONS request if c.Request.Method == "OPTIONS" { c.AbortWithStatus(http.StatusNoContent) return } c.Next() } } // Security returns a middleware that adds security headers func Security() gin.HandlerFunc { return func(c *gin.Context) { // Security headers c.Header("X-Frame-Options", "DENY") c.Header("X-Content-Type-Options", "nosniff") c.Header("X-XSS-Protection", "1; mode=block") c.Header("Referrer-Policy", "strict-origin-when-cross-origin") // Set Content Security Policy - more permissive for test pages in development csp := "default-src 'self' 'unsafe-inline' 'unsafe-eval'; style-src 'self' 'unsafe-inline'; script-src 'self' 'unsafe-inline' 'unsafe-eval'" c.Header("Content-Security-Policy", csp) c.Header("Strict-Transport-Security", "max-age=31536000; includeSubDomains") c.Next() } } // RateLimiter holds rate limiting data type RateLimiter struct { limiters map[string]*rate.Limiter mu sync.RWMutex rate rate.Limit burst int } // NewRateLimiter creates a new rate limiter func NewRateLimiter(rps, burst int) *RateLimiter { return &RateLimiter{ limiters: make(map[string]*rate.Limiter), rate: rate.Limit(rps), burst: burst, } } // GetLimiter returns the rate limiter for a given key func (rl *RateLimiter) GetLimiter(key string) *rate.Limiter { rl.mu.RLock() limiter, exists := rl.limiters[key] rl.mu.RUnlock() if !exists { limiter = rate.NewLimiter(rl.rate, rl.burst) rl.mu.Lock() rl.limiters[key] = limiter rl.mu.Unlock() } return limiter } // RateLimit returns a middleware that implements rate limiting func RateLimit(rps, burst int) gin.HandlerFunc { limiter := NewRateLimiter(rps, burst) return func(c *gin.Context) { // Use client IP as the key for rate limiting key := c.ClientIP() // Get the limiter for this client clientLimiter := limiter.GetLimiter(key) // Check if request is allowed if !clientLimiter.Allow() { // Add rate limit headers c.Header("X-RateLimit-Limit", strconv.Itoa(burst)) c.Header("X-RateLimit-Remaining", "0") c.Header("X-RateLimit-Reset", strconv.FormatInt(time.Now().Add(time.Minute).Unix(), 10)) c.JSON(http.StatusTooManyRequests, gin.H{ "error": "Rate limit exceeded", "message": "Too many requests. Please try again later.", }) c.Abort() return } // Add rate limit headers for successful requests remaining := burst - int(clientLimiter.Tokens()) if remaining < 0 { remaining = 0 } c.Header("X-RateLimit-Limit", strconv.Itoa(burst)) c.Header("X-RateLimit-Remaining", strconv.Itoa(remaining)) c.Header("X-RateLimit-Reset", strconv.FormatInt(time.Now().Add(time.Minute).Unix(), 10)) c.Next() } } // Authentication returns a middleware that handles authentication func Authentication(cfg config.ConfigProvider, logger *zap.Logger) gin.HandlerFunc { return func(c *gin.Context) { // For now, we'll implement a basic header-based authentication // This will be expanded when we implement the full authentication service userEmail := c.GetHeader(cfg.GetString("AUTH_HEADER_USER_EMAIL")) if userEmail == "" { logger.Warn("Authentication failed: missing user email header", zap.String("path", c.Request.URL.Path), zap.String("method", c.Request.Method), ) c.JSON(http.StatusUnauthorized, gin.H{ "error": "Unauthorized", "message": "Authentication required", }) c.Abort() return } // Set user context for downstream handlers c.Set("user_id", userEmail) c.Set("auth_method", "header") logger.Debug("Authentication successful", zap.String("user_id", userEmail), zap.String("auth_method", "header"), ) c.Next() } } // RequestID returns a middleware that adds a unique request ID to each request func RequestID() gin.HandlerFunc { return func(c *gin.Context) { requestID := c.GetHeader("X-Request-ID") if requestID == "" { requestID = generateRequestID() } c.Header("X-Request-ID", requestID) c.Set("request_id", requestID) c.Next() } } // generateRequestID generates a simple request ID // In production, you might want to use a more sophisticated ID generator func generateRequestID() string { return strconv.FormatInt(time.Now().UnixNano(), 36) } // Timeout returns a middleware that adds timeout to requests func Timeout(timeout time.Duration) gin.HandlerFunc { return func(c *gin.Context) { ctx, cancel := context.WithTimeout(c.Request.Context(), timeout) defer cancel() c.Request = c.Request.WithContext(ctx) c.Next() } } // ValidateContentType returns a middleware that validates Content-Type header for JSON requests func ValidateContentType() gin.HandlerFunc { return func(c *gin.Context) { // Only validate for POST, PUT, and PATCH requests if c.Request.Method == "POST" || c.Request.Method == "PUT" || c.Request.Method == "PATCH" { contentType := c.GetHeader("Content-Type") // For requests with a body or when Content-Length is not explicitly 0, // require application/json content type if c.Request.ContentLength != 0 { if contentType == "" { c.JSON(http.StatusBadRequest, gin.H{ "error": "Bad Request", "message": "Content-Type header is required for POST/PUT/PATCH requests", }) c.Abort() return } // Require application/json content type for requests with JSON bodies if contentType != "application/json" { c.JSON(http.StatusBadRequest, gin.H{ "error": "Bad Request", "message": "Content-Type must be application/json", }) c.Abort() return } } } c.Next() } }