[Enhance] Support to modify non_blocking parameters. #2567

Merged
110 changes: 56 additions & 54 deletions mmdet3d/models/data_preprocessors/data_preprocessor.py
@@ -1,13 +1,15 @@
# Copyright (c) OpenMMLab. All rights reserved.
import math
from numbers import Number
from typing import Dict, List, Optional, Sequence, Union
from typing import Dict, List, Optional, Sequence, Tuple, Union

import numpy as np
import torch
from mmdet.models import DetDataPreprocessor
from mmdet.models.utils.misc import samplelist_boxtype2tensor
from mmengine.model import stack_batch
from mmengine.utils import is_list_of
from mmengine.utils import is_seq_of
from torch import Tensor
from torch.nn import functional as F

from mmdet3d.registry import MODELS
@@ -27,52 +29,56 @@ class Det3DDataPreprocessor(DetDataPreprocessor):
- Collate and move image and point cloud data to the target device.

- 1) For image data:
- Pad images in inputs to the maximum size of current batch with defined
``pad_value``. The padding size can be divisible by a defined
``pad_size_divisor``.
- Stack images in inputs to batch_imgs.
- Convert images in inputs from bgr to rgb if the shape of input is
(3, H, W).
- Normalize images in inputs with defined std and mean.
- Do batch augmentations during training.

- Pad images in inputs to the maximum size of current batch with defined
``pad_value``. The padding size can be divisible by a defined
``pad_size_divisor``.
- Stack images in inputs to batch_imgs.
- Convert images in inputs from bgr to rgb if the shape of input is
(3, H, W).
- Normalize images in inputs with defined std and mean.
- Do batch augmentations during training.

- 2) For point cloud data:
- If no voxelization, directly return list of point cloud data.
- If voxelization is applied, voxelize point cloud according to
``voxel_type`` and obtain ``voxels``.

- If no voxelization, directly return list of point cloud data.
- If voxelization is applied, voxelize point cloud according to
``voxel_type`` and obtain ``voxels``.

Args:
voxel (bool): Whether to apply voxelization to point cloud.
Defaults to False.
voxel_type (str): Voxelization type. Two voxelization types are
provided: 'hard' and 'dynamic', respectively for hard
voxelization and dynamic voxelization. Defaults to 'hard'.
provided: 'hard' and 'dynamic', respectively for hard voxelization
and dynamic voxelization. Defaults to 'hard'.
voxel_layer (dict or :obj:`ConfigDict`, optional): Voxelization layer
config. Defaults to None.
batch_first (bool): Whether to put the batch dimension to the first
dimension when getting voxel coordinates. Defaults to True.
max_voxels (int): Maximum number of voxels in each voxel grid. Defaults
to None.
max_voxels (int, optional): Maximum number of voxels in each voxel
grid. Defaults to None.
mean (Sequence[Number], optional): The pixel mean of R, G, B channels.
Defaults to None.
std (Sequence[Number], optional): The pixel standard deviation of
R, G, B channels. Defaults to None.
pad_size_divisor (int): The size of padded image should be
divisible by ``pad_size_divisor``. Defaults to 1.
pad_value (Number): The padded pixel value. Defaults to 0.
pad_size_divisor (int): The size of padded image should be divisible by
``pad_size_divisor``. Defaults to 1.
pad_value (float or int): The padded pixel value. Defaults to 0.
pad_mask (bool): Whether to pad instance masks. Defaults to False.
mask_pad_value (int): The padded pixel value for instance masks.
Defaults to 0.
pad_seg (bool): Whether to pad semantic segmentation maps.
Defaults to False.
seg_pad_value (int): The padded pixel value for semantic
segmentation maps. Defaults to 255.
seg_pad_value (int): The padded pixel value for semantic segmentation
maps. Defaults to 255.
bgr_to_rgb (bool): Whether to convert image from BGR to RGB.
Defaults to False.
rgb_to_bgr (bool): Whether to convert image from RGB to BGR.
Defaults to False.
boxtype2tensor (bool): Whether to keep the ``BaseBoxes`` type of
bboxes data or not. Defaults to True.
boxtype2tensor (bool): Whether to convert the ``BaseBoxes`` type of
bboxes data to ``Tensor`` type. Defaults to True.
non_blocking (bool): Whether to block current process when transferring
data to device. Defaults to False.
batch_augments (List[dict], optional): Batch-level augmentations.
Defaults to None.
"""
@@ -94,6 +100,7 @@ def __init__(self,
bgr_to_rgb: bool = False,
rgb_to_bgr: bool = False,
boxtype2tensor: bool = True,
non_blocking: bool = False,
batch_augments: Optional[List[dict]] = None) -> None:
super(Det3DDataPreprocessor, self).__init__(
mean=mean,
@@ -106,6 +113,8 @@ def __init__(self,
seg_pad_value=seg_pad_value,
bgr_to_rgb=bgr_to_rgb,
rgb_to_bgr=rgb_to_bgr,
boxtype2tensor=boxtype2tensor,
non_blocking=non_blocking,
batch_augments=batch_augments)
self.voxel = voxel
self.voxel_type = voxel_type
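
For reference, a hypothetical config snippet enabling the new flag; field names follow the ``Det3DDataPreprocessor`` signature above, and the mean/std values are illustrative ImageNet statistics, not taken from this PR:

    data_preprocessor = dict(
        type='Det3DDataPreprocessor',
        voxel=False,
        mean=[123.675, 116.28, 103.53],  # illustrative ImageNet statistics
        std=[58.395, 57.12, 57.375],
        bgr_to_rgb=True,
        non_blocking=True)  # the parameter added by this PR
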
@@ -121,9 +130,9 @@ def forward(self,
``BaseDataPreprocessor``.

Args:
data (dict or List[dict]): Data from dataloader.
The dict contains the whole batch data, when it is
a list[dict], the list indicate test time augmentation.
data (dict or List[dict]): Data from dataloader. The dict contains
the whole batch data, when it is a list[dict], the list
indicates test time augmentation.
training (bool): Whether to enable training time augmentation.
Defaults to False.

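The dict-vs-list dispatch described above roughly amounts to the following; a sketch of the hidden body, not the verbatim implementation:

    if isinstance(data, dict):
        return self.simple_process(data, training)
    elif isinstance(data, (list, tuple)):
        # Test-time augmentation: process each augmented view independently.
        return [self.simple_process(aug_data, training) for aug_data in data]
    else:
        raise TypeError(f'Unsupported data type: {type(data)}')
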
@@ -184,17 +193,10 @@ def simple_process(self, data: dict, training: bool = False) -> dict:
'pad_shape': pad_shape
})

if hasattr(self, 'boxtype2tensor') and self.boxtype2tensor:
from mmdet.models.utils.misc import \
samplelist_boxtype2tensor
if self.boxtype2tensor:
samplelist_boxtype2tensor(data_samples)
elif hasattr(self, 'boxlist2tensor') and self.boxlist2tensor:
from mmdet.models.utils.misc import \
samplelist_boxlist2tensor
samplelist_boxlist2tensor(data_samples)
if self.pad_mask:
self.pad_gt_masks(data_samples)

if self.pad_seg:
self.pad_gt_sem_seg(data_samples)

@@ -205,7 +207,7 @@ def simple_process(self, data: dict, training: bool = False) -> dict:

return {'inputs': batch_inputs, 'data_samples': data_samples}

def preprocess_img(self, _batch_img: torch.Tensor) -> torch.Tensor:
def preprocess_img(self, _batch_img: Tensor) -> Tensor:
# channel transform
if self._channel_conversion:
_batch_img = _batch_img[[2, 1, 0], ...]
@@ -223,12 +225,11 @@ def preprocess_img(self, _batch_img: torch.Tensor) -> torch.Tensor:
return _batch_img
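
The normalization step elided from this hunk presumably follows the usual mmengine preprocessor pattern; a sketch, assuming the parent-class attributes ``_enable_normalize``, ``self.mean`` and ``self.std``:

    _batch_img = _batch_img.float()
    if self._enable_normalize:
        _batch_img = (_batch_img - self.mean) / self.std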

def collate_data(self, data: dict) -> dict:
"""Copying data to the target device and Performs normalization,
padding and bgr2rgb conversion and stack based on
``BaseDataPreprocessor``.
"""Copy data to the target device and perform normalization, padding
and bgr2rgb conversion and stack based on ``BaseDataPreprocessor``.

Collates the data sampled from dataloader into a list of dict and
list of labels, and then copies tensor to the target device.
Collates the data sampled from dataloader into a list of dict and list
of labels, and then copies tensor to the target device.

Args:
data (dict): Data sampled from dataloader.
@@ -241,7 +242,7 @@ def collate_data(self, data: dict) -> dict:
if 'img' in data['inputs']:
_batch_imgs = data['inputs']['img']
# Process data with `pseudo_collate`.
if is_list_of(_batch_imgs, torch.Tensor):
if is_seq_of(_batch_imgs, torch.Tensor):
batch_imgs = []
img_dim = _batch_imgs[0].dim()
for _batch_img in _batch_imgs:
@@ -289,7 +290,7 @@ def collate_data(self, data: dict) -> dict:
else:
raise TypeError(
'Output of `cast_data` should be a list of dict '
'or a tuple with inputs and data_samples, but got'
'or a tuple with inputs and data_samples, but got '
f'{type(data)}: {data}')

data['inputs']['imgs'] = batch_imgs
@@ -298,13 +299,13 @@ def collate_data(self, data: dict) -> dict:

return data
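
The stacking branch hidden in this hunk plausibly dispatches on ``img_dim``; a sketch, assuming ``stack_batch`` from mmengine (imported above) and ``multiview_img_stack_batch`` from the utils file changed below:

    if img_dim == 3:  # single-view: (C, H, W) tensors
        batch_imgs = stack_batch(batch_imgs, self.pad_size_divisor,
                                 self.pad_value)
    elif img_dim == 4:  # multi-view: (N_views, C, H, W) tensors
        batch_imgs = multiview_img_stack_batch(batch_imgs,
                                               self.pad_size_divisor,
                                               self.pad_value)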

def _get_pad_shape(self, data: dict) -> List[tuple]:
def _get_pad_shape(self, data: dict) -> List[Tuple[int, int]]:
"""Get the pad_shape of each image based on data and
pad_size_divisor."""
# rewrite `_get_pad_shape` for obtaining image inputs.
_batch_inputs = data['inputs']['img']
# Process data with `pseudo_collate`.
if is_list_of(_batch_inputs, torch.Tensor):
if is_seq_of(_batch_inputs, torch.Tensor):
batch_pad_shape = []
for ori_input in _batch_inputs:
if ori_input.dim() == 4:
@@ -338,8 +339,8 @@ def _get_pad_shape(self, data: dict) -> List[tuple]:
return batch_pad_shape
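
Each entry of the returned list is just the spatial shape rounded up to the divisor; a self-contained sketch of the arithmetic (the helper name ``pad_shape`` is hypothetical):

    import math

    def pad_shape(h: int, w: int, divisor: int) -> tuple:
        # Round each spatial dim up to the nearest multiple of the divisor,
        # e.g. (375, 1242) with divisor 32 -> (384, 1248).
        return (math.ceil(h / divisor) * divisor,
                math.ceil(w / divisor) * divisor)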

@torch.no_grad()
def voxelize(self, points: List[torch.Tensor],
data_samples: SampleList) -> Dict[str, torch.Tensor]:
def voxelize(self, points: List[Tensor],
data_samples: SampleList) -> Dict[str, Tensor]:
"""Apply voxelization to point cloud.

Args:
@@ -466,7 +467,8 @@ def voxelize(self, points: List[torch.Tensor],

return voxel_dict
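
One detail worth spelling out is how ``batch_first`` affects the voxel coordinates: the sample index is either prepended or appended to each (z, y, x) triple before concatenation. A sketch, assuming a list ``coors_list`` of per-sample coordinate tensors (names hypothetical):

    import torch
    import torch.nn.functional as F

    coors_list = [torch.zeros(5, 3, dtype=torch.int32),
                  torch.ones(7, 3, dtype=torch.int32)]  # dummy per-sample coords
    batch_first = True

    coors_batch = []
    for i, res_coors in enumerate(coors_list):
        if batch_first:
            # (N, 3) -> (N, 4): (batch_idx, z, y, x)
            res_coors = F.pad(res_coors, (1, 0), mode='constant', value=i)
        else:
            # (N, 3) -> (N, 4): (z, y, x, batch_idx)
            res_coors = F.pad(res_coors, (0, 1), mode='constant', value=i)
        coors_batch.append(res_coors)
    coors = torch.cat(coors_batch, dim=0)
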

def get_voxel_seg(self, res_coors: torch.Tensor, data_sample: SampleList):
def get_voxel_seg(self, res_coors: Tensor,
data_sample: SampleList) -> None:
"""Get voxel-wise segmentation label and point2voxel map.

Args:
@@ -490,7 +492,7 @@ def get_voxel_seg(self, res_coors: torch.Tensor, data_sample: SampleList):
data_sample.point2voxel_map = point2voxel_map

def ravel_hash(self, x: np.ndarray) -> np.ndarray:
"""Get voxel coordinates hash for np.unique().
"""Get voxel coordinates hash for np.unique.

Args:
x (np.ndarray): The voxel coordinates of points, Nx3.
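
The hash body itself is hidden in this hunk; a hypothetical implementation in the spirit of torchsparse, which folds the three coordinate columns into one uint64 key so a single np.unique pass can deduplicate voxels:

    import numpy as np

    def ravel_hash(x: np.ndarray) -> np.ndarray:
        assert x.ndim == 2, x.shape
        x = x - np.min(x, axis=0)                       # shift coords to be non-negative
        x = x.astype(np.uint64, copy=False)
        xmax = np.max(x, axis=0).astype(np.uint64) + 1  # per-column extent
        h = np.zeros(x.shape[0], dtype=np.uint64)
        for k in range(x.shape[1] - 1):
            h += x[:, k]
            h *= xmax[k + 1]                            # mixed-radix (row-major) raveling
        h += x[:, -1]
        return h
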
@@ -519,14 +521,14 @@ def sparse_quantize(self,

Args:
coords (np.ndarray): The voxel coordinates of points, Nx3.
return_index (bool): Whether to return the indices of the
unique coords, shape (M,).
return_index (bool): Whether to return the indices of the unique
coords, shape (M,).
return_inverse (bool): Whether to return the indices of the
original coords shape (N,).
original coords, shape (N,).

Returns:
List[np.ndarray] or None: Return index and inverse map if
return_index and return_inverse is True.
List[np.ndarray]: Return index and inverse map if return_index and
return_inverse is True.
"""
_, indices, inverse_indices = np.unique(
self.ravel_hash(coords), return_index=True, return_inverse=True)
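
A hypothetical usage example (the ``preprocessor`` handle is assumed): duplicated voxel coordinates collapse to a single entry, while the inverse map records each point's voxel id:

    import numpy as np

    coords = np.array([[0, 0, 0], [1, 2, 3], [0, 0, 0]])
    indices, inverse = preprocessor.sparse_quantize(
        coords, return_index=True, return_inverse=True)
    unique_coords = coords[indices]  # one row per voxel: [[0, 0, 0], [1, 2, 3]]
    # inverse -> [0, 1, 0]: points 0 and 2 fall into the same voxel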
40 changes: 19 additions & 21 deletions mmdet3d/models/data_preprocessors/utils.py
@@ -3,41 +3,39 @@

import torch
import torch.nn.functional as F
from torch import Tensor


def multiview_img_stack_batch(
tensor_list: List[torch.Tensor],
pad_size_divisor: int = 1,
pad_value: Union[int, float] = 0) -> torch.Tensor:
"""
Compared to the stack_batch in mmengine.model.utils,
def multiview_img_stack_batch(tensor_list: List[Tensor],
pad_size_divisor: int = 1,
pad_value: Union[int, float] = 0) -> Tensor:
"""Compared to the ``stack_batch`` in `mmengine.model.utils`,
multiview_img_stack_batch further handles multiview images.
see diff of padded_sizes[:, :-2] = 0 vs padded_sizes[:, 0] = 0 in line 47
Stack multiple tensors to form a batch and pad the tensor to the max
shape use the right bottom padding mode in these images. If

See diff of padded_sizes[:, :-2] = 0 vs padded_sizes[:, 0] = 0 in line 47.

Stack multiple tensors to form a batch and pad the tensors to the max shape
using the right-bottom padding mode in these images. If
``pad_size_divisor > 0``, add padding to ensure the shape of each dim is
divisible by ``pad_size_divisor``.

Args:
tensor_list (List[Tensor]): A list of tensors with the same dim.
pad_size_divisor (int): If ``pad_size_divisor > 0``, add padding
to ensure the shape of each dim is divisible by
``pad_size_divisor``. This depends on the model, and many
models need to be divisible by 32. Defaults to 1.
pad_size_divisor (int): If ``pad_size_divisor > 0``, add padding to
ensure the shape of each dim is divisible by ``pad_size_divisor``.
This depends on the model, and many models need to be divisible by
32. Defaults to 1.
pad_value (int or float): The padding value. Defaults to 0.

Returns:
Tensor: The n dim tensor.
"""
assert isinstance(
tensor_list,
list), f'Expected input type to be list, but got {type(tensor_list)}'
assert isinstance(tensor_list, list), \
f'Expected input type to be list, but got {type(tensor_list)}'
assert tensor_list, '`tensor_list` could not be an empty list'
assert len({
tensor.ndim
for tensor in tensor_list
}) == 1, ('Expected the dimensions of all tensors must be the same, '
f'but got {[tensor.ndim for tensor in tensor_list]}')
assert len({tensor.ndim for tensor in tensor_list}) == 1, \
'Expected the dimensions of all tensors must be the same, ' \
f'but got {[tensor.ndim for tensor in tensor_list]}'

dim = tensor_list[0].dim()
num_img = len(tensor_list)
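
The padding math that follows (hidden in this hunk) reduces to the one line the docstring calls out; a sketch with assumed local names:

    all_sizes = torch.Tensor([tensor.shape for tensor in tensor_list])
    max_sizes = torch.ceil(all_sizes.max(0)[0] / pad_size_divisor) * pad_size_divisor
    padded_sizes = max_sizes - all_sizes
    # Only pad the last two (spatial) dims; stack_batch instead zeroes just
    # dim 0, which is the documented one-line difference.
    padded_sizes[:, :-2] = 0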