From 7fec1d533bacf5b5c89d456f65c775c3cc458c72 Mon Sep 17 00:00:00 2001 From: Ziyi Wu Date: Thu, 1 Jul 2021 19:26:40 +0800 Subject: [PATCH] [Fix] Fix `LoadMultiViewImageFromFiles` to be compatible with `DefaultFormatBundle` (#611) * unravel multi-view img to list * add unit test --- mmdet3d/datasets/pipelines/loading.py | 11 +++-- .../test_load_images_from_multi_views.py | 45 +++++++++++++++++++ 2 files changed, 53 insertions(+), 3 deletions(-) create mode 100644 tests/test_data/test_pipelines/test_loadings/test_load_images_from_multi_views.py diff --git a/mmdet3d/datasets/pipelines/loading.py b/mmdet3d/datasets/pipelines/loading.py index e9ed68190..e171bb96d 100644 --- a/mmdet3d/datasets/pipelines/loading.py +++ b/mmdet3d/datasets/pipelines/loading.py @@ -41,12 +41,15 @@ def __call__(self, results): - img_norm_cfg (dict): Normalization configuration of images. """ filename = results['img_filename'] + # img is of shape (h, w, c, num_views) img = np.stack( [mmcv.imread(name, self.color_type) for name in filename], axis=-1) if self.to_float32: img = img.astype(np.float32) results['filename'] = filename - results['img'] = img + # unravel to list, see `DefaultFormatBundle` in formating.py + # which will transpose each image separately and then stack into array + results['img'] = [img[..., i] for i in range(img.shape[-1])] results['img_shape'] = img.shape results['ori_shape'] = img.shape # Set initial values for default meta_keys @@ -61,8 +64,10 @@ def __call__(self, results): def __repr__(self): """str: Return a string that describes the module.""" - return f'{self.__class__.__name__} (to_float32={self.to_float32}, '\ - f"color_type='{self.color_type}')" + repr_str = self.__class__.__name__ + repr_str += f'(to_float32={self.to_float32}, ' + repr_str += f"color_type='{self.color_type}')" + return repr_str @PIPELINES.register_module() diff --git a/tests/test_data/test_pipelines/test_loadings/test_load_images_from_multi_views.py b/tests/test_data/test_pipelines/test_loadings/test_load_images_from_multi_views.py new file mode 100644 index 000000000..9c227160e --- /dev/null +++ b/tests/test_data/test_pipelines/test_loadings/test_load_images_from_multi_views.py @@ -0,0 +1,45 @@ +import numpy as np +import torch +from mmcv.parallel import DataContainer + +from mmdet3d.datasets.pipelines import (DefaultFormatBundle, + LoadMultiViewImageFromFiles) + + +def test_load_multi_view_image_from_files(): + multi_view_img_loader = LoadMultiViewImageFromFiles(to_float32=True) + + num_views = 6 + filename = 'tests/data/waymo/kitti_format/training/image_0/0000000.png' + filenames = [filename for _ in range(num_views)] + + input_dict = dict(img_filename=filenames) + results = multi_view_img_loader(input_dict) + img = results['img'] + img0 = img[0] + img_norm_cfg = results['img_norm_cfg'] + + assert isinstance(img, list) + assert len(img) == num_views + assert img0.dtype == np.float32 + assert results['filename'] == filenames + assert results['img_shape'] == results['ori_shape'] == \ + results['pad_shape'] == (1280, 1920, 3, num_views) + assert results['scale_factor'] == 1.0 + assert np.all(img_norm_cfg['mean'] == np.zeros(3, dtype=np.float32)) + assert np.all(img_norm_cfg['std'] == np.ones(3, dtype=np.float32)) + assert not img_norm_cfg['to_rgb'] + + repr_str = repr(multi_view_img_loader) + expected_str = 'LoadMultiViewImageFromFiles(to_float32=True, ' \ + "color_type='unchanged')" + assert repr_str == expected_str + + # test LoadMultiViewImageFromFiles's compatibility with DefaultFormatBundle + # refer to https://github.com/open-mmlab/mmdetection3d/issues/227 + default_format_bundle = DefaultFormatBundle() + results = default_format_bundle(results) + img = results['img'] + + assert isinstance(img, DataContainer) + assert img._data.shape == torch.Size((num_views, 3, 1280, 1920))