Source code for pointtorch.operations.torch._shuffle

"""Shuffling of points in a point cloud."""

__all__ = ["shuffle"]

from typing import Optional, Tuple

import torch


def shuffle(
    points: torch.Tensor, point_cloud_sizes: torch.Tensor, generator: Optional[torch.Generator] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""
    Shuffles points within a batch of point clouds. Each point cloud in the batch can contain a different number of
    points. Points are only permuted within their own point cloud, i.e., the order of the point clouds in the batch is
    preserved.

    A single random permutation of length :code:`max(point_cloud_sizes)` is drawn and shared by all point clouds. The
    order of the points within each point cloud is obtained by restricting that permutation to the indices that are
    valid for the respective point cloud.

    Args:
        points: Batch of point clouds to shuffle.
        point_cloud_sizes: Number of points contained in each input point cloud.
        generator: Random generator to be used for shuffling. Defaults to `None`.

    Returns:
        Tuple of two tensors. The first is the shuffled tensor. The second contains the index of each point after
        shuffling, i.e., indexing the shuffled tensor with the second tensor restores the original order. If
        :attr:`point_cloud_sizes` is empty, an empty selection of :attr:`points` and an empty index tensor are
        returned.

    Shape:
        - :attr:`points`: :math:`(N, ...)`
        - :attr:`point_cloud_sizes`: :math:`(B)`
        - Output: Tuple of two tensors. The first has the same shape as :attr:`points`. The second has shape `(N)`.

          | where
          |
          | :math:`N = \text{ number of points}`
          | :math:`B = \text{ batch size}`
    """

    if point_cloud_sizes.numel() == 0:
        # An empty batch contains nothing to shuffle, and calling max() on an empty tensor would raise an error.
        empty_indices = torch.empty(0, dtype=torch.long, device=points.device)
        return points[empty_indices], empty_indices

    max_point_cloud_size = int(point_cloud_sizes.max().item())

    # One permutation is drawn for the largest point cloud and replicated once per point cloud in the batch.
    shuffled_indices = torch.randperm(max_point_cloud_size, dtype=torch.long, device=points.device, generator=generator)
    shuffled_indices = shuffled_indices.unsqueeze(0).repeat(len(point_cloud_sizes), 1)

    # The mask must be computed before the offsets are added because it compares point-cloud-local indices.
    invalid_mask = shuffled_indices >= point_cloud_sizes.unsqueeze(-1)

    # Convert point-cloud-local indices to indices into the flat batch by adding the start offset of each point cloud.
    shuffled_indices[1:] += point_cloud_sizes.cumsum(dim=0)[:-1].unsqueeze(-1)

    # mask invalid indices resulting from different point cloud sizes
    shuffled_indices[invalid_mask] = -1
    shuffled_indices = shuffled_indices.reshape(-1)
    shuffled_indices = shuffled_indices[shuffled_indices != -1]

    # Sorting the permutation yields its inverse: entry i is the position of the original point i after shuffling.
    index_mapping = torch.argsort(shuffled_indices)

    return points[shuffled_indices], index_mapping