Files
skybridge/internal/auth/jwt.go
2025-08-23 22:31:47 -04:00

309 lines
9.5 KiB
Go

package auth
import (
"context"
"crypto/rand"
"encoding/base64"
"fmt"
"time"
"github.com/golang-jwt/jwt/v5"
"go.uber.org/zap"
"github.com/kms/api-key-service/internal/cache"
"github.com/kms/api-key-service/internal/config"
"github.com/kms/api-key-service/internal/domain"
"github.com/kms/api-key-service/internal/errors"
)
// JWTManager issues, validates, refreshes and revokes JWT tokens.
//
// The signing secret is read from config when a token is signed or
// validated, and revoked token IDs are kept in the cache behind cacheManager.
type JWTManager struct {
	// config supplies the JWT signing secret.
	config config.ConfigProvider
	// logger receives the diagnostics emitted by every token operation.
	logger *zap.Logger
	// cacheManager stores and looks up revocation (blacklist) entries.
	cacheManager *cache.CacheManager
}
// NewJWTManager builds a JWTManager around the given configuration and logger.
//
// A cache manager is created from the same configuration and logger; the
// returned manager uses it to record and look up revoked tokens.
func NewJWTManager(config config.ConfigProvider, logger *zap.Logger) *JWTManager {
	manager := new(JWTManager)
	manager.config = config
	manager.logger = logger
	manager.cacheManager = cache.NewCacheManager(config, logger)
	return manager
}
// CustomClaims is the payload carried by the JWT tokens of this service: the
// application-specific fields below plus the standard registered claims.
type CustomClaims struct {
	// UserID identifies the user the token was issued for.
	UserID string `json:"user_id"`
	// AppID identifies the application the token was issued for.
	AppID string `json:"app_id"`
	// Permissions lists the permissions granted to the token.
	Permissions []string `json:"permissions"`
	// TokenType is the kind of token, as defined by the domain package.
	TokenType domain.TokenType `json:"token_type"`
	// MaxValidAt is a Unix timestamp in seconds; validation and refresh are
	// refused once the current time is past it.
	MaxValidAt int64 `json:"max_valid_at"`
	// Claims holds optional extra key/value claims; omitted when empty.
	Claims map[string]string `json:"claims,omitempty"`
	// RegisteredClaims embeds the standard claims (iss, sub, aud, exp, iat,
	// nbf, jti).
	jwt.RegisteredClaims
}
// GenerateToken signs a new HS256 JWT describing userToken.
//
// The custom claims are copied from userToken. The registered claims use the
// fixed issuer "kms-api-service", the user ID as subject, the app ID as the
// only audience, userToken.IssuedAt for both iat and nbf, userToken.ExpiresAt
// for exp, and a freshly generated random jti.
//
// It returns a validation error when userToken is nil or no JWT secret is
// configured, and an internal error when the jti cannot be generated or the
// token cannot be signed.
func (j *JWTManager) GenerateToken(userToken *domain.UserToken) (string, error) {
	// Every statement below dereferences userToken; reject nil up front
	// instead of panicking inside the first log call.
	if userToken == nil {
		return "", errors.NewValidationError("User token is required")
	}

	j.logger.Debug("Generating JWT token",
		zap.String("user_id", userToken.UserID),
		zap.String("app_id", userToken.AppID),
		zap.Strings("permissions", userToken.Permissions))

	// Get JWT secret from config
	jwtSecret := j.config.GetJWTSecret()
	if jwtSecret == "" {
		return "", errors.NewValidationError("JWT secret not configured")
	}

	// Generate secure JWT ID; an empty result signals that the random source
	// failed, in which case no token may be issued.
	jti := j.generateJTI()
	if jti == "" {
		return "", errors.NewInternalError("Failed to generate secure JWT ID - cryptographic random number generation failed")
	}

	// Create custom claims
	claims := CustomClaims{
		UserID:      userToken.UserID,
		AppID:       userToken.AppID,
		Permissions: userToken.Permissions,
		TokenType:   userToken.TokenType,
		MaxValidAt:  userToken.MaxValidAt.Unix(),
		Claims:      userToken.Claims,
		RegisteredClaims: jwt.RegisteredClaims{
			Issuer:    "kms-api-service",
			Subject:   userToken.UserID,
			Audience:  []string{userToken.AppID},
			ExpiresAt: jwt.NewNumericDate(userToken.ExpiresAt),
			IssuedAt:  jwt.NewNumericDate(userToken.IssuedAt),
			NotBefore: jwt.NewNumericDate(userToken.IssuedAt),
			ID:        jti,
		},
	}

	// Create token with claims
	token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)

	// Sign token with secret
	tokenString, err := token.SignedString([]byte(jwtSecret))
	if err != nil {
		j.logger.Error("Failed to sign JWT token", zap.Error(err))
		return "", errors.NewInternalError("Failed to generate token").WithInternal(err)
	}

	j.logger.Debug("JWT token generated successfully",
		zap.String("user_id", userToken.UserID),
		zap.String("app_id", userToken.AppID))

	return tokenString, nil
}
// ValidateToken verifies the signature and standard claims of tokenString and
// returns its custom claims.
//
// Only HMAC-signed tokens are accepted. On top of the library's own checks,
// the token is rejected once the current time is past its max_valid_at claim.
// It returns a validation error when no JWT secret is configured and an
// authentication error for every other rejection.
func (j *JWTManager) ValidateToken(tokenString string) (*CustomClaims, error) {
	j.logger.Debug("Validating JWT token")

	secret := j.config.GetJWTSecret()
	if secret == "" {
		return nil, errors.NewValidationError("JWT secret not configured")
	}

	// keyFunc hands the secret to the parser, but only for tokens whose
	// header declares an HMAC signing method.
	keyFunc := func(t *jwt.Token) (interface{}, error) {
		if _, isHMAC := t.Method.(*jwt.SigningMethodHMAC); isHMAC {
			return []byte(secret), nil
		}
		return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"])
	}

	parsed, parseErr := jwt.ParseWithClaims(tokenString, new(CustomClaims), keyFunc)
	if parseErr != nil {
		j.logger.Warn("Failed to parse JWT token", zap.Error(parseErr))
		return nil, errors.NewAuthenticationError("Invalid token").WithInternal(parseErr)
	}

	claims, isCustom := parsed.Claims.(*CustomClaims)
	if !(isCustom && parsed.Valid) {
		j.logger.Warn("Invalid JWT token claims")
		return nil, errors.NewAuthenticationError("Invalid token claims")
	}

	// Enforce the absolute lifetime bound carried in the token itself.
	if maxValidAt := claims.MaxValidAt; time.Now().Unix() > maxValidAt {
		j.logger.Warn("JWT token expired beyond max valid time",
			zap.Int64("max_valid_at", maxValidAt),
			zap.Int64("current_time", time.Now().Unix()))
		return nil, errors.NewAuthenticationError("Token expired beyond maximum validity")
	}

	j.logger.Debug("JWT token validated successfully",
		zap.String("user_id", claims.UserID),
		zap.String("app_id", claims.AppID))

	return claims, nil
}
// RefreshToken validates oldTokenString and issues a new token carrying the
// same user, app, permissions, token type and extra claims.
//
// The new token is issued at the current time and expires at newExpiration,
// except that the expiration is clamped to the old token's max_valid_at so
// the exp claim never advertises a lifetime that ValidateToken would refuse.
// The max_valid_at value itself is carried over unchanged.
//
// It returns the error from ValidateToken when the old token is invalid, an
// authentication error when the maximum validity has passed, and any error
// from GenerateToken.
func (j *JWTManager) RefreshToken(oldTokenString string, newExpiration time.Time) (string, error) {
	j.logger.Debug("Refreshing JWT token")

	// Validate the old token first
	claims, err := j.ValidateToken(oldTokenString)
	if err != nil {
		return "", err
	}

	now := time.Now()
	maxValidAt := time.Unix(claims.MaxValidAt, 0)

	// ValidateToken already enforces this bound; the re-check only covers the
	// clock moving past it between the two reads.
	if now.Unix() > claims.MaxValidAt {
		return "", errors.NewAuthenticationError("Token cannot be refreshed - past maximum validity")
	}

	// Never hand out an exp claim later than the absolute validity bound.
	expiresAt := newExpiration
	if expiresAt.After(maxValidAt) {
		j.logger.Debug("Clamping refreshed token expiration to maximum validity",
			zap.Time("requested_expiration", newExpiration),
			zap.Time("max_valid_at", maxValidAt))
		expiresAt = maxValidAt
	}

	// Create new user token with updated expiration
	userToken := &domain.UserToken{
		AppID:       claims.AppID,
		UserID:      claims.UserID,
		Permissions: claims.Permissions,
		IssuedAt:    now,
		ExpiresAt:   expiresAt,
		MaxValidAt:  maxValidAt,
		TokenType:   claims.TokenType,
		Claims:      claims.Claims,
	}

	// Generate new token
	return j.GenerateToken(userToken)
}
// ExtractClaims decodes the claims of tokenString without verifying its
// signature or validating its claims, so it also works for expired tokens.
//
// The result must not be trusted for authentication decisions. It returns a
// validation error when the token cannot be parsed or its claims are not
// CustomClaims.
func (j *JWTManager) ExtractClaims(tokenString string) (*CustomClaims, error) {
	j.logger.Debug("Extracting claims from JWT token")

	// A zero-value parser is enough: no verification options are needed for
	// an unverified parse.
	var parser jwt.Parser
	parsed, _, parseErr := parser.ParseUnverified(tokenString, &CustomClaims{})
	if parseErr != nil {
		j.logger.Warn("Failed to parse JWT token for claims extraction", zap.Error(parseErr))
		return nil, errors.NewValidationError("Invalid token format").WithInternal(parseErr)
	}

	if claims, isCustom := parsed.Claims.(*CustomClaims); isCustom {
		return claims, nil
	}

	j.logger.Warn("Invalid JWT token claims format")
	return nil, errors.NewValidationError("Invalid token claims format")
}
// RevokeToken adds the token's ID (jti) to the revocation list (blacklist).
//
// The blacklist entry is stored in the cache with a TTL equal to the time
// remaining until the token's exp claim, so it disappears once the token
// would have expired anyway. A token that is already expired is not
// blacklisted and nil is returned.
//
// It returns a validation error when the token cannot be parsed, carries no
// exp claim, or carries no jti, and an internal error when the cache write
// fails.
//
// NOTE(review): claims are read with ExtractClaims, i.e. without verifying
// the signature. Confirm that callers authenticate the token before revoking.
func (j *JWTManager) RevokeToken(tokenString string) error {
	j.logger.Debug("Revoking JWT token")

	// Extract claims to get token ID and expiration
	claims, err := j.ExtractClaims(tokenString)
	if err != nil {
		return err
	}

	// The claims are unverified, so exp may be absent; dereferencing a nil
	// ExpiresAt would panic.
	if claims.ExpiresAt == nil {
		j.logger.Warn("Cannot revoke token without expiration claim",
			zap.String("jti", claims.ID))
		return errors.NewValidationError("Token has no expiration and cannot be revoked")
	}

	// Calculate TTL for the blacklist entry (until token would naturally expire)
	ttl := time.Until(claims.ExpiresAt.Time)
	if ttl <= 0 {
		// Token is already expired, no need to blacklist
		j.logger.Debug("Token already expired, skipping blacklist",
			zap.String("jti", claims.ID))
		return nil
	}

	// An empty jti would produce a blacklist key shared by every token that
	// lacks an ID, revoking all of them at once.
	if claims.ID == "" {
		j.logger.Warn("Cannot revoke token without ID claim")
		return errors.NewValidationError("Token has no ID and cannot be revoked")
	}

	// Store token ID in blacklist cache
	ctx := context.Background()
	blacklistKey := cache.CacheKey(cache.KeyPrefixTokenRevoked, claims.ID)

	// Store revocation info
	revocationInfo := map[string]interface{}{
		"revoked_at": time.Now().Unix(),
		"user_id":    claims.UserID,
		"app_id":     claims.AppID,
		"reason":     "manual_revocation",
	}

	if err := j.cacheManager.SetJSON(ctx, blacklistKey, revocationInfo, ttl); err != nil {
		j.logger.Error("Failed to blacklist token",
			zap.String("jti", claims.ID),
			zap.Error(err))
		return errors.NewInternalError("Failed to revoke token").WithInternal(err)
	}

	j.logger.Info("Token successfully revoked",
		zap.String("jti", claims.ID),
		zap.String("user_id", claims.UserID),
		zap.String("app_id", claims.AppID),
		zap.Duration("ttl", ttl))

	return nil
}
// IsTokenRevoked reports whether the token's ID (jti) is present in the
// revocation list.
//
// The token is decoded without signature verification. An error is returned
// only when the token cannot be decoded. A failure of the cache lookup is
// logged and reported as "not revoked" with a nil error, so that a cache
// outage does not block otherwise valid requests.
func (j *JWTManager) IsTokenRevoked(tokenString string) (bool, error) {
	j.logger.Debug("Checking if JWT token is revoked")

	claims, err := j.ExtractClaims(tokenString)
	if err != nil {
		return false, err
	}

	jti := claims.ID
	key := cache.CacheKey(cache.KeyPrefixTokenRevoked, jti)

	revoked, lookupErr := j.cacheManager.Exists(context.Background(), key)
	if lookupErr != nil {
		j.logger.Error("Failed to check token blacklist",
			zap.String("jti", jti),
			zap.Error(lookupErr))
		// Fail open on purpose: availability is preferred over strict
		// revocation enforcement here. This could be made configurable
		// based on security requirements.
		return false, nil
	}

	j.logger.Debug("Token revocation check completed",
		zap.String("jti", jti),
		zap.Bool("revoked", revoked))

	return revoked, nil
}
// generateJTI returns a fresh JWT ID: 16 bytes from the cryptographic random
// source, encoded with URL-safe base64.
//
// When the random source fails, the error is logged and the empty string is
// returned; callers must treat that as a failure. No predictable fallback ID
// is ever produced.
func (j *JWTManager) generateJTI() string {
	var raw [16]byte

	_, err := rand.Read(raw[:])
	if err == nil {
		return base64.URLEncoding.EncodeToString(raw[:])
	}

	// Fail securely: report the problem and signal it with an empty ID.
	j.logger.Error("Cryptographic random number generation failed - cannot generate secure JWT ID", zap.Error(err))
	return ""
}
// GetTokenInfo decodes tokenString without verification and returns its main
// fields as a map, for debugging and logging.
//
// When the token cannot be decoded, the map holds a single "error" entry
// with the error text. Otherwise it holds user_id, app_id, permissions,
// token_type, issued_at, expires_at, max_valid_at and jti. The time entries
// are time.Time values with second precision; issued_at and expires_at are
// the zero time.Time when the token does not carry the corresponding claim.
func (j *JWTManager) GetTokenInfo(tokenString string) map[string]interface{} {
	claims, err := j.ExtractClaims(tokenString)
	if err != nil {
		return map[string]interface{}{
			"error": err.Error(),
		}
	}

	// asTime converts an optional numeric date to a time.Time. The claims
	// come from an unverified token, so iat/exp may be missing (nil), and
	// calling Unix on a nil date would panic.
	asTime := func(d *jwt.NumericDate) time.Time {
		if d == nil {
			return time.Time{}
		}
		return time.Unix(d.Unix(), 0)
	}

	return map[string]interface{}{
		"user_id":      claims.UserID,
		"app_id":       claims.AppID,
		"permissions":  claims.Permissions,
		"token_type":   claims.TokenType,
		"issued_at":    asTime(claims.IssuedAt),
		"expires_at":   asTime(claims.ExpiresAt),
		"max_valid_at": time.Unix(claims.MaxValidAt, 0),
		"jti":          claims.ID,
	}
}