Files
skybridge/kms/internal/repository/postgres/session_repository.go
2025-08-26 19:16:41 -04:00

625 lines
19 KiB
Go

package postgres
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"time"
"github.com/google/uuid"
"github.com/jmoiron/sqlx"
"go.uber.org/zap"
"github.com/kms/api-key-service/internal/domain"
"github.com/kms/api-key-service/internal/errors"
"github.com/kms/api-key-service/internal/repository"
)
// sessionRepository is the PostgreSQL-backed implementation of
// repository.SessionRepository. All reads and writes target the
// user_sessions table. Instances are built by NewSessionRepository.
type sessionRepository struct {
	db     *sqlx.DB    // database handle every query is executed through
	logger *zap.Logger // structured logger used for debug/warn/error output
}
// NewSessionRepository returns a repository.SessionRepository that stores
// sessions in PostgreSQL via db.
//
// Every repository method writes log entries through logger. A nil logger is
// replaced with zap.NewNop() so that those calls do not dereference a nil
// *zap.Logger and panic.
func NewSessionRepository(db *sqlx.DB, logger *zap.Logger) repository.SessionRepository {
	if logger == nil {
		logger = zap.NewNop()
	}
	return &sessionRepository{
		db:     db,
		logger: logger,
	}
}
// Create inserts session as a new row in user_sessions.
//
// The caller's struct is modified in place:
//   - a zero session.ID is replaced with a freshly generated UUID;
//   - CreatedAt, UpdatedAt and LastActivity are all set to the same
//     application-side time.Now() value.
//
// Metadata is JSON-encoded before insertion. A serialization failure or a
// failed INSERT is returned as an internal error.
func (r *sessionRepository) Create(ctx context.Context, session *domain.UserSession) error {
	r.logger.Debug("Creating new session",
		zap.String("user_id", session.UserID),
		zap.String("app_id", session.AppID),
		zap.String("session_type", string(session.SessionType)))

	if session.ID == uuid.Nil {
		session.ID = uuid.New()
	}

	stamp := time.Now()
	session.CreatedAt, session.UpdatedAt, session.LastActivity = stamp, stamp, stamp

	encodedMetadata, marshalErr := json.Marshal(session.Metadata)
	if marshalErr != nil {
		return errors.NewInternalError("Failed to serialize session metadata").WithInternal(marshalErr)
	}

	const insertSQL = `
INSERT INTO user_sessions (
id, user_id, app_id, session_type, status, access_token,
refresh_token, id_token, ip_address, user_agent,
last_activity, expires_at, created_at, updated_at, metadata
) VALUES (
$1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15
)`

	// Positional values, in the same order as the column list above.
	values := []interface{}{
		session.ID,
		session.UserID,
		session.AppID,
		session.SessionType,
		session.Status,
		session.AccessToken,
		session.RefreshToken,
		session.IDToken,
		session.IPAddress,
		session.UserAgent,
		session.LastActivity,
		session.ExpiresAt,
		session.CreatedAt,
		session.UpdatedAt,
		encodedMetadata,
	}

	if _, execErr := r.db.ExecContext(ctx, insertSQL, values...); execErr != nil {
		r.logger.Error("Failed to create session", zap.Error(execErr))
		return errors.NewInternalError("Failed to create session").WithInternal(execErr)
	}

	r.logger.Debug("Session created successfully", zap.String("session_id", session.ID.String()))
	return nil
}
// GetByID loads the session whose primary key is sessionID.
//
// It returns a not-found error when no row matches, and an internal error for
// any other database failure. NULL revoked_at / revoked_by columns become nil
// pointers.
//
// Metadata handling:
//   - An empty (e.g. NULL) metadata column leaves Metadata at its zero value
//     without logging.
//   - Metadata that fails to decode is logged together with the session ID
//     and replaced by an empty domain.SessionMetadata, so a corrupt blob does
//     not make the session unreadable.
func (r *sessionRepository) GetByID(ctx context.Context, sessionID uuid.UUID) (*domain.UserSession, error) {
	r.logger.Debug("Getting session by ID", zap.String("session_id", sessionID.String()))
	query := `
SELECT id, user_id, app_id, session_type, status, access_token,
refresh_token, id_token, ip_address, user_agent,
last_activity, expires_at, created_at, updated_at,
revoked_at, revoked_by, metadata
FROM user_sessions
WHERE id = $1`

	var (
		session      domain.UserSession
		metadataJSON []byte
		revokedAt    sql.NullTime
		revokedBy    sql.NullString
	)
	err := r.db.QueryRowContext(ctx, query, sessionID).Scan(
		&session.ID,
		&session.UserID,
		&session.AppID,
		&session.SessionType,
		&session.Status,
		&session.AccessToken,
		&session.RefreshToken,
		&session.IDToken,
		&session.IPAddress,
		&session.UserAgent,
		&session.LastActivity,
		&session.ExpiresAt,
		&session.CreatedAt,
		&session.UpdatedAt,
		&revokedAt,
		&revokedBy,
		&metadataJSON,
	)
	if err != nil {
		if err == sql.ErrNoRows {
			return nil, errors.NewNotFoundError("Session not found")
		}
		r.logger.Error("Failed to get session by ID", zap.Error(err))
		return nil, errors.NewInternalError("Failed to retrieve session").WithInternal(err)
	}

	if revokedAt.Valid {
		session.RevokedAt = &revokedAt.Time
	}
	if revokedBy.Valid {
		session.RevokedBy = &revokedBy.String
	}

	// A NULL column scans to an empty slice; json.Unmarshal would reject it
	// with "unexpected end of JSON input", so skip decoding in that case.
	if len(metadataJSON) > 0 {
		if err := json.Unmarshal(metadataJSON, &session.Metadata); err != nil {
			r.logger.Warn("Failed to deserialize session metadata",
				zap.String("session_id", sessionID.String()),
				zap.Error(err))
			session.Metadata = domain.SessionMetadata{} // discard any partially decoded data
		}
	}

	r.logger.Debug("Session retrieved successfully", zap.String("session_id", sessionID.String()))
	return &session, nil
}
// GetByUserID returns every session (in any status) owned by userID, ordered
// by created_at descending so the newest session comes first.
func (r *sessionRepository) GetByUserID(ctx context.Context, userID string) ([]*domain.UserSession, error) {
	r.logger.Debug("Getting sessions by user ID", zap.String("user_id", userID))

	const byUserSQL = `
SELECT id, user_id, app_id, session_type, status, access_token,
refresh_token, id_token, ip_address, user_agent,
last_activity, expires_at, created_at, updated_at,
revoked_at, revoked_by, metadata
FROM user_sessions
WHERE user_id = $1
ORDER BY created_at DESC`

	sessions, err := r.scanSessions(ctx, byUserSQL, userID)
	return sessions, err
}
// GetByUserAndApp returns every session (in any status) that userID holds for
// application appID, newest first by created_at.
func (r *sessionRepository) GetByUserAndApp(ctx context.Context, userID, appID string) ([]*domain.UserSession, error) {
	r.logger.Debug("Getting sessions by user and app",
		zap.String("user_id", userID),
		zap.String("app_id", appID))

	const byUserAndAppSQL = `
SELECT id, user_id, app_id, session_type, status, access_token,
refresh_token, id_token, ip_address, user_agent,
last_activity, expires_at, created_at, updated_at,
revoked_at, revoked_by, metadata
FROM user_sessions
WHERE user_id = $1 AND app_id = $2
ORDER BY created_at DESC`

	sessions, err := r.scanSessions(ctx, byUserAndAppSQL, userID, appID)
	return sessions, err
}
// GetActiveByUserID returns the sessions of userID that are both in the
// active status and not yet past expires_at (compared against the database's
// NOW()). Results are ordered by last_activity, most recently used first.
func (r *sessionRepository) GetActiveByUserID(ctx context.Context, userID string) ([]*domain.UserSession, error) {
	r.logger.Debug("Getting active sessions by user ID", zap.String("user_id", userID))

	const activeByUserSQL = `
SELECT id, user_id, app_id, session_type, status, access_token,
refresh_token, id_token, ip_address, user_agent,
last_activity, expires_at, created_at, updated_at,
revoked_at, revoked_by, metadata
FROM user_sessions
WHERE user_id = $1 AND status = $2 AND expires_at > NOW()
ORDER BY last_activity DESC`

	sessions, err := r.scanSessions(ctx, activeByUserSQL, userID, domain.SessionStatusActive)
	return sessions, err
}
// List returns one page of sessions that match the optional filters in req,
// together with the total number of sessions that match those filters.
//
// Filter semantics:
//   - Empty-string filters (UserID, AppID, TenantID) are ignored.
//   - Nil pointer filters (Status, SessionType) are ignored.
//   - TenantID is compared with the "tenant_id" key of the JSON metadata
//     column.
//
// Rows are ordered by created_at DESC with id DESC as a tie-breaker, so
// sessions sharing a creation timestamp keep a stable order across pages.
//
// Errors:
//   - A validation error if req is nil or if Limit or Offset is negative.
//     PostgreSQL rejects negative LIMIT/OFFSET, which would otherwise surface
//     as an internal error.
//   - An internal error on any database failure.
//
// NOTE(review): the COUNT and the page query are separate statements with no
// shared snapshot, so Total can disagree with the page under concurrent writes.
func (r *sessionRepository) List(ctx context.Context, req *domain.SessionListRequest) (*domain.SessionListResponse, error) {
	if req == nil {
		return nil, errors.NewValidationError("Session list request is required")
	}
	r.logger.Debug("Listing sessions with filters",
		zap.String("user_id", req.UserID),
		zap.String("app_id", req.AppID),
		zap.Int("limit", req.Limit),
		zap.Int("offset", req.Offset))

	if req.Limit < 0 || req.Offset < 0 {
		return nil, errors.NewValidationError("Limit and offset must not be negative")
	}

	// Only fixed column expressions are interpolated into the SQL text; every
	// filter value is bound as the next positional parameter ($1, $2, ...).
	whereClause := "WHERE 1=1"
	var args []interface{}
	addFilter := func(column string, value interface{}) {
		args = append(args, value)
		whereClause += fmt.Sprintf(" AND %s = $%d", column, len(args))
	}
	if req.UserID != "" {
		addFilter("user_id", req.UserID)
	}
	if req.AppID != "" {
		addFilter("app_id", req.AppID)
	}
	if req.Status != nil {
		addFilter("status", *req.Status)
	}
	if req.SessionType != nil {
		addFilter("session_type", *req.SessionType)
	}
	if req.TenantID != "" {
		addFilter("metadata->>'tenant_id'", req.TenantID)
	}

	countQuery := fmt.Sprintf("SELECT COUNT(*) FROM user_sessions %s", whereClause)
	var total int
	if err := r.db.QueryRowContext(ctx, countQuery, args...).Scan(&total); err != nil {
		r.logger.Error("Failed to get session count", zap.Error(err))
		return nil, errors.NewInternalError("Failed to count sessions").WithInternal(err)
	}

	// LIMIT/OFFSET take the two placeholders after the filter parameters.
	query := fmt.Sprintf(`
SELECT id, user_id, app_id, session_type, status, access_token,
refresh_token, id_token, ip_address, user_agent,
last_activity, expires_at, created_at, updated_at,
revoked_at, revoked_by, metadata
FROM user_sessions
%s
ORDER BY created_at DESC, id DESC
LIMIT $%d OFFSET $%d`, whereClause, len(args)+1, len(args)+2)
	args = append(args, req.Limit, req.Offset)

	sessions, err := r.scanSessions(ctx, query, args...)
	if err != nil {
		return nil, err
	}

	return &domain.SessionListResponse{
		Sessions: sessions,
		Total:    total,
		Limit:    req.Limit,
		Offset:   req.Offset,
	}, nil
}
// Update applies the non-nil fields of updates (Status, LastActivity,
// ExpiresAt, IPAddress, UserAgent) to the session identified by sessionID.
// It also sets updated_at to the database's NOW().
//
// Errors:
//   - A validation error if updates is nil or sets no fields.
//   - A not-found error if no row has sessionID.
//   - An internal error on any database failure.
func (r *sessionRepository) Update(ctx context.Context, sessionID uuid.UUID, updates *domain.UpdateSessionRequest) error {
	r.logger.Debug("Updating session", zap.String("session_id", sessionID.String()))
	if updates == nil {
		return errors.NewValidationError("No fields to update")
	}

	// Column names are fixed strings; each value is bound as the next
	// positional parameter, so the clause reads "updated_at = NOW(), col = $1, ...".
	setClause := "updated_at = NOW()"
	var args []interface{}
	set := func(column string, value interface{}) {
		args = append(args, value)
		setClause += fmt.Sprintf(", %s = $%d", column, len(args))
	}
	if updates.Status != nil {
		set("status", *updates.Status)
	}
	if updates.LastActivity != nil {
		set("last_activity", *updates.LastActivity)
	}
	if updates.ExpiresAt != nil {
		set("expires_at", *updates.ExpiresAt)
	}
	if updates.IPAddress != nil {
		set("ip_address", *updates.IPAddress)
	}
	if updates.UserAgent != nil {
		set("user_agent", *updates.UserAgent)
	}
	if len(args) == 0 {
		return errors.NewValidationError("No fields to update")
	}

	// The session ID is the final parameter, after all SET values.
	args = append(args, sessionID)
	query := fmt.Sprintf("UPDATE user_sessions SET %s WHERE id = $%d", setClause, len(args))

	result, err := r.db.ExecContext(ctx, query, args...)
	if err != nil {
		r.logger.Error("Failed to update session", zap.Error(err))
		return errors.NewInternalError("Failed to update session").WithInternal(err)
	}
	rowsAffected, err := result.RowsAffected()
	if err != nil {
		return errors.NewInternalError("Failed to get affected rows").WithInternal(err)
	}
	if rowsAffected == 0 {
		return errors.NewNotFoundError("Session not found")
	}

	r.logger.Debug("Session updated successfully", zap.String("session_id", sessionID.String()))
	return nil
}
// UpdateActivity sets last_activity and updated_at of session sessionID to
// the database's NOW(). The row's status is not checked, so revoked or
// expired sessions are touched too.
//
// It returns a not-found error when no row has sessionID.
func (r *sessionRepository) UpdateActivity(ctx context.Context, sessionID uuid.UUID) error {
	r.logger.Debug("Updating session activity", zap.String("session_id", sessionID.String()))

	const touchSQL = `UPDATE user_sessions SET last_activity = NOW(), updated_at = NOW() WHERE id = $1`
	res, execErr := r.db.ExecContext(ctx, touchSQL, sessionID)
	if execErr != nil {
		r.logger.Error("Failed to update session activity", zap.Error(execErr))
		return errors.NewInternalError("Failed to update session activity").WithInternal(execErr)
	}

	n, countErr := res.RowsAffected()
	switch {
	case countErr != nil:
		return errors.NewInternalError("Failed to get affected rows").WithInternal(countErr)
	case n == 0:
		return errors.NewNotFoundError("Session not found")
	}
	return nil
}
// Revoke moves session sessionID to the revoked status. It sets revoked_at
// and updated_at to the database's NOW() and records revokedBy.
//
// The UPDATE is not filtered by the current status, so revoking an
// already-revoked session overwrites its revoked_at / revoked_by values.
// It returns a not-found error when no row has sessionID.
func (r *sessionRepository) Revoke(ctx context.Context, sessionID uuid.UUID, revokedBy string) error {
	id := sessionID.String()
	r.logger.Debug("Revoking session",
		zap.String("session_id", id),
		zap.String("revoked_by", revokedBy))

	const revokeSQL = `
UPDATE user_sessions
SET status = $1, revoked_at = NOW(), revoked_by = $2, updated_at = NOW()
WHERE id = $3`

	res, execErr := r.db.ExecContext(ctx, revokeSQL, domain.SessionStatusRevoked, revokedBy, sessionID)
	if execErr != nil {
		r.logger.Error("Failed to revoke session", zap.Error(execErr))
		return errors.NewInternalError("Failed to revoke session").WithInternal(execErr)
	}

	n, countErr := res.RowsAffected()
	switch {
	case countErr != nil:
		return errors.NewInternalError("Failed to get affected rows").WithInternal(countErr)
	case n == 0:
		return errors.NewNotFoundError("Session not found")
	}

	r.logger.Debug("Session revoked successfully", zap.String("session_id", id))
	return nil
}
// RevokeAllByUser revokes every currently active session of userID. It sets
// revoked_at and updated_at to the database's NOW() and records revokedBy.
//
// Sessions in any other status are left alone. Revoking zero sessions is not
// an error; the number of revoked rows is only written to the debug log.
func (r *sessionRepository) RevokeAllByUser(ctx context.Context, userID string, revokedBy string) error {
	r.logger.Debug("Revoking all sessions for user",
		zap.String("user_id", userID),
		zap.String("revoked_by", revokedBy))

	const revokeUserSQL = `
UPDATE user_sessions
SET status = $1, revoked_at = NOW(), revoked_by = $2, updated_at = NOW()
WHERE user_id = $3 AND status = $4`

	res, execErr := r.db.ExecContext(ctx, revokeUserSQL,
		domain.SessionStatusRevoked, revokedBy, userID, domain.SessionStatusActive)
	if execErr != nil {
		r.logger.Error("Failed to revoke user sessions", zap.Error(execErr))
		return errors.NewInternalError("Failed to revoke user sessions").WithInternal(execErr)
	}

	revoked, countErr := res.RowsAffected()
	if countErr != nil {
		return errors.NewInternalError("Failed to get affected rows").WithInternal(countErr)
	}

	r.logger.Debug("User sessions revoked",
		zap.String("user_id", userID),
		zap.Int64("sessions_revoked", revoked))
	return nil
}
// RevokeAllByUserAndApp revokes every currently active session that userID
// holds for application appID. It sets revoked_at and updated_at to the
// database's NOW() and records revokedBy.
//
// Matching no sessions is not an error; the count of revoked rows is only
// written to the debug log.
func (r *sessionRepository) RevokeAllByUserAndApp(ctx context.Context, userID, appID string, revokedBy string) error {
	r.logger.Debug("Revoking all sessions for user and app",
		zap.String("user_id", userID),
		zap.String("app_id", appID),
		zap.String("revoked_by", revokedBy))

	const revokeUserAppSQL = `
UPDATE user_sessions
SET status = $1, revoked_at = NOW(), revoked_by = $2, updated_at = NOW()
WHERE user_id = $3 AND app_id = $4 AND status = $5`

	res, execErr := r.db.ExecContext(ctx, revokeUserAppSQL,
		domain.SessionStatusRevoked, revokedBy, userID, appID, domain.SessionStatusActive)
	if execErr != nil {
		r.logger.Error("Failed to revoke user app sessions", zap.Error(execErr))
		return errors.NewInternalError("Failed to revoke user app sessions").WithInternal(execErr)
	}

	revoked, countErr := res.RowsAffected()
	if countErr != nil {
		return errors.NewInternalError("Failed to get affected rows").WithInternal(countErr)
	}

	r.logger.Debug("User app sessions revoked",
		zap.String("user_id", userID),
		zap.String("app_id", appID),
		zap.Int64("sessions_revoked", revoked))
	return nil
}
// ExpireOldSessions switches every active session whose expires_at is
// earlier than the database's NOW() to the expired status and bumps its
// updated_at. It returns how many sessions were transitioned.
func (r *sessionRepository) ExpireOldSessions(ctx context.Context) (int, error) {
	r.logger.Debug("Expiring old sessions")

	const expireSQL = `
UPDATE user_sessions
SET status = $1, updated_at = NOW()
WHERE expires_at < NOW() AND status = $2`

	res, execErr := r.db.ExecContext(ctx, expireSQL, domain.SessionStatusExpired, domain.SessionStatusActive)
	if execErr != nil {
		r.logger.Error("Failed to expire old sessions", zap.Error(execErr))
		return 0, errors.NewInternalError("Failed to expire old sessions").WithInternal(execErr)
	}

	expired, countErr := res.RowsAffected()
	if countErr != nil {
		return 0, errors.NewInternalError("Failed to get affected rows").WithInternal(countErr)
	}

	r.logger.Debug("Old sessions expired", zap.Int64("sessions_expired", expired))
	return int(expired), nil
}
// DeleteExpiredSessions permanently deletes sessions in the expired status
// whose updated_at is older than olderThan before now. Sessions in any other
// status (including revoked) are kept. It returns the number of deleted rows.
//
// NOTE(review): the cutoff is computed from the application clock
// (time.Now()), while updated_at is written by the database's NOW() elsewhere
// in this repository. Clock skew between the two shifts the retention
// window; confirm whether that matters in deployment.
func (r *sessionRepository) DeleteExpiredSessions(ctx context.Context, olderThan time.Duration) (int, error) {
	r.logger.Debug("Deleting expired sessions", zap.Duration("older_than", olderThan))

	cutoff := time.Now().Add(-olderThan)
	const purgeSQL = `DELETE FROM user_sessions WHERE status = $1 AND updated_at < $2`

	res, execErr := r.db.ExecContext(ctx, purgeSQL, domain.SessionStatusExpired, cutoff)
	if execErr != nil {
		r.logger.Error("Failed to delete expired sessions", zap.Error(execErr))
		return 0, errors.NewInternalError("Failed to delete expired sessions").WithInternal(execErr)
	}

	deleted, countErr := res.RowsAffected()
	if countErr != nil {
		return 0, errors.NewInternalError("Failed to get affected rows").WithInternal(countErr)
	}

	r.logger.Debug("Expired sessions deleted", zap.Int64("sessions_deleted", deleted))
	return int(deleted), nil
}
// Exists reports whether a user_sessions row with primary key sessionID is
// present, regardless of its status or expiry.
func (r *sessionRepository) Exists(ctx context.Context, sessionID uuid.UUID) (bool, error) {
	r.logger.Debug("Checking if session exists", zap.String("session_id", sessionID.String()))

	var found bool
	row := r.db.QueryRowContext(ctx, `SELECT EXISTS(SELECT 1 FROM user_sessions WHERE id = $1)`, sessionID)
	if scanErr := row.Scan(&found); scanErr != nil {
		r.logger.Error("Failed to check session existence", zap.Error(scanErr))
		return false, errors.NewInternalError("Failed to check session existence").WithInternal(scanErr)
	}
	return found, nil
}
// GetSessionCount returns how many sessions (in any status, expired or not)
// are stored for userID.
func (r *sessionRepository) GetSessionCount(ctx context.Context, userID string) (int, error) {
	r.logger.Debug("Getting session count for user", zap.String("user_id", userID))

	var n int
	row := r.db.QueryRowContext(ctx, `SELECT COUNT(*) FROM user_sessions WHERE user_id = $1`, userID)
	if scanErr := row.Scan(&n); scanErr != nil {
		r.logger.Error("Failed to get session count", zap.Error(scanErr))
		return 0, errors.NewInternalError("Failed to get session count").WithInternal(scanErr)
	}
	return n, nil
}
// GetActiveSessionCount returns how many sessions of userID are in the active
// status and have an expires_at later than the database's NOW().
func (r *sessionRepository) GetActiveSessionCount(ctx context.Context, userID string) (int, error) {
	r.logger.Debug("Getting active session count for user", zap.String("user_id", userID))

	const activeCountSQL = `SELECT COUNT(*) FROM user_sessions WHERE user_id = $1 AND status = $2 AND expires_at > NOW()`

	var n int
	row := r.db.QueryRowContext(ctx, activeCountSQL, userID, domain.SessionStatusActive)
	if scanErr := row.Scan(&n); scanErr != nil {
		r.logger.Error("Failed to get active session count", zap.Error(scanErr))
		return 0, errors.NewInternalError("Failed to get active session count").WithInternal(scanErr)
	}
	return n, nil
}
// scanSessions executes query with args and converts every returned row into
// a *domain.UserSession.
//
// Query contract: the query must select these 17 user_sessions columns, in
// this order: id, user_id, app_id, session_type, status, access_token,
// refresh_token, id_token, ip_address, user_agent, last_activity, expires_at,
// created_at, updated_at, revoked_at, revoked_by, metadata.
//
// Row conversion:
//   - NULL revoked_at / revoked_by become nil pointers.
//   - Empty (e.g. NULL) metadata is left at its zero value.
//   - Metadata that fails to decode is logged with the row's session ID and
//     replaced by an empty domain.SessionMetadata, so one bad row does not
//     fail the whole query.
//
// Return value: a nil slice when no rows match. Query, scan and iteration
// failures are returned as internal errors.
func (r *sessionRepository) scanSessions(ctx context.Context, query string, args ...interface{}) ([]*domain.UserSession, error) {
	rows, err := r.db.QueryContext(ctx, query, args...)
	if err != nil {
		r.logger.Error("Failed to execute session query", zap.Error(err))
		return nil, errors.NewInternalError("Failed to retrieve sessions").WithInternal(err)
	}
	defer rows.Close()

	var sessions []*domain.UserSession
	for rows.Next() {
		var (
			session      domain.UserSession
			metadataJSON []byte
			revokedAt    sql.NullTime
			revokedBy    sql.NullString
		)
		err := rows.Scan(
			&session.ID,
			&session.UserID,
			&session.AppID,
			&session.SessionType,
			&session.Status,
			&session.AccessToken,
			&session.RefreshToken,
			&session.IDToken,
			&session.IPAddress,
			&session.UserAgent,
			&session.LastActivity,
			&session.ExpiresAt,
			&session.CreatedAt,
			&session.UpdatedAt,
			&revokedAt,
			&revokedBy,
			&metadataJSON,
		)
		if err != nil {
			r.logger.Error("Failed to scan session row", zap.Error(err))
			return nil, errors.NewInternalError("Failed to scan session data").WithInternal(err)
		}

		if revokedAt.Valid {
			session.RevokedAt = &revokedAt.Time
		}
		if revokedBy.Valid {
			session.RevokedBy = &revokedBy.String
		}

		// A NULL column scans to an empty slice; json.Unmarshal would reject it
		// with "unexpected end of JSON input", so skip decoding in that case.
		if len(metadataJSON) > 0 {
			if err := json.Unmarshal(metadataJSON, &session.Metadata); err != nil {
				r.logger.Warn("Failed to deserialize session metadata",
					zap.String("session_id", session.ID.String()),
					zap.Error(err))
				session.Metadata = domain.SessionMetadata{} // discard any partially decoded data
			}
		}

		sessions = append(sessions, &session)
	}

	if err := rows.Err(); err != nil {
		r.logger.Error("Error iterating session rows", zap.Error(err))
		return nil, errors.NewInternalError("Failed to iterate session results").WithInternal(err)
	}
	return sessions, nil
}