v2
This commit is contained in:
@ -104,6 +104,57 @@ type GrantedPermissionRepository interface {
|
||||
HasAnyPermission(ctx context.Context, tokenType domain.TokenType, tokenID uuid.UUID, scopes []string) (map[string]bool, error)
|
||||
}
|
||||
|
||||
// SessionRepository defines the interface for user session data operations
|
||||
type SessionRepository interface {
|
||||
// Create creates a new user session
|
||||
Create(ctx context.Context, session *domain.UserSession) error
|
||||
|
||||
// GetByID retrieves a session by its ID
|
||||
GetByID(ctx context.Context, sessionID uuid.UUID) (*domain.UserSession, error)
|
||||
|
||||
// GetByUserID retrieves all sessions for a user
|
||||
GetByUserID(ctx context.Context, userID string) ([]*domain.UserSession, error)
|
||||
|
||||
// GetByUserAndApp retrieves sessions for a specific user and application
|
||||
GetByUserAndApp(ctx context.Context, userID, appID string) ([]*domain.UserSession, error)
|
||||
|
||||
// GetActiveByUserID retrieves all active sessions for a user
|
||||
GetActiveByUserID(ctx context.Context, userID string) ([]*domain.UserSession, error)
|
||||
|
||||
// List retrieves sessions with filtering and pagination
|
||||
List(ctx context.Context, req *domain.SessionListRequest) (*domain.SessionListResponse, error)
|
||||
|
||||
// Update updates an existing session
|
||||
Update(ctx context.Context, sessionID uuid.UUID, updates *domain.UpdateSessionRequest) error
|
||||
|
||||
// UpdateActivity updates the last activity timestamp for a session
|
||||
UpdateActivity(ctx context.Context, sessionID uuid.UUID) error
|
||||
|
||||
// Revoke revokes a session
|
||||
Revoke(ctx context.Context, sessionID uuid.UUID, revokedBy string) error
|
||||
|
||||
// RevokeAllByUser revokes all sessions for a user
|
||||
RevokeAllByUser(ctx context.Context, userID string, revokedBy string) error
|
||||
|
||||
// RevokeAllByUserAndApp revokes all sessions for a user and application
|
||||
RevokeAllByUserAndApp(ctx context.Context, userID, appID string, revokedBy string) error
|
||||
|
||||
// ExpireOldSessions marks expired sessions as expired
|
||||
ExpireOldSessions(ctx context.Context) (int, error)
|
||||
|
||||
// DeleteExpiredSessions removes expired sessions older than the specified duration
|
||||
DeleteExpiredSessions(ctx context.Context, olderThan time.Duration) (int, error)
|
||||
|
||||
// Exists checks if a session exists
|
||||
Exists(ctx context.Context, sessionID uuid.UUID) (bool, error)
|
||||
|
||||
// GetSessionCount returns the total number of sessions for a user
|
||||
GetSessionCount(ctx context.Context, userID string) (int, error)
|
||||
|
||||
// GetActiveSessionCount returns the number of active sessions for a user
|
||||
GetActiveSessionCount(ctx context.Context, userID string) (int, error)
|
||||
}
|
||||
|
||||
// DatabaseProvider defines the interface for database operations
|
||||
type DatabaseProvider interface {
|
||||
// GetDB returns the underlying database connection
|
||||
|
||||
624
internal/repository/postgres/session_repository.go
Normal file
624
internal/repository/postgres/session_repository.go
Normal file
@ -0,0 +1,624 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/jmoiron/sqlx"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"github.com/kms/api-key-service/internal/domain"
|
||||
"github.com/kms/api-key-service/internal/errors"
|
||||
"github.com/kms/api-key-service/internal/repository"
|
||||
)
|
||||
|
||||
// sessionRepository implements the SessionRepository interface
|
||||
type sessionRepository struct {
|
||||
db *sqlx.DB
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewSessionRepository creates a new session repository
|
||||
func NewSessionRepository(db *sqlx.DB, logger *zap.Logger) repository.SessionRepository {
|
||||
return &sessionRepository{
|
||||
db: db,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// Create creates a new user session
|
||||
func (r *sessionRepository) Create(ctx context.Context, session *domain.UserSession) error {
|
||||
r.logger.Debug("Creating new session",
|
||||
zap.String("user_id", session.UserID),
|
||||
zap.String("app_id", session.AppID),
|
||||
zap.String("session_type", string(session.SessionType)))
|
||||
|
||||
// Generate ID if not provided
|
||||
if session.ID == uuid.Nil {
|
||||
session.ID = uuid.New()
|
||||
}
|
||||
|
||||
// Set timestamps
|
||||
now := time.Now()
|
||||
session.CreatedAt = now
|
||||
session.UpdatedAt = now
|
||||
session.LastActivity = now
|
||||
|
||||
// Serialize metadata
|
||||
metadataJSON, err := json.Marshal(session.Metadata)
|
||||
if err != nil {
|
||||
return errors.NewInternalError("Failed to serialize session metadata").WithInternal(err)
|
||||
}
|
||||
|
||||
query := `
|
||||
INSERT INTO user_sessions (
|
||||
id, user_id, app_id, session_type, status, access_token,
|
||||
refresh_token, id_token, ip_address, user_agent,
|
||||
last_activity, expires_at, created_at, updated_at, metadata
|
||||
) VALUES (
|
||||
$1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15
|
||||
)`
|
||||
|
||||
_, err = r.db.ExecContext(ctx, query,
|
||||
session.ID,
|
||||
session.UserID,
|
||||
session.AppID,
|
||||
session.SessionType,
|
||||
session.Status,
|
||||
session.AccessToken,
|
||||
session.RefreshToken,
|
||||
session.IDToken,
|
||||
session.IPAddress,
|
||||
session.UserAgent,
|
||||
session.LastActivity,
|
||||
session.ExpiresAt,
|
||||
session.CreatedAt,
|
||||
session.UpdatedAt,
|
||||
metadataJSON,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
r.logger.Error("Failed to create session", zap.Error(err))
|
||||
return errors.NewInternalError("Failed to create session").WithInternal(err)
|
||||
}
|
||||
|
||||
r.logger.Debug("Session created successfully", zap.String("session_id", session.ID.String()))
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetByID retrieves a session by its ID
|
||||
func (r *sessionRepository) GetByID(ctx context.Context, sessionID uuid.UUID) (*domain.UserSession, error) {
|
||||
r.logger.Debug("Getting session by ID", zap.String("session_id", sessionID.String()))
|
||||
|
||||
query := `
|
||||
SELECT id, user_id, app_id, session_type, status, access_token,
|
||||
refresh_token, id_token, ip_address, user_agent,
|
||||
last_activity, expires_at, created_at, updated_at,
|
||||
revoked_at, revoked_by, metadata
|
||||
FROM user_sessions
|
||||
WHERE id = $1`
|
||||
|
||||
var session domain.UserSession
|
||||
var metadataJSON []byte
|
||||
var revokedAt sql.NullTime
|
||||
var revokedBy sql.NullString
|
||||
|
||||
err := r.db.QueryRowContext(ctx, query, sessionID).Scan(
|
||||
&session.ID,
|
||||
&session.UserID,
|
||||
&session.AppID,
|
||||
&session.SessionType,
|
||||
&session.Status,
|
||||
&session.AccessToken,
|
||||
&session.RefreshToken,
|
||||
&session.IDToken,
|
||||
&session.IPAddress,
|
||||
&session.UserAgent,
|
||||
&session.LastActivity,
|
||||
&session.ExpiresAt,
|
||||
&session.CreatedAt,
|
||||
&session.UpdatedAt,
|
||||
&revokedAt,
|
||||
&revokedBy,
|
||||
&metadataJSON,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, errors.NewNotFoundError("Session not found")
|
||||
}
|
||||
r.logger.Error("Failed to get session by ID", zap.Error(err))
|
||||
return nil, errors.NewInternalError("Failed to retrieve session").WithInternal(err)
|
||||
}
|
||||
|
||||
// Handle nullable fields
|
||||
if revokedAt.Valid {
|
||||
session.RevokedAt = &revokedAt.Time
|
||||
}
|
||||
if revokedBy.Valid {
|
||||
session.RevokedBy = &revokedBy.String
|
||||
}
|
||||
|
||||
// Deserialize metadata
|
||||
if err := json.Unmarshal(metadataJSON, &session.Metadata); err != nil {
|
||||
r.logger.Warn("Failed to deserialize session metadata", zap.Error(err))
|
||||
session.Metadata = domain.SessionMetadata{} // Use empty metadata on error
|
||||
}
|
||||
|
||||
r.logger.Debug("Session retrieved successfully", zap.String("session_id", sessionID.String()))
|
||||
return &session, nil
|
||||
}
|
||||
|
||||
// GetByUserID retrieves all sessions for a user
|
||||
func (r *sessionRepository) GetByUserID(ctx context.Context, userID string) ([]*domain.UserSession, error) {
|
||||
r.logger.Debug("Getting sessions by user ID", zap.String("user_id", userID))
|
||||
|
||||
query := `
|
||||
SELECT id, user_id, app_id, session_type, status, access_token,
|
||||
refresh_token, id_token, ip_address, user_agent,
|
||||
last_activity, expires_at, created_at, updated_at,
|
||||
revoked_at, revoked_by, metadata
|
||||
FROM user_sessions
|
||||
WHERE user_id = $1
|
||||
ORDER BY created_at DESC`
|
||||
|
||||
return r.scanSessions(ctx, query, userID)
|
||||
}
|
||||
|
||||
// GetByUserAndApp retrieves sessions for a specific user and application
|
||||
func (r *sessionRepository) GetByUserAndApp(ctx context.Context, userID, appID string) ([]*domain.UserSession, error) {
|
||||
r.logger.Debug("Getting sessions by user and app",
|
||||
zap.String("user_id", userID),
|
||||
zap.String("app_id", appID))
|
||||
|
||||
query := `
|
||||
SELECT id, user_id, app_id, session_type, status, access_token,
|
||||
refresh_token, id_token, ip_address, user_agent,
|
||||
last_activity, expires_at, created_at, updated_at,
|
||||
revoked_at, revoked_by, metadata
|
||||
FROM user_sessions
|
||||
WHERE user_id = $1 AND app_id = $2
|
||||
ORDER BY created_at DESC`
|
||||
|
||||
return r.scanSessions(ctx, query, userID, appID)
|
||||
}
|
||||
|
||||
// GetActiveByUserID retrieves all active sessions for a user
|
||||
func (r *sessionRepository) GetActiveByUserID(ctx context.Context, userID string) ([]*domain.UserSession, error) {
|
||||
r.logger.Debug("Getting active sessions by user ID", zap.String("user_id", userID))
|
||||
|
||||
query := `
|
||||
SELECT id, user_id, app_id, session_type, status, access_token,
|
||||
refresh_token, id_token, ip_address, user_agent,
|
||||
last_activity, expires_at, created_at, updated_at,
|
||||
revoked_at, revoked_by, metadata
|
||||
FROM user_sessions
|
||||
WHERE user_id = $1 AND status = $2 AND expires_at > NOW()
|
||||
ORDER BY last_activity DESC`
|
||||
|
||||
return r.scanSessions(ctx, query, userID, domain.SessionStatusActive)
|
||||
}
|
||||
|
||||
// List retrieves sessions with filtering and pagination
|
||||
func (r *sessionRepository) List(ctx context.Context, req *domain.SessionListRequest) (*domain.SessionListResponse, error) {
|
||||
r.logger.Debug("Listing sessions with filters",
|
||||
zap.String("user_id", req.UserID),
|
||||
zap.String("app_id", req.AppID),
|
||||
zap.Int("limit", req.Limit),
|
||||
zap.Int("offset", req.Offset))
|
||||
|
||||
// Build WHERE clause dynamically
|
||||
whereClause := "WHERE 1=1"
|
||||
args := []interface{}{}
|
||||
argIndex := 1
|
||||
|
||||
if req.UserID != "" {
|
||||
whereClause += fmt.Sprintf(" AND user_id = $%d", argIndex)
|
||||
args = append(args, req.UserID)
|
||||
argIndex++
|
||||
}
|
||||
|
||||
if req.AppID != "" {
|
||||
whereClause += fmt.Sprintf(" AND app_id = $%d", argIndex)
|
||||
args = append(args, req.AppID)
|
||||
argIndex++
|
||||
}
|
||||
|
||||
if req.Status != nil {
|
||||
whereClause += fmt.Sprintf(" AND status = $%d", argIndex)
|
||||
args = append(args, *req.Status)
|
||||
argIndex++
|
||||
}
|
||||
|
||||
if req.SessionType != nil {
|
||||
whereClause += fmt.Sprintf(" AND session_type = $%d", argIndex)
|
||||
args = append(args, *req.SessionType)
|
||||
argIndex++
|
||||
}
|
||||
|
||||
if req.TenantID != "" {
|
||||
whereClause += fmt.Sprintf(" AND metadata->>'tenant_id' = $%d", argIndex)
|
||||
args = append(args, req.TenantID)
|
||||
argIndex++
|
||||
}
|
||||
|
||||
// Get total count
|
||||
countQuery := fmt.Sprintf("SELECT COUNT(*) FROM user_sessions %s", whereClause)
|
||||
var total int
|
||||
err := r.db.QueryRowContext(ctx, countQuery, args...).Scan(&total)
|
||||
if err != nil {
|
||||
r.logger.Error("Failed to get session count", zap.Error(err))
|
||||
return nil, errors.NewInternalError("Failed to count sessions").WithInternal(err)
|
||||
}
|
||||
|
||||
// Get sessions with pagination
|
||||
query := fmt.Sprintf(`
|
||||
SELECT id, user_id, app_id, session_type, status, access_token,
|
||||
refresh_token, id_token, ip_address, user_agent,
|
||||
last_activity, expires_at, created_at, updated_at,
|
||||
revoked_at, revoked_by, metadata
|
||||
FROM user_sessions
|
||||
%s
|
||||
ORDER BY created_at DESC
|
||||
LIMIT $%d OFFSET $%d`, whereClause, argIndex, argIndex+1)
|
||||
|
||||
args = append(args, req.Limit, req.Offset)
|
||||
sessions, err := r.scanSessions(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &domain.SessionListResponse{
|
||||
Sessions: sessions,
|
||||
Total: total,
|
||||
Limit: req.Limit,
|
||||
Offset: req.Offset,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Update updates an existing session
|
||||
func (r *sessionRepository) Update(ctx context.Context, sessionID uuid.UUID, updates *domain.UpdateSessionRequest) error {
|
||||
r.logger.Debug("Updating session", zap.String("session_id", sessionID.String()))
|
||||
|
||||
// Build UPDATE clause dynamically
|
||||
setParts := []string{"updated_at = NOW()"}
|
||||
args := []interface{}{}
|
||||
argIndex := 1
|
||||
|
||||
if updates.Status != nil {
|
||||
setParts = append(setParts, fmt.Sprintf("status = $%d", argIndex))
|
||||
args = append(args, *updates.Status)
|
||||
argIndex++
|
||||
}
|
||||
|
||||
if updates.LastActivity != nil {
|
||||
setParts = append(setParts, fmt.Sprintf("last_activity = $%d", argIndex))
|
||||
args = append(args, *updates.LastActivity)
|
||||
argIndex++
|
||||
}
|
||||
|
||||
if updates.ExpiresAt != nil {
|
||||
setParts = append(setParts, fmt.Sprintf("expires_at = $%d", argIndex))
|
||||
args = append(args, *updates.ExpiresAt)
|
||||
argIndex++
|
||||
}
|
||||
|
||||
if updates.IPAddress != nil {
|
||||
setParts = append(setParts, fmt.Sprintf("ip_address = $%d", argIndex))
|
||||
args = append(args, *updates.IPAddress)
|
||||
argIndex++
|
||||
}
|
||||
|
||||
if updates.UserAgent != nil {
|
||||
setParts = append(setParts, fmt.Sprintf("user_agent = $%d", argIndex))
|
||||
args = append(args, *updates.UserAgent)
|
||||
argIndex++
|
||||
}
|
||||
|
||||
if len(setParts) == 1 {
|
||||
return errors.NewValidationError("No fields to update")
|
||||
}
|
||||
|
||||
// Build the complete query
|
||||
setClause := fmt.Sprintf("%s", setParts[0])
|
||||
for i := 1; i < len(setParts); i++ {
|
||||
setClause += fmt.Sprintf(", %s", setParts[i])
|
||||
}
|
||||
|
||||
query := fmt.Sprintf("UPDATE user_sessions SET %s WHERE id = $%d", setClause, argIndex)
|
||||
args = append(args, sessionID)
|
||||
|
||||
result, err := r.db.ExecContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
r.logger.Error("Failed to update session", zap.Error(err))
|
||||
return errors.NewInternalError("Failed to update session").WithInternal(err)
|
||||
}
|
||||
|
||||
rowsAffected, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return errors.NewInternalError("Failed to get affected rows").WithInternal(err)
|
||||
}
|
||||
|
||||
if rowsAffected == 0 {
|
||||
return errors.NewNotFoundError("Session not found")
|
||||
}
|
||||
|
||||
r.logger.Debug("Session updated successfully", zap.String("session_id", sessionID.String()))
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateActivity updates the last activity timestamp for a session
|
||||
func (r *sessionRepository) UpdateActivity(ctx context.Context, sessionID uuid.UUID) error {
|
||||
r.logger.Debug("Updating session activity", zap.String("session_id", sessionID.String()))
|
||||
|
||||
query := `UPDATE user_sessions SET last_activity = NOW(), updated_at = NOW() WHERE id = $1`
|
||||
|
||||
result, err := r.db.ExecContext(ctx, query, sessionID)
|
||||
if err != nil {
|
||||
r.logger.Error("Failed to update session activity", zap.Error(err))
|
||||
return errors.NewInternalError("Failed to update session activity").WithInternal(err)
|
||||
}
|
||||
|
||||
rowsAffected, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return errors.NewInternalError("Failed to get affected rows").WithInternal(err)
|
||||
}
|
||||
|
||||
if rowsAffected == 0 {
|
||||
return errors.NewNotFoundError("Session not found")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Revoke revokes a session
|
||||
func (r *sessionRepository) Revoke(ctx context.Context, sessionID uuid.UUID, revokedBy string) error {
|
||||
r.logger.Debug("Revoking session",
|
||||
zap.String("session_id", sessionID.String()),
|
||||
zap.String("revoked_by", revokedBy))
|
||||
|
||||
query := `
|
||||
UPDATE user_sessions
|
||||
SET status = $1, revoked_at = NOW(), revoked_by = $2, updated_at = NOW()
|
||||
WHERE id = $3`
|
||||
|
||||
result, err := r.db.ExecContext(ctx, query, domain.SessionStatusRevoked, revokedBy, sessionID)
|
||||
if err != nil {
|
||||
r.logger.Error("Failed to revoke session", zap.Error(err))
|
||||
return errors.NewInternalError("Failed to revoke session").WithInternal(err)
|
||||
}
|
||||
|
||||
rowsAffected, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return errors.NewInternalError("Failed to get affected rows").WithInternal(err)
|
||||
}
|
||||
|
||||
if rowsAffected == 0 {
|
||||
return errors.NewNotFoundError("Session not found")
|
||||
}
|
||||
|
||||
r.logger.Debug("Session revoked successfully", zap.String("session_id", sessionID.String()))
|
||||
return nil
|
||||
}
|
||||
|
||||
// RevokeAllByUser revokes all sessions for a user
|
||||
func (r *sessionRepository) RevokeAllByUser(ctx context.Context, userID string, revokedBy string) error {
|
||||
r.logger.Debug("Revoking all sessions for user",
|
||||
zap.String("user_id", userID),
|
||||
zap.String("revoked_by", revokedBy))
|
||||
|
||||
query := `
|
||||
UPDATE user_sessions
|
||||
SET status = $1, revoked_at = NOW(), revoked_by = $2, updated_at = NOW()
|
||||
WHERE user_id = $3 AND status = $4`
|
||||
|
||||
result, err := r.db.ExecContext(ctx, query, domain.SessionStatusRevoked, revokedBy, userID, domain.SessionStatusActive)
|
||||
if err != nil {
|
||||
r.logger.Error("Failed to revoke user sessions", zap.Error(err))
|
||||
return errors.NewInternalError("Failed to revoke user sessions").WithInternal(err)
|
||||
}
|
||||
|
||||
rowsAffected, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return errors.NewInternalError("Failed to get affected rows").WithInternal(err)
|
||||
}
|
||||
|
||||
r.logger.Debug("User sessions revoked",
|
||||
zap.String("user_id", userID),
|
||||
zap.Int64("sessions_revoked", rowsAffected))
|
||||
return nil
|
||||
}
|
||||
|
||||
// RevokeAllByUserAndApp revokes all sessions for a user and application
|
||||
func (r *sessionRepository) RevokeAllByUserAndApp(ctx context.Context, userID, appID string, revokedBy string) error {
|
||||
r.logger.Debug("Revoking all sessions for user and app",
|
||||
zap.String("user_id", userID),
|
||||
zap.String("app_id", appID),
|
||||
zap.String("revoked_by", revokedBy))
|
||||
|
||||
query := `
|
||||
UPDATE user_sessions
|
||||
SET status = $1, revoked_at = NOW(), revoked_by = $2, updated_at = NOW()
|
||||
WHERE user_id = $3 AND app_id = $4 AND status = $5`
|
||||
|
||||
result, err := r.db.ExecContext(ctx, query, domain.SessionStatusRevoked, revokedBy, userID, appID, domain.SessionStatusActive)
|
||||
if err != nil {
|
||||
r.logger.Error("Failed to revoke user app sessions", zap.Error(err))
|
||||
return errors.NewInternalError("Failed to revoke user app sessions").WithInternal(err)
|
||||
}
|
||||
|
||||
rowsAffected, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return errors.NewInternalError("Failed to get affected rows").WithInternal(err)
|
||||
}
|
||||
|
||||
r.logger.Debug("User app sessions revoked",
|
||||
zap.String("user_id", userID),
|
||||
zap.String("app_id", appID),
|
||||
zap.Int64("sessions_revoked", rowsAffected))
|
||||
return nil
|
||||
}
|
||||
|
||||
// ExpireOldSessions marks expired sessions as expired
|
||||
func (r *sessionRepository) ExpireOldSessions(ctx context.Context) (int, error) {
|
||||
r.logger.Debug("Expiring old sessions")
|
||||
|
||||
query := `
|
||||
UPDATE user_sessions
|
||||
SET status = $1, updated_at = NOW()
|
||||
WHERE expires_at < NOW() AND status = $2`
|
||||
|
||||
result, err := r.db.ExecContext(ctx, query, domain.SessionStatusExpired, domain.SessionStatusActive)
|
||||
if err != nil {
|
||||
r.logger.Error("Failed to expire old sessions", zap.Error(err))
|
||||
return 0, errors.NewInternalError("Failed to expire old sessions").WithInternal(err)
|
||||
}
|
||||
|
||||
rowsAffected, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return 0, errors.NewInternalError("Failed to get affected rows").WithInternal(err)
|
||||
}
|
||||
|
||||
r.logger.Debug("Old sessions expired", zap.Int64("sessions_expired", rowsAffected))
|
||||
return int(rowsAffected), nil
|
||||
}
|
||||
|
||||
// DeleteExpiredSessions removes expired sessions older than the specified duration
|
||||
func (r *sessionRepository) DeleteExpiredSessions(ctx context.Context, olderThan time.Duration) (int, error) {
|
||||
r.logger.Debug("Deleting expired sessions", zap.Duration("older_than", olderThan))
|
||||
|
||||
cutoffTime := time.Now().Add(-olderThan)
|
||||
query := `DELETE FROM user_sessions WHERE status = $1 AND updated_at < $2`
|
||||
|
||||
result, err := r.db.ExecContext(ctx, query, domain.SessionStatusExpired, cutoffTime)
|
||||
if err != nil {
|
||||
r.logger.Error("Failed to delete expired sessions", zap.Error(err))
|
||||
return 0, errors.NewInternalError("Failed to delete expired sessions").WithInternal(err)
|
||||
}
|
||||
|
||||
rowsAffected, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return 0, errors.NewInternalError("Failed to get affected rows").WithInternal(err)
|
||||
}
|
||||
|
||||
r.logger.Debug("Expired sessions deleted", zap.Int64("sessions_deleted", rowsAffected))
|
||||
return int(rowsAffected), nil
|
||||
}
|
||||
|
||||
// Exists checks if a session exists
|
||||
func (r *sessionRepository) Exists(ctx context.Context, sessionID uuid.UUID) (bool, error) {
|
||||
r.logger.Debug("Checking if session exists", zap.String("session_id", sessionID.String()))
|
||||
|
||||
query := `SELECT EXISTS(SELECT 1 FROM user_sessions WHERE id = $1)`
|
||||
|
||||
var exists bool
|
||||
err := r.db.QueryRowContext(ctx, query, sessionID).Scan(&exists)
|
||||
if err != nil {
|
||||
r.logger.Error("Failed to check session existence", zap.Error(err))
|
||||
return false, errors.NewInternalError("Failed to check session existence").WithInternal(err)
|
||||
}
|
||||
|
||||
return exists, nil
|
||||
}
|
||||
|
||||
// GetSessionCount returns the total number of sessions for a user
|
||||
func (r *sessionRepository) GetSessionCount(ctx context.Context, userID string) (int, error) {
|
||||
r.logger.Debug("Getting session count for user", zap.String("user_id", userID))
|
||||
|
||||
query := `SELECT COUNT(*) FROM user_sessions WHERE user_id = $1`
|
||||
|
||||
var count int
|
||||
err := r.db.QueryRowContext(ctx, query, userID).Scan(&count)
|
||||
if err != nil {
|
||||
r.logger.Error("Failed to get session count", zap.Error(err))
|
||||
return 0, errors.NewInternalError("Failed to get session count").WithInternal(err)
|
||||
}
|
||||
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// GetActiveSessionCount returns the number of active sessions for a user
|
||||
func (r *sessionRepository) GetActiveSessionCount(ctx context.Context, userID string) (int, error) {
|
||||
r.logger.Debug("Getting active session count for user", zap.String("user_id", userID))
|
||||
|
||||
query := `SELECT COUNT(*) FROM user_sessions WHERE user_id = $1 AND status = $2 AND expires_at > NOW()`
|
||||
|
||||
var count int
|
||||
err := r.db.QueryRowContext(ctx, query, userID, domain.SessionStatusActive).Scan(&count)
|
||||
if err != nil {
|
||||
r.logger.Error("Failed to get active session count", zap.Error(err))
|
||||
return 0, errors.NewInternalError("Failed to get active session count").WithInternal(err)
|
||||
}
|
||||
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// scanSessions is a helper method to scan multiple sessions from query results
|
||||
func (r *sessionRepository) scanSessions(ctx context.Context, query string, args ...interface{}) ([]*domain.UserSession, error) {
|
||||
rows, err := r.db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
r.logger.Error("Failed to execute session query", zap.Error(err))
|
||||
return nil, errors.NewInternalError("Failed to retrieve sessions").WithInternal(err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var sessions []*domain.UserSession
|
||||
for rows.Next() {
|
||||
var session domain.UserSession
|
||||
var metadataJSON []byte
|
||||
var revokedAt sql.NullTime
|
||||
var revokedBy sql.NullString
|
||||
|
||||
err := rows.Scan(
|
||||
&session.ID,
|
||||
&session.UserID,
|
||||
&session.AppID,
|
||||
&session.SessionType,
|
||||
&session.Status,
|
||||
&session.AccessToken,
|
||||
&session.RefreshToken,
|
||||
&session.IDToken,
|
||||
&session.IPAddress,
|
||||
&session.UserAgent,
|
||||
&session.LastActivity,
|
||||
&session.ExpiresAt,
|
||||
&session.CreatedAt,
|
||||
&session.UpdatedAt,
|
||||
&revokedAt,
|
||||
&revokedBy,
|
||||
&metadataJSON,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
r.logger.Error("Failed to scan session row", zap.Error(err))
|
||||
return nil, errors.NewInternalError("Failed to scan session data").WithInternal(err)
|
||||
}
|
||||
|
||||
// Handle nullable fields
|
||||
if revokedAt.Valid {
|
||||
session.RevokedAt = &revokedAt.Time
|
||||
}
|
||||
if revokedBy.Valid {
|
||||
session.RevokedBy = &revokedBy.String
|
||||
}
|
||||
|
||||
// Deserialize metadata
|
||||
if err := json.Unmarshal(metadataJSON, &session.Metadata); err != nil {
|
||||
r.logger.Warn("Failed to deserialize session metadata", zap.Error(err))
|
||||
session.Metadata = domain.SessionMetadata{} // Use empty metadata on error
|
||||
}
|
||||
|
||||
sessions = append(sessions, &session)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
r.logger.Error("Error iterating session rows", zap.Error(err))
|
||||
return nil, errors.NewInternalError("Failed to iterate session results").WithInternal(err)
|
||||
}
|
||||
|
||||
return sessions, nil
|
||||
}
|
||||
Reference in New Issue
Block a user