Source code for pointtorch.operations.torch._voxel_downsampling

"""Voxel-based downsampling of a point cloud.

Points of each point cloud in a batch are assigned to voxels of a fixed edge length, and all points falling into the
same voxel are replaced by a single representative point (either the mean of the points or the point closest to the
voxel center). Point features can be aggregated per voxel as well.
"""

__all__ = ["voxel_downsampling"]

from typing import Literal, Optional, Tuple

import torch
from torch_scatter.scatter import scatter, scatter_min

from ._make_labels_consecutive import make_labels_consecutive
from ._ravel_index import ravel_multi_index, unravel_flat_index


def voxel_downsampling(  # pylint: disable=too-many-locals
    coords: torch.Tensor,
    batch_indices: torch.Tensor,
    point_cloud_sizes: torch.Tensor,
    voxel_size: float,
    *,
    features: Optional[torch.Tensor] = None,
    feature_aggregation: Literal["max", "mean", "min", "nearest_neighbor"] = "mean",
    point_aggregation: Literal["mean", "nearest_neighbor"] = "mean",
    preserve_order: bool = True,
    start: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor, torch.Tensor, torch.Tensor]:
    r"""
    Voxel-based downsampling of a batch of point clouds.

    Args:
        coords: Coordinates of the points to be downsampled.
        batch_indices: Indices indicating to which input point cloud each point in the batch belongs.
        point_cloud_sizes: Number of points contained in each input point cloud.
        voxel_size: The size of the voxels used for downsampling.
        features: The features of the points to be downsampled. Defaults to `None`.
        feature_aggregation: Method to be used to aggregate features of points within the same voxel: `"max"` | \
            `"mean"` | `"min"` | `"nearest_neighbor"`. `"nearest_neighbor"` means that the features of the point \
            closest to the voxel center are selected. Defaults to `"mean"`.
        point_aggregation: Method to be used to aggregate the point coordinates within the same voxel: `"mean"` |
            `"nearest_neighbor"`. `"nearest_neighbor"` means that the coordinates of the point closest to the voxel
            center are selected. Defaults to `"mean"`.
        preserve_order: If this is set to `True` and `point_aggregation` is set to `"nearest_neighbor"`, the point
            order is preserved during downsampling. This means that for any two points included in the downsampled
            point cloud, the point that is first in the original point cloud is also first in the downsampled point
            cloud. Defaults to `True`.
        start: Coordinates of a point at which the voxel grid is to be aligned, i.e., the grid is placed so that it
            starts at a corner point of a voxel. Defaults to `None`, which means that the grid is aligned at the
            coordinate origin.

    Returns:
        Tuple of five tensors. The first contains the coordinates of the downsampled points. The second tensor
        contains the features of the downsampled points and is `None` if no input features are provided. The third
        contains indices indicating to which point cloud each downsampled point belongs. The fourth contains the size
        of each downsampled point cloud (zero for input point clouds that contain no points). The fifth contains
        indices indicating in which voxel each point from the original point clouds is located; these indices refer
        to the rows of the first three outputs, also when the point order is preserved.

    Raises:
        ValueError: If `start` is not `None` and has an invalid shape.

    Shape:
        - :code:`coords`: :math:`(N, 3)`
        - :code:`batch_indices`: :math:`(N)`
        - :code:`point_cloud_sizes`: :math:`(B)`
        - :code:`features`: :math:`(N, D)`
        - :code:`start`: :math:`(B, 3)`
        - Output: Tuple of five tensors. The first has shape :math:`(N', 3)`. The second has shape :math:`(N', D)`. \
          The third has shape :math:`(N')`. The fourth has shape :math:`(B)`. The fifth has shape :math:`(N)`.

          | where
          |
          | :math:`N = \text{ number of points before downsampling}`
          | :math:`N' = \text{ number of points after downsampling}`
          | :math:`B = \text{ batch size}`
          | :math:`D = \text{ number of feature channels}`
    """
    batch_size = len(point_cloud_sizes)

    if start is None:
        # one grid start per point cloud; the column count follows coords so that the subtraction below broadcasts
        start_coords = torch.zeros((batch_size, coords.size(1)), dtype=torch.float, device=coords.device)
    else:
        if start.ndim != 2 or len(start) != batch_size or start.size(1) != coords.size(1):
            raise ValueError(f"The shape of the 'start' tensor is invalid: {start.shape}.")
        start_coords = start

    if len(coords) == 0:
        # the min / max reductions below are not defined for empty tensors, so the empty result is built directly
        empty_indices = torch.empty(0, dtype=torch.long, device=coords.device)
        return (
            coords,
            features,
            empty_indices,
            torch.zeros(batch_size, dtype=torch.long, device=coords.device),
            empty_indices.clone(),
        )

    shifted_coords = coords - start_coords[batch_indices]
    voxel_indices = torch.floor_divide(shifted_coords, voxel_size).long()  # (N, 3)

    # add batch index as additional coordinate dimension so that points from different batch items are put into
    # different voxels
    voxel_indices = torch.column_stack([batch_indices.unsqueeze(-1), voxel_indices])  # (N, 4)
    # shift every column so that its minimum is zero; this avoids negative indices and keeps `dimensions` small
    shift = voxel_indices.amin(dim=0, keepdim=True)  # (1, 4)
    voxel_indices = voxel_indices - shift
    # move the coordinates into the same shifted grid frame so that they can be compared with the voxel centers
    shifted_coords = shifted_coords - shift[:, 1:].float() * voxel_size

    dimensions = voxel_indices.amax(0) + 1  # (4)
    flattened_indices = ravel_multi_index(voxel_indices, dimensions)  # (N)

    unique_cluster_indices, cluster = torch.unique(flattened_indices, sorted=True, return_inverse=True)
    cluster_centers = unravel_flat_index(unique_cluster_indices, dimensions)
    # the batch column was shifted by its minimum above; undo this so that the batch indices refer to the input
    # point clouds even if the first point cloud(s) of the batch contain no points
    batch_indices = cluster_centers[:, 0] + shift[0, 0]
    cluster_centers = cluster_centers[:, 1:].float() * voxel_size + 0.5 * voxel_size

    scatter_indices: torch.Tensor = make_labels_consecutive(  # type: ignore[assignment]
        flattened_indices - flattened_indices.min()
    )

    reorder = preserve_order and point_aggregation == "nearest_neighbor"
    if point_aggregation == "nearest_neighbor" or (features is not None and feature_aggregation == "nearest_neighbor"):
        dists_to_cluster_centers = torch.linalg.norm(  # pylint: disable=not-callable
            shifted_coords - cluster_centers[cluster], dim=-1
        )
        # for each voxel, the index of the input point that is closest to the voxel center
        _, selected_indices = scatter_min(dists_to_cluster_centers, scatter_indices)

        if reorder:
            selected_indices, sorting_indices = selected_indices.sort()
            # output row j now holds voxel sorting_indices[j]; re-index the per-voxel batch indices and remap the
            # per-point voxel indices through the inverse permutation so that all outputs stay aligned
            batch_indices = batch_indices[sorting_indices]
            cluster = torch.argsort(sorting_indices)[cluster]

    if point_aggregation == "nearest_neighbor":
        coords = coords[selected_indices]
    else:
        coords = scatter(coords, scatter_indices, dim=0, reduce="mean")

    if features is not None:
        if feature_aggregation == "nearest_neighbor":
            # selected_indices is already in output order (sorted when the point order is preserved)
            features = features[selected_indices]
        else:
            features = scatter(features, scatter_indices, dim=0, reduce=feature_aggregation)
            if reorder:
                features = features[sorting_indices]  # pylint: disable=used-before-assignment

    # dim_size guarantees one entry per input point cloud, including trailing point clouds without points
    point_cloud_sizes = scatter(
        torch.ones_like(batch_indices), batch_indices, dim=0, dim_size=batch_size, reduce="sum"
    )

    return (coords, features, batch_indices, point_cloud_sizes, cluster)