415 lines
12 KiB
Go
415 lines
12 KiB
Go
package services
|
|
|
|
import (
|
|
"context"
|
|
"time"
|
|
|
|
"github.com/google/uuid"
|
|
"go.uber.org/zap"
|
|
|
|
"github.com/kms/api-key-service/internal/config"
|
|
"github.com/kms/api-key-service/internal/domain"
|
|
"github.com/kms/api-key-service/internal/errors"
|
|
"github.com/kms/api-key-service/internal/repository"
|
|
)
|
|
|
|
// sessionService implements the SessionService interface
|
|
type sessionService struct {
|
|
sessionRepo repository.SessionRepository
|
|
appRepo repository.ApplicationRepository
|
|
config config.ConfigProvider
|
|
logger *zap.Logger
|
|
}
|
|
|
|
// NewSessionService creates a new session service
|
|
func NewSessionService(
|
|
sessionRepo repository.SessionRepository,
|
|
appRepo repository.ApplicationRepository,
|
|
config config.ConfigProvider,
|
|
logger *zap.Logger,
|
|
) SessionService {
|
|
return &sessionService{
|
|
sessionRepo: sessionRepo,
|
|
appRepo: appRepo,
|
|
config: config,
|
|
logger: logger,
|
|
}
|
|
}
|
|
|
|
// CreateSession creates a new user session
|
|
func (s *sessionService) CreateSession(ctx context.Context, req *domain.CreateSessionRequest) (*domain.UserSession, error) {
|
|
s.logger.Debug("Creating new session",
|
|
zap.String("user_id", req.UserID),
|
|
zap.String("app_id", req.AppID),
|
|
zap.String("session_type", string(req.SessionType)))
|
|
|
|
// Validate application exists
|
|
app, err := s.appRepo.GetByID(ctx, req.AppID)
|
|
if err != nil {
|
|
if errors.IsNotFound(err) {
|
|
return nil, errors.NewValidationError("Application not found")
|
|
}
|
|
return nil, err
|
|
}
|
|
|
|
// Check if application supports user tokens
|
|
supportsUser := false
|
|
for _, appType := range app.Type {
|
|
if appType == domain.ApplicationTypeUser {
|
|
supportsUser = true
|
|
break
|
|
}
|
|
}
|
|
if !supportsUser {
|
|
return nil, errors.NewValidationError("Application does not support user sessions")
|
|
}
|
|
|
|
// Create session object
|
|
session := &domain.UserSession{
|
|
ID: uuid.New(),
|
|
UserID: req.UserID,
|
|
AppID: req.AppID,
|
|
SessionType: req.SessionType,
|
|
Status: domain.SessionStatusActive,
|
|
IPAddress: req.IPAddress,
|
|
UserAgent: req.UserAgent,
|
|
ExpiresAt: req.ExpiresAt,
|
|
Metadata: domain.SessionMetadata{
|
|
TenantID: req.TenantID,
|
|
Permissions: req.Permissions,
|
|
Claims: req.Claims,
|
|
LoginMethod: "oauth2",
|
|
},
|
|
}
|
|
|
|
// Create session in repository
|
|
if err := s.sessionRepo.Create(ctx, session); err != nil {
|
|
s.logger.Error("Failed to create session", zap.Error(err))
|
|
return nil, err
|
|
}
|
|
|
|
s.logger.Debug("Session created successfully", zap.String("session_id", session.ID.String()))
|
|
return session, nil
|
|
}
|
|
|
|
// GetSession retrieves a session by its ID
|
|
func (s *sessionService) GetSession(ctx context.Context, sessionID uuid.UUID) (*domain.UserSession, error) {
|
|
s.logger.Debug("Getting session", zap.String("session_id", sessionID.String()))
|
|
|
|
session, err := s.sessionRepo.GetByID(ctx, sessionID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return session, nil
|
|
}
|
|
|
|
// GetUserSessions retrieves all sessions for a user
|
|
func (s *sessionService) GetUserSessions(ctx context.Context, userID string) ([]*domain.UserSession, error) {
|
|
s.logger.Debug("Getting user sessions", zap.String("user_id", userID))
|
|
|
|
sessions, err := s.sessionRepo.GetByUserID(ctx, userID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return sessions, nil
|
|
}
|
|
|
|
// GetUserAppSessions retrieves sessions for a specific user and application
|
|
func (s *sessionService) GetUserAppSessions(ctx context.Context, userID, appID string) ([]*domain.UserSession, error) {
|
|
s.logger.Debug("Getting user app sessions",
|
|
zap.String("user_id", userID),
|
|
zap.String("app_id", appID))
|
|
|
|
sessions, err := s.sessionRepo.GetByUserAndApp(ctx, userID, appID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return sessions, nil
|
|
}
|
|
|
|
// GetActiveSessions retrieves all active sessions for a user
|
|
func (s *sessionService) GetActiveSessions(ctx context.Context, userID string) ([]*domain.UserSession, error) {
|
|
s.logger.Debug("Getting active sessions", zap.String("user_id", userID))
|
|
|
|
sessions, err := s.sessionRepo.GetActiveByUserID(ctx, userID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return sessions, nil
|
|
}
|
|
|
|
// ListSessions retrieves sessions with filtering and pagination
|
|
func (s *sessionService) ListSessions(ctx context.Context, req *domain.SessionListRequest) (*domain.SessionListResponse, error) {
|
|
s.logger.Debug("Listing sessions",
|
|
zap.String("user_id", req.UserID),
|
|
zap.String("app_id", req.AppID),
|
|
zap.Int("limit", req.Limit),
|
|
zap.Int("offset", req.Offset))
|
|
|
|
// Set default pagination if not provided
|
|
if req.Limit <= 0 {
|
|
req.Limit = 50
|
|
}
|
|
if req.Limit > 100 {
|
|
req.Limit = 100
|
|
}
|
|
|
|
response, err := s.sessionRepo.List(ctx, req)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return response, nil
|
|
}
|
|
|
|
// UpdateSession updates an existing session
|
|
func (s *sessionService) UpdateSession(ctx context.Context, sessionID uuid.UUID, updates *domain.UpdateSessionRequest) error {
|
|
s.logger.Debug("Updating session", zap.String("session_id", sessionID.String()))
|
|
|
|
// Validate session exists
|
|
_, err := s.sessionRepo.GetByID(ctx, sessionID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// Update session
|
|
if err := s.sessionRepo.Update(ctx, sessionID, updates); err != nil {
|
|
return err
|
|
}
|
|
|
|
s.logger.Debug("Session updated successfully", zap.String("session_id", sessionID.String()))
|
|
return nil
|
|
}
|
|
|
|
// UpdateSessionActivity updates the last activity timestamp for a session
|
|
func (s *sessionService) UpdateSessionActivity(ctx context.Context, sessionID uuid.UUID) error {
|
|
s.logger.Debug("Updating session activity", zap.String("session_id", sessionID.String()))
|
|
|
|
if err := s.sessionRepo.UpdateActivity(ctx, sessionID); err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// RevokeSession revokes a specific session
|
|
func (s *sessionService) RevokeSession(ctx context.Context, sessionID uuid.UUID, revokedBy string) error {
|
|
s.logger.Debug("Revoking session",
|
|
zap.String("session_id", sessionID.String()),
|
|
zap.String("revoked_by", revokedBy))
|
|
|
|
// Validate session exists and is active
|
|
session, err := s.sessionRepo.GetByID(ctx, sessionID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if session.Status != domain.SessionStatusActive {
|
|
return errors.NewValidationError("Session is not active")
|
|
}
|
|
|
|
// Revoke session
|
|
if err := s.sessionRepo.Revoke(ctx, sessionID, revokedBy); err != nil {
|
|
return err
|
|
}
|
|
|
|
s.logger.Debug("Session revoked successfully", zap.String("session_id", sessionID.String()))
|
|
return nil
|
|
}
|
|
|
|
// RevokeUserSessions revokes all sessions for a user
|
|
func (s *sessionService) RevokeUserSessions(ctx context.Context, userID string, revokedBy string) error {
|
|
s.logger.Debug("Revoking user sessions",
|
|
zap.String("user_id", userID),
|
|
zap.String("revoked_by", revokedBy))
|
|
|
|
if err := s.sessionRepo.RevokeAllByUser(ctx, userID, revokedBy); err != nil {
|
|
return err
|
|
}
|
|
|
|
s.logger.Debug("User sessions revoked successfully", zap.String("user_id", userID))
|
|
return nil
|
|
}
|
|
|
|
// RevokeUserAppSessions revokes all sessions for a user and application
|
|
func (s *sessionService) RevokeUserAppSessions(ctx context.Context, userID, appID string, revokedBy string) error {
|
|
s.logger.Debug("Revoking user app sessions",
|
|
zap.String("user_id", userID),
|
|
zap.String("app_id", appID),
|
|
zap.String("revoked_by", revokedBy))
|
|
|
|
if err := s.sessionRepo.RevokeAllByUserAndApp(ctx, userID, appID, revokedBy); err != nil {
|
|
return err
|
|
}
|
|
|
|
s.logger.Debug("User app sessions revoked successfully",
|
|
zap.String("user_id", userID),
|
|
zap.String("app_id", appID))
|
|
return nil
|
|
}
|
|
|
|
// ValidateSession validates if a session is active and valid
|
|
func (s *sessionService) ValidateSession(ctx context.Context, sessionID uuid.UUID) (*domain.UserSession, error) {
|
|
s.logger.Debug("Validating session", zap.String("session_id", sessionID.String()))
|
|
|
|
session, err := s.sessionRepo.GetByID(ctx, sessionID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Check if session is active
|
|
if !session.IsActive() {
|
|
if session.IsExpired() {
|
|
return nil, errors.NewAuthenticationError("Session has expired")
|
|
}
|
|
if session.IsRevoked() {
|
|
return nil, errors.NewAuthenticationError("Session has been revoked")
|
|
}
|
|
return nil, errors.NewAuthenticationError("Session is not active")
|
|
}
|
|
|
|
// Update last activity
|
|
if err := s.sessionRepo.UpdateActivity(ctx, sessionID); err != nil {
|
|
s.logger.Warn("Failed to update session activity", zap.Error(err))
|
|
// Don't fail validation if we can't update activity
|
|
}
|
|
|
|
s.logger.Debug("Session validated successfully", zap.String("session_id", sessionID.String()))
|
|
return session, nil
|
|
}
|
|
|
|
// RefreshSession refreshes a session's expiration time
|
|
func (s *sessionService) RefreshSession(ctx context.Context, sessionID uuid.UUID, newExpiration time.Time) error {
|
|
s.logger.Debug("Refreshing session",
|
|
zap.String("session_id", sessionID.String()),
|
|
zap.Time("new_expiration", newExpiration))
|
|
|
|
// Validate session exists and is active
|
|
session, err := s.sessionRepo.GetByID(ctx, sessionID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if !session.IsActive() {
|
|
return errors.NewValidationError("Cannot refresh inactive session")
|
|
}
|
|
|
|
// Update expiration
|
|
updates := &domain.UpdateSessionRequest{
|
|
ExpiresAt: &newExpiration,
|
|
}
|
|
|
|
if err := s.sessionRepo.Update(ctx, sessionID, updates); err != nil {
|
|
return err
|
|
}
|
|
|
|
s.logger.Debug("Session refreshed successfully", zap.String("session_id", sessionID.String()))
|
|
return nil
|
|
}
|
|
|
|
// CleanupExpiredSessions marks expired sessions as expired and optionally deletes old ones
|
|
func (s *sessionService) CleanupExpiredSessions(ctx context.Context, deleteOlderThan *time.Duration) (expired int, deleted int, err error) {
|
|
s.logger.Debug("Cleaning up expired sessions")
|
|
|
|
// Mark expired sessions
|
|
expired, err = s.sessionRepo.ExpireOldSessions(ctx)
|
|
if err != nil {
|
|
s.logger.Error("Failed to expire old sessions", zap.Error(err))
|
|
return 0, 0, err
|
|
}
|
|
|
|
// Delete old expired sessions if requested
|
|
if deleteOlderThan != nil {
|
|
deleted, err = s.sessionRepo.DeleteExpiredSessions(ctx, *deleteOlderThan)
|
|
if err != nil {
|
|
s.logger.Error("Failed to delete expired sessions", zap.Error(err))
|
|
return expired, 0, err
|
|
}
|
|
}
|
|
|
|
s.logger.Debug("Session cleanup completed",
|
|
zap.Int("expired", expired),
|
|
zap.Int("deleted", deleted))
|
|
|
|
return expired, deleted, nil
|
|
}
|
|
|
|
// GetSessionStats returns session statistics for a user
|
|
func (s *sessionService) GetSessionStats(ctx context.Context, userID string) (total int, active int, err error) {
|
|
s.logger.Debug("Getting session stats", zap.String("user_id", userID))
|
|
|
|
total, err = s.sessionRepo.GetSessionCount(ctx, userID)
|
|
if err != nil {
|
|
return 0, 0, err
|
|
}
|
|
|
|
active, err = s.sessionRepo.GetActiveSessionCount(ctx, userID)
|
|
if err != nil {
|
|
return 0, 0, err
|
|
}
|
|
|
|
return total, active, nil
|
|
}
|
|
|
|
// CreateOAuth2Session creates a session from OAuth2 authentication flow
|
|
func (s *sessionService) CreateOAuth2Session(ctx context.Context, userID, appID string, tokenResponse *domain.TokenResponse, userInfo *domain.UserInfo, sessionType domain.SessionType, ipAddress, userAgent string) (*domain.UserSession, error) {
|
|
s.logger.Debug("Creating OAuth2 session",
|
|
zap.String("user_id", userID),
|
|
zap.String("app_id", appID),
|
|
zap.String("session_type", string(sessionType)))
|
|
|
|
// Validate application exists
|
|
app, err := s.appRepo.GetByID(ctx, appID)
|
|
if err != nil {
|
|
if errors.IsNotFound(err) {
|
|
return nil, errors.NewValidationError("Application not found")
|
|
}
|
|
return nil, err
|
|
}
|
|
|
|
// Calculate expiration based on token response
|
|
expiresAt := time.Now().Add(time.Duration(tokenResponse.ExpiresIn) * time.Second)
|
|
|
|
// Use application's max token duration if shorter
|
|
maxExpiration := time.Now().Add(app.MaxTokenDuration.Duration)
|
|
if expiresAt.After(maxExpiration) {
|
|
expiresAt = maxExpiration
|
|
}
|
|
|
|
// Create session object
|
|
session := &domain.UserSession{
|
|
ID: uuid.New(),
|
|
UserID: userID,
|
|
AppID: appID,
|
|
SessionType: sessionType,
|
|
Status: domain.SessionStatusActive,
|
|
AccessToken: tokenResponse.AccessToken, // In production, encrypt this
|
|
RefreshToken: tokenResponse.RefreshToken, // In production, encrypt this
|
|
IDToken: tokenResponse.IDToken, // In production, encrypt this
|
|
IPAddress: ipAddress,
|
|
UserAgent: userAgent,
|
|
ExpiresAt: expiresAt,
|
|
Metadata: domain.SessionMetadata{
|
|
LoginMethod: "oauth2",
|
|
Claims: map[string]string{
|
|
"sub": userInfo.Sub,
|
|
"email": userInfo.Email,
|
|
"name": userInfo.Name,
|
|
},
|
|
},
|
|
}
|
|
|
|
// Create session in repository
|
|
if err := s.sessionRepo.Create(ctx, session); err != nil {
|
|
s.logger.Error("Failed to create OAuth2 session", zap.Error(err))
|
|
return nil, err
|
|
}
|
|
|
|
s.logger.Debug("OAuth2 session created successfully", zap.String("session_id", session.ID.String()))
|
|
return session, nil
|
|
}
|