Source code for pointtree.instance_segmentation._instance_segmentation_algorithm

""" Base class for implementing tree instance segmentation algorithms. """

__all__ = ["InstanceSegmentationAlgorithm"]

import abc
import logging
from typing import Optional, Tuple
import sys

from numba_kdtree import KDTree
import numpy as np
import pandas as pd

from pointtree.evaluation import Timer, TimeTracker
from pointtree.operations import voxel_downsampling


class InstanceSegmentationAlgorithm(abc.ABC):
    r"""
    Base class for implementing tree instance segmentation algorithms.

    Args:
        trunk_class_id: Integer class ID that designates the tree trunk points.
        crown_class_id: Integer class ID that designates the tree crown points.
        branch_class_id: Integer class ID that designates the tree branch points. Defaults to `None`, assuming that
            branch points are not separately labeled.
        downsampling_voxel_size: Voxel size for the voxel-based downsampling of the tree points before performing the
            tree instance segmentation. Defaults to :code:`None`, which means that the tree instance segmentation is
            performed with the full resolution of the point cloud.
    """

    def __init__(
        self,
        trunk_class_id: int,
        crown_class_id: int,
        branch_class_id: Optional[int] = None,
        downsampling_voxel_size: Optional[float] = None,
    ):
        self._trunk_class_id = trunk_class_id
        self._crown_class_id = crown_class_id
        self._branch_class_id = branch_class_id
        self._downsampling_voxel_size = downsampling_voxel_size
        self._time_tracker = TimeTracker()

        # NOTE(review): basicConfig mutates the process-wide root logger as a side effect of constructing an
        # instance; callers embedding this class in a larger application may want to configure logging themselves.
        # Kept for backward compatibility — confirm before removing.
        logging.basicConfig(
            level=logging.INFO,
            format="%(asctime)s:%(levelname)s: %(message)s",
            datefmt="%Y-%m-%d %H:%M:%S",
            handlers=[logging.StreamHandler(sys.stdout)],
        )
        self._logger = logging.getLogger(__name__)
        self._logger.setLevel(logging.INFO)

    def __call__(  # pylint: disable=too-many-locals
        self,
        point_cloud: pd.DataFrame,
        point_cloud_id: Optional[str] = None,
        semantic_segmentation_column: str = "classification",
    ) -> np.ndarray:
        r"""
        Segments tree instances in a point cloud.

        Args:
            point_cloud: Point cloud to be segmented.
            point_cloud_id: ID of the point cloud to be used in the file names of the visualization results. Defaults
                to `None`, which means that no visualizations are created.
            semantic_segmentation_column: Point cloud column containing the semantic segmentation class IDs. Defaults
                to `"classification"`.

        Returns:
            Instance IDs of each point. Points that do not belong to any instance are assigned the ID :math:`-1`.

        Raises:
            ValueError: If the point cloud does not contain the columns :code:`"x"`, :code:`"y"`, :code:`"z"`, and
                the semantic segmentation column.

        Shape:
            - :code:`point_cloud`: :math:`(N, D)`
            - :code:`classification`: :math:`(N)`
            - Output: :math:`(N)`.

            | where
            |
            | :math:`N = \text{ number of points}`
            | :math:`D = \text{ number of feature channels per point}`
        """
        self._time_tracker.reset()
        with Timer("Total", self._time_tracker):
            # check that the point cloud contains all required columns
            required_columns = ["x", "y", "z", semantic_segmentation_column]
            missing_columns = list(set(required_columns).difference(point_cloud.columns))

            if len(missing_columns) > 0:
                if len(missing_columns) > 1:
                    missing_columns_str = ", ".join(missing_columns[:-1]) + ", and " + missing_columns[-1]
                else:
                    missing_columns_str = missing_columns[0]
                raise ValueError(f"The point cloud must contain the columns {missing_columns_str}.")

            point_cloud_size = len(point_cloud)

            # select all points whose class ID marks them as part of a tree (trunk, crown, and optionally branch)
            tree_mask = np.logical_or(
                point_cloud[semantic_segmentation_column].to_numpy() == self._trunk_class_id,
                point_cloud[semantic_segmentation_column].to_numpy() == self._crown_class_id,
            )
            if self._branch_class_id is not None:
                tree_mask = np.logical_or(
                    tree_mask, point_cloud[semantic_segmentation_column].to_numpy() == self._branch_class_id
                )

            # no tree points at all: every point gets the "no instance" ID
            if tree_mask.sum() == 0:
                return np.full(len(point_cloud), fill_value=-1, dtype=np.int64)

            tree_points = point_cloud[tree_mask]
            # release the full point cloud early to reduce peak memory usage
            del point_cloud

            if self._downsampling_voxel_size is not None:
                with Timer("Voxel-based downsampling", self._time_tracker):
                    self._logger.info("Downsample point cloud...")
                    _, predicted_point_indices = voxel_downsampling(
                        tree_points.to_numpy(), self._downsampling_voxel_size, point_aggregation="nearest_neighbor"
                    )
                    downsampled_tree_points = tree_points.iloc[predicted_point_indices]
                    self._logger.info("Points after downsampling: %d", len(downsampled_tree_points))
                    # indices of the tree points that were discarded by the downsampling; their labels are
                    # reconstructed by nearest-neighbor upsampling after the segmentation
                    non_predicted_point_indices = np.setdiff1d(
                        np.arange(len(tree_points), dtype=np.int64), predicted_point_indices
                    )
            else:
                downsampled_tree_points = tree_points

            tree_coords = downsampled_tree_points[["x", "y", "z"]].to_numpy()
            classification = downsampled_tree_points[semantic_segmentation_column].to_numpy()
            del downsampled_tree_points

            # delegate the actual segmentation to the subclass implementation
            instance_ids, unique_instance_ids = self._segment_tree_points(
                tree_coords, classification, point_cloud_id=point_cloud_id
            )

            if self._downsampling_voxel_size is not None:
                with Timer("Upsampling of labels", self._time_tracker):
                    instance_ids = self.upsample_instance_ids(
                        tree_points[["x", "y", "z"]].to_numpy(),
                        tree_coords,
                        instance_ids,
                        non_predicted_point_indices,
                        predicted_point_indices,
                    )

            # scatter the tree-point labels back into a full-size result array; non-tree points keep -1
            full_instance_ids = np.full(point_cloud_size, fill_value=-1, dtype=np.int64)
            full_instance_ids[tree_mask] = instance_ids

            self._logger.info("Detected %d trees.", len(unique_instance_ids))

        return full_instance_ids

    @abc.abstractmethod
    def _segment_tree_points(
        self, tree_coords: np.ndarray, classification: np.ndarray, point_cloud_id: Optional[str] = None
    ) -> Tuple[np.ndarray, np.ndarray]:
        r"""
        Performs tree instance segmentation. Has to be overridden by subclasses.

        Args:
            tree_coords: Coordinates of all tree points.
            classification: Class IDs of each tree point.
            point_cloud_id: ID of the point cloud to be used in the file names of the visualization results. Defaults
                to `None`, which means that no visualizations are created.

        Returns:
            Instance IDs of each point and unique instance IDs. Points that do not belong to any instance are
            assigned the ID :math:`-1`.

        Shape:
            - :code:`tree_coords`: :math:`(N, 3)`
            - :code:`classification`: :math:`(N)`
            - Output: Tuple of two arrays. The first has shape :math:`(N)` and the second :math:`(I)`.

            | where
            |
            | :math:`N = \text{ number of tree points}`
            | :math:`I = \text{ number of tree instances}`
        """

    @staticmethod
    def upsample_instance_ids(
        original_coords: np.ndarray,
        downsampled_coords: np.ndarray,
        instance_ids: np.ndarray,
        non_predicted_point_indices: np.ndarray,
        predicted_point_indices: np.ndarray,
    ) -> np.ndarray:
        r"""
        Upsamples instance segmentation labels to a higher point cloud resolution. Points discarded during
        downsampling receive the label of their nearest neighbor among the kept points.

        Args:
            original_coords: Coordinates of the points from the higher-resolution point cloud.
            downsampled_coords: Coordinates of the points from the lower-resolution point cloud.
            instance_ids: Instance segmentation labels to be upsampled.
            non_predicted_point_indices: Indices of the points from the higher-resolution point cloud that were
                discarded during downsampling.
            predicted_point_indices: Indices of the points from the higher-resolution point cloud that were kept
                during downsampling.

        Returns:
            Upsampled instance segmentation labels.

        Shape:
            - :code:`original_coords`: :math:`(N, 3)`
            - :code:`downsampled_coords`: :math:`(N', 3)`
            - :code:`instance_ids`: :math:`(N')`
            - :code:`non_predicted_point_indices`: :math:`(N - N')`
            - :code:`predicted_point_indices`: :math:`(N')`
            - Output: :math:`(N)`.

            | where
            |
            | :math:`N = \text{ number of points in the higher-resolution point cloud}`
            | :math:`N' = \text{ number of points in the lower-resolution point cloud}`
        """
        query_points = original_coords[non_predicted_point_indices]
        nearest_neighbor_indices = KDTree(downsampled_coords).query(query_points, k=1)[1]

        upsampled_instance_ids = np.full(len(original_coords), fill_value=-1, dtype=np.int64)
        upsampled_instance_ids[predicted_point_indices] = instance_ids
        upsampled_instance_ids[non_predicted_point_indices] = instance_ids[nearest_neighbor_indices[:, 0]]

        return upsampled_instance_ids

    def runtime_stats(
        self,
    ) -> pd.DataFrame:
        """
        Returns:
            Tracked execution times as `pandas.DataFrame
            <https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.html>`__ with the columns
            :code:`"Description"` and :code:`"Runtime"`.
        """
        return self._time_tracker.to_pandas()