package auth import ( "context" "crypto/rand" "encoding/base64" "fmt" "time" "github.com/golang-jwt/jwt/v5" "go.uber.org/zap" "github.com/kms/api-key-service/internal/cache" "github.com/kms/api-key-service/internal/config" "github.com/kms/api-key-service/internal/domain" "github.com/kms/api-key-service/internal/errors" ) // JWTManager handles JWT token operations type JWTManager struct { config config.ConfigProvider logger *zap.Logger cacheManager *cache.CacheManager } // NewJWTManager creates a new JWT manager func NewJWTManager(config config.ConfigProvider, logger *zap.Logger) *JWTManager { cacheManager := cache.NewCacheManager(config, logger) return &JWTManager{ config: config, logger: logger, cacheManager: cacheManager, } } // CustomClaims represents the custom claims in our JWT tokens type CustomClaims struct { UserID string `json:"user_id"` AppID string `json:"app_id"` Permissions []string `json:"permissions"` TokenType domain.TokenType `json:"token_type"` MaxValidAt int64 `json:"max_valid_at"` Claims map[string]string `json:"claims,omitempty"` jwt.RegisteredClaims } // GenerateToken generates a new JWT token for a user func (j *JWTManager) GenerateToken(userToken *domain.UserToken) (string, error) { j.logger.Debug("Generating JWT token", zap.String("user_id", userToken.UserID), zap.String("app_id", userToken.AppID), zap.Strings("permissions", userToken.Permissions)) // Get JWT secret from config jwtSecret := j.config.GetJWTSecret() if jwtSecret == "" { return "", errors.NewValidationError("JWT secret not configured") } // Create custom claims claims := CustomClaims{ UserID: userToken.UserID, AppID: userToken.AppID, Permissions: userToken.Permissions, TokenType: userToken.TokenType, MaxValidAt: userToken.MaxValidAt.Unix(), Claims: userToken.Claims, RegisteredClaims: jwt.RegisteredClaims{ Issuer: "kms-api-service", Subject: userToken.UserID, Audience: []string{userToken.AppID}, ExpiresAt: jwt.NewNumericDate(userToken.ExpiresAt), IssuedAt: jwt.NewNumericDate(userToken.IssuedAt), NotBefore: jwt.NewNumericDate(userToken.IssuedAt), ID: j.generateJTI(), }, } // Create token with claims token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) // Sign token with secret tokenString, err := token.SignedString([]byte(jwtSecret)) if err != nil { j.logger.Error("Failed to sign JWT token", zap.Error(err)) return "", errors.NewInternalError("Failed to generate token").WithInternal(err) } j.logger.Debug("JWT token generated successfully", zap.String("user_id", userToken.UserID), zap.String("app_id", userToken.AppID)) return tokenString, nil } // ValidateToken validates and parses a JWT token func (j *JWTManager) ValidateToken(tokenString string) (*CustomClaims, error) { j.logger.Debug("Validating JWT token") // Get JWT secret from config jwtSecret := j.config.GetJWTSecret() if jwtSecret == "" { return nil, errors.NewValidationError("JWT secret not configured") } // Parse token with custom claims token, err := jwt.ParseWithClaims(tokenString, &CustomClaims{}, func(token *jwt.Token) (interface{}, error) { // Validate signing method if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) } return []byte(jwtSecret), nil }) if err != nil { j.logger.Warn("Failed to parse JWT token", zap.Error(err)) return nil, errors.NewAuthenticationError("Invalid token").WithInternal(err) } // Extract custom claims claims, ok := token.Claims.(*CustomClaims) if !ok || !token.Valid { j.logger.Warn("Invalid JWT token claims") return nil, errors.NewAuthenticationError("Invalid token claims") } // Check if token is expired beyond max valid time if time.Now().Unix() > claims.MaxValidAt { j.logger.Warn("JWT token expired beyond max valid time", zap.Int64("max_valid_at", claims.MaxValidAt), zap.Int64("current_time", time.Now().Unix())) return nil, errors.NewAuthenticationError("Token expired beyond maximum validity") } j.logger.Debug("JWT token validated successfully", zap.String("user_id", claims.UserID), zap.String("app_id", claims.AppID)) return claims, nil } // RefreshToken generates a new token with updated expiration func (j *JWTManager) RefreshToken(oldTokenString string, newExpiration time.Time) (string, error) { j.logger.Debug("Refreshing JWT token") // Validate the old token first claims, err := j.ValidateToken(oldTokenString) if err != nil { return "", err } // Check if we can still refresh (not past max valid time) if time.Now().Unix() > claims.MaxValidAt { return "", errors.NewAuthenticationError("Token cannot be refreshed - past maximum validity") } // Create new user token with updated expiration userToken := &domain.UserToken{ AppID: claims.AppID, UserID: claims.UserID, Permissions: claims.Permissions, IssuedAt: time.Now(), ExpiresAt: newExpiration, MaxValidAt: time.Unix(claims.MaxValidAt, 0), TokenType: claims.TokenType, Claims: claims.Claims, } // Generate new token return j.GenerateToken(userToken) } // ExtractClaims extracts claims from a token without full validation (for expired tokens) func (j *JWTManager) ExtractClaims(tokenString string) (*CustomClaims, error) { j.logger.Debug("Extracting claims from JWT token") // Parse token without validation to extract claims token, _, err := new(jwt.Parser).ParseUnverified(tokenString, &CustomClaims{}) if err != nil { j.logger.Warn("Failed to parse JWT token for claims extraction", zap.Error(err)) return nil, errors.NewValidationError("Invalid token format").WithInternal(err) } claims, ok := token.Claims.(*CustomClaims) if !ok { j.logger.Warn("Invalid JWT token claims format") return nil, errors.NewValidationError("Invalid token claims format") } return claims, nil } // RevokeToken adds a token to the revocation list (blacklist) func (j *JWTManager) RevokeToken(tokenString string) error { j.logger.Debug("Revoking JWT token") // Extract claims to get token ID and expiration claims, err := j.ExtractClaims(tokenString) if err != nil { return err } // Calculate TTL for the blacklist entry (until token would naturally expire) ttl := time.Until(claims.ExpiresAt.Time) if ttl <= 0 { // Token is already expired, no need to blacklist j.logger.Debug("Token already expired, skipping blacklist", zap.String("jti", claims.ID)) return nil } // Store token ID in blacklist cache ctx := context.Background() blacklistKey := cache.CacheKey(cache.KeyPrefixTokenRevoked, claims.ID) // Store revocation info revocationInfo := map[string]interface{}{ "revoked_at": time.Now().Unix(), "user_id": claims.UserID, "app_id": claims.AppID, "reason": "manual_revocation", } if err := j.cacheManager.SetJSON(ctx, blacklistKey, revocationInfo, ttl); err != nil { j.logger.Error("Failed to blacklist token", zap.String("jti", claims.ID), zap.Error(err)) return errors.NewInternalError("Failed to revoke token").WithInternal(err) } j.logger.Info("Token successfully revoked", zap.String("jti", claims.ID), zap.String("user_id", claims.UserID), zap.String("app_id", claims.AppID), zap.Duration("ttl", ttl)) return nil } // IsTokenRevoked checks if a token has been revoked func (j *JWTManager) IsTokenRevoked(tokenString string) (bool, error) { j.logger.Debug("Checking if JWT token is revoked") // Extract claims to get token ID claims, err := j.ExtractClaims(tokenString) if err != nil { return false, err } // Check blacklist cache ctx := context.Background() blacklistKey := cache.CacheKey(cache.KeyPrefixTokenRevoked, claims.ID) exists, err := j.cacheManager.Exists(ctx, blacklistKey) if err != nil { j.logger.Error("Failed to check token blacklist", zap.String("jti", claims.ID), zap.Error(err)) // In case of cache error, we'll assume token is not revoked to avoid blocking valid requests // This could be made configurable based on security requirements return false, nil } j.logger.Debug("Token revocation check completed", zap.String("jti", claims.ID), zap.Bool("revoked", exists)) return exists, nil } // generateJTI generates a unique JWT ID func (j *JWTManager) generateJTI() string { bytes := make([]byte, 16) if _, err := rand.Read(bytes); err != nil { // Fallback to timestamp-based ID if random generation fails return fmt.Sprintf("jti_%d", time.Now().UnixNano()) } return base64.URLEncoding.EncodeToString(bytes) } // GetTokenInfo extracts token information for debugging/logging func (j *JWTManager) GetTokenInfo(tokenString string) map[string]interface{} { claims, err := j.ExtractClaims(tokenString) if err != nil { return map[string]interface{}{ "error": err.Error(), } } return map[string]interface{}{ "user_id": claims.UserID, "app_id": claims.AppID, "permissions": claims.Permissions, "token_type": claims.TokenType, "issued_at": time.Unix(claims.IssuedAt.Unix(), 0), "expires_at": time.Unix(claims.ExpiresAt.Unix(), 0), "max_valid_at": time.Unix(claims.MaxValidAt, 0), "jti": claims.ID, } }