243 lines
6.8 KiB
Go
243 lines
6.8 KiB
Go
package middleware
|
|
|
|
import (
|
|
"context"
|
|
"net/http"
|
|
"runtime/debug"
|
|
"strconv"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
"go.uber.org/zap"
|
|
"golang.org/x/time/rate"
|
|
|
|
"github.com/kms/api-key-service/internal/config"
|
|
)
|
|
|
|
// Recovery returns a middleware that recovers from any panics
|
|
func Recovery(logger *zap.Logger) gin.HandlerFunc {
|
|
return gin.CustomRecoveryWithWriter(gin.DefaultWriter, func(c *gin.Context, recovered interface{}) {
|
|
if err, ok := recovered.(string); ok {
|
|
logger.Error("Panic recovered",
|
|
zap.String("error", err),
|
|
zap.String("stack", string(debug.Stack())),
|
|
zap.String("path", c.Request.URL.Path),
|
|
zap.String("method", c.Request.Method),
|
|
)
|
|
}
|
|
c.AbortWithStatus(http.StatusInternalServerError)
|
|
})
|
|
}
|
|
|
|
// CORS returns a middleware that handles Cross-Origin Resource Sharing
|
|
func CORS() gin.HandlerFunc {
|
|
return func(c *gin.Context) {
|
|
// Set CORS headers
|
|
c.Header("Access-Control-Allow-Origin", "*") // In production, be more specific
|
|
c.Header("Access-Control-Allow-Credentials", "true")
|
|
c.Header("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
|
|
c.Header("Access-Control-Allow-Headers", "Origin, Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, X-User-Email")
|
|
c.Header("Access-Control-Expose-Headers", "Content-Length")
|
|
c.Header("Access-Control-Max-Age", "86400")
|
|
|
|
// Handle preflight OPTIONS request
|
|
if c.Request.Method == "OPTIONS" {
|
|
c.AbortWithStatus(http.StatusNoContent)
|
|
return
|
|
}
|
|
|
|
c.Next()
|
|
}
|
|
}
|
|
|
|
// Security returns a middleware that adds security headers
|
|
func Security() gin.HandlerFunc {
|
|
return func(c *gin.Context) {
|
|
// Security headers
|
|
c.Header("X-Frame-Options", "DENY")
|
|
c.Header("X-Content-Type-Options", "nosniff")
|
|
c.Header("X-XSS-Protection", "1; mode=block")
|
|
c.Header("Referrer-Policy", "strict-origin-when-cross-origin")
|
|
|
|
// Set Content Security Policy - more permissive for test pages in development
|
|
csp := "default-src 'self' 'unsafe-inline' 'unsafe-eval'; style-src 'self' 'unsafe-inline'; script-src 'self' 'unsafe-inline' 'unsafe-eval'"
|
|
c.Header("Content-Security-Policy", csp)
|
|
c.Header("Strict-Transport-Security", "max-age=31536000; includeSubDomains")
|
|
|
|
c.Next()
|
|
}
|
|
}
|
|
|
|
// RateLimiter holds rate limiting data
|
|
type RateLimiter struct {
|
|
limiters map[string]*rate.Limiter
|
|
mu sync.RWMutex
|
|
rate rate.Limit
|
|
burst int
|
|
}
|
|
|
|
// NewRateLimiter creates a new rate limiter
|
|
func NewRateLimiter(rps, burst int) *RateLimiter {
|
|
return &RateLimiter{
|
|
limiters: make(map[string]*rate.Limiter),
|
|
rate: rate.Limit(rps),
|
|
burst: burst,
|
|
}
|
|
}
|
|
|
|
// GetLimiter returns the rate limiter for a given key
|
|
func (rl *RateLimiter) GetLimiter(key string) *rate.Limiter {
|
|
rl.mu.RLock()
|
|
limiter, exists := rl.limiters[key]
|
|
rl.mu.RUnlock()
|
|
|
|
if !exists {
|
|
limiter = rate.NewLimiter(rl.rate, rl.burst)
|
|
rl.mu.Lock()
|
|
rl.limiters[key] = limiter
|
|
rl.mu.Unlock()
|
|
}
|
|
|
|
return limiter
|
|
}
|
|
|
|
// RateLimit returns a middleware that implements rate limiting
|
|
func RateLimit(rps, burst int) gin.HandlerFunc {
|
|
limiter := NewRateLimiter(rps, burst)
|
|
|
|
return func(c *gin.Context) {
|
|
// Use client IP as the key for rate limiting
|
|
key := c.ClientIP()
|
|
|
|
// Get the limiter for this client
|
|
clientLimiter := limiter.GetLimiter(key)
|
|
|
|
// Check if request is allowed
|
|
if !clientLimiter.Allow() {
|
|
// Add rate limit headers
|
|
c.Header("X-RateLimit-Limit", strconv.Itoa(burst))
|
|
c.Header("X-RateLimit-Remaining", "0")
|
|
c.Header("X-RateLimit-Reset", strconv.FormatInt(time.Now().Add(time.Minute).Unix(), 10))
|
|
|
|
c.JSON(http.StatusTooManyRequests, gin.H{
|
|
"error": "Rate limit exceeded",
|
|
"message": "Too many requests. Please try again later.",
|
|
})
|
|
c.Abort()
|
|
return
|
|
}
|
|
|
|
// Add rate limit headers for successful requests
|
|
remaining := burst - int(clientLimiter.Tokens())
|
|
if remaining < 0 {
|
|
remaining = 0
|
|
}
|
|
c.Header("X-RateLimit-Limit", strconv.Itoa(burst))
|
|
c.Header("X-RateLimit-Remaining", strconv.Itoa(remaining))
|
|
c.Header("X-RateLimit-Reset", strconv.FormatInt(time.Now().Add(time.Minute).Unix(), 10))
|
|
|
|
c.Next()
|
|
}
|
|
}
|
|
|
|
// Authentication returns a middleware that handles authentication
|
|
func Authentication(cfg config.ConfigProvider, logger *zap.Logger) gin.HandlerFunc {
|
|
return func(c *gin.Context) {
|
|
// For now, we'll implement a basic header-based authentication
|
|
// This will be expanded when we implement the full authentication service
|
|
|
|
userEmail := c.GetHeader(cfg.GetString("AUTH_HEADER_USER_EMAIL"))
|
|
if userEmail == "" {
|
|
logger.Warn("Authentication failed: missing user email header",
|
|
zap.String("path", c.Request.URL.Path),
|
|
zap.String("method", c.Request.Method),
|
|
)
|
|
|
|
c.JSON(http.StatusUnauthorized, gin.H{
|
|
"error": "Unauthorized",
|
|
"message": "Authentication required",
|
|
})
|
|
c.Abort()
|
|
return
|
|
}
|
|
|
|
// Set user context for downstream handlers
|
|
c.Set("user_id", userEmail)
|
|
c.Set("auth_method", "header")
|
|
|
|
logger.Debug("Authentication successful",
|
|
zap.String("user_id", userEmail),
|
|
zap.String("auth_method", "header"),
|
|
)
|
|
|
|
c.Next()
|
|
}
|
|
}
|
|
|
|
// RequestID returns a middleware that adds a unique request ID to each request
|
|
func RequestID() gin.HandlerFunc {
|
|
return func(c *gin.Context) {
|
|
requestID := c.GetHeader("X-Request-ID")
|
|
if requestID == "" {
|
|
requestID = generateRequestID()
|
|
}
|
|
|
|
c.Header("X-Request-ID", requestID)
|
|
c.Set("request_id", requestID)
|
|
|
|
c.Next()
|
|
}
|
|
}
|
|
|
|
// requestIDSeq is a process-wide counter appended to generated request IDs so
// that IDs stay unique even when two requests observe the same timestamp.
var requestIDSeq struct {
	sync.Mutex
	n uint64
}

// generateRequestID returns a request ID of the form "<ts>-<seq>". ts is the
// current Unix time in nanoseconds and seq is a per-process counter, both
// written in base 36.
//
// The timestamp alone is not unique: concurrent requests can read the same
// clock value, and on platforms with a coarse clock consecutive calls often do.
// The counter guarantees uniqueness within a single process.
//
// In production, you might want to use a more sophisticated ID generator: IDs
// produced by different processes or instances can still collide.
func generateRequestID() string {
	requestIDSeq.Lock()
	requestIDSeq.n++
	seq := requestIDSeq.n
	requestIDSeq.Unlock()

	return strconv.FormatInt(time.Now().UnixNano(), 36) + "-" + strconv.FormatUint(seq, 36)
}
|
|
|
|
// Timeout returns a middleware that adds timeout to requests
|
|
func Timeout(timeout time.Duration) gin.HandlerFunc {
|
|
return func(c *gin.Context) {
|
|
ctx, cancel := context.WithTimeout(c.Request.Context(), timeout)
|
|
defer cancel()
|
|
|
|
c.Request = c.Request.WithContext(ctx)
|
|
c.Next()
|
|
}
|
|
}
|
|
|
|
// ValidateContentType returns a middleware that validates Content-Type header for JSON requests
|
|
func ValidateContentType() gin.HandlerFunc {
|
|
return func(c *gin.Context) {
|
|
// Only validate for POST, PUT, and PATCH requests
|
|
if c.Request.Method == "POST" || c.Request.Method == "PUT" || c.Request.Method == "PATCH" {
|
|
contentType := c.GetHeader("Content-Type")
|
|
|
|
// For requests with a body or when Content-Length is not explicitly 0,
|
|
// require application/json content type
|
|
if c.Request.ContentLength != 0 {
|
|
if contentType == "" {
|
|
c.JSON(http.StatusBadRequest, gin.H{
|
|
"error": "Bad Request",
|
|
"message": "Content-Type header is required for POST/PUT/PATCH requests",
|
|
})
|
|
c.Abort()
|
|
return
|
|
}
|
|
|
|
// Require application/json content type for requests with JSON bodies
|
|
if contentType != "application/json" {
|
|
c.JSON(http.StatusBadRequest, gin.H{
|
|
"error": "Bad Request",
|
|
"message": "Content-Type must be application/json",
|
|
})
|
|
c.Abort()
|
|
return
|
|
}
|
|
}
|
|
}
|
|
c.Next()
|
|
}
|
|
}
|