Files
skybridge/internal/middleware/middleware.go
2025-08-26 19:15:37 -04:00

243 lines
6.8 KiB
Go

package middleware
import (
	"context"
	"crypto/rand"
	"encoding/hex"
	"net/http"
	"runtime/debug"
	"strconv"
	"sync"
	"time"

	"github.com/gin-gonic/gin"
	"go.uber.org/zap"
	"golang.org/x/time/rate"

	"github.com/kms/api-key-service/internal/config"
)
// Recovery returns a middleware that recovers from panics raised by
// downstream handlers. It logs the panic value together with a stack trace,
// the request path and the method, and then aborts the request with
// 500 Internal Server Error.
//
// Every panic value is logged, whatever its dynamic type. This covers errors
// (including runtime errors such as nil dereferences), strings, and any other
// value passed to panic. Logging only string values would let the most common
// panics turn into 500 responses with no log entry at all.
func Recovery(logger *zap.Logger) gin.HandlerFunc {
	return gin.CustomRecoveryWithWriter(gin.DefaultWriter, func(c *gin.Context, recovered interface{}) {
		// zap.Any chooses an encoding suited to the dynamic type of the
		// recovered value, so errors and non-string values are not dropped.
		logger.Error("Panic recovered",
			zap.Any("error", recovered),
			zap.String("stack", string(debug.Stack())),
			zap.String("path", c.Request.URL.Path),
			zap.String("method", c.Request.Method),
		)
		c.AbortWithStatus(http.StatusInternalServerError)
	})
}
// CORS returns a middleware that adds permissive Cross-Origin Resource Sharing
// headers to every response. It answers preflight OPTIONS requests directly
// with 204 No Content and does not run downstream handlers for them.
//
// Any origin is allowed ("*"). The CORS protocol forbids combining a wildcard
// Access-Control-Allow-Origin with Access-Control-Allow-Credentials: true, and
// browsers reject credentialed responses that carry both. For that reason no
// credentials header is sent.
//
// Explicit Authorization headers still work without credentials mode. To
// support cookies or HTTP auth from browsers, replace the wildcard with an
// origin allow-list, echo the matched Origin (plus "Vary: Origin"), and only
// then send Access-Control-Allow-Credentials.
func CORS() gin.HandlerFunc {
	return func(c *gin.Context) {
		c.Header("Access-Control-Allow-Origin", "*") // In production, be more specific
		c.Header("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
		c.Header("Access-Control-Allow-Headers", "Origin, Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, X-User-Email")
		c.Header("Access-Control-Expose-Headers", "Content-Length")
		// Let browsers cache preflight results for 24 hours.
		c.Header("Access-Control-Max-Age", "86400")

		// Preflight requests are fully answered by the headers above.
		if c.Request.Method == http.MethodOptions {
			c.AbortWithStatus(http.StatusNoContent)
			return
		}
		c.Next()
	}
}
// Security returns a middleware that stamps a fixed set of defensive HTTP
// response headers onto every request, then passes control downstream.
//
// NOTE(review): the Content-Security-Policy permits 'unsafe-inline' and
// 'unsafe-eval' (described as a development allowance for test pages), which
// largely disables CSP's XSS protection. Tighten it for production.
func Security() gin.HandlerFunc {
	// Each entry is a {header name, header value} pair, applied in order.
	securityHeaders := [...][2]string{
		{"X-Frame-Options", "DENY"},
		{"X-Content-Type-Options", "nosniff"},
		{"X-XSS-Protection", "1; mode=block"},
		{"Referrer-Policy", "strict-origin-when-cross-origin"},
		{"Content-Security-Policy", "default-src 'self' 'unsafe-inline' 'unsafe-eval'; style-src 'self' 'unsafe-inline'; script-src 'self' 'unsafe-inline' 'unsafe-eval'"},
		{"Strict-Transport-Security", "max-age=31536000; includeSubDomains"},
	}

	return func(c *gin.Context) {
		for _, kv := range securityHeaders {
			c.Header(kv[0], kv[1])
		}
		c.Next()
	}
}
// RateLimiter holds one rate.Limiter per key, creating each lazily (see
// GetLimiter). Every limiter shares the same rate and burst settings.
//
// NOTE(review): nothing in this file removes entries from limiters, so the map
// grows with every distinct key seen (e.g. every client IP). Consider adding
// idle-entry eviction.
type RateLimiter struct {
	// limiters maps a key to its limiter. Guarded by mu.
	limiters map[string]*rate.Limiter
	// mu protects concurrent access to limiters.
	mu sync.RWMutex
	// rate is the limit (converted from requests per second) given to each new limiter.
	rate rate.Limit
	// burst is the bucket size given to each new limiter.
	burst int
}
// NewRateLimiter builds an empty RateLimiter. Limiters created later through
// GetLimiter allow rps events per second with a burst size of burst.
func NewRateLimiter(rps, burst int) *RateLimiter {
	rl := new(RateLimiter)
	rl.limiters = map[string]*rate.Limiter{}
	rl.rate = rate.Limit(rps)
	rl.burst = burst
	return rl
}
// GetLimiter returns the limiter for key, creating one with the configured
// rate and burst on first use. It is safe for concurrent use.
//
// The fast path takes only a read lock. On a miss it takes the write lock and
// checks the map again, because another goroutine may have inserted the key
// in the meantime. Without that re-check, two concurrent first requests would
// each create a limiter and the later write would replace the earlier one.
// Tokens spent on the discarded limiter would then never be charged, letting
// extra requests through.
func (rl *RateLimiter) GetLimiter(key string) *rate.Limiter {
	rl.mu.RLock()
	limiter, ok := rl.limiters[key]
	rl.mu.RUnlock()
	if ok {
		return limiter
	}

	rl.mu.Lock()
	defer rl.mu.Unlock()
	// Re-check under the write lock: another goroutine may have won the race.
	if limiter, ok = rl.limiters[key]; ok {
		return limiter
	}
	limiter = rate.NewLimiter(rl.rate, rl.burst)
	rl.limiters[key] = limiter
	return limiter
}
// RateLimit returns a middleware that rate-limits requests per client IP (as
// reported by gin's ClientIP). Each client gets its own limiter that allows
// rps requests per second with a burst size of burst.
//
// A rejected request receives 429 Too Many Requests with a JSON error body,
// and the handler chain is aborted. Every response carries these headers:
//   - X-RateLimit-Limit: the burst size.
//   - X-RateLimit-Remaining: the number of whole tokens left in this client's
//     bucket after the current request. It is 0 when the request was rejected.
//
// NOTE(review): X-RateLimit-Reset is a fixed "now + 1 minute" estimate. It is
// not derived from the limiter's actual refill rate.
func RateLimit(rps, burst int) gin.HandlerFunc {
	limiter := NewRateLimiter(rps, burst)
	limitHeader := strconv.Itoa(burst) // constant for the middleware's lifetime

	setHeaders := func(c *gin.Context, remaining int) {
		c.Header("X-RateLimit-Limit", limitHeader)
		c.Header("X-RateLimit-Remaining", strconv.Itoa(remaining))
		c.Header("X-RateLimit-Reset", strconv.FormatInt(time.Now().Add(time.Minute).Unix(), 10))
	}

	return func(c *gin.Context) {
		// Use client IP as the key for rate limiting.
		clientLimiter := limiter.GetLimiter(c.ClientIP())

		if !clientLimiter.Allow() {
			setHeaders(c, 0)
			c.JSON(http.StatusTooManyRequests, gin.H{
				"error":   "Rate limit exceeded",
				"message": "Too many requests. Please try again later.",
			})
			c.Abort()
			return
		}

		// Tokens() is the number of tokens still available, i.e. the requests
		// that remain. Subtracting it from burst would invert the meaning.
		// Truncate to whole requests and clamp to [0, burst].
		remaining := int(clientLimiter.Tokens())
		if remaining < 0 {
			remaining = 0
		} else if remaining > burst {
			remaining = burst
		}
		setHeaders(c, remaining)
		c.Next()
	}
}
// Authentication returns a middleware that identifies the caller from a
// request header. The header's name is read from the AUTH_HEADER_USER_EMAIL
// config key on every request.
//
// When the header is present, its value is stored in the context as
// "user_id", "auth_method" is set to "header", and the chain continues.
// When it is missing or empty, a warning is logged and the request is
// aborted with 401 Unauthorized.
//
// NOTE(review): the header value is trusted as-is; nothing here verifies it.
// Unless an upstream proxy sets or strips this header, any client can claim
// an arbitrary identity.
func Authentication(cfg config.ConfigProvider, logger *zap.Logger) gin.HandlerFunc {
	return func(c *gin.Context) {
		// Basic header-based authentication, to be replaced by the full
		// authentication service.
		headerName := cfg.GetString("AUTH_HEADER_USER_EMAIL")
		email := c.GetHeader(headerName)

		if email != "" {
			// Expose the caller's identity to downstream handlers.
			c.Set("user_id", email)
			c.Set("auth_method", "header")
			logger.Debug("Authentication successful",
				zap.String("user_id", email),
				zap.String("auth_method", "header"),
			)
			c.Next()
			return
		}

		logger.Warn("Authentication failed: missing user email header",
			zap.String("path", c.Request.URL.Path),
			zap.String("method", c.Request.Method),
		)
		c.JSON(http.StatusUnauthorized, gin.H{
			"error":   "Unauthorized",
			"message": "Authentication required",
		})
		c.Abort()
	}
}
// RequestID returns a middleware that tags each request with an ID. The ID is
// echoed in the X-Request-ID response header and stored in the context as
// "request_id".
//
// A client-supplied X-Request-ID is reused only when isValidRequestID accepts
// it. Otherwise a fresh ID is generated. The header is attacker-controlled
// and ends up in responses and logs, so oversized or unusual values are not
// propagated.
func RequestID() gin.HandlerFunc {
	return func(c *gin.Context) {
		requestID := c.GetHeader("X-Request-ID")
		if !isValidRequestID(requestID) {
			requestID = generateRequestID()
		}
		c.Header("X-Request-ID", requestID)
		c.Set("request_id", requestID)
		c.Next()
	}
}

// isValidRequestID reports whether id is non-empty, at most 128 bytes long,
// and made only of visible ASCII characters (0x21-0x7E).
func isValidRequestID(id string) bool {
	const maxLen = 128
	if id == "" || len(id) > maxLen {
		return false
	}
	for i := 0; i < len(id); i++ {
		if id[i] < 0x21 || id[i] > 0x7e {
			return false
		}
	}
	return true
}
// generateRequestID returns a new request ID: 128 bits from crypto/rand,
// hex-encoded as 32 lowercase characters.
//
// A bare nanosecond timestamp is not unique: concurrent requests, or clocks
// with coarse resolution, yield identical IDs. Random bits make collisions
// negligible. If the system random source fails, the function falls back to
// the base-36 nanosecond timestamp so that request handling never stops.
func generateRequestID() string {
	var buf [16]byte
	if _, err := rand.Read(buf[:]); err != nil {
		return strconv.FormatInt(time.Now().UnixNano(), 36)
	}
	return hex.EncodeToString(buf[:])
}
// Timeout returns a middleware that replaces the request's context with a
// child context that expires after timeout. The context is cancelled once the
// rest of the handler chain returns.
//
// It only attaches a deadline. Downstream handlers are not interrupted, and
// no timeout response is written. Handlers must observe
// c.Request.Context() (e.g. pass it to DB or HTTP calls) for the deadline to
// take effect.
func Timeout(timeout time.Duration) gin.HandlerFunc {
	return func(c *gin.Context) {
		deadlineCtx, cancelFn := context.WithTimeout(c.Request.Context(), timeout)
		c.Request = c.Request.WithContext(deadlineCtx)
		defer cancelFn()

		c.Next()
	}
}
// ValidateContentType returns a middleware that requires a JSON Content-Type
// on POST, PUT and PATCH requests that may carry a body.
//
// Such a request is checked whenever its ContentLength is non-zero, which
// includes -1 (length unknown, e.g. chunked transfer). It is rejected with
// 400 Bad Request in two cases:
//   - the Content-Type header is missing;
//   - its media type is not "application/json". The media type is the part
//     before any parameters such as "; charset=utf-8".
//
// All other requests pass through unchanged.
func ValidateContentType() gin.HandlerFunc {
	return func(c *gin.Context) {
		abortBadRequest := func(message string) {
			c.JSON(http.StatusBadRequest, gin.H{
				"error":   "Bad Request",
				"message": message,
			})
			c.Abort()
		}

		switch c.Request.Method {
		case http.MethodPost, http.MethodPut, http.MethodPatch:
			// Validated below.
		default:
			c.Next()
			return
		}

		// An explicit Content-Length of 0 means there is no body to describe.
		if c.Request.ContentLength == 0 {
			c.Next()
			return
		}

		if c.GetHeader("Content-Type") == "" {
			abortBadRequest("Content-Type header is required for POST/PUT/PATCH requests")
			return
		}

		// c.ContentType() drops parameters, so "application/json; charset=utf-8"
		// is accepted. An exact comparison on the raw header would reject it.
		if c.ContentType() != "application/json" {
			abortBadRequest("Content-Type must be application/json")
			return
		}

		c.Next()
	}
}