package postgres

import (
	"context"
	"database/sql"
	"errors"
	"fmt"
	"strings"
	"time"

	"github.com/google/uuid"
	"github.com/jmoiron/sqlx"
	"github.com/lib/pq"

	"github.com/RyanCopley/skybridge/user/internal/domain"
	"github.com/RyanCopley/skybridge/user/internal/repository/interfaces"
)

// userRepository implements interfaces.UserRepository on top of a sqlx
// database handle.
type userRepository struct {
	db *sqlx.DB
}

// NewUserRepository creates a new user repository backed by db.
func NewUserRepository(db *sqlx.DB) interfaces.UserRepository {
	return &userRepository{db: db}
}

// userSelectColumns is the complete column list read by the single-user
// lookups (GetByID, GetByEmail). Each column must map to a field of
// domain.User, otherwise the sqlx scan fails.
const userSelectColumns = `
	id, email, first_name, last_name, display_name, avatar, role, status,
	last_login_at, password_hash, password_salt, email_verified,
	email_verification_token, email_verification_expires_at,
	password_reset_token, password_reset_expires_at,
	failed_login_attempts, locked_until,
	two_factor_enabled, two_factor_secret, two_factor_backup_codes,
	last_password_change, created_at, updated_at, created_by, updated_by`

// isUserUniqueViolation reports whether err, or any error it wraps, is a
// lib/pq error carrying SQLSTATE 23505 (unique_violation).
func isUserUniqueViolation(err error) bool {
	var pqErr *pq.Error
	return errors.As(err, &pqErr) && pqErr.Code == "23505"
}

// Create inserts user as a new row.
//
// Before inserting it mutates user in place:
//   - a zero ID is replaced with a fresh random UUID;
//   - CreatedAt and UpdatedAt are set to the same current timestamp;
//   - an empty Status defaults to domain.UserStatusPending.
//
// A unique-constraint violation is reported as a duplicate-email error.
// NOTE(review): any unique violation (e.g. a caller-supplied colliding ID)
// produces that same message; the constraint name is not inspected.
func (r *userRepository) Create(ctx context.Context, user *domain.User) error {
	const query = `
		INSERT INTO users (
			id, email, first_name, last_name, display_name, avatar, role, status,
			password_hash, password_salt, email_verified, email_verification_token,
			email_verification_expires_at, two_factor_enabled, two_factor_secret,
			two_factor_backup_codes, created_at, updated_at, created_by, updated_by
		) VALUES (
			:id, :email, :first_name, :last_name, :display_name, :avatar, :role, :status,
			:password_hash, :password_salt, :email_verified, :email_verification_token,
			:email_verification_expires_at, :two_factor_enabled, :two_factor_secret,
			:two_factor_backup_codes, :created_at, :updated_at, :created_by, :updated_by
		)`

	if user.ID == uuid.Nil {
		user.ID = uuid.New()
	}

	// A single clock read so a freshly created row has created_at == updated_at.
	now := time.Now()
	user.CreatedAt = now
	user.UpdatedAt = now

	if user.Status == "" {
		user.Status = domain.UserStatusPending
	}

	if _, err := r.db.NamedExecContext(ctx, query, user); err != nil {
		if isUserUniqueViolation(err) {
			return fmt.Errorf("user with email %s already exists", user.Email)
		}
		return fmt.Errorf("failed to create user: %w", err)
	}
	return nil
}

// getUserBy loads the single user whose column equals value.
//
// column is interpolated into the SQL text and must be a trusted, hard-coded
// identifier; value is bound as $1. Returns a "user not found" error when no
// row matches.
func (r *userRepository) getUserBy(ctx context.Context, column string, value interface{}) (*domain.User, error) {
	query := "SELECT " + userSelectColumns + " FROM users WHERE " + column + " = $1"

	var user domain.User
	if err := r.db.GetContext(ctx, &user, query, value); err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			return nil, fmt.Errorf("user not found")
		}
		return nil, fmt.Errorf("failed to get user: %w", err)
	}
	return &user, nil
}

// GetByID returns the user with the given id, including its credential,
// verification and two-factor columns. It returns a "user not found" error
// when no row matches.
func (r *userRepository) GetByID(ctx context.Context, id uuid.UUID) (*domain.User, error) {
	return r.getUserBy(ctx, "id", id)
}

// GetByEmail returns the user whose email column equals email exactly, as an
// SQL `=` comparison. It returns a "user not found" error when no row matches.
func (r *userRepository) GetByEmail(ctx context.Context, email string) (*domain.User, error) {
	return r.getUserBy(ctx, "email", email)
}

// Update overwrites the profile, role, status and last-login columns of the
// row identified by user.ID. It refreshes user.UpdatedAt in place before
// writing. Password, email-verification and two-factor columns are left
// untouched.
//
// It returns a duplicate-email error on a unique-constraint violation and a
// "user not found" error when no row matched.
func (r *userRepository) Update(ctx context.Context, user *domain.User) error {
	user.UpdatedAt = time.Now()

	const query = `
		UPDATE users SET
			email = :email,
			first_name = :first_name,
			last_name = :last_name,
			display_name = :display_name,
			avatar = :avatar,
			role = :role,
			status = :status,
			last_login_at = :last_login_at,
			updated_at = :updated_at,
			updated_by = :updated_by
		WHERE id = :id`

	result, err := r.db.NamedExecContext(ctx, query, user)
	if err != nil {
		if isUserUniqueViolation(err) {
			return fmt.Errorf("user with email %s already exists", user.Email)
		}
		return fmt.Errorf("failed to update user: %w", err)
	}

	rowsAffected, err := result.RowsAffected()
	if err != nil {
		return fmt.Errorf("failed to get affected rows: %w", err)
	}
	if rowsAffected == 0 {
		return fmt.Errorf("user not found")
	}
	return nil
}
func (r *userRepository) Delete(ctx context.Context, id uuid.UUID) error { query := `DELETE FROM users WHERE id = $1` result, err := r.db.ExecContext(ctx, query, id) if err != nil { return fmt.Errorf("failed to delete user: %w", err) } rowsAffected, err := result.RowsAffected() if err != nil { return fmt.Errorf("failed to get affected rows: %w", err) } if rowsAffected == 0 { return fmt.Errorf("user not found") } return nil } func (r *userRepository) List(ctx context.Context, req *domain.ListUsersRequest) (*domain.ListUsersResponse, error) { // Build WHERE clause var conditions []string var args []interface{} argCounter := 1 if req.Status != nil { conditions = append(conditions, fmt.Sprintf("status = $%d", argCounter)) args = append(args, *req.Status) argCounter++ } if req.Role != nil { conditions = append(conditions, fmt.Sprintf("role = $%d", argCounter)) args = append(args, *req.Role) argCounter++ } if req.Search != "" { searchPattern := "%" + strings.ToLower(req.Search) + "%" conditions = append(conditions, fmt.Sprintf("(LOWER(email) LIKE $%d OR LOWER(first_name) LIKE $%d OR LOWER(last_name) LIKE $%d OR LOWER(display_name) LIKE $%d)", argCounter, argCounter, argCounter, argCounter)) args = append(args, searchPattern) argCounter++ } whereClause := "" if len(conditions) > 0 { whereClause = "WHERE " + strings.Join(conditions, " AND ") } // Build ORDER BY clause orderBy := "created_at" orderDir := "DESC" if req.OrderBy != "" { orderBy = req.OrderBy } if req.OrderDir != "" { orderDir = strings.ToUpper(req.OrderDir) } // Set default pagination limit := 20 if req.Limit > 0 { limit = req.Limit } offset := 0 if req.Offset > 0 { offset = req.Offset } // Query for users query := fmt.Sprintf(` SELECT id, email, first_name, last_name, display_name, avatar, role, status, last_login_at, created_at, updated_at, created_by, updated_by FROM users %s ORDER BY %s %s LIMIT $%d OFFSET $%d`, whereClause, orderBy, orderDir, argCounter, argCounter+1) args = append(args, limit, offset) 
var users []domain.User err := r.db.SelectContext(ctx, &users, query, args...) if err != nil { return nil, fmt.Errorf("failed to list users: %w", err) } // Get total count total, err := r.Count(ctx, req) if err != nil { return nil, fmt.Errorf("failed to get user count: %w", err) } hasMore := offset+len(users) < total return &domain.ListUsersResponse{ Users: users, Total: total, Limit: limit, Offset: offset, HasMore: hasMore, }, nil } func (r *userRepository) UpdateLastLogin(ctx context.Context, id uuid.UUID) error { query := `UPDATE users SET last_login_at = $1 WHERE id = $2` result, err := r.db.ExecContext(ctx, query, time.Now(), id) if err != nil { return fmt.Errorf("failed to update last login: %w", err) } rowsAffected, err := result.RowsAffected() if err != nil { return fmt.Errorf("failed to get affected rows: %w", err) } if rowsAffected == 0 { return fmt.Errorf("user not found") } return nil } func (r *userRepository) Count(ctx context.Context, req *domain.ListUsersRequest) (int, error) { var conditions []string var args []interface{} argCounter := 1 if req.Status != nil { conditions = append(conditions, fmt.Sprintf("status = $%d", argCounter)) args = append(args, *req.Status) argCounter++ } if req.Role != nil { conditions = append(conditions, fmt.Sprintf("role = $%d", argCounter)) args = append(args, *req.Role) argCounter++ } if req.Search != "" { searchPattern := "%" + strings.ToLower(req.Search) + "%" conditions = append(conditions, fmt.Sprintf("(LOWER(email) LIKE $%d OR LOWER(first_name) LIKE $%d OR LOWER(last_name) LIKE $%d OR LOWER(display_name) LIKE $%d)", argCounter, argCounter, argCounter, argCounter)) args = append(args, searchPattern) argCounter++ } whereClause := "" if len(conditions) > 0 { whereClause = "WHERE " + strings.Join(conditions, " AND ") } query := fmt.Sprintf("SELECT COUNT(*) FROM users %s", whereClause) var count int err := r.db.GetContext(ctx, &count, query, args...) 
if err != nil { return 0, fmt.Errorf("failed to count users: %w", err) } return count, nil } func (r *userRepository) ExistsByEmail(ctx context.Context, email string) (bool, error) { query := `SELECT EXISTS(SELECT 1 FROM users WHERE email = $1)` var exists bool err := r.db.GetContext(ctx, &exists, query, email) if err != nil { return false, fmt.Errorf("failed to check user existence: %w", err) } return exists, nil } // Security methods func (r *userRepository) IncrementFailedAttempts(ctx context.Context, userID uuid.UUID, lockoutDuration time.Duration) error { query := ` UPDATE users SET failed_login_attempts = failed_login_attempts + 1, locked_until = CASE WHEN failed_login_attempts + 1 >= 5 THEN $2 ELSE locked_until END, updated_at = $3 WHERE id = $1` _, err := r.db.ExecContext(ctx, query, userID, time.Now().Add(lockoutDuration), time.Now()) if err != nil { return fmt.Errorf("failed to increment failed attempts: %w", err) } return nil } func (r *userRepository) ResetFailedAttempts(ctx context.Context, userID uuid.UUID) error { query := ` UPDATE users SET failed_login_attempts = 0, locked_until = NULL, updated_at = $2 WHERE id = $1` _, err := r.db.ExecContext(ctx, query, userID, time.Now()) if err != nil { return fmt.Errorf("failed to reset failed attempts: %w", err) } return nil } func (r *userRepository) GetFailedAttempts(ctx context.Context, userID uuid.UUID) (int, *time.Time, error) { query := `SELECT failed_login_attempts, locked_until FROM users WHERE id = $1` var attempts int var lockedUntil *time.Time err := r.db.QueryRowContext(ctx, query, userID).Scan(&attempts, &lockedUntil) if err != nil { if err == sql.ErrNoRows { return 0, nil, fmt.Errorf("user not found") } return 0, nil, fmt.Errorf("failed to get failed attempts: %w", err) } return attempts, lockedUntil, nil } func (r *userRepository) SetEmailVerified(ctx context.Context, userID uuid.UUID, verified bool) error { query := ` UPDATE users SET email_verified = $2, email_verification_token = NULL, 
email_verification_expires_at = NULL, updated_at = $3 WHERE id = $1` result, err := r.db.ExecContext(ctx, query, userID, verified, time.Now()) if err != nil { return fmt.Errorf("failed to set email verified: %w", err) } rowsAffected, err := result.RowsAffected() if err != nil { return fmt.Errorf("failed to get affected rows: %w", err) } if rowsAffected == 0 { return fmt.Errorf("user not found") } return nil } func (r *userRepository) UpdatePassword(ctx context.Context, userID uuid.UUID, passwordHash string) error { query := ` UPDATE users SET password_hash = $2, last_password_change = $3, password_reset_token = NULL, password_reset_expires_at = NULL, updated_at = $3 WHERE id = $1` result, err := r.db.ExecContext(ctx, query, userID, passwordHash, time.Now()) if err != nil { return fmt.Errorf("failed to update password: %w", err) } rowsAffected, err := result.RowsAffected() if err != nil { return fmt.Errorf("failed to get affected rows: %w", err) } if rowsAffected == 0 { return fmt.Errorf("user not found") } return nil } func (r *userRepository) UpdateTwoFactorSettings(ctx context.Context, userID uuid.UUID, enabled bool, secret *string, backupCodes []string) error { query := ` UPDATE users SET two_factor_enabled = $2, two_factor_secret = $3, two_factor_backup_codes = $4, updated_at = $5 WHERE id = $1` result, err := r.db.ExecContext(ctx, query, userID, enabled, secret, pq.Array(backupCodes), time.Now()) if err != nil { return fmt.Errorf("failed to update two factor settings: %w", err) } rowsAffected, err := result.RowsAffected() if err != nil { return fmt.Errorf("failed to get affected rows: %w", err) } if rowsAffected == 0 { return fmt.Errorf("user not found") } return nil }