Files
skybridge/internal/services/token_service.go
2025-08-23 17:22:37 -04:00

647 lines
20 KiB
Go

package services
import (
"context"
"fmt"
"strings"
"time"
"github.com/google/uuid"
"go.uber.org/zap"
"github.com/kms/api-key-service/internal/auth"
"github.com/kms/api-key-service/internal/config"
"github.com/kms/api-key-service/internal/crypto"
"github.com/kms/api-key-service/internal/domain"
"github.com/kms/api-key-service/internal/repository"
)
// tokenService is the concrete TokenService implementation. It coordinates the
// repositories that persist static tokens, applications, permissions and
// permission grants, and owns the helpers used to mint and check tokens.
type tokenService struct {
	// Persistence dependencies.
	tokenRepo repository.StaticTokenRepository
	appRepo   repository.ApplicationRepository
	permRepo  repository.PermissionRepository
	grantRepo repository.GrantedPermissionRepository

	// tokenGen creates static tokens and checks presented tokens against
	// stored hashes.
	tokenGen *crypto.TokenGenerator
	// jwtManager signs and validates user (JWT) tokens and reports whether
	// one has been revoked.
	jwtManager *auth.JWTManager

	logger *zap.Logger
}
// NewTokenService wires a TokenService from its repositories and settings.
//
// hmacKey is passed to crypto.NewTokenGenerator, which handles static tokens;
// cfg and logger are passed to auth.NewJWTManager, which handles user (JWT)
// tokens. The parameter is named cfg so it does not shadow the config package.
func NewTokenService(
	tokenRepo repository.StaticTokenRepository,
	appRepo repository.ApplicationRepository,
	permRepo repository.PermissionRepository,
	grantRepo repository.GrantedPermissionRepository,
	hmacKey string,
	cfg config.ConfigProvider,
	logger *zap.Logger,
) TokenService {
	svc := &tokenService{
		tokenRepo: tokenRepo,
		appRepo:   appRepo,
		permRepo:  permRepo,
		grantRepo: grantRepo,
		logger:    logger,
	}
	svc.tokenGen = crypto.NewTokenGenerator(hmacKey)
	svc.jwtManager = auth.NewJWTManager(cfg, logger)
	return svc
}
// CreateStaticToken creates a static token for req.AppID owned by req.Owner
// and grants it every scope in req.Permissions. userID is recorded as the
// creator of each grant.
//
// Every requested scope must pass ValidatePermissionScopes and resolve to a
// permission ID. All of this happens before anything is written, so a lookup
// failure aborts the request instead of producing a token with fewer grants
// than the response lists.
//
// Only the token hash is stored; the plaintext token is returned once in the
// response. If granting fails after the token row was created, the token is
// deleted again on a best-effort basis, and a failed cleanup is logged.
func (s *tokenService) CreateStaticToken(ctx context.Context, req *domain.CreateStaticTokenRequest, userID string) (*domain.CreateStaticTokenResponse, error) {
	s.logger.Info("Creating static token", zap.String("app_id", req.AppID), zap.String("user_id", userID))

	// Validate application exists; its TokenPrefix shapes the generated token.
	app, err := s.appRepo.GetByID(ctx, req.AppID)
	if err != nil {
		s.logger.Error("Failed to get application", zap.Error(err), zap.String("app_id", req.AppID))
		return nil, fmt.Errorf("application not found: %w", err)
	}

	// Reject the whole request if any requested scope is unknown.
	validPermissions, err := s.permRepo.ValidatePermissionScopes(ctx, req.Permissions)
	if err != nil {
		s.logger.Error("Failed to validate permissions", zap.Error(err))
		return nil, fmt.Errorf("failed to validate permissions: %w", err)
	}
	if len(validPermissions) != len(req.Permissions) {
		s.logger.Warn("Some permissions are invalid",
			zap.Strings("requested", req.Permissions),
			zap.Strings("valid", validPermissions))
		return nil, fmt.Errorf("some requested permissions are invalid")
	}

	tokenID := uuid.New()

	// Resolve every scope to its permission ID before persisting anything, so
	// the token is never created with fewer grants than the response reports.
	grants := make([]*domain.GrantedPermission, 0, len(validPermissions))
	for _, permScope := range validPermissions {
		perm, err := s.permRepo.GetAvailablePermissionByScope(ctx, permScope)
		if err != nil {
			s.logger.Error("Failed to get permission by scope", zap.Error(err), zap.String("scope", permScope))
			return nil, fmt.Errorf("failed to resolve permission %q: %w", permScope, err)
		}
		grants = append(grants, &domain.GrantedPermission{
			ID:           uuid.New(),
			TokenType:    domain.TokenTypeStatic,
			TokenID:      tokenID,
			PermissionID: perm.ID,
			Scope:        permScope,
			CreatedBy:    userID,
		})
	}

	// Generate secure token with the application's custom prefix.
	tokenInfo, err := s.tokenGen.GenerateTokenWithInfoAndPrefix(app.TokenPrefix, "static")
	if err != nil {
		s.logger.Error("Failed to generate secure token", zap.Error(err))
		return nil, fmt.Errorf("failed to generate token: %w", err)
	}

	now := time.Now()
	token := &domain.StaticToken{
		ID:        tokenID,
		AppID:     req.AppID,
		Owner:     req.Owner,
		KeyHash:   tokenInfo.Hash, // only the hash is persisted
		Type:      "hmac",
		CreatedAt: now,
		UpdatedAt: now,
	}
	if err := s.tokenRepo.Create(ctx, token); err != nil {
		s.logger.Error("Failed to create token in database", zap.Error(err), zap.String("token_id", tokenID.String()))
		return nil, fmt.Errorf("failed to create token: %w", err)
	}

	if len(grants) > 0 {
		if err := s.grantRepo.GrantPermissions(ctx, grants); err != nil {
			s.logger.Error("Failed to grant permissions", zap.Error(err))
			// Roll back so the token does not exist without its permissions.
			if delErr := s.tokenRepo.Delete(ctx, tokenID); delErr != nil {
				s.logger.Error("Failed to clean up token after grant failure",
					zap.Error(delErr), zap.String("token_id", tokenID.String()))
			}
			return nil, fmt.Errorf("failed to grant permissions: %w", err)
		}
	}

	response := &domain.CreateStaticTokenResponse{
		ID:          tokenID,
		Token:       tokenInfo.Token, // Return the actual token only once
		Permissions: validPermissions,
		CreatedAt:   now,
	}

	s.logger.Info("Static token created successfully",
		zap.String("token_id", tokenID.String()),
		zap.String("app_id", app.AppID),
		zap.Strings("permissions", validPermissions))

	return response, nil
}
// ListByApp returns one page of the static tokens belonging to appID.
//
// The repository query is not paginated, so all tokens are loaded and the
// window [offset, offset+limit) is sliced in memory. A negative offset or
// limit is clamped to 0 rather than causing a slice-bounds panic. An offset
// past the end yields an empty page, and a limit reaching past the end yields
// the remaining tokens.
func (s *tokenService) ListByApp(ctx context.Context, appID string, limit, offset int) ([]*domain.StaticToken, error) {
	s.logger.Debug("Listing tokens for application", zap.String("app_id", appID))

	tokens, err := s.tokenRepo.GetByAppID(ctx, appID)
	if err != nil {
		s.logger.Error("Failed to list tokens from repository", zap.Error(err), zap.String("app_id", appID))
		return nil, fmt.Errorf("failed to list tokens: %w", err)
	}

	if offset < 0 {
		offset = 0
	}
	if limit < 0 {
		limit = 0
	}

	var page []*domain.StaticToken
	if offset > len(tokens) {
		page = []*domain.StaticToken{}
	} else {
		// Compare limit with the remaining length instead of computing
		// offset+limit up front, which can overflow for very large limits.
		end := len(tokens)
		if limit < end-offset {
			end = offset + limit
		}
		page = tokens[offset:end]
	}

	s.logger.Debug("Listed tokens successfully", zap.String("app_id", appID), zap.Int("count", len(page)))
	return page, nil
}
// Delete removes the static token tokenID and then revokes its granted
// permissions. userID is only used for logging.
//
// A token that does not exist is reported as an error. A failure to revoke
// permissions is logged as a warning but does not fail the call, because the
// token itself has already been deleted by then.
func (s *tokenService) Delete(ctx context.Context, tokenID uuid.UUID, userID string) error {
	idField := zap.String("token_id", tokenID.String())
	s.logger.Info("Deleting token", idField, zap.String("user_id", userID))

	found, err := s.tokenRepo.Exists(ctx, tokenID)
	switch {
	case err != nil:
		s.logger.Error("Failed to check token existence", zap.Error(err), idField)
		return err
	case !found:
		s.logger.Error("Token not found", idField)
		return fmt.Errorf("token with ID '%s' not found", tokenID.String())
	}

	if err := s.tokenRepo.Delete(ctx, tokenID); err != nil {
		s.logger.Error("Failed to delete token", zap.Error(err), idField)
		return err
	}

	// Best effort: leftover grants of a deleted token are not a reason to
	// report the deletion itself as failed.
	if err := s.grantRepo.RevokeAllPermissions(ctx, domain.TokenTypeStatic, tokenID, "system-cleanup"); err != nil {
		s.logger.Warn("Failed to revoke permissions for deleted token", idField, zap.Error(err))
	}
	return nil
}
// GenerateUserToken issues a signed user (JWT) token for userID in appID
// carrying the given permission scopes.
//
// Every provided scope must be valid or the call fails; an empty list yields a
// token without scopes. The token expires after the application's
// TokenRenewalDuration, but never later than its hard limit (issue time plus
// MaxTokenDuration). When the application has a TokenPrefix the JWT is
// returned as "<prefix>UT-<jwt>", otherwise as the bare JWT.
func (s *tokenService) GenerateUserToken(ctx context.Context, appID, userID string, permissions []string) (string, error) {
	s.logger.Info("Generating user token", zap.String("app_id", appID), zap.String("user_id", userID))

	app, err := s.appRepo.GetByID(ctx, appID)
	if err != nil {
		s.logger.Error("Failed to get application", zap.Error(err), zap.String("app_id", appID))
		return "", fmt.Errorf("application not found: %w", err)
	}

	// Validate permissions exist (if any provided).
	var validPermissions []string
	if len(permissions) > 0 {
		validPermissions, err = s.permRepo.ValidatePermissionScopes(ctx, permissions)
		if err != nil {
			s.logger.Error("Failed to validate permissions", zap.Error(err))
			return "", fmt.Errorf("failed to validate permissions: %w", err)
		}
		if len(validPermissions) != len(permissions) {
			s.logger.Warn("Some permissions are invalid",
				zap.Strings("requested", permissions),
				zap.Strings("valid", validPermissions))
			return "", fmt.Errorf("some requested permissions are invalid")
		}
	}

	now := time.Now()
	userToken := &domain.UserToken{
		AppID:       appID,
		UserID:      userID,
		Permissions: validPermissions,
		IssuedAt:    now,
		ExpiresAt:   now.Add(app.TokenRenewalDuration.Duration),
		MaxValidAt:  now.Add(app.MaxTokenDuration.Duration),
		TokenType:   domain.TokenTypeUser,
	}
	// Cap the renewal window at the hard lifetime limit. Otherwise a renewal
	// duration longer than MaxTokenDuration would produce a token that
	// outlives MaxValidAt. RenewUserToken applies the same cap.
	if userToken.ExpiresAt.After(userToken.MaxValidAt) {
		userToken.ExpiresAt = userToken.MaxValidAt
	}

	jwtTokenString, err := s.jwtManager.GenerateToken(userToken)
	if err != nil {
		s.logger.Error("Failed to generate JWT token", zap.Error(err))
		return "", fmt.Errorf("failed to generate token: %w", err)
	}

	// Wrap the JWT as "<prefix>UT-<jwt>" when the application has a prefix.
	finalToken := jwtTokenString
	if app.TokenPrefix != "" {
		finalToken = app.TokenPrefix + "UT-" + jwtTokenString
	}

	s.logger.Info("User token generated successfully",
		zap.String("app_id", appID),
		zap.String("user_id", userID),
		zap.Strings("permissions", validPermissions),
		zap.Time("expires_at", userToken.ExpiresAt),
		zap.Time("max_valid_at", userToken.MaxValidAt))

	return finalToken, nil
}
// detectTokenType classifies token as a user (JWT) or static token using only
// its text. Checks run in this order:
//
//  1. app.TokenPrefix+"UT-" means user; app.TokenPrefix+"T-" means static.
//  2. If the first '-' sits at index 3..6, text ending in "UT-" up to that
//     dash means user and text ending in "T-" means static. This covers
//     foreign prefixes; the letters themselves are not validated.
//  3. A "kms_" prefix means static.
//  4. Three non-empty dot-separated segments (header.payload.signature)
//     mean user. GenerateUserToken returns a bare JWT when the application
//     has no TokenPrefix, and such tokens must reach JWT verification.
//  5. Anything else is treated as static.
func (s *tokenService) detectTokenType(token string, app *domain.Application) domain.TokenType {
	// Application-specific prefixes are checked first.
	if app.TokenPrefix != "" {
		userPrefix := app.TokenPrefix + "UT-"
		if strings.HasPrefix(token, userPrefix) {
			return domain.TokenTypeUser
		}
		staticPrefix := app.TokenPrefix + "T-"
		if strings.HasPrefix(token, staticPrefix) {
			return domain.TokenTypeStatic
		}
	}

	// Generic "<prefix>UT-" / "<prefix>T-" shape: a 2-4 character prefix plus
	// "T" or "UT" puts the first dash at index 3..6.
	if len(token) >= 6 {
		dashIndex := strings.Index(token, "-")
		if dashIndex >= 3 && dashIndex <= 6 {
			prefixPart := token[:dashIndex+1]
			if strings.HasSuffix(prefixPart, "UT-") {
				return domain.TokenTypeUser
			}
			if strings.HasSuffix(prefixPart, "T-") {
				return domain.TokenTypeStatic
			}
		}
	}

	// Default kms_ prefix (static tokens).
	if strings.HasPrefix(token, "kms_") {
		return domain.TokenTypeStatic
	}

	// Unprefixed JWT: without this check it would fall through to static and
	// fail the static token format check.
	if segs := strings.Split(token, "."); len(segs) == 3 && segs[0] != "" && segs[1] != "" && segs[2] != "" {
		return domain.TokenTypeUser
	}

	// Default to static if the pattern is unclear.
	return domain.TokenTypeStatic
}
// VerifyToken checks req.Token against application req.AppID. When
// req.Permissions is non-empty, it also reports whether the token grants
// those scopes.
//
// The token type is inferred from the token text by detectTokenType, and the
// request is dispatched to the static or user verifier. A missing token or
// unknown application is reported in the response (Valid=false plus an Error
// message) with a nil Go error.
func (s *tokenService) VerifyToken(ctx context.Context, req *domain.VerifyRequest) (*domain.VerifyResponse, error) {
	if req.Token == "" {
		return &domain.VerifyResponse{Valid: false, Permitted: false, Error: "Token is required"}, nil
	}

	app, err := s.appRepo.GetByID(ctx, req.AppID)
	if err != nil {
		s.logger.Error("Failed to get application", zap.Error(err), zap.String("app_id", req.AppID))
		return &domain.VerifyResponse{Valid: false, Permitted: false, Error: "Invalid application"}, nil
	}

	detected := s.detectTokenType(req.Token, app)
	appField := zap.String("app_id", req.AppID)
	s.logger.Debug("Auto-detected token type", appField, zap.String("detected_type", string(detected)))
	s.logger.Debug("Verifying token", appField, zap.String("type", string(detected)))

	if detected == domain.TokenTypeStatic {
		return s.verifyStaticToken(ctx, req, app)
	}
	if detected == domain.TokenTypeUser {
		return s.verifyUserToken(ctx, req, app)
	}

	// As written, detectTokenType only returns the two types above. This is a
	// safe answer should that ever change.
	return &domain.VerifyResponse{Valid: false, Permitted: false, Error: "Invalid token type"}, nil
}
// verifyStaticToken authenticates a static token and reports its permissions.
//
// The token must pass crypto.IsValidTokenFormat. It is then compared against
// the stored hash of each static token of req.AppID until one matches; this
// is linear in the number of tokens the app has. For a match, the granted
// scopes are loaded. If req.Permissions is non-empty, the requested scopes are
// checked with grantRepo.HasAnyPermission, and Permitted is true only when
// every requested scope is granted. Failures are reported in the response
// with a nil Go error. app is currently unused.
func (s *tokenService) verifyStaticToken(ctx context.Context, req *domain.VerifyRequest, app *domain.Application) (*domain.VerifyResponse, error) {
	reject := func(reason string) (*domain.VerifyResponse, error) {
		return &domain.VerifyResponse{Valid: false, Permitted: false, Error: reason}, nil
	}

	s.logger.Debug("Verifying static token", zap.String("app_id", req.AppID))

	if !crypto.IsValidTokenFormat(req.Token) {
		s.logger.Warn("Invalid token format", zap.String("app_id", req.AppID))
		return reject("Invalid token format")
	}

	candidates, err := s.tokenRepo.GetByAppID(ctx, req.AppID)
	if err != nil {
		s.logger.Error("Failed to get tokens for app", zap.Error(err), zap.String("app_id", req.AppID))
		return reject("Token verification failed")
	}

	// Stop at the first stored hash the presented token verifies against.
	var match *domain.StaticToken
	for i := 0; i < len(candidates) && match == nil; i++ {
		if s.tokenGen.VerifyToken(req.Token, candidates[i].KeyHash) {
			match = candidates[i]
		}
	}
	if match == nil {
		s.logger.Warn("Token not found or invalid", zap.String("app_id", req.AppID))
		return reject("Invalid token")
	}

	granted, err := s.grantRepo.GetGrantedPermissionScopes(ctx, domain.TokenTypeStatic, match.ID)
	if err != nil {
		s.logger.Error("Failed to get token permissions", zap.Error(err), zap.String("token_id", match.ID.String()))
		return reject("Failed to retrieve permissions")
	}

	// With no specific scopes requested, a valid token is permitted.
	permitted := true
	var results map[string]bool
	if len(req.Permissions) > 0 {
		results, err = s.grantRepo.HasAnyPermission(ctx, domain.TokenTypeStatic, match.ID, req.Permissions)
		if err != nil {
			s.logger.Error("Failed to check specific permissions", zap.Error(err))
			return reject("Failed to check permissions")
		}
		// A missing map entry reads as false, i.e. not granted.
		for _, scope := range req.Permissions {
			if !results[scope] {
				permitted = false
				break
			}
		}
	}

	s.logger.Info("Static token verified successfully",
		zap.String("token_id", match.ID.String()),
		zap.String("app_id", req.AppID),
		zap.Strings("permissions", granted),
		zap.Bool("permitted", permitted))

	return &domain.VerifyResponse{
		Valid:             true,
		Permitted:         permitted,
		Permissions:       granted,
		PermissionResults: results,
		TokenType:         domain.TokenTypeStatic,
	}, nil
}
// verifyUserToken authenticates a user (JWT) token and reports its claims.
//
// When the application has a TokenPrefix, the token must look like
// "<prefix>UT-<jwt>" and the prefix is stripped. Otherwise the whole token is
// treated as the JWT. The token is rejected when any of these holds:
//   - the JWT manager reports it revoked;
//   - the JWT manager reports it invalid;
//   - it was issued for another application;
//   - its MaxValidAt claim (hard lifetime limit) has passed.
//
// If req.Permissions is non-empty, each requested scope is matched exactly
// against the token's scopes, and Permitted is true only when all are present.
// Failures are reported in the response with a nil Go error.
func (s *tokenService) verifyUserToken(ctx context.Context, req *domain.VerifyRequest, app *domain.Application) (*domain.VerifyResponse, error) {
	s.logger.Debug("Verifying user token", zap.String("app_id", req.AppID))

	// Extract the JWT from the application's prefixed format.
	jwtToken := req.Token
	if app.TokenPrefix != "" {
		expectedPrefix := app.TokenPrefix + "UT-"
		if !strings.HasPrefix(req.Token, expectedPrefix) {
			s.logger.Warn("User token missing expected prefix",
				zap.String("app_id", req.AppID),
				zap.String("expected_prefix", expectedPrefix))
			return &domain.VerifyResponse{Valid: false, Permitted: false, Error: "Invalid token format"}, nil
		}
		jwtToken = strings.TrimPrefix(req.Token, expectedPrefix)
	}

	isRevoked, err := s.jwtManager.IsTokenRevoked(jwtToken)
	if err != nil {
		s.logger.Error("Failed to check token revocation status", zap.Error(err))
		return &domain.VerifyResponse{Valid: false, Permitted: false, Error: "Token verification failed"}, nil
	}
	if isRevoked {
		s.logger.Warn("Token is revoked", zap.String("app_id", req.AppID))
		return &domain.VerifyResponse{Valid: false, Permitted: false, Error: "Token has been revoked"}, nil
	}

	claims, err := s.jwtManager.ValidateToken(jwtToken)
	if err != nil {
		s.logger.Warn("JWT token validation failed", zap.Error(err), zap.String("app_id", req.AppID))
		return &domain.VerifyResponse{Valid: false, Permitted: false, Error: "Invalid token"}, nil
	}

	// A token is only valid for the application it was issued for.
	if claims.AppID != req.AppID {
		s.logger.Warn("Token app_id mismatch",
			zap.String("expected", req.AppID),
			zap.String("actual", claims.AppID))
		return &domain.VerifyResponse{Valid: false, Permitted: false, Error: "Token not valid for this application"}, nil
	}

	// Enforce the hard lifetime limit independently of the JWT expiry: a
	// token must not be accepted after MaxValidAt even if its expiry is later.
	var maxValidAt *time.Time
	if claims.MaxValidAt > 0 {
		maxTime := time.Unix(claims.MaxValidAt, 0)
		if time.Now().After(maxTime) {
			s.logger.Warn("Token is past maximum validity period",
				zap.String("app_id", req.AppID),
				zap.Time("max_valid_at", maxTime))
			return &domain.VerifyResponse{Valid: false, Permitted: false, Error: "Token has expired"}, nil
		}
		maxValidAt = &maxTime
	}

	// Check specific permissions if requested; with none requested, a valid
	// token is permitted.
	var permissionResults map[string]bool
	permitted := true
	if len(req.Permissions) > 0 {
		granted := make(map[string]struct{}, len(claims.Permissions))
		for _, scope := range claims.Permissions {
			granted[scope] = struct{}{}
		}
		permissionResults = make(map[string]bool, len(req.Permissions))
		for _, requestedPerm := range req.Permissions {
			_, hasPermission := granted[requestedPerm]
			permissionResults[requestedPerm] = hasPermission
			if !hasPermission {
				permitted = false
			}
		}
	}

	var expiresAt *time.Time
	if claims.ExpiresAt != nil {
		expTime := claims.ExpiresAt.Time
		expiresAt = &expTime
	}

	s.logger.Info("User token verified successfully",
		zap.String("user_id", claims.UserID),
		zap.String("app_id", req.AppID),
		zap.Strings("permissions", claims.Permissions),
		zap.Bool("permitted", permitted))

	return &domain.VerifyResponse{
		Valid:             true,
		Permitted:         permitted,
		UserID:            claims.UserID,
		Permissions:       claims.Permissions,
		PermissionResults: permissionResults,
		ExpiresAt:         expiresAt,
		MaxValidAt:        maxValidAt,
		TokenType:         domain.TokenTypeUser,
		Claims:            claims.Claims,
	}, nil
}
// RenewUserToken exchanges a still-valid user token for a new one with a fresh
// expiry. The new token keeps the original permissions, custom claims and
// MaxValidAt.
//
// The presented token goes through the same checks as verification:
//   - the application's "<prefix>UT-" wrapper (if any) is stripped;
//   - revocation and JWT validity are checked with the JWT manager;
//   - the token must belong to req.AppID and req.UserID.
//
// The new expiry is now+TokenRenewalDuration, capped at MaxValidAt. Renewal is
// refused once MaxValidAt has passed or if the token carries no MaxValidAt.
// The returned token uses the same "<prefix>UT-" format as GenerateUserToken.
//
// Failures are reported in RenewResponse.Error with a nil Go error. The codes
// are "invalid_application", "invalid_token", "token_expired" and
// "token_generation_failed".
func (s *tokenService) RenewUserToken(ctx context.Context, req *domain.RenewRequest) (*domain.RenewResponse, error) {
	s.logger.Info("Renewing user token", zap.String("app_id", req.AppID), zap.String("user_id", req.UserID))

	app, err := s.appRepo.GetByID(ctx, req.AppID)
	if err != nil {
		s.logger.Error("Failed to get application for token renewal", zap.Error(err), zap.String("app_id", req.AppID))
		return &domain.RenewResponse{Error: "invalid_application"}, nil
	}

	// Unwrap "<prefix>UT-<jwt>" (or accept a bare JWT when the application has
	// no prefix) before handing the token to the JWT manager.
	jwtToken := req.Token
	if app.TokenPrefix != "" {
		expectedPrefix := app.TokenPrefix + "UT-"
		if !strings.HasPrefix(req.Token, expectedPrefix) {
			s.logger.Warn("User token missing expected prefix",
				zap.String("app_id", req.AppID),
				zap.String("expected_prefix", expectedPrefix))
			return &domain.RenewResponse{Error: "invalid_token"}, nil
		}
		jwtToken = strings.TrimPrefix(req.Token, expectedPrefix)
	}

	// A revoked token must not be renewable; fail closed if the check errors.
	isRevoked, err := s.jwtManager.IsTokenRevoked(jwtToken)
	if err != nil {
		s.logger.Error("Failed to check token revocation status", zap.Error(err))
		return &domain.RenewResponse{Error: "invalid_token"}, nil
	}
	if isRevoked {
		s.logger.Warn("Token is revoked", zap.String("app_id", req.AppID), zap.String("user_id", req.UserID))
		return &domain.RenewResponse{Error: "invalid_token"}, nil
	}

	claims, err := s.jwtManager.ValidateToken(jwtToken)
	if err != nil {
		s.logger.Warn("Invalid token for renewal", zap.Error(err), zap.String("app_id", req.AppID), zap.String("user_id", req.UserID))
		return &domain.RenewResponse{Error: "invalid_token"}, nil
	}

	// The token must have been issued for this application and this user.
	if claims.AppID != req.AppID {
		s.logger.Warn("Token app_id mismatch during renewal",
			zap.String("expected", req.AppID),
			zap.String("actual", claims.AppID))
		return &domain.RenewResponse{Error: "invalid_token"}, nil
	}
	if claims.UserID != req.UserID {
		s.logger.Warn("Token user ID mismatch during renewal",
			zap.String("expected", req.UserID),
			zap.String("actual", claims.UserID))
		return &domain.RenewResponse{Error: "invalid_token"}, nil
	}

	// Without a hard lifetime limit there is nothing to cap the renewal at.
	if claims.MaxValidAt <= 0 {
		s.logger.Warn("Token has no maximum validity period", zap.String("user_id", req.UserID))
		return &domain.RenewResponse{Error: "invalid_token"}, nil
	}
	maxValidAt := time.Unix(claims.MaxValidAt, 0)
	now := time.Now()
	if now.After(maxValidAt) {
		s.logger.Warn("Token is past maximum validity period",
			zap.String("user_id", req.UserID),
			zap.Time("max_valid_at", maxValidAt))
		return &domain.RenewResponse{Error: "token_expired"}, nil
	}

	// Extend the expiry but keep the original permissions, claims and max
	// validity.
	newToken := &domain.UserToken{
		AppID:       req.AppID,
		UserID:      req.UserID,
		Permissions: claims.Permissions,
		IssuedAt:    now,
		ExpiresAt:   now.Add(app.TokenRenewalDuration.Duration),
		MaxValidAt:  maxValidAt,
		TokenType:   domain.TokenTypeUser,
		Claims:      claims.Claims,
	}
	// Ensure the new expiry doesn't exceed the max valid date.
	if newToken.ExpiresAt.After(newToken.MaxValidAt) {
		newToken.ExpiresAt = newToken.MaxValidAt
	}

	// NOTE(review): the presented token is not revoked here, so it remains
	// usable until its own expiry — confirm this is intended.
	tokenString, err := s.jwtManager.GenerateToken(newToken)
	if err != nil {
		s.logger.Error("Failed to generate renewed token", zap.Error(err), zap.String("user_id", req.UserID))
		return &domain.RenewResponse{Error: "token_generation_failed"}, nil
	}
	if app.TokenPrefix != "" {
		tokenString = app.TokenPrefix + "UT-" + tokenString
	}

	s.logger.Info("User token renewed successfully",
		zap.String("app_id", req.AppID),
		zap.String("user_id", req.UserID),
		zap.Time("expires_at", newToken.ExpiresAt),
		zap.Time("max_valid_at", newToken.MaxValidAt))

	return &domain.RenewResponse{
		Token:      tokenString,
		ExpiresAt:  newToken.ExpiresAt,
		MaxValidAt: newToken.MaxValidAt,
	}, nil
}