309 lines
9.5 KiB
Go
309 lines
9.5 KiB
Go
package auth
|
|
|
|
import (
|
|
"context"
|
|
"crypto/rand"
|
|
"encoding/base64"
|
|
"fmt"
|
|
"time"
|
|
|
|
"github.com/golang-jwt/jwt/v5"
|
|
"go.uber.org/zap"
|
|
|
|
"github.com/kms/api-key-service/internal/cache"
|
|
"github.com/kms/api-key-service/internal/config"
|
|
"github.com/kms/api-key-service/internal/domain"
|
|
"github.com/kms/api-key-service/internal/errors"
|
|
)
|
|
|
|
// JWTManager issues, validates, refreshes and revokes the service's JWT
// tokens. Construct it with NewJWTManager so that every dependency is set.
type JWTManager struct {
	// config supplies the HMAC signing secret (GetJWTSecret).
	config config.ConfigProvider
	// logger receives debug, warning and error records for token operations.
	logger *zap.Logger
	// cacheManager stores the revocation list written by RevokeToken and
	// read by IsTokenRevoked.
	cacheManager *cache.CacheManager
}
|
|
|
|
// NewJWTManager creates a new JWT manager
|
|
func NewJWTManager(config config.ConfigProvider, logger *zap.Logger) *JWTManager {
|
|
cacheManager := cache.NewCacheManager(config, logger)
|
|
return &JWTManager{
|
|
config: config,
|
|
logger: logger,
|
|
cacheManager: cacheManager,
|
|
}
|
|
}
|
|
|
|
// CustomClaims is the claim set carried by tokens issued by JWTManager: the
// service-specific fields below plus the standard registered claims.
type CustomClaims struct {
	// UserID identifies the user; GenerateToken also copies it into the
	// registered Subject claim.
	UserID string `json:"user_id"`
	// AppID identifies the application; GenerateToken also uses it as the
	// single Audience entry.
	AppID string `json:"app_id"`
	// Permissions lists the permissions granted to the token.
	Permissions []string `json:"permissions"`
	// TokenType is copied verbatim from the originating domain.UserToken.
	TokenType domain.TokenType `json:"token_type"`
	// MaxValidAt is a Unix timestamp in seconds. ValidateToken rejects the
	// token once the current time is past it, and RefreshToken carries it
	// over unchanged to the refreshed token.
	MaxValidAt int64 `json:"max_valid_at"`
	// Claims holds optional additional string claims; omitted from the
	// token when empty.
	Claims map[string]string `json:"claims,omitempty"`
	// RegisteredClaims carries the standard iss/sub/aud/exp/nbf/iat/jti claims.
	jwt.RegisteredClaims
}
|
|
|
|
// GenerateToken generates a new JWT token for a user
|
|
func (j *JWTManager) GenerateToken(userToken *domain.UserToken) (string, error) {
|
|
j.logger.Debug("Generating JWT token",
|
|
zap.String("user_id", userToken.UserID),
|
|
zap.String("app_id", userToken.AppID),
|
|
zap.Strings("permissions", userToken.Permissions))
|
|
|
|
// Get JWT secret from config
|
|
jwtSecret := j.config.GetJWTSecret()
|
|
if jwtSecret == "" {
|
|
return "", errors.NewValidationError("JWT secret not configured")
|
|
}
|
|
|
|
// Generate secure JWT ID
|
|
jti := j.generateJTI()
|
|
if jti == "" {
|
|
return "", errors.NewInternalError("Failed to generate secure JWT ID - cryptographic random number generation failed")
|
|
}
|
|
|
|
// Create custom claims
|
|
claims := CustomClaims{
|
|
UserID: userToken.UserID,
|
|
AppID: userToken.AppID,
|
|
Permissions: userToken.Permissions,
|
|
TokenType: userToken.TokenType,
|
|
MaxValidAt: userToken.MaxValidAt.Unix(),
|
|
Claims: userToken.Claims,
|
|
RegisteredClaims: jwt.RegisteredClaims{
|
|
Issuer: "kms-api-service",
|
|
Subject: userToken.UserID,
|
|
Audience: []string{userToken.AppID},
|
|
ExpiresAt: jwt.NewNumericDate(userToken.ExpiresAt),
|
|
IssuedAt: jwt.NewNumericDate(userToken.IssuedAt),
|
|
NotBefore: jwt.NewNumericDate(userToken.IssuedAt),
|
|
ID: jti,
|
|
},
|
|
}
|
|
|
|
// Create token with claims
|
|
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
|
|
|
// Sign token with secret
|
|
tokenString, err := token.SignedString([]byte(jwtSecret))
|
|
if err != nil {
|
|
j.logger.Error("Failed to sign JWT token", zap.Error(err))
|
|
return "", errors.NewInternalError("Failed to generate token").WithInternal(err)
|
|
}
|
|
|
|
j.logger.Debug("JWT token generated successfully",
|
|
zap.String("user_id", userToken.UserID),
|
|
zap.String("app_id", userToken.AppID))
|
|
|
|
return tokenString, nil
|
|
}
|
|
|
|
// ValidateToken validates and parses a JWT token
|
|
func (j *JWTManager) ValidateToken(tokenString string) (*CustomClaims, error) {
|
|
j.logger.Debug("Validating JWT token")
|
|
|
|
// Get JWT secret from config
|
|
jwtSecret := j.config.GetJWTSecret()
|
|
if jwtSecret == "" {
|
|
return nil, errors.NewValidationError("JWT secret not configured")
|
|
}
|
|
|
|
// Parse token with custom claims
|
|
token, err := jwt.ParseWithClaims(tokenString, &CustomClaims{}, func(token *jwt.Token) (interface{}, error) {
|
|
// Validate signing method
|
|
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
|
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
|
}
|
|
return []byte(jwtSecret), nil
|
|
})
|
|
|
|
if err != nil {
|
|
j.logger.Warn("Failed to parse JWT token", zap.Error(err))
|
|
return nil, errors.NewAuthenticationError("Invalid token").WithInternal(err)
|
|
}
|
|
|
|
// Extract custom claims
|
|
claims, ok := token.Claims.(*CustomClaims)
|
|
if !ok || !token.Valid {
|
|
j.logger.Warn("Invalid JWT token claims")
|
|
return nil, errors.NewAuthenticationError("Invalid token claims")
|
|
}
|
|
|
|
// Check if token is expired beyond max valid time
|
|
if time.Now().Unix() > claims.MaxValidAt {
|
|
j.logger.Warn("JWT token expired beyond max valid time",
|
|
zap.Int64("max_valid_at", claims.MaxValidAt),
|
|
zap.Int64("current_time", time.Now().Unix()))
|
|
return nil, errors.NewAuthenticationError("Token expired beyond maximum validity")
|
|
}
|
|
|
|
j.logger.Debug("JWT token validated successfully",
|
|
zap.String("user_id", claims.UserID),
|
|
zap.String("app_id", claims.AppID))
|
|
|
|
return claims, nil
|
|
}
|
|
|
|
// RefreshToken generates a new token with updated expiration
|
|
func (j *JWTManager) RefreshToken(oldTokenString string, newExpiration time.Time) (string, error) {
|
|
j.logger.Debug("Refreshing JWT token")
|
|
|
|
// Validate the old token first
|
|
claims, err := j.ValidateToken(oldTokenString)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
// Check if we can still refresh (not past max valid time)
|
|
if time.Now().Unix() > claims.MaxValidAt {
|
|
return "", errors.NewAuthenticationError("Token cannot be refreshed - past maximum validity")
|
|
}
|
|
|
|
// Create new user token with updated expiration
|
|
userToken := &domain.UserToken{
|
|
AppID: claims.AppID,
|
|
UserID: claims.UserID,
|
|
Permissions: claims.Permissions,
|
|
IssuedAt: time.Now(),
|
|
ExpiresAt: newExpiration,
|
|
MaxValidAt: time.Unix(claims.MaxValidAt, 0),
|
|
TokenType: claims.TokenType,
|
|
Claims: claims.Claims,
|
|
}
|
|
|
|
// Generate new token
|
|
return j.GenerateToken(userToken)
|
|
}
|
|
|
|
// ExtractClaims extracts claims from a token without full validation (for expired tokens)
|
|
func (j *JWTManager) ExtractClaims(tokenString string) (*CustomClaims, error) {
|
|
j.logger.Debug("Extracting claims from JWT token")
|
|
|
|
// Parse token without validation to extract claims
|
|
token, _, err := new(jwt.Parser).ParseUnverified(tokenString, &CustomClaims{})
|
|
if err != nil {
|
|
j.logger.Warn("Failed to parse JWT token for claims extraction", zap.Error(err))
|
|
return nil, errors.NewValidationError("Invalid token format").WithInternal(err)
|
|
}
|
|
|
|
claims, ok := token.Claims.(*CustomClaims)
|
|
if !ok {
|
|
j.logger.Warn("Invalid JWT token claims format")
|
|
return nil, errors.NewValidationError("Invalid token claims format")
|
|
}
|
|
|
|
return claims, nil
|
|
}
|
|
|
|
// RevokeToken adds a token to the revocation list (blacklist)
|
|
func (j *JWTManager) RevokeToken(tokenString string) error {
|
|
j.logger.Debug("Revoking JWT token")
|
|
|
|
// Extract claims to get token ID and expiration
|
|
claims, err := j.ExtractClaims(tokenString)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// Calculate TTL for the blacklist entry (until token would naturally expire)
|
|
ttl := time.Until(claims.ExpiresAt.Time)
|
|
if ttl <= 0 {
|
|
// Token is already expired, no need to blacklist
|
|
j.logger.Debug("Token already expired, skipping blacklist",
|
|
zap.String("jti", claims.ID))
|
|
return nil
|
|
}
|
|
|
|
// Store token ID in blacklist cache
|
|
ctx := context.Background()
|
|
blacklistKey := cache.CacheKey(cache.KeyPrefixTokenRevoked, claims.ID)
|
|
|
|
// Store revocation info
|
|
revocationInfo := map[string]interface{}{
|
|
"revoked_at": time.Now().Unix(),
|
|
"user_id": claims.UserID,
|
|
"app_id": claims.AppID,
|
|
"reason": "manual_revocation",
|
|
}
|
|
|
|
if err := j.cacheManager.SetJSON(ctx, blacklistKey, revocationInfo, ttl); err != nil {
|
|
j.logger.Error("Failed to blacklist token",
|
|
zap.String("jti", claims.ID),
|
|
zap.Error(err))
|
|
return errors.NewInternalError("Failed to revoke token").WithInternal(err)
|
|
}
|
|
|
|
j.logger.Info("Token successfully revoked",
|
|
zap.String("jti", claims.ID),
|
|
zap.String("user_id", claims.UserID),
|
|
zap.String("app_id", claims.AppID),
|
|
zap.Duration("ttl", ttl))
|
|
|
|
return nil
|
|
}
|
|
|
|
// IsTokenRevoked checks if a token has been revoked
|
|
func (j *JWTManager) IsTokenRevoked(tokenString string) (bool, error) {
|
|
j.logger.Debug("Checking if JWT token is revoked")
|
|
|
|
// Extract claims to get token ID
|
|
claims, err := j.ExtractClaims(tokenString)
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
|
|
// Check blacklist cache
|
|
ctx := context.Background()
|
|
blacklistKey := cache.CacheKey(cache.KeyPrefixTokenRevoked, claims.ID)
|
|
|
|
exists, err := j.cacheManager.Exists(ctx, blacklistKey)
|
|
if err != nil {
|
|
j.logger.Error("Failed to check token blacklist",
|
|
zap.String("jti", claims.ID),
|
|
zap.Error(err))
|
|
// In case of cache error, we'll assume token is not revoked to avoid blocking valid requests
|
|
// This could be made configurable based on security requirements
|
|
return false, nil
|
|
}
|
|
|
|
j.logger.Debug("Token revocation check completed",
|
|
zap.String("jti", claims.ID),
|
|
zap.Bool("revoked", exists))
|
|
|
|
return exists, nil
|
|
}
|
|
|
|
// generateJTI generates a unique JWT ID
|
|
func (j *JWTManager) generateJTI() string {
|
|
bytes := make([]byte, 16)
|
|
if _, err := rand.Read(bytes); err != nil {
|
|
// Log the error and fail securely - do not generate predictable fallback IDs
|
|
j.logger.Error("Cryptographic random number generation failed - cannot generate secure JWT ID", zap.Error(err))
|
|
// Return an error indicator that will cause token generation to fail
|
|
return ""
|
|
}
|
|
return base64.URLEncoding.EncodeToString(bytes)
|
|
}
|
|
|
|
// GetTokenInfo extracts token information for debugging/logging
|
|
func (j *JWTManager) GetTokenInfo(tokenString string) map[string]interface{} {
|
|
claims, err := j.ExtractClaims(tokenString)
|
|
if err != nil {
|
|
return map[string]interface{}{
|
|
"error": err.Error(),
|
|
}
|
|
}
|
|
|
|
return map[string]interface{}{
|
|
"user_id": claims.UserID,
|
|
"app_id": claims.AppID,
|
|
"permissions": claims.Permissions,
|
|
"token_type": claims.TokenType,
|
|
"issued_at": time.Unix(claims.IssuedAt.Unix(), 0),
|
|
"expires_at": time.Unix(claims.ExpiresAt.Unix(), 0),
|
|
"max_valid_at": time.Unix(claims.MaxValidAt, 0),
|
|
"jti": claims.ID,
|
|
}
|
|
}
|