package postgres import ( "context" "database/sql" "encoding/json" "fmt" "time" "github.com/google/uuid" "github.com/jmoiron/sqlx" "go.uber.org/zap" "github.com/kms/api-key-service/internal/domain" "github.com/kms/api-key-service/internal/errors" "github.com/kms/api-key-service/internal/repository" ) // sessionRepository implements the SessionRepository interface type sessionRepository struct { db *sqlx.DB logger *zap.Logger } // NewSessionRepository creates a new session repository func NewSessionRepository(db *sqlx.DB, logger *zap.Logger) repository.SessionRepository { return &sessionRepository{ db: db, logger: logger, } } // Create creates a new user session func (r *sessionRepository) Create(ctx context.Context, session *domain.UserSession) error { r.logger.Debug("Creating new session", zap.String("user_id", session.UserID), zap.String("app_id", session.AppID), zap.String("session_type", string(session.SessionType))) // Generate ID if not provided if session.ID == uuid.Nil { session.ID = uuid.New() } // Set timestamps now := time.Now() session.CreatedAt = now session.UpdatedAt = now session.LastActivity = now // Serialize metadata metadataJSON, err := json.Marshal(session.Metadata) if err != nil { return errors.NewInternalError("Failed to serialize session metadata").WithInternal(err) } query := ` INSERT INTO user_sessions ( id, user_id, app_id, session_type, status, access_token, refresh_token, id_token, ip_address, user_agent, last_activity, expires_at, created_at, updated_at, metadata ) VALUES ( $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15 )` _, err = r.db.ExecContext(ctx, query, session.ID, session.UserID, session.AppID, session.SessionType, session.Status, session.AccessToken, session.RefreshToken, session.IDToken, session.IPAddress, session.UserAgent, session.LastActivity, session.ExpiresAt, session.CreatedAt, session.UpdatedAt, metadataJSON, ) if err != nil { r.logger.Error("Failed to create session", zap.Error(err)) return errors.NewInternalError("Failed to create session").WithInternal(err) } r.logger.Debug("Session created successfully", zap.String("session_id", session.ID.String())) return nil } // GetByID retrieves a session by its ID func (r *sessionRepository) GetByID(ctx context.Context, sessionID uuid.UUID) (*domain.UserSession, error) { r.logger.Debug("Getting session by ID", zap.String("session_id", sessionID.String())) query := ` SELECT id, user_id, app_id, session_type, status, access_token, refresh_token, id_token, ip_address, user_agent, last_activity, expires_at, created_at, updated_at, revoked_at, revoked_by, metadata FROM user_sessions WHERE id = $1` var session domain.UserSession var metadataJSON []byte var revokedAt sql.NullTime var revokedBy sql.NullString err := r.db.QueryRowContext(ctx, query, sessionID).Scan( &session.ID, &session.UserID, &session.AppID, &session.SessionType, &session.Status, &session.AccessToken, &session.RefreshToken, &session.IDToken, &session.IPAddress, &session.UserAgent, &session.LastActivity, &session.ExpiresAt, &session.CreatedAt, &session.UpdatedAt, &revokedAt, &revokedBy, &metadataJSON, ) if err != nil { if err == sql.ErrNoRows { return nil, errors.NewNotFoundError("Session not found") } r.logger.Error("Failed to get session by ID", zap.Error(err)) return nil, errors.NewInternalError("Failed to retrieve session").WithInternal(err) } // Handle nullable fields if revokedAt.Valid { session.RevokedAt = &revokedAt.Time } if revokedBy.Valid { session.RevokedBy = &revokedBy.String } // Deserialize metadata if err := json.Unmarshal(metadataJSON, &session.Metadata); err != nil { r.logger.Warn("Failed to deserialize session metadata", zap.Error(err)) session.Metadata = domain.SessionMetadata{} // Use empty metadata on error } r.logger.Debug("Session retrieved successfully", zap.String("session_id", sessionID.String())) return &session, nil } // GetByUserID retrieves all sessions for a user func (r *sessionRepository) GetByUserID(ctx context.Context, userID string) ([]*domain.UserSession, error) { r.logger.Debug("Getting sessions by user ID", zap.String("user_id", userID)) query := ` SELECT id, user_id, app_id, session_type, status, access_token, refresh_token, id_token, ip_address, user_agent, last_activity, expires_at, created_at, updated_at, revoked_at, revoked_by, metadata FROM user_sessions WHERE user_id = $1 ORDER BY created_at DESC` return r.scanSessions(ctx, query, userID) } // GetByUserAndApp retrieves sessions for a specific user and application func (r *sessionRepository) GetByUserAndApp(ctx context.Context, userID, appID string) ([]*domain.UserSession, error) { r.logger.Debug("Getting sessions by user and app", zap.String("user_id", userID), zap.String("app_id", appID)) query := ` SELECT id, user_id, app_id, session_type, status, access_token, refresh_token, id_token, ip_address, user_agent, last_activity, expires_at, created_at, updated_at, revoked_at, revoked_by, metadata FROM user_sessions WHERE user_id = $1 AND app_id = $2 ORDER BY created_at DESC` return r.scanSessions(ctx, query, userID, appID) } // GetActiveByUserID retrieves all active sessions for a user func (r *sessionRepository) GetActiveByUserID(ctx context.Context, userID string) ([]*domain.UserSession, error) { r.logger.Debug("Getting active sessions by user ID", zap.String("user_id", userID)) query := ` SELECT id, user_id, app_id, session_type, status, access_token, refresh_token, id_token, ip_address, user_agent, last_activity, expires_at, created_at, updated_at, revoked_at, revoked_by, metadata FROM user_sessions WHERE user_id = $1 AND status = $2 AND expires_at > NOW() ORDER BY last_activity DESC` return r.scanSessions(ctx, query, userID, domain.SessionStatusActive) } // List retrieves sessions with filtering and pagination func (r *sessionRepository) List(ctx context.Context, req *domain.SessionListRequest) (*domain.SessionListResponse, error) { r.logger.Debug("Listing sessions with filters", zap.String("user_id", req.UserID), zap.String("app_id", req.AppID), zap.Int("limit", req.Limit), zap.Int("offset", req.Offset)) // Build WHERE clause dynamically whereClause := "WHERE 1=1" args := []interface{}{} argIndex := 1 if req.UserID != "" { whereClause += fmt.Sprintf(" AND user_id = $%d", argIndex) args = append(args, req.UserID) argIndex++ } if req.AppID != "" { whereClause += fmt.Sprintf(" AND app_id = $%d", argIndex) args = append(args, req.AppID) argIndex++ } if req.Status != nil { whereClause += fmt.Sprintf(" AND status = $%d", argIndex) args = append(args, *req.Status) argIndex++ } if req.SessionType != nil { whereClause += fmt.Sprintf(" AND session_type = $%d", argIndex) args = append(args, *req.SessionType) argIndex++ } if req.TenantID != "" { whereClause += fmt.Sprintf(" AND metadata->>'tenant_id' = $%d", argIndex) args = append(args, req.TenantID) argIndex++ } // Get total count countQuery := fmt.Sprintf("SELECT COUNT(*) FROM user_sessions %s", whereClause) var total int err := r.db.QueryRowContext(ctx, countQuery, args...).Scan(&total) if err != nil { r.logger.Error("Failed to get session count", zap.Error(err)) return nil, errors.NewInternalError("Failed to count sessions").WithInternal(err) } // Get sessions with pagination query := fmt.Sprintf(` SELECT id, user_id, app_id, session_type, status, access_token, refresh_token, id_token, ip_address, user_agent, last_activity, expires_at, created_at, updated_at, revoked_at, revoked_by, metadata FROM user_sessions %s ORDER BY created_at DESC LIMIT $%d OFFSET $%d`, whereClause, argIndex, argIndex+1) args = append(args, req.Limit, req.Offset) sessions, err := r.scanSessions(ctx, query, args...) if err != nil { return nil, err } return &domain.SessionListResponse{ Sessions: sessions, Total: total, Limit: req.Limit, Offset: req.Offset, }, nil } // Update updates an existing session func (r *sessionRepository) Update(ctx context.Context, sessionID uuid.UUID, updates *domain.UpdateSessionRequest) error { r.logger.Debug("Updating session", zap.String("session_id", sessionID.String())) // Build UPDATE clause dynamically setParts := []string{"updated_at = NOW()"} args := []interface{}{} argIndex := 1 if updates.Status != nil { setParts = append(setParts, fmt.Sprintf("status = $%d", argIndex)) args = append(args, *updates.Status) argIndex++ } if updates.LastActivity != nil { setParts = append(setParts, fmt.Sprintf("last_activity = $%d", argIndex)) args = append(args, *updates.LastActivity) argIndex++ } if updates.ExpiresAt != nil { setParts = append(setParts, fmt.Sprintf("expires_at = $%d", argIndex)) args = append(args, *updates.ExpiresAt) argIndex++ } if updates.IPAddress != nil { setParts = append(setParts, fmt.Sprintf("ip_address = $%d", argIndex)) args = append(args, *updates.IPAddress) argIndex++ } if updates.UserAgent != nil { setParts = append(setParts, fmt.Sprintf("user_agent = $%d", argIndex)) args = append(args, *updates.UserAgent) argIndex++ } if len(setParts) == 1 { return errors.NewValidationError("No fields to update") } // Build the complete query setClause := fmt.Sprintf("%s", setParts[0]) for i := 1; i < len(setParts); i++ { setClause += fmt.Sprintf(", %s", setParts[i]) } query := fmt.Sprintf("UPDATE user_sessions SET %s WHERE id = $%d", setClause, argIndex) args = append(args, sessionID) result, err := r.db.ExecContext(ctx, query, args...) if err != nil { r.logger.Error("Failed to update session", zap.Error(err)) return errors.NewInternalError("Failed to update session").WithInternal(err) } rowsAffected, err := result.RowsAffected() if err != nil { return errors.NewInternalError("Failed to get affected rows").WithInternal(err) } if rowsAffected == 0 { return errors.NewNotFoundError("Session not found") } r.logger.Debug("Session updated successfully", zap.String("session_id", sessionID.String())) return nil } // UpdateActivity updates the last activity timestamp for a session func (r *sessionRepository) UpdateActivity(ctx context.Context, sessionID uuid.UUID) error { r.logger.Debug("Updating session activity", zap.String("session_id", sessionID.String())) query := `UPDATE user_sessions SET last_activity = NOW(), updated_at = NOW() WHERE id = $1` result, err := r.db.ExecContext(ctx, query, sessionID) if err != nil { r.logger.Error("Failed to update session activity", zap.Error(err)) return errors.NewInternalError("Failed to update session activity").WithInternal(err) } rowsAffected, err := result.RowsAffected() if err != nil { return errors.NewInternalError("Failed to get affected rows").WithInternal(err) } if rowsAffected == 0 { return errors.NewNotFoundError("Session not found") } return nil } // Revoke revokes a session func (r *sessionRepository) Revoke(ctx context.Context, sessionID uuid.UUID, revokedBy string) error { r.logger.Debug("Revoking session", zap.String("session_id", sessionID.String()), zap.String("revoked_by", revokedBy)) query := ` UPDATE user_sessions SET status = $1, revoked_at = NOW(), revoked_by = $2, updated_at = NOW() WHERE id = $3` result, err := r.db.ExecContext(ctx, query, domain.SessionStatusRevoked, revokedBy, sessionID) if err != nil { r.logger.Error("Failed to revoke session", zap.Error(err)) return errors.NewInternalError("Failed to revoke session").WithInternal(err) } rowsAffected, err := result.RowsAffected() if err != nil { return errors.NewInternalError("Failed to get affected rows").WithInternal(err) } if rowsAffected == 0 { return errors.NewNotFoundError("Session not found") } r.logger.Debug("Session revoked successfully", zap.String("session_id", sessionID.String())) return nil } // RevokeAllByUser revokes all sessions for a user func (r *sessionRepository) RevokeAllByUser(ctx context.Context, userID string, revokedBy string) error { r.logger.Debug("Revoking all sessions for user", zap.String("user_id", userID), zap.String("revoked_by", revokedBy)) query := ` UPDATE user_sessions SET status = $1, revoked_at = NOW(), revoked_by = $2, updated_at = NOW() WHERE user_id = $3 AND status = $4` result, err := r.db.ExecContext(ctx, query, domain.SessionStatusRevoked, revokedBy, userID, domain.SessionStatusActive) if err != nil { r.logger.Error("Failed to revoke user sessions", zap.Error(err)) return errors.NewInternalError("Failed to revoke user sessions").WithInternal(err) } rowsAffected, err := result.RowsAffected() if err != nil { return errors.NewInternalError("Failed to get affected rows").WithInternal(err) } r.logger.Debug("User sessions revoked", zap.String("user_id", userID), zap.Int64("sessions_revoked", rowsAffected)) return nil } // RevokeAllByUserAndApp revokes all sessions for a user and application func (r *sessionRepository) RevokeAllByUserAndApp(ctx context.Context, userID, appID string, revokedBy string) error { r.logger.Debug("Revoking all sessions for user and app", zap.String("user_id", userID), zap.String("app_id", appID), zap.String("revoked_by", revokedBy)) query := ` UPDATE user_sessions SET status = $1, revoked_at = NOW(), revoked_by = $2, updated_at = NOW() WHERE user_id = $3 AND app_id = $4 AND status = $5` result, err := r.db.ExecContext(ctx, query, domain.SessionStatusRevoked, revokedBy, userID, appID, domain.SessionStatusActive) if err != nil { r.logger.Error("Failed to revoke user app sessions", zap.Error(err)) return errors.NewInternalError("Failed to revoke user app sessions").WithInternal(err) } rowsAffected, err := result.RowsAffected() if err != nil { return errors.NewInternalError("Failed to get affected rows").WithInternal(err) } r.logger.Debug("User app sessions revoked", zap.String("user_id", userID), zap.String("app_id", appID), zap.Int64("sessions_revoked", rowsAffected)) return nil } // ExpireOldSessions marks expired sessions as expired func (r *sessionRepository) ExpireOldSessions(ctx context.Context) (int, error) { r.logger.Debug("Expiring old sessions") query := ` UPDATE user_sessions SET status = $1, updated_at = NOW() WHERE expires_at < NOW() AND status = $2` result, err := r.db.ExecContext(ctx, query, domain.SessionStatusExpired, domain.SessionStatusActive) if err != nil { r.logger.Error("Failed to expire old sessions", zap.Error(err)) return 0, errors.NewInternalError("Failed to expire old sessions").WithInternal(err) } rowsAffected, err := result.RowsAffected() if err != nil { return 0, errors.NewInternalError("Failed to get affected rows").WithInternal(err) } r.logger.Debug("Old sessions expired", zap.Int64("sessions_expired", rowsAffected)) return int(rowsAffected), nil } // DeleteExpiredSessions removes expired sessions older than the specified duration func (r *sessionRepository) DeleteExpiredSessions(ctx context.Context, olderThan time.Duration) (int, error) { r.logger.Debug("Deleting expired sessions", zap.Duration("older_than", olderThan)) cutoffTime := time.Now().Add(-olderThan) query := `DELETE FROM user_sessions WHERE status = $1 AND updated_at < $2` result, err := r.db.ExecContext(ctx, query, domain.SessionStatusExpired, cutoffTime) if err != nil { r.logger.Error("Failed to delete expired sessions", zap.Error(err)) return 0, errors.NewInternalError("Failed to delete expired sessions").WithInternal(err) } rowsAffected, err := result.RowsAffected() if err != nil { return 0, errors.NewInternalError("Failed to get affected rows").WithInternal(err) } r.logger.Debug("Expired sessions deleted", zap.Int64("sessions_deleted", rowsAffected)) return int(rowsAffected), nil } // Exists checks if a session exists func (r *sessionRepository) Exists(ctx context.Context, sessionID uuid.UUID) (bool, error) { r.logger.Debug("Checking if session exists", zap.String("session_id", sessionID.String())) query := `SELECT EXISTS(SELECT 1 FROM user_sessions WHERE id = $1)` var exists bool err := r.db.QueryRowContext(ctx, query, sessionID).Scan(&exists) if err != nil { r.logger.Error("Failed to check session existence", zap.Error(err)) return false, errors.NewInternalError("Failed to check session existence").WithInternal(err) } return exists, nil } // GetSessionCount returns the total number of sessions for a user func (r *sessionRepository) GetSessionCount(ctx context.Context, userID string) (int, error) { r.logger.Debug("Getting session count for user", zap.String("user_id", userID)) query := `SELECT COUNT(*) FROM user_sessions WHERE user_id = $1` var count int err := r.db.QueryRowContext(ctx, query, userID).Scan(&count) if err != nil { r.logger.Error("Failed to get session count", zap.Error(err)) return 0, errors.NewInternalError("Failed to get session count").WithInternal(err) } return count, nil } // GetActiveSessionCount returns the number of active sessions for a user func (r *sessionRepository) GetActiveSessionCount(ctx context.Context, userID string) (int, error) { r.logger.Debug("Getting active session count for user", zap.String("user_id", userID)) query := `SELECT COUNT(*) FROM user_sessions WHERE user_id = $1 AND status = $2 AND expires_at > NOW()` var count int err := r.db.QueryRowContext(ctx, query, userID, domain.SessionStatusActive).Scan(&count) if err != nil { r.logger.Error("Failed to get active session count", zap.Error(err)) return 0, errors.NewInternalError("Failed to get active session count").WithInternal(err) } return count, nil } // scanSessions is a helper method to scan multiple sessions from query results func (r *sessionRepository) scanSessions(ctx context.Context, query string, args ...interface{}) ([]*domain.UserSession, error) { rows, err := r.db.QueryContext(ctx, query, args...) if err != nil { r.logger.Error("Failed to execute session query", zap.Error(err)) return nil, errors.NewInternalError("Failed to retrieve sessions").WithInternal(err) } defer rows.Close() var sessions []*domain.UserSession for rows.Next() { var session domain.UserSession var metadataJSON []byte var revokedAt sql.NullTime var revokedBy sql.NullString err := rows.Scan( &session.ID, &session.UserID, &session.AppID, &session.SessionType, &session.Status, &session.AccessToken, &session.RefreshToken, &session.IDToken, &session.IPAddress, &session.UserAgent, &session.LastActivity, &session.ExpiresAt, &session.CreatedAt, &session.UpdatedAt, &revokedAt, &revokedBy, &metadataJSON, ) if err != nil { r.logger.Error("Failed to scan session row", zap.Error(err)) return nil, errors.NewInternalError("Failed to scan session data").WithInternal(err) } // Handle nullable fields if revokedAt.Valid { session.RevokedAt = &revokedAt.Time } if revokedBy.Valid { session.RevokedBy = &revokedBy.String } // Deserialize metadata if err := json.Unmarshal(metadataJSON, &session.Metadata); err != nil { r.logger.Warn("Failed to deserialize session metadata", zap.Error(err)) session.Metadata = domain.SessionMetadata{} // Use empty metadata on error } sessions = append(sessions, &session) } if err := rows.Err(); err != nil { r.logger.Error("Error iterating session rows", zap.Error(err)) return nil, errors.NewInternalError("Failed to iterate session results").WithInternal(err) } return sessions, nil }