From 80f372e1a9b79b0a613902f635c1bbbe335ca60d Mon Sep 17 00:00:00 2001 From: ChaimZhu Date: Sun, 10 Oct 2021 21:26:25 +0800 Subject: [PATCH] [Feature] add SMOKE augmentation method (#955) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * [Refactor] Main code modification for coordinate system refactor (#677) * [Enhance] Add script for data update (#774) * Fixed wrong config paths and fixed a bug in test * Fixed metafile * Coord sys refactor (main code) * Update test_waymo_dataset.py * Manually resolve conflict * Removed unused lines and fixed imports * remove coord2box and box2coord * update dir_limit_offset * Some minor improvements * Removed some \s in comments * Revert a change * Change Box3DMode to Coord3DMode where points are converted * Fix points_in_bbox function * Fix Imvoxelnet config * Revert adding a line * Fix rotation bug when batch size is 0 * Keep sign of dir_scores as before * Fix several comments * Add a comment * Fix docstring * Add data update scripts * Fix comments * fix import (#839) * [Enhance] refactor iou_neg_piecewise_sampler.py (#842) * [Refactor] Main code modification for coordinate system refactor (#677) * [Enhance] Add script for data update (#774) * Fixed wrong config paths and fixed a bug in test * Fixed metafile * Coord sys refactor (main code) * Update test_waymo_dataset.py * Manually resolve conflict * Removed unused lines and fixed imports * remove coord2box and box2coord * update dir_limit_offset * Some minor improvements * Removed some \s in comments * Revert a change * Change Box3DMode to Coord3DMode where points are converted * Fix points_in_bbox function * Fix Imvoxelnet config * Revert adding a line * Fix rotation bug when batch size is 0 * Keep sign of dir_scores as before * Fix several comments * Add a comment * Fix docstring * Add data update scripts * Fix comments * fix import * refactor iou_neg_piecewise_sampler.py * add docstring * modify docstring Co-authored-by: Yezhen Cong <52420115+THU17cyz@users.noreply.github.com> Co-authored-by: THU17cyz * [Feature] Add roipooling cuda ops (#843) * [Refactor] Main code modification for coordinate system refactor (#677) * [Enhance] Add script for data update (#774) * Fixed wrong config paths and fixed a bug in test * Fixed metafile * Coord sys refactor (main code) * Update test_waymo_dataset.py * Manually resolve conflict * Removed unused lines and fixed imports * remove coord2box and box2coord * update dir_limit_offset * Some minor improvements * Removed some \s in comments * Revert a change * Change Box3DMode to Coord3DMode where points are converted * Fix points_in_bbox function * Fix Imvoxelnet config * Revert adding a line * Fix rotation bug when batch size is 0 * Keep sign of dir_scores as before * Fix several comments * Add a comment * Fix docstring * Add data update scripts * Fix comments * fix import * add roipooling cuda ops * add roi extractor * add test_roi_extractor unittest * Modify setup.py to install roipooling ops * modify docstring * remove enlarge bbox in roipoint pooling * add_roipooling_ops * modify docstring Co-authored-by: Yezhen Cong <52420115+THU17cyz@users.noreply.github.com> Co-authored-by: THU17cyz * [Refactor] Refactor code structure and docstrings (#803) * refactor points_in_boxes * Merge same functions of three boxes * More docstring fixes and unify x/y/z size * Add "optional" and fix "Default" * Add "optional" and fix "Default" * Add "optional" and fix "Default" * Add "optional" and fix "Default" * Add "optional" and fix "Default" * Remove None in function param type * Fix unittest * Add comments for NMS functions * Merge methods of Points * Add unittest * Add optional and default value * Fix box conversion and add unittest * Fix comments * Add unit test * Indent * Fix CI * Remove useless \\ * Remove useless \\ * Remove useless \\ * Remove useless \\ * Remove useless \\ * Add unit test for box bev * More unit tests and refine docstrings in box_np_ops * Fix comment * Add deprecation warning * [Feature] PointXYZWHLRBBoxCoder (#856) * support PointBasedBoxCoder * fix unittest bug * support unittest in gpu * support unittest in gpu * modified docstring * add args * add args * [Enhance] Change Groupfree3D config (#855) * All mods * PointSample * PointSample * [Doc] Add tutorials/data_pipeline Chinese version (#827) * [Doc] Add tutorials/data_pipeline Chinese version * refine doc * Use the absolute link * Use the absolute link Co-authored-by: Tai-Wang * [Doc] Add Chinese doc for `scannet_det.md` (#836) * Part * Complete * Fix comments * Fix comments * [Doc] Add Chinese doc for `waymo_det.md` (#859) * Add complete translation * Refinements * Fix comments * Fix a minor typo Co-authored-by: Tai-Wang * Remove 2D annotations on Lyft (#867) * Add header for files (#869) * Add header for files * Add header for files * Add header for files * Add header for files * [fix] fix typos (#872) * Fix 3 unworking configs (#882) * [Fix] Fix `index.rst` for Chinese docs (#873) * Fix index.rst for zh docs * Change switch language * [Fix] Centerpoint head nested list transpose (#879) * FIX Transpose nested lists without Numpy * Removed unused Numpy import * [Enhance] Update PointFusion (#791) * update point fusion * remove LIDAR hardcode * move get_proj_mat_by_coord_type to utils * fix lint * remove todo * fix lint * [Doc] Add nuscenes_det.md Chinese version (#854) * add nus chinese doc * add nuScenes Chinese doc * fix typo * fix typo * fix typo * fix typo * fix typo * [Fix] Fix RegNet pretrained weight loading (#889) * Fix regnet pretrained weight loading * Remove unused file * Fix centerpoint tta (#892) * [Enhance] Add benchmark regression script (#808) * Initial commit * [Feature] Support DGCNN (v1.0.0.dev0) (#896) * support dgcnn * support dgcnn * support dgcnn * support dgcnn * support dgcnn * support dgcnn * support dgcnn * support dgcnn * support dgcnn * support dgcnn * fix typo * fix typo * fix typo * del gf&fa registry (wo reuse pointnet module) * fix typo * add benchmark and add copyright header (for DGCNN only) * fix typo * fix typo * fix typo * fix typo * fix typo * support dgcnn * Change cam rot_3d_in_axis (#906) * [Doc] Add coord sys tutorial pic and change links to dev branch (#912) * Modify link branch and add pic * Fix pic * [Feature] add kitti AP40 evaluation metric (v1.0.0.dev0) (#927) * Add citation (#901) * [Feature] Add python3.9 in CI (#900) * Add python3.0 in CI * Add python3.0 in CI * Bump to v0.17.0 (#898) * Update README.md * Update README_zh-CN.md * Update version.py * Update getting_started.md * Update getting_started.md * Update changelog.md * Remove "recent" in the news * Remove "recent" in the news * Fix comments * [Docs] Fix the version of sphinx (#902) * Fix sphinx version * Fix sphinx version * Fix sphinx version * Fix sphinx version * Fix sphinx version * Fix sphinx version * Fix sphinx version * Fix sphinx version * Fix sphinx version * Fix sphinx version * Fix sphinx version * Fix sphinx version * add AP40 * add unitest * add unitest * seperate AP11 and AP40 * fix some typos Co-authored-by: dingchang Co-authored-by: Tai-Wang * [Feature] add smoke backbone neck (#939) * add smoke detecotor and it's backbone and neck * typo fix * fix typo * add docstring * fix typo * fix comments * fix comments * fix comments * fix typo * fix typo * fix * fix typo * fix docstring * refine feature * fix typo * use Basemodule in Neck * [Refactor] Refactor the transformation from image to camera coordinates (#938) * Refactor points_img2cam * Refine docstring * Support array converter and add unit tests * [Feature] FCOS3D BBox Coder (#940) * FCOS3D BBox Coder * Add unit tests * Change the value from long to float/double * Rename bbox_out as bbox * Add comments to forward returns * [Feature] PGD BBox Coder (#948) * Support PGD BBox Coder * Refine docstring * add smoke augmentation method * add docs * fix docstrings * fix typos * change point name * fix typos * fix typos * fix typos Co-authored-by: Yezhen Cong <52420115+THU17cyz@users.noreply.github.com> Co-authored-by: Xi Liu <75658786+xiliu8006@users.noreply.github.com> Co-authored-by: THU17cyz Co-authored-by: Wenhao Wu <79644370+wHao-Wu@users.noreply.github.com> Co-authored-by: Tai-Wang Co-authored-by: dingchang Co-authored-by: 谢恩泽 Co-authored-by: Robin Karlsson <34254153+robin-karlsson0@users.noreply.github.com> Co-authored-by: Danila Rukhovich --- mmdet3d/datasets/__init__.py | 6 +- mmdet3d/datasets/pipelines/__init__.py | 18 +- mmdet3d/datasets/pipelines/transforms_3d.py | 257 ++++++++++++++++++ .../test_augmentations/test_transforms_3d.py | 107 +++++++- 4 files changed, 371 insertions(+), 17 deletions(-) diff --git a/mmdet3d/datasets/__init__.py b/mmdet3d/datasets/__init__.py index cb64c89d2..cf76d09cb 100644 --- a/mmdet3d/datasets/__init__.py +++ b/mmdet3d/datasets/__init__.py @@ -9,14 +9,14 @@ from .nuscenes_dataset import NuScenesDataset from .nuscenes_mono_dataset import NuScenesMonoDataset # yapf: disable -from .pipelines import (BackgroundPointsFilter, GlobalAlignment, +from .pipelines import (AffineResize, BackgroundPointsFilter, GlobalAlignment, GlobalRotScaleTrans, IndoorPatchPointSample, IndoorPointSample, LoadAnnotations3D, LoadPointsFromFile, LoadPointsFromMultiSweeps, NormalizePointsColor, ObjectNameFilter, ObjectNoise, ObjectRangeFilter, ObjectSample, PointSample, PointShuffle, PointsRangeFilter, RandomDropPointsColor, - RandomFlip3D, RandomJitterPoints, + RandomFlip3D, RandomJitterPoints, RandomShiftScale, VoxelBasedPointSampler) # yapf: enable from .s3dis_dataset import S3DISDataset, S3DISSegDataset @@ -38,5 +38,5 @@ 'Custom3DDataset', 'Custom3DSegDataset', 'LoadPointsFromMultiSweeps', 'WaymoDataset', 'BackgroundPointsFilter', 'VoxelBasedPointSampler', 'get_loading_pipeline', 'RandomDropPointsColor', 'RandomJitterPoints', - 'ObjectNameFilter' + 'ObjectNameFilter', 'AffineResize', 'RandomShiftScale' ] diff --git a/mmdet3d/datasets/pipelines/__init__.py b/mmdet3d/datasets/pipelines/__init__.py index 68da65a0b..1f0526bec 100644 --- a/mmdet3d/datasets/pipelines/__init__.py +++ b/mmdet3d/datasets/pipelines/__init__.py @@ -7,13 +7,15 @@ LoadPointsFromMultiSweeps, NormalizePointsColor, PointSegClassMapping) from .test_time_aug import MultiScaleFlipAug3D -from .transforms_3d import (BackgroundPointsFilter, GlobalAlignment, - GlobalRotScaleTrans, IndoorPatchPointSample, - IndoorPointSample, ObjectNameFilter, ObjectNoise, - ObjectRangeFilter, ObjectSample, PointSample, - PointShuffle, PointsRangeFilter, - RandomDropPointsColor, RandomFlip3D, - RandomJitterPoints, VoxelBasedPointSampler) +# yapf: disable +from .transforms_3d import (AffineResize, BackgroundPointsFilter, + GlobalAlignment, GlobalRotScaleTrans, + IndoorPatchPointSample, IndoorPointSample, + ObjectNameFilter, ObjectNoise, ObjectRangeFilter, + ObjectSample, PointSample, PointShuffle, + PointsRangeFilter, RandomDropPointsColor, + RandomFlip3D, RandomJitterPoints, RandomShiftScale, + VoxelBasedPointSampler) __all__ = [ 'ObjectSample', 'RandomFlip3D', 'ObjectNoise', 'GlobalRotScaleTrans', @@ -25,5 +27,5 @@ 'LoadPointsFromMultiSweeps', 'BackgroundPointsFilter', 'VoxelBasedPointSampler', 'GlobalAlignment', 'IndoorPatchPointSample', 'LoadImageFromFileMono3D', 'ObjectNameFilter', 'RandomDropPointsColor', - 'RandomJitterPoints' + 'RandomJitterPoints', 'AffineResize', 'RandomShiftScale' ] diff --git a/mmdet3d/datasets/pipelines/transforms_3d.py b/mmdet3d/datasets/pipelines/transforms_3d.py index ad675dade..2bd60ae3c 100644 --- a/mmdet3d/datasets/pipelines/transforms_3d.py +++ b/mmdet3d/datasets/pipelines/transforms_3d.py @@ -1,5 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. +import cv2 import numpy as np +import random import warnings from mmcv import is_tuple_of from mmcv.utils import build_from_cfg @@ -1422,3 +1424,258 @@ def _auto_indent(repr_str, indent): repr_str += ' ' * indent + 'prev_voxel_generator=\n' repr_str += f'{_auto_indent(repr(self.prev_voxel_generator), 8)})' return repr_str + + +@PIPELINES.register_module() +class AffineResize(object): + """Get the affine transform matrices to the target size. + + Different from :class:`RandomAffine` in MMDetection, this class can + calculate the affine transform matrices while resizing the input image + to a fixed size. The affine transform matrices include: 1) matrix + transforming original image to the network input image size. 2) matrix + transforming original image to the network output feature map size. + + Args: + img_scale (tuple): Images scales for resizing. + down_ratio (int): The down ratio of feature map. + Actually the arg should be >= 1. + bbox_clip_border (bool, optional): Whether clip the objects + outside the border of the image. Defaults to True. + """ + + def __init__(self, img_scale, down_ratio, bbox_clip_border=True): + + self.img_scale = img_scale + self.down_ratio = down_ratio + self.bbox_clip_border = bbox_clip_border + + def __call__(self, results): + """Call function to do affine transform to input image and labels. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Results after affine resize, 'affine_aug', 'trans_mat' + keys are added in the result dict. + """ + # The results have gone through RandomShiftScale before AffineResize + if 'center' not in results: + img = results['img'] + height, width = img.shape[:2] + center = np.array([width / 2, height / 2], dtype=np.float32) + size = np.array([width, height], dtype=np.float32) + results['affine_aug'] = False + else: + # The results did not go through RandomShiftScale before + # AffineResize + img = results['img'] + center = results['center'] + size = results['size'] + + trans_affine = self._get_transform_matrix(center, size, self.img_scale) + + img = cv2.warpAffine(img, trans_affine[:2, :], self.img_scale) + + if isinstance(self.down_ratio, tuple): + trans_mat = [ + self._get_transform_matrix( + center, size, + (self.img_scale[0] // ratio, self.img_scale[1] // ratio)) + for ratio in self.down_ratio + ] # (3, 3) + else: + trans_mat = self._get_transform_matrix( + center, size, (self.img_scale[0] // self.down_ratio, + self.img_scale[1] // self.down_ratio)) + + results['img'] = img + results['img_shape'] = img.shape + results['pad_shape'] = img.shape + results['trans_mat'] = trans_mat + + self._affine_bboxes(results, trans_affine) + + if 'centers2d' in results: + centers2d = self._affine_transform(results['centers2d'], + trans_affine) + valid_index = (centers2d[:, 0] > + 0) & (centers2d[:, 0] < + self.img_scale[0]) & (centers2d[:, 1] > 0) & ( + centers2d[:, 1] < self.img_scale[1]) + results['centers2d'] = centers2d[valid_index] + + for key in results.get('bbox_fields', []): + if key in ['gt_bboxes']: + results[key] = results[key][valid_index] + if 'gt_labels' in results: + results['gt_labels'] = results['gt_labels'][ + valid_index] + if 'gt_masks' in results: + raise NotImplementedError( + 'AffineResize only supports bbox.') + + for key in results.get('bbox3d_fields', []): + if key in ['gt_bboxes_3d']: + results[key].tensor = results[key].tensor[valid_index] + if 'gt_labels_3d' in results: + results['gt_labels_3d'] = results['gt_labels_3d'][ + valid_index] + + results['depths'] = results['depths'][valid_index] + + return results + + def _affine_bboxes(self, results, matrix): + """Affine transform bboxes to input image. + + Args: + results (dict): Result dict from loading pipeline. + matrix (np.ndarray): Matrix transforming original + image to the network input image size. + shape: (3, 3) + """ + + for key in results.get('bbox_fields', []): + bboxes = results[key] + bboxes[:, :2] = self._affine_transform(bboxes[:, :2], matrix) + bboxes[:, 2:] = self._affine_transform(bboxes[:, 2:], matrix) + if self.bbox_clip_border: + bboxes[:, + [0, 2]] = bboxes[:, + [0, 2]].clip(0, self.img_scale[0] - 1) + bboxes[:, + [1, 3]] = bboxes[:, + [1, 3]].clip(0, self.img_scale[1] - 1) + results[key] = bboxes + + def _affine_transform(self, points, matrix): + """Affine transform bbox points to input iamge. + + Args: + points (np.ndarray): Points to be transformed. + shape: (N, 2) + matrix (np.ndarray): Affine transform matrix. + shape: (3, 3) + + Returns: + np.ndarray: Transformed points. + """ + num_points = points.shape[0] + hom_points_2d = np.concatenate((points, np.ones((num_points, 1))), + axis=1) + hom_points_2d = hom_points_2d.T + affined_points = np.matmul(matrix, hom_points_2d).T + return affined_points[:, :2] + + def _get_transform_matrix(self, center, scale, output_scale): + """Get affine transform matrix. + + Args: + center (tuple): Center of current image. + scale (tuple): Scale of current image. + output_scale (tuple[float]): The transform target image scales. + + Returns: + np.ndarray: Affine transform matrix. + """ + # TODO: further add rot and shift here. + src_w = scale[0] + dst_w = output_scale[0] + dst_h = output_scale[1] + + src_dir = np.array([0, src_w * -0.5]) + dst_dir = np.array([0, dst_w * -0.5]) + + src = np.zeros((3, 2), dtype=np.float32) + dst = np.zeros((3, 2), dtype=np.float32) + src[0, :] = center + src[1, :] = center + src_dir + dst[0, :] = np.array([dst_w * 0.5, dst_h * 0.5]) + dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5]) + dst_dir + + src[2, :] = self._get_ref_point(src[0, :], src[1, :]) + dst[2, :] = self._get_ref_point(dst[0, :], dst[1, :]) + + get_matrix = cv2.getAffineTransform(src, dst) + + matrix = np.concatenate((get_matrix, [[0., 0., 1.]])) + + return matrix.astype(np.float32) + + def _get_ref_point(self, ref_point1, ref_point2): + """Get reference point to calculate affine transfrom matrix. + + While using opencv to calculate the affine matrix, we need at least + three corresponding points seperately on original image and target + image. Here we use two points to get the the third reference point. + """ + d = ref_point1 - ref_point2 + ref_point3 = ref_point2 + np.array([-d[1], d[0]]) + return ref_point3 + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(img_scale={self.img_scale}, ' + repr_str += f'down_ratio={self.down_ratio}) ' + return repr_str + + +@PIPELINES.register_module() +class RandomShiftScale(object): + """Random shift scale. + + Different from the normal shift and scale function, it doesn't + directly shift or scale image. It can record the shift and scale + infos into loading pipelines. It's desgined to be used with + AffineResize together. + + Args: + shift_scale (tuple[float]): Shift and scale range. + aug_prob (float): The shifting and scaling probability. + """ + + def __init__(self, shift_scale, aug_prob): + + self.shift_scale = shift_scale + self.aug_prob = aug_prob + + def __call__(self, results): + """Call function to record random shift and scale infos. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Results after random shift and scale, 'center', 'size' + and 'affine_aug' keys are added in the result dict. + """ + img = results['img'] + + height, width = img.shape[:2] + + center = np.array([width / 2, height / 2], dtype=np.float32) + size = np.array([width, height], dtype=np.float32) + + if random.random() < self.aug_prob: + shift, scale = self.shift_scale[0], self.shift_scale[1] + shift_ranges = np.arange(-shift, shift + 0.1, 0.1) + center[0] += size[0] * random.choice(shift_ranges) + center[1] += size[1] * random.choice(shift_ranges) + scale_ranges = np.arange(1 - scale, 1 + scale + 0.1, 0.1) + size *= random.choice(scale_ranges) + results['affine_aug'] = True + else: + results['affine_aug'] = False + + results['center'] = center + results['size'] = size + + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(shift_scale={self.shift_scale}, ' + repr_str += f'aug_prob={self.aug_prob}) ' + return repr_str diff --git a/tests/test_data/test_pipelines/test_augmentations/test_transforms_3d.py b/tests/test_data/test_pipelines/test_augmentations/test_transforms_3d.py index 5ac7aeda3..a12022213 100644 --- a/tests/test_data/test_pipelines/test_augmentations/test_transforms_3d.py +++ b/tests/test_data/test_pipelines/test_augmentations/test_transforms_3d.py @@ -8,12 +8,14 @@ DepthInstance3DBoxes, LiDARInstance3DBoxes) from mmdet3d.core.bbox import Coord3DMode from mmdet3d.core.points import DepthPoints, LiDARPoints -from mmdet3d.datasets import (BackgroundPointsFilter, GlobalAlignment, - GlobalRotScaleTrans, ObjectNameFilter, - ObjectNoise, ObjectRangeFilter, ObjectSample, - PointSample, PointShuffle, PointsRangeFilter, - RandomDropPointsColor, RandomFlip3D, - RandomJitterPoints, VoxelBasedPointSampler) +# yapf: disable +from mmdet3d.datasets import (AffineResize, BackgroundPointsFilter, + GlobalAlignment, GlobalRotScaleTrans, + ObjectNameFilter, ObjectNoise, ObjectRangeFilter, + ObjectSample, PointSample, PointShuffle, + PointsRangeFilter, RandomDropPointsColor, + RandomFlip3D, RandomJitterPoints, + RandomShiftScale, VoxelBasedPointSampler) def test_remove_points_in_boxes(): @@ -755,3 +757,96 @@ def test_points_sample(): select_idx = np.array([449, 444]) expected_pts = points.tensor.numpy()[select_idx] assert np.allclose(sampled_pts.tensor.numpy(), expected_pts) + + +def test_affine_resize(): + + def create_random_bboxes(num_bboxes, img_w, img_h): + bboxes_left_top = np.random.uniform(0, 0.5, size=(num_bboxes, 2)) + bboxes_right_bottom = np.random.uniform(0.5, 1, size=(num_bboxes, 2)) + bboxes = np.concatenate((bboxes_left_top, bboxes_right_bottom), 1) + bboxes = (bboxes * np.array([img_w, img_h, img_w, img_h])).astype( + np.float32) + return bboxes + + affine_reszie = AffineResize(img_scale=(1290, 384), down_ratio=4) + + # test the situation: not use Random_Scale_Shift before AffineResize + results = dict() + img = mmcv.imread('./tests/data/kitti/training/image_2/000000.png', + 'color') + results['img'] = img + results['bbox_fields'] = ['gt_bboxes'] + results['bbox3d_fields'] = ['gt_bboxes_3d'] + + h, w, _ = img.shape + gt_bboxes = create_random_bboxes(8, w, h) + gt_bboxes_3d = CameraInstance3DBoxes(torch.randn((8, 7))) + results['gt_labels'] = np.ones(gt_bboxes.shape[0], dtype=np.int64) + results['gt_labels3d'] = results['gt_labels'] + results['gt_bboxes'] = gt_bboxes + results['gt_bboxes_3d'] = gt_bboxes_3d + results['depths'] = np.random.randn(gt_bboxes.shape[0]) + centers2d_x = (gt_bboxes[:, [0]] + gt_bboxes[:, [2]]) / 2 + centers2d_y = (gt_bboxes[:, [1]] + gt_bboxes[:, [3]]) / 2 + centers2d = np.concatenate((centers2d_x, centers2d_y), axis=1) + results['centers2d'] = centers2d + + results = affine_reszie(results) + + assert results['gt_labels'].shape[0] == results['centers2d'].shape[0] + assert results['gt_labels3d'].shape[0] == results['centers2d'].shape[0] + assert results['gt_bboxes'].shape[0] == results['centers2d'].shape[0] + assert results['gt_bboxes_3d'].tensor.shape[0] == \ + results['centers2d'].shape[0] + assert results['affine_aug'] is False + + # test the situation: not use Random_Scale_Shift before AffineResize + results = dict() + img = mmcv.imread('./tests/data/kitti/training/image_2/000000.png', + 'color') + results['img'] = img + results['bbox_fields'] = ['gt_bboxes'] + results['bbox3d_fields'] = ['gt_bboxes_3d'] + h, w, _ = img.shape + center = np.array([w / 2, h / 2], dtype=np.float32) + size = np.array([w, h], dtype=np.float32) + + results['center'] = center + results['size'] = size + results['affine_aug'] = False + + gt_bboxes = create_random_bboxes(8, w, h) + gt_bboxes_3d = CameraInstance3DBoxes(torch.randn((8, 7))) + results['gt_labels'] = np.ones(gt_bboxes.shape[0], dtype=np.int64) + results['gt_labels3d'] = results['gt_labels'] + results['gt_bboxes'] = gt_bboxes + results['gt_bboxes_3d'] = gt_bboxes_3d + results['depths'] = np.random.randn(gt_bboxes.shape[0]) + centers2d_x = (gt_bboxes[:, [0]] + gt_bboxes[:, [2]]) / 2 + centers2d_y = (gt_bboxes[:, [1]] + gt_bboxes[:, [3]]) / 2 + centers2d = np.concatenate((centers2d_x, centers2d_y), axis=1) + results['centers2d'] = centers2d + + results = affine_reszie(results) + + assert results['gt_labels'].shape[0] == results['centers2d'].shape[0] + assert results['gt_labels3d'].shape[0] == results['centers2d'].shape[0] + assert results['gt_bboxes'].shape[0] == results['centers2d'].shape[0] + assert results['gt_bboxes_3d'].tensor.shape[0] == results[ + 'centers2d'].shape[0] + assert 'center' in results + assert 'size' in results + assert 'affine_aug' in results + + +def test_random_shift_scale(): + random_shift_scale = RandomShiftScale(shift_scale=(0.2, 0.4), aug_prob=0.3) + results = dict() + img = mmcv.imread('./tests/data/kitti/training/image_2/000000.png', + 'color') + results['img'] = img + results = random_shift_scale(results) + assert results['center'].dtype == np.float32 + assert results['size'].dtype == np.float32 + assert 'affine_aug' in results