// Source path: skybridge/user/internal/repository/postgres/user_repository.go
// Snapshot metadata: 2025-09-01 18:26:44 -04:00, 448 lines, 12 KiB, Go.

package postgres
import (
	"context"
	"database/sql"
	"errors"
	"fmt"
	"strings"
	"time"

	"github.com/google/uuid"
	"github.com/jmoiron/sqlx"
	"github.com/lib/pq"

	"github.com/RyanCopley/skybridge/user/internal/domain"
	"github.com/RyanCopley/skybridge/user/internal/repository/interfaces"
)
// userRepository implements interfaces.UserRepository on top of a
// PostgreSQL database reached through an sqlx handle.
type userRepository struct {
	db *sqlx.DB
}

// NewUserRepository returns an interfaces.UserRepository whose queries all
// run against the given sqlx database handle.
func NewUserRepository(db *sqlx.DB) interfaces.UserRepository {
	repo := &userRepository{db: db}
	return repo
}
// Create inserts user as a new row in the users table.
//
// Before inserting, it fills in defaults directly on the passed-in struct:
//   - ID is set to a fresh random UUID when it is the zero UUID.
//   - CreatedAt and UpdatedAt are both set to the same current time.
//   - Status is set to domain.UserStatusPending when it is empty.
//
// A PostgreSQL unique-violation (SQLSTATE 23505) is reported as an
// "already exists" error naming the email. Any other failure is wrapped.
func (r *userRepository) Create(ctx context.Context, user *domain.User) error {
	const query = `
INSERT INTO users (
id, email, first_name, last_name, display_name, avatar,
role, status, password_hash, password_salt, email_verified,
email_verification_token, email_verification_expires_at,
two_factor_enabled, two_factor_secret, two_factor_backup_codes,
created_at, updated_at, created_by, updated_by
) VALUES (
:id, :email, :first_name, :last_name, :display_name, :avatar,
:role, :status, :password_hash, :password_salt, :email_verified,
:email_verification_token, :email_verification_expires_at,
:two_factor_enabled, :two_factor_secret, :two_factor_backup_codes,
:created_at, :updated_at, :created_by, :updated_by
)`

	if user.ID == uuid.Nil {
		user.ID = uuid.New()
	}
	// Read the clock once so a freshly created row has created_at == updated_at.
	now := time.Now()
	user.CreatedAt = now
	user.UpdatedAt = now
	if user.Status == "" {
		user.Status = domain.UserStatusPending
	}

	if _, err := r.db.NamedExecContext(ctx, query, user); err != nil {
		// errors.As also matches a *pq.Error that some layer has wrapped,
		// which a plain type assertion would miss.
		// NOTE(review): 23505 fires for any unique constraint (e.g. id), not
		// only email — confirm the message is acceptable for those cases.
		var pqErr *pq.Error
		if errors.As(err, &pqErr) && pqErr.Code == "23505" {
			return fmt.Errorf("user with email %s already exists", user.Email)
		}
		return fmt.Errorf("failed to create user: %w", err)
	}
	return nil
}
// GetByID loads the user whose primary key is id. The result includes the
// password, verification, lockout and two-factor columns.
//
// When no row matches, it returns an error whose message is "user not found".
// Any other query failure is wrapped.
func (r *userRepository) GetByID(ctx context.Context, id uuid.UUID) (*domain.User, error) {
	const query = `
SELECT id, email, first_name, last_name, display_name, avatar,
role, status, last_login_at, password_hash, password_salt,
email_verified, email_verification_token, email_verification_expires_at,
password_reset_token, password_reset_expires_at, failed_login_attempts,
locked_until, two_factor_enabled, two_factor_secret, two_factor_backup_codes,
last_password_change, created_at, updated_at, created_by, updated_by
FROM users
WHERE id = $1`

	var user domain.User
	if err := r.db.GetContext(ctx, &user, query, id); err != nil {
		// errors.Is keeps working if the driver or sqlx ever wraps ErrNoRows.
		if errors.Is(err, sql.ErrNoRows) {
			return nil, errors.New("user not found")
		}
		return nil, fmt.Errorf("failed to get user: %w", err)
	}
	return &user, nil
}
// GetByEmail loads the user whose email column equals email. The result
// includes the password, verification, lockout and two-factor columns.
//
// NOTE(review): the match uses plain `=`. Whether that is case-sensitive
// depends on the column type (text vs citext) and is not visible here; confirm.
//
// When no row matches, it returns an error whose message is "user not found".
// Any other query failure is wrapped.
func (r *userRepository) GetByEmail(ctx context.Context, email string) (*domain.User, error) {
	const query = `
SELECT id, email, first_name, last_name, display_name, avatar,
role, status, last_login_at, password_hash, password_salt,
email_verified, email_verification_token, email_verification_expires_at,
password_reset_token, password_reset_expires_at, failed_login_attempts,
locked_until, two_factor_enabled, two_factor_secret, two_factor_backup_codes,
last_password_change, created_at, updated_at, created_by, updated_by
FROM users
WHERE email = $1`

	var user domain.User
	if err := r.db.GetContext(ctx, &user, query, email); err != nil {
		// errors.Is keeps working if the driver or sqlx ever wraps ErrNoRows.
		if errors.Is(err, sql.ErrNoRows) {
			return nil, errors.New("user not found")
		}
		return nil, fmt.Errorf("failed to get user: %w", err)
	}
	return &user, nil
}
// Update writes the user's profile, role, status, last_login_at and
// updated_by columns for the row whose id is user.ID.
//
// user.UpdatedAt is set to the current time on the passed-in struct before
// writing. Password, email-verification, lockout and two-factor columns are
// NOT touched; dedicated methods handle those.
//
// Errors:
//   - A unique-violation (SQLSTATE 23505) is reported as an "already exists"
//     error naming the email.
//   - "user not found" is returned when no row has that id.
//   - Other failures are wrapped.
func (r *userRepository) Update(ctx context.Context, user *domain.User) error {
	user.UpdatedAt = time.Now()

	const query = `
UPDATE users SET
email = :email,
first_name = :first_name,
last_name = :last_name,
display_name = :display_name,
avatar = :avatar,
role = :role,
status = :status,
last_login_at = :last_login_at,
updated_at = :updated_at,
updated_by = :updated_by
WHERE id = :id`

	result, err := r.db.NamedExecContext(ctx, query, user)
	if err != nil {
		// errors.As also matches a *pq.Error that some layer has wrapped.
		var pqErr *pq.Error
		if errors.As(err, &pqErr) && pqErr.Code == "23505" {
			return fmt.Errorf("user with email %s already exists", user.Email)
		}
		return fmt.Errorf("failed to update user: %w", err)
	}

	rowsAffected, err := result.RowsAffected()
	if err != nil {
		return fmt.Errorf("failed to get affected rows: %w", err)
	}
	if rowsAffected == 0 {
		return errors.New("user not found")
	}
	return nil
}
// Delete removes the users row with the given id.
// It reports "user not found" when the DELETE matched no row.
func (r *userRepository) Delete(ctx context.Context, id uuid.UUID) error {
	res, execErr := r.db.ExecContext(ctx, `DELETE FROM users WHERE id = $1`, id)
	if execErr != nil {
		return fmt.Errorf("failed to delete user: %w", execErr)
	}

	n, countErr := res.RowsAffected()
	switch {
	case countErr != nil:
		return fmt.Errorf("failed to get affected rows: %w", countErr)
	case n == 0:
		return fmt.Errorf("user not found")
	default:
		return nil
	}
}
// List returns one page of users plus the total number of users matching the
// same filters. It does not select the password, verification, lockout or
// two-factor columns.
//
// Filters (all optional, combined with AND):
//   - req.Status: exact match on status.
//   - req.Role: exact match on role.
//   - req.Search: case-insensitive substring match on email, first_name,
//     last_name or display_name.
//
// Ordering:
//   - Column is req.OrderBy (default "created_at"); direction is req.OrderDir
//     (default "DESC").
//   - Both values are interpolated into the SQL text, not bound as
//     parameters, so they are validated against an allow-list. Invalid
//     values return an error.
//   - id is appended as a tie-breaker so OFFSET pagination is stable.
//
// Pagination: Limit defaults to 20 and Offset to 0 when not positive.
// The total comes from r.Count, so List and Count must build identical
// filter SQL.
func (r *userRepository) List(ctx context.Context, req *domain.ListUsersRequest) (*domain.ListUsersResponse, error) {
	orderBy, orderDir, err := userListOrdering(req.OrderBy, req.OrderDir)
	if err != nil {
		return nil, fmt.Errorf("failed to list users: %w", err)
	}

	// Build WHERE clause; each filter binds its value to the next $N parameter.
	var conditions []string
	var args []interface{}
	argCounter := 1
	if req.Status != nil {
		conditions = append(conditions, fmt.Sprintf("status = $%d", argCounter))
		args = append(args, *req.Status)
		argCounter++
	}
	if req.Role != nil {
		conditions = append(conditions, fmt.Sprintf("role = $%d", argCounter))
		args = append(args, *req.Role)
		argCounter++
	}
	if req.Search != "" {
		searchPattern := "%" + strings.ToLower(req.Search) + "%"
		conditions = append(conditions, fmt.Sprintf("(LOWER(email) LIKE $%d OR LOWER(first_name) LIKE $%d OR LOWER(last_name) LIKE $%d OR LOWER(display_name) LIKE $%d)", argCounter, argCounter, argCounter, argCounter))
		args = append(args, searchPattern)
		argCounter++
	}
	whereClause := ""
	if len(conditions) > 0 {
		whereClause = "WHERE " + strings.Join(conditions, " AND ")
	}

	// Set default pagination.
	limit := 20
	if req.Limit > 0 {
		limit = req.Limit
	}
	offset := 0
	if req.Offset > 0 {
		offset = req.Offset
	}

	// Without a unique tie-breaker, rows sharing the sort value may be
	// returned in a different order per query, so pages can skip or repeat rows.
	tieBreak := ""
	if orderBy != "id" {
		tieBreak = ", id " + orderDir
	}

	query := fmt.Sprintf(`
SELECT id, email, first_name, last_name, display_name, avatar,
role, status, last_login_at, created_at, updated_at, created_by, updated_by
FROM users
%s
ORDER BY %s %s%s
LIMIT $%d OFFSET $%d`,
		whereClause, orderBy, orderDir, tieBreak, argCounter, argCounter+1)
	args = append(args, limit, offset)

	var users []domain.User
	if err := r.db.SelectContext(ctx, &users, query, args...); err != nil {
		return nil, fmt.Errorf("failed to list users: %w", err)
	}

	// Get total count.
	total, err := r.Count(ctx, req)
	if err != nil {
		return nil, fmt.Errorf("failed to get user count: %w", err)
	}

	return &domain.ListUsersResponse{
		Users:   users,
		Total:   total,
		Limit:   limit,
		Offset:  offset,
		HasMore: offset+len(users) < total,
	}, nil
}

// userListOrdering validates a caller-supplied sort column and direction for
// List and returns values that are safe to splice into SQL. Empty inputs
// default to "created_at" / "DESC". The column is case-folded. Only the
// columns List selects are accepted. The direction must be ASC or DESC in
// any case.
func userListOrdering(orderBy, orderDir string) (string, string, error) {
	column := strings.ToLower(strings.TrimSpace(orderBy))
	if column == "" {
		column = "created_at"
	}
	switch column {
	case "id", "email", "first_name", "last_name", "display_name", "avatar",
		"role", "status", "last_login_at", "created_at", "updated_at",
		"created_by", "updated_by":
	default:
		return "", "", fmt.Errorf("invalid order by column %q", orderBy)
	}

	direction := strings.ToUpper(strings.TrimSpace(orderDir))
	if direction == "" {
		direction = "DESC"
	}
	if direction != "ASC" && direction != "DESC" {
		return "", "", fmt.Errorf("invalid order direction %q", orderDir)
	}
	return column, direction, nil
}
// UpdateLastLogin stamps last_login_at with the current time for the user
// with the given id. updated_at is left unchanged. It reports
// "user not found" when no row matched.
func (r *userRepository) UpdateLastLogin(ctx context.Context, id uuid.UUID) error {
	const stmt = `UPDATE users SET last_login_at = $1 WHERE id = $2`
	loginAt := time.Now()

	res, execErr := r.db.ExecContext(ctx, stmt, loginAt, id)
	if execErr != nil {
		return fmt.Errorf("failed to update last login: %w", execErr)
	}

	n, countErr := res.RowsAffected()
	switch {
	case countErr != nil:
		return fmt.Errorf("failed to get affected rows: %w", countErr)
	case n == 0:
		return fmt.Errorf("user not found")
	default:
		return nil
	}
}
// Count returns how many users match the Status, Role and Search filters of
// req. Limit, Offset and ordering are ignored. The filter SQL must stay
// identical to List's so that ListUsersResponse.Total is consistent with
// the returned page.
func (r *userRepository) Count(ctx context.Context, req *domain.ListUsersRequest) (int, error) {
	var (
		clauses []string
		params  []interface{}
	)
	// bind appends v as the next positional parameter and returns its $N index.
	bind := func(v interface{}) int {
		params = append(params, v)
		return len(params)
	}

	if req.Status != nil {
		clauses = append(clauses, fmt.Sprintf("status = $%d", bind(*req.Status)))
	}
	if req.Role != nil {
		clauses = append(clauses, fmt.Sprintf("role = $%d", bind(*req.Role)))
	}
	if req.Search != "" {
		// One bound pattern is reused for all four LIKE comparisons.
		idx := bind("%" + strings.ToLower(req.Search) + "%")
		clauses = append(clauses, fmt.Sprintf(
			"(LOWER(email) LIKE $%[1]d OR LOWER(first_name) LIKE $%[1]d OR LOWER(last_name) LIKE $%[1]d OR LOWER(display_name) LIKE $%[1]d)",
			idx))
	}

	var where string
	if len(clauses) != 0 {
		where = "WHERE " + strings.Join(clauses, " AND ")
	}

	var total int
	if err := r.db.GetContext(ctx, &total, fmt.Sprintf("SELECT COUNT(*) FROM users %s", where), params...); err != nil {
		return 0, fmt.Errorf("failed to count users: %w", err)
	}
	return total, nil
}
// ExistsByEmail reports whether any users row has an email column equal to
// email.
func (r *userRepository) ExistsByEmail(ctx context.Context, email string) (bool, error) {
	const q = `SELECT EXISTS(SELECT 1 FROM users WHERE email = $1)`
	var found bool
	if err := r.db.GetContext(ctx, &found, q, email); err != nil {
		return false, fmt.Errorf("failed to check user existence: %w", err)
	}
	return found, nil
}
// Security methods
// IncrementFailedAttempts records one more failed login for userID in a
// single UPDATE statement.
//
// When the incremented counter reaches the lockout threshold (or is already
// above it), locked_until is set to now+lockoutDuration. Otherwise
// locked_until keeps its current value. updated_at is set to now.
// No error is returned when userID matches no row.
func (r *userRepository) IncrementFailedAttempts(ctx context.Context, userID uuid.UUID, lockoutDuration time.Duration) error {
	// lockoutThreshold is the failed-attempt count at which the account locks.
	// It is bound as a parameter instead of being a literal buried in the SQL.
	const lockoutThreshold = 5

	// In an UPDATE, failed_login_attempts on the right-hand side refers to
	// the pre-update value, hence the "+ 1" in the CASE condition.
	const query = `
UPDATE users SET
failed_login_attempts = failed_login_attempts + 1,
locked_until = CASE
WHEN failed_login_attempts + 1 >= $4 THEN $2
ELSE locked_until
END,
updated_at = $3
WHERE id = $1`

	// Read the clock once so the lock expiry and updated_at share a base time.
	now := time.Now()
	if _, err := r.db.ExecContext(ctx, query, userID, now.Add(lockoutDuration), now, lockoutThreshold); err != nil {
		return fmt.Errorf("failed to increment failed attempts: %w", err)
	}
	return nil
}
// ResetFailedAttempts zeroes failed_login_attempts for userID, clears
// locked_until and sets updated_at to the current time. Matching no row is
// not treated as an error.
func (r *userRepository) ResetFailedAttempts(ctx context.Context, userID uuid.UUID) error {
	const stmt = `
UPDATE users SET
failed_login_attempts = 0,
locked_until = NULL,
updated_at = $2
WHERE id = $1`
	resetAt := time.Now()
	if _, err := r.db.ExecContext(ctx, stmt, userID, resetAt); err != nil {
		return fmt.Errorf("failed to reset failed attempts: %w", err)
	}
	return nil
}
// GetFailedAttempts returns the failed-login counter and the lock expiry
// stored for userID.
//
// The returned *time.Time is nil when locked_until is NULL. When no row
// matches, the error message is "user not found". Other failures are wrapped.
func (r *userRepository) GetFailedAttempts(ctx context.Context, userID uuid.UUID) (int, *time.Time, error) {
	const query = `SELECT failed_login_attempts, locked_until FROM users WHERE id = $1`

	var (
		attempts    int
		lockedUntil *time.Time // scanning through **time.Time maps NULL to nil
	)
	if err := r.db.QueryRowContext(ctx, query, userID).Scan(&attempts, &lockedUntil); err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			return 0, nil, errors.New("user not found")
		}
		return 0, nil, fmt.Errorf("failed to get failed attempts: %w", err)
	}
	return attempts, lockedUntil, nil
}
// SetEmailVerified stores verified in email_verified for userID. In the same
// statement it clears email_verification_token and
// email_verification_expires_at, and sets updated_at to the current time.
// It reports "user not found" when no row matched.
func (r *userRepository) SetEmailVerified(ctx context.Context, userID uuid.UUID, verified bool) error {
	const stmt = `
UPDATE users SET
email_verified = $2,
email_verification_token = NULL,
email_verification_expires_at = NULL,
updated_at = $3
WHERE id = $1`

	res, execErr := r.db.ExecContext(ctx, stmt, userID, verified, time.Now())
	if execErr != nil {
		return fmt.Errorf("failed to set email verified: %w", execErr)
	}

	n, countErr := res.RowsAffected()
	switch {
	case countErr != nil:
		return fmt.Errorf("failed to get affected rows: %w", countErr)
	case n == 0:
		return fmt.Errorf("user not found")
	default:
		return nil
	}
}
// UpdatePassword stores passwordHash for userID. It also:
//   - sets last_password_change and updated_at to the same current time;
//   - clears any pending password-reset token and its expiry.
//
// It reports "user not found" when no row matched.
//
// NOTE(review): password_salt is not modified here. Confirm this is correct
// for the hashing scheme in use.
func (r *userRepository) UpdatePassword(ctx context.Context, userID uuid.UUID, passwordHash string) error {
	const stmt = `
UPDATE users SET
password_hash = $2,
last_password_change = $3,
password_reset_token = NULL,
password_reset_expires_at = NULL,
updated_at = $3
WHERE id = $1`

	changedAt := time.Now()
	res, execErr := r.db.ExecContext(ctx, stmt, userID, passwordHash, changedAt)
	if execErr != nil {
		return fmt.Errorf("failed to update password: %w", execErr)
	}

	n, countErr := res.RowsAffected()
	switch {
	case countErr != nil:
		return fmt.Errorf("failed to get affected rows: %w", countErr)
	case n == 0:
		return fmt.Errorf("user not found")
	default:
		return nil
	}
}
// UpdateTwoFactorSettings overwrites the two-factor columns for userID with
// enabled, secret and backupCodes, and sets updated_at to the current time.
// backupCodes is sent as a PostgreSQL array via pq.Array. A nil secret is
// presumably stored as NULL; confirm against the driver's pointer handling.
// It reports "user not found" when no row matched.
func (r *userRepository) UpdateTwoFactorSettings(ctx context.Context, userID uuid.UUID, enabled bool, secret *string, backupCodes []string) error {
	const stmt = `
UPDATE users SET
two_factor_enabled = $2,
two_factor_secret = $3,
two_factor_backup_codes = $4,
updated_at = $5
WHERE id = $1`

	codes := pq.Array(backupCodes)
	res, execErr := r.db.ExecContext(ctx, stmt, userID, enabled, secret, codes, time.Now())
	if execErr != nil {
		return fmt.Errorf("failed to update two factor settings: %w", execErr)
	}

	n, countErr := res.RowsAffected()
	switch {
	case countErr != nil:
		return fmt.Errorf("failed to get affected rows: %w", countErr)
	case n == 0:
		return fmt.Errorf("user not found")
	default:
		return nil
	}
}