Skip to content

Commit

Permalink
[Fix] Fix LoadMultiViewImageFromFiles to be compatible with `DefaultFormatBundle` (#611)
Browse files Browse the repository at this point in the history

* unravel multi-view img to list

* add unit test
  • Loading branch information
Wuziyi616 authored Jul 1, 2021
1 parent 0759041 commit 7fec1d5
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 3 deletions.
11 changes: 8 additions & 3 deletions mmdet3d/datasets/pipelines/loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,15 @@ def __call__(self, results):
- img_norm_cfg (dict): Normalization configuration of images.
"""
filename = results['img_filename']
# img is of shape (h, w, c, num_views)
img = np.stack(
[mmcv.imread(name, self.color_type) for name in filename], axis=-1)
if self.to_float32:
img = img.astype(np.float32)
results['filename'] = filename
results['img'] = img
# unravel to list, see `DefaultFormatBundle` in formating.py
# which will transpose each image separately and then stack into array
results['img'] = [img[..., i] for i in range(img.shape[-1])]
results['img_shape'] = img.shape
results['ori_shape'] = img.shape
# Set initial values for default meta_keys
Expand All @@ -61,8 +64,10 @@ def __call__(self, results):

def __repr__(self):
"""str: Return a string that describes the module."""
return f'{self.__class__.__name__} (to_float32={self.to_float32}, '\
f"color_type='{self.color_type}')"
repr_str = self.__class__.__name__
repr_str += f'(to_float32={self.to_float32}, '
repr_str += f"color_type='{self.color_type}')"
return repr_str


@PIPELINES.register_module()
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import numpy as np
import torch
from mmcv.parallel import DataContainer

from mmdet3d.datasets.pipelines import (DefaultFormatBundle,
LoadMultiViewImageFromFiles)


def test_load_multi_view_image_from_files():
    """Check ``LoadMultiViewImageFromFiles`` and its formatting support.

    The same image file is loaded six times to emulate a multi-view
    sample. The test verifies the loader's outputs (per-view image list,
    meta keys, default normalization config and ``repr``) and then feeds
    the result through ``DefaultFormatBundle`` to confirm the list of
    views is packed into a single ``DataContainer``.
    """
    view_count = 6
    image_path = (
        'tests/data/waymo/kitti_format/training/image_0/0000000.png')
    view_paths = [image_path] * view_count

    loader = LoadMultiViewImageFromFiles(to_float32=True)
    results = loader(dict(img_filename=view_paths))

    views = results['img']
    first_view = views[0]
    norm_cfg = results['img_norm_cfg']

    # The loader hands back one array per view rather than a stacked array.
    assert isinstance(views, list)
    assert len(views) == view_count
    assert first_view.dtype == np.float32
    assert results['filename'] == view_paths

    # The shape meta keys all report the stacked (h, w, c, num_views) shape.
    stacked_shape = (1280, 1920, 3, view_count)
    for shape_key in ('img_shape', 'ori_shape', 'pad_shape'):
        assert results[shape_key] == stacked_shape
    assert results['scale_factor'] == 1.0

    # Default normalization config: zero mean, unit std, no RGB conversion.
    zero_mean = np.zeros(3, dtype=np.float32)
    unit_std = np.ones(3, dtype=np.float32)
    assert np.all(norm_cfg['mean'] == zero_mean)
    assert np.all(norm_cfg['std'] == unit_std)
    assert not norm_cfg['to_rgb']

    expected_repr = ('LoadMultiViewImageFromFiles(to_float32=True, '
                     "color_type='unchanged')")
    assert repr(loader) == expected_repr

    # test LoadMultiViewImageFromFiles's compatibility with DefaultFormatBundle
    # refer to https://github.com/open-mmlab/mmdetection3d/issues/227
    bundled = DefaultFormatBundle()(results)
    packed = bundled['img']

    assert isinstance(packed, DataContainer)
    assert packed._data.shape == torch.Size((view_count, 3, 1280, 1920))

0 comments on commit 7fec1d5

Please sign in to comment.