Files
skybridge/internal/services/session_service.go
2025-08-22 18:57:40 -04:00

415 lines
12 KiB
Go

package services
import (
"context"
"time"
"github.com/google/uuid"
"go.uber.org/zap"
"github.com/kms/api-key-service/internal/config"
"github.com/kms/api-key-service/internal/domain"
"github.com/kms/api-key-service/internal/errors"
"github.com/kms/api-key-service/internal/repository"
)
// sessionService is the concrete SessionService returned by NewSessionService.
// It delegates persistence to the session repository, consults the application
// repository when sessions are created, and logs each operation via zap.
type sessionService struct {
	// sessionRepo stores and queries user sessions.
	sessionRepo repository.SessionRepository
	// appRepo is used to look up the application a session belongs to.
	appRepo repository.ApplicationRepository
	// config is the configuration provider handed to NewSessionService.
	// NOTE(review): not read by any method in this file — confirm it is still needed.
	config config.ConfigProvider
	// logger receives debug/warn/error records for session operations.
	logger *zap.Logger
}
// NewSessionService builds a SessionService backed by the given session and
// application repositories, configuration provider and logger. The arguments
// are stored as-is; none of them is validated here.
func NewSessionService(
	sessionRepo repository.SessionRepository,
	appRepo repository.ApplicationRepository,
	config config.ConfigProvider,
	logger *zap.Logger,
) SessionService {
	svc := new(sessionService)
	svc.sessionRepo = sessionRepo
	svc.appRepo = appRepo
	svc.config = config
	svc.logger = logger
	return svc
}
// CreateSession creates and persists a new active user session described by req.
//
// The target application must exist and list domain.ApplicationTypeUser among
// its types. req.ExpiresAt must lie in the future. When the application has a
// positive MaxTokenDuration, the session expiration is also capped at
// now+MaxTokenDuration so a session can never outlive the application's token
// lifetime. A non-positive MaxTokenDuration is treated as "no cap".
//
// Errors:
//   - validation error if req is nil, the application is not found, the
//     application does not support user sessions, or req.ExpiresAt is not in
//     the future;
//   - any other application or session repository error is returned unchanged.
func (s *sessionService) CreateSession(ctx context.Context, req *domain.CreateSessionRequest) (*domain.UserSession, error) {
	if req == nil {
		return nil, errors.NewValidationError("Session request is required")
	}

	s.logger.Debug("Creating new session",
		zap.String("user_id", req.UserID),
		zap.String("app_id", req.AppID),
		zap.String("session_type", string(req.SessionType)))

	// Validate application exists; a missing app is the caller's input error.
	app, err := s.appRepo.GetByID(ctx, req.AppID)
	if err != nil {
		if errors.IsNotFound(err) {
			return nil, errors.NewValidationError("Application not found")
		}
		return nil, err
	}

	// Only applications that support user tokens may own user sessions.
	supportsUser := false
	for _, appType := range app.Type {
		if appType == domain.ApplicationTypeUser {
			supportsUser = true
			break
		}
	}
	if !supportsUser {
		return nil, errors.NewValidationError("Application does not support user sessions")
	}

	// Reject expirations that are not in the future: such a session would be
	// unusable from the moment it is stored (assuming IsExpired compares
	// ExpiresAt against the current time — TODO confirm in domain).
	now := time.Now()
	if !req.ExpiresAt.After(now) {
		return nil, errors.NewValidationError("Session expiration must be in the future")
	}

	// Cap the lifetime at the application's maximum token duration, if one is set.
	expiresAt := req.ExpiresAt
	if app.MaxTokenDuration > 0 {
		if maxExpiration := now.Add(app.MaxTokenDuration); expiresAt.After(maxExpiration) {
			expiresAt = maxExpiration
		}
	}

	session := &domain.UserSession{
		ID:          uuid.New(),
		UserID:      req.UserID,
		AppID:       req.AppID,
		SessionType: req.SessionType,
		Status:      domain.SessionStatusActive,
		IPAddress:   req.IPAddress,
		UserAgent:   req.UserAgent,
		ExpiresAt:   expiresAt,
		Metadata: domain.SessionMetadata{
			TenantID:    req.TenantID,
			Permissions: req.Permissions,
			Claims:      req.Claims,
			// NOTE(review): hard-coded even though this is the generic (non-OAuth2)
			// creation path; confirm whether the request should carry the login method.
			LoginMethod: "oauth2",
		},
	}

	if err := s.sessionRepo.Create(ctx, session); err != nil {
		s.logger.Error("Failed to create session", zap.Error(err))
		return nil, err
	}

	s.logger.Debug("Session created successfully", zap.String("session_id", session.ID.String()))
	return session, nil
}
// GetSession fetches the session identified by sessionID from the repository.
// No status or expiry checks are applied here. Any repository error is
// returned unchanged, with a nil session.
func (s *sessionService) GetSession(ctx context.Context, sessionID uuid.UUID) (*domain.UserSession, error) {
	s.logger.Debug("Getting session", zap.String("session_id", sessionID.String()))

	found, lookupErr := s.sessionRepo.GetByID(ctx, sessionID)
	if lookupErr == nil {
		return found, nil
	}
	return nil, lookupErr
}
// GetUserSessions returns every session the repository holds for userID,
// regardless of status. On a repository error it returns a nil slice together
// with that error.
func (s *sessionService) GetUserSessions(ctx context.Context, userID string) ([]*domain.UserSession, error) {
	s.logger.Debug("Getting user sessions", zap.String("user_id", userID))

	userSessions, lookupErr := s.sessionRepo.GetByUserID(ctx, userID)
	if lookupErr == nil {
		return userSessions, nil
	}
	return nil, lookupErr
}
// GetUserAppSessions returns the sessions that userID holds within the
// application appID. On a repository error it returns a nil slice together
// with that error.
func (s *sessionService) GetUserAppSessions(ctx context.Context, userID, appID string) ([]*domain.UserSession, error) {
	logFields := []zap.Field{
		zap.String("user_id", userID),
		zap.String("app_id", appID),
	}
	s.logger.Debug("Getting user app sessions", logFields...)

	matches, lookupErr := s.sessionRepo.GetByUserAndApp(ctx, userID, appID)
	if lookupErr != nil {
		return nil, lookupErr
	}
	return matches, nil
}
// GetActiveSessions returns the sessions the repository classifies as active
// for userID. What counts as "active" is decided by
// SessionRepository.GetActiveByUserID. On error it returns a nil slice
// together with the repository error.
func (s *sessionService) GetActiveSessions(ctx context.Context, userID string) ([]*domain.UserSession, error) {
	s.logger.Debug("Getting active sessions", zap.String("user_id", userID))

	active, lookupErr := s.sessionRepo.GetActiveByUserID(ctx, userID)
	if lookupErr == nil {
		return active, nil
	}
	return nil, lookupErr
}
// ListSessions retrieves sessions with filtering and pagination
func (s *sessionService) ListSessions(ctx context.Context, req *domain.SessionListRequest) (*domain.SessionListResponse, error) {
s.logger.Debug("Listing sessions",
zap.String("user_id", req.UserID),
zap.String("app_id", req.AppID),
zap.Int("limit", req.Limit),
zap.Int("offset", req.Offset))
// Set default pagination if not provided
if req.Limit <= 0 {
req.Limit = 50
}
if req.Limit > 100 {
req.Limit = 100
}
response, err := s.sessionRepo.List(ctx, req)
if err != nil {
return nil, err
}
return response, nil
}
// UpdateSession applies updates to the session identified by sessionID.
// The session is first looked up so that a repository lookup error (for
// example, a missing session) is returned before any write is attempted.
// Errors from either repository call are returned unchanged.
func (s *sessionService) UpdateSession(ctx context.Context, sessionID uuid.UUID, updates *domain.UpdateSessionRequest) error {
	id := sessionID.String()
	s.logger.Debug("Updating session", zap.String("session_id", id))

	if _, lookupErr := s.sessionRepo.GetByID(ctx, sessionID); lookupErr != nil {
		return lookupErr
	}

	if updateErr := s.sessionRepo.Update(ctx, sessionID, updates); updateErr != nil {
		return updateErr
	}

	s.logger.Debug("Session updated successfully", zap.String("session_id", id))
	return nil
}
// UpdateSessionActivity asks the repository to record activity for the
// session identified by sessionID. It returns nil on success and otherwise
// returns the repository error as-is.
func (s *sessionService) UpdateSessionActivity(ctx context.Context, sessionID uuid.UUID) error {
	s.logger.Debug("Updating session activity", zap.String("session_id", sessionID.String()))

	activityErr := s.sessionRepo.UpdateActivity(ctx, sessionID)
	if activityErr == nil {
		return nil
	}
	return activityErr
}
// RevokeSession revokes the session identified by sessionID and passes
// revokedBy through to the repository.
//
// Only sessions whose stored Status equals domain.SessionStatusActive can be
// revoked. Note that this compares the Status field directly rather than
// calling IsActive(). Any other status yields a validation error. Repository
// errors are returned unchanged.
func (s *sessionService) RevokeSession(ctx context.Context, sessionID uuid.UUID, revokedBy string) error {
	id := sessionID.String()
	s.logger.Debug("Revoking session",
		zap.String("session_id", id),
		zap.String("revoked_by", revokedBy))

	current, err := s.sessionRepo.GetByID(ctx, sessionID)
	switch {
	case err != nil:
		return err
	case current.Status != domain.SessionStatusActive:
		return errors.NewValidationError("Session is not active")
	}

	if err = s.sessionRepo.Revoke(ctx, sessionID, revokedBy); err != nil {
		return err
	}

	s.logger.Debug("Session revoked successfully", zap.String("session_id", id))
	return nil
}
// RevokeUserSessions revokes every session belonging to userID in a single
// repository call, passing revokedBy through to the repository. Repository
// errors are returned unchanged.
func (s *sessionService) RevokeUserSessions(ctx context.Context, userID string, revokedBy string) error {
	userField := zap.String("user_id", userID)
	s.logger.Debug("Revoking user sessions", userField, zap.String("revoked_by", revokedBy))

	if revokeErr := s.sessionRepo.RevokeAllByUser(ctx, userID, revokedBy); revokeErr != nil {
		return revokeErr
	}

	s.logger.Debug("User sessions revoked successfully", userField)
	return nil
}
// RevokeUserAppSessions revokes every session that userID holds within appID
// in a single repository call, passing revokedBy through to the repository.
// Repository errors are returned unchanged.
func (s *sessionService) RevokeUserAppSessions(ctx context.Context, userID, appID string, revokedBy string) error {
	userField, appField := zap.String("user_id", userID), zap.String("app_id", appID)
	s.logger.Debug("Revoking user app sessions",
		userField,
		appField,
		zap.String("revoked_by", revokedBy))

	if revokeErr := s.sessionRepo.RevokeAllByUserAndApp(ctx, userID, appID, revokedBy); revokeErr != nil {
		return revokeErr
	}

	s.logger.Debug("User app sessions revoked successfully", userField, appField)
	return nil
}
// ValidateSession loads the session identified by sessionID and confirms that
// it is usable.
//
// If session.IsActive() is false, an authentication error is returned. Its
// message reports expiry first, then revocation, and otherwise a generic
// "not active". For an active session the repository's activity timestamp is
// updated on a best-effort basis: a failure is logged as a warning and does
// not fail validation. The returned session is the record as read before that
// update.
func (s *sessionService) ValidateSession(ctx context.Context, sessionID uuid.UUID) (*domain.UserSession, error) {
	id := sessionID.String()
	s.logger.Debug("Validating session", zap.String("session_id", id))

	session, err := s.sessionRepo.GetByID(ctx, sessionID)
	if err != nil {
		return nil, err
	}

	if !session.IsActive() {
		// Order matters: an expired session is reported as expired even if it
		// might also be revoked.
		switch {
		case session.IsExpired():
			return nil, errors.NewAuthenticationError("Session has expired")
		case session.IsRevoked():
			return nil, errors.NewAuthenticationError("Session has been revoked")
		default:
			return nil, errors.NewAuthenticationError("Session is not active")
		}
	}

	if activityErr := s.sessionRepo.UpdateActivity(ctx, sessionID); activityErr != nil {
		s.logger.Warn("Failed to update session activity", zap.Error(activityErr))
	}

	s.logger.Debug("Session validated successfully", zap.String("session_id", id))
	return session, nil
}
// RefreshSession sets the expiration of an active session to newExpiration.
//
// newExpiration must lie strictly in the future. Otherwise the "refresh" would
// leave the session unusable, which is not what a refresh is for; use
// RevokeSession to end a session. The session must exist and report
// IsActive().
//
// Errors: validation error if newExpiration is not in the future or the
// session is not active; repository errors are returned unchanged.
func (s *sessionService) RefreshSession(ctx context.Context, sessionID uuid.UUID, newExpiration time.Time) error {
	s.logger.Debug("Refreshing session",
		zap.String("session_id", sessionID.String()),
		zap.Time("new_expiration", newExpiration))

	// Checked before any repository access so bad input costs no round trip.
	if !newExpiration.After(time.Now()) {
		return errors.NewValidationError("New expiration must be in the future")
	}

	session, err := s.sessionRepo.GetByID(ctx, sessionID)
	if err != nil {
		return err
	}
	if !session.IsActive() {
		return errors.NewValidationError("Cannot refresh inactive session")
	}

	updates := &domain.UpdateSessionRequest{
		ExpiresAt: &newExpiration,
	}
	if err := s.sessionRepo.Update(ctx, sessionID, updates); err != nil {
		return err
	}

	s.logger.Debug("Session refreshed successfully", zap.String("session_id", sessionID.String()))
	return nil
}
// CleanupExpiredSessions asks the repository to mark expired sessions. When
// deleteOlderThan is non-nil, it then asks the repository to delete expired
// sessions using that duration.
//
// It returns the counts reported by the repository. If marking fails, both
// counts are 0. If deletion fails, the expired count is still returned with a
// deleted count of 0. Each failure is logged at error level before it is
// returned.
func (s *sessionService) CleanupExpiredSessions(ctx context.Context, deleteOlderThan *time.Duration) (expired int, deleted int, err error) {
	s.logger.Debug("Cleaning up expired sessions")

	expiredCount, expireErr := s.sessionRepo.ExpireOldSessions(ctx)
	if expireErr != nil {
		s.logger.Error("Failed to expire old sessions", zap.Error(expireErr))
		return 0, 0, expireErr
	}

	deletedCount := 0
	if deleteOlderThan != nil {
		removed, deleteErr := s.sessionRepo.DeleteExpiredSessions(ctx, *deleteOlderThan)
		if deleteErr != nil {
			s.logger.Error("Failed to delete expired sessions", zap.Error(deleteErr))
			return expiredCount, 0, deleteErr
		}
		deletedCount = removed
	}

	s.logger.Debug("Session cleanup completed",
		zap.Int("expired", expiredCount),
		zap.Int("deleted", deletedCount))
	return expiredCount, deletedCount, nil
}
// GetSessionStats returns the total and active session counts for userID.
// It makes two separate repository calls, so the two numbers are not read
// atomically. If either call fails, it returns 0, 0 and that error.
func (s *sessionService) GetSessionStats(ctx context.Context, userID string) (total int, active int, err error) {
	s.logger.Debug("Getting session stats", zap.String("user_id", userID))

	allCount, countErr := s.sessionRepo.GetSessionCount(ctx, userID)
	if countErr != nil {
		return 0, 0, countErr
	}

	activeCount, countErr := s.sessionRepo.GetActiveSessionCount(ctx, userID)
	if countErr != nil {
		return 0, 0, countErr
	}

	return allCount, activeCount, nil
}
// CreateOAuth2Session creates and persists an active user session from the
// result of an OAuth2 authentication flow.
//
// The application must exist and list domain.ApplicationTypeUser among its
// types; this is the same requirement CreateSession enforces.
//
// The session lifetime is derived as follows:
//   - it starts at tokenResponse.ExpiresIn seconds;
//   - it is capped at app.MaxTokenDuration when that is positive;
//   - if the token reports no positive lifetime, app.MaxTokenDuration is used;
//   - if neither is positive, session creation is refused rather than storing
//     a session that is already expired.
//
// The access, refresh and ID tokens are stored on the session as received.
// The subject, email and name from userInfo are copied into the session claims.
//
// Errors:
//   - validation error if tokenResponse or userInfo is nil, the application is
//     not found or does not support user sessions, or no positive lifetime can
//     be determined;
//   - other repository errors are returned unchanged.
func (s *sessionService) CreateOAuth2Session(ctx context.Context, userID, appID string, tokenResponse *domain.TokenResponse, userInfo *domain.UserInfo, sessionType domain.SessionType, ipAddress, userAgent string) (*domain.UserSession, error) {
	s.logger.Debug("Creating OAuth2 session",
		zap.String("user_id", userID),
		zap.String("app_id", appID),
		zap.String("session_type", string(sessionType)))

	// Both are dereferenced below; reject rather than panic.
	if tokenResponse == nil {
		return nil, errors.NewValidationError("Token response is required")
	}
	if userInfo == nil {
		return nil, errors.NewValidationError("User info is required")
	}

	app, err := s.appRepo.GetByID(ctx, appID)
	if err != nil {
		if errors.IsNotFound(err) {
			return nil, errors.NewValidationError("Application not found")
		}
		return nil, err
	}

	// Only applications that support user tokens may own user sessions.
	supportsUser := false
	for _, appType := range app.Type {
		if appType == domain.ApplicationTypeUser {
			supportsUser = true
			break
		}
	}
	if !supportsUser {
		return nil, errors.NewValidationError("Application does not support user sessions")
	}

	// Effective lifetime is min(token lifetime, app max). A non-positive
	// MaxTokenDuration is treated as "no cap" (NOTE(review): confirm that is the
	// intended meaning of an unset value). A missing or zero expires_in falls
	// back to the app max.
	lifetime := time.Duration(tokenResponse.ExpiresIn) * time.Second
	if lifetime <= 0 || (app.MaxTokenDuration > 0 && lifetime > app.MaxTokenDuration) {
		lifetime = app.MaxTokenDuration
	}
	if lifetime <= 0 {
		return nil, errors.NewValidationError("Unable to determine session lifetime")
	}
	expiresAt := time.Now().Add(lifetime)

	session := &domain.UserSession{
		ID:          uuid.New(),
		UserID:      userID,
		AppID:       appID,
		SessionType: sessionType,
		Status:      domain.SessionStatusActive,
		// TODO(security): tokens are persisted exactly as received (no encryption
		// here); encrypt at rest before production use.
		AccessToken:  tokenResponse.AccessToken,
		RefreshToken: tokenResponse.RefreshToken,
		IDToken:      tokenResponse.IDToken,
		IPAddress:    ipAddress,
		UserAgent:    userAgent,
		ExpiresAt:    expiresAt,
		Metadata: domain.SessionMetadata{
			LoginMethod: "oauth2",
			Claims: map[string]string{
				"sub":   userInfo.Sub,
				"email": userInfo.Email,
				"name":  userInfo.Name,
			},
		},
	}

	if err := s.sessionRepo.Create(ctx, session); err != nil {
		s.logger.Error("Failed to create OAuth2 session", zap.Error(err))
		return nil, err
	}

	s.logger.Debug("OAuth2 session created successfully", zap.String("session_id", session.ID.String()))
	return session, nil
}