diff --git a/demo/body3d_pose_lifter_demo.py b/demo/body3d_pose_lifter_demo.py index 72e7b93958..b5c19e8916 100644 --- a/demo/body3d_pose_lifter_demo.py +++ b/demo/body3d_pose_lifter_demo.py @@ -305,6 +305,8 @@ def process_one_image(args, detector, frame, frame_idx, pose_estimator, data_sample=pred_3d_data_samples, det_data_sample=det_data_sample, draw_gt=False, + dataset_2d=pose_det_dataset['type'], + dataset_3d=pose_lift_dataset['type'], show=args.show, draw_bbox=True, kpt_thr=args.kpt_thr, diff --git a/mmpose/visualization/local_visualizer_3d.py b/mmpose/visualization/local_visualizer_3d.py index 7e3462ce79..99d8086a1e 100644 --- a/mmpose/visualization/local_visualizer_3d.py +++ b/mmpose/visualization/local_visualizer_3d.py @@ -9,6 +9,7 @@ from mmengine.dist import master_only from mmengine.structures import InstanceData +from mmpose.apis import convert_keypoint_definition from mmpose.registry import VISUALIZERS from mmpose.structures import PoseDataSample from . import PoseLocalVisualizer @@ -74,18 +75,17 @@ def __init__( self.det_dataset_skeleton = det_dataset_skeleton self.det_dataset_link_color = det_dataset_link_color - def _draw_3d_data_samples( - self, - image: np.ndarray, - pose_samples: PoseDataSample, - draw_gt: bool = True, - kpt_thr: float = 0.3, - num_instances=-1, - axis_azimuth: float = 70, - axis_limit: float = 1.7, - axis_dist: float = 10.0, - axis_elev: float = 15.0, - ): + def _draw_3d_data_samples(self, + image: np.ndarray, + pose_samples: PoseDataSample, + draw_gt: bool = True, + kpt_thr: float = 0.3, + num_instances=-1, + axis_azimuth: float = 70, + axis_limit: float = 1.7, + axis_dist: float = 10.0, + axis_elev: float = 15.0, + scores_2d: Optional[np.ndarray] = None): """Draw keypoints and skeletons (optional) of GT or prediction. Args: @@ -109,6 +109,8 @@ def _draw_3d_data_samples( - y: [y_c - axis_limit/2, y_c + axis_limit/2] - z: [0, axis_limit] Where x_c, y_c is the mean value of x and y coordinates + scores_2d (np.ndarray, optional): Keypoint scores of 2d estimation + that will be used to filter 3d instances. Returns: Tuple(np.ndarray): the drawn image which channel is RGB. @@ -145,20 +147,21 @@ def _draw_3d_data_samples( def _draw_3d_instances_kpts(keypoints, scores, + scores_2d, keypoints_visible, fig_idx, title=None): - for idx, (kpts, score, visible) in enumerate( - zip(keypoints, scores, keypoints_visible)): + for idx, (kpts, score, score_2d) in enumerate( + zip(keypoints, scores, scores_2d)): - valid = np.logical_and(score >= kpt_thr, + valid = np.logical_and(score >= kpt_thr, score_2d >= kpt_thr, np.any(~np.isnan(kpts), axis=-1)) + kpts_valid = kpts[valid] ax = fig.add_subplot( 1, num_fig, fig_idx * (idx + 1), projection='3d') ax.view_init(elev=axis_elev, azim=axis_azimuth) - ax.set_zlim3d([0, axis_limit]) ax.set_aspect('auto') ax.set_xticks([]) ax.set_yticks([]) @@ -171,13 +174,14 @@ def _draw_3d_instances_kpts(keypoints, ax.set_title(f'{title} ({idx})') ax.dist = axis_dist - x_c = np.mean(kpts[valid, 0]) if valid.any() else 0 - y_c = np.mean(kpts[valid, 1]) if valid.any() else 0 + x_c = np.mean(kpts_valid[:, 0]) if valid.any() else 0 + y_c = np.mean(kpts_valid[:, 1]) if valid.any() else 0 + z_c = np.mean(kpts_valid[:, 2]) if valid.any() else 0 ax.set_xlim3d([x_c - axis_limit / 2, x_c + axis_limit / 2]) ax.set_ylim3d([y_c - axis_limit / 2, y_c + axis_limit / 2]) - - kpts = np.array(kpts, copy=False) + ax.set_zlim3d( + [min(0, z_c - axis_limit / 2), z_c + axis_limit / 2]) if self.kpt_color is None or isinstance(self.kpt_color, str): kpt_color = [self.kpt_color] * len(kpts) @@ -189,8 +193,7 @@ def _draw_3d_instances_kpts(keypoints, f'({len(self.kpt_color)}) does not matches ' f'that of keypoints ({len(kpts)})') - kpts = kpts[valid] - x_3d, y_3d, z_3d = np.split(kpts[:, :3], [1, 2], axis=1) + x_3d, y_3d, z_3d = np.split(kpts_valid[:, :3], [1, 2], axis=1) kpt_color = kpt_color[valid][..., ::-1] / 255. @@ -218,7 +221,9 @@ def _draw_3d_instances_kpts(keypoints, ys_3d = kpts[sk_indices, 1] zs_3d = kpts[sk_indices, 2] kpt_score = score[sk_indices] - if kpt_score.min() > kpt_thr: + kpt_score_2d = score_2d[sk_indices] + if kpt_score.min() > kpt_thr and kpt_score_2d.min( + ) > kpt_thr: # matplotlib uses RGB color in [0, 1] value range _color = link_color[sk_id][::-1] / 255. ax.plot( @@ -233,13 +238,16 @@ def _draw_3d_instances_kpts(keypoints, else: scores = np.ones(keypoints.shape[:-1]) + if scores_2d is None: + scores_2d = np.ones(keypoints.shape[:-1]) + if 'keypoints_visible' in pred_instances: keypoints_visible = pred_instances.keypoints_visible else: keypoints_visible = np.ones(keypoints.shape[:-1]) - _draw_3d_instances_kpts(keypoints, scores, keypoints_visible, 1, - 'Prediction') + _draw_3d_instances_kpts(keypoints, scores, scores_2d, + keypoints_visible, 1, 'Prediction') if draw_gt and 'gt_instances' in pose_samples: gt_instances = pose_samples.gt_instances @@ -300,6 +308,7 @@ def _draw_instances_kpts(self, self.set_image(image) img_h, img_w, _ = image.shape + scores = None if 'keypoints' in instances: keypoints = instances.get('transformed_keypoints', @@ -452,7 +461,7 @@ def _draw_instances_kpts(self, self.draw_lines( X, Y, color, line_widths=self.line_width) - return self.get_image() + return self.get_image(), scores @master_only def add_datasample(self, @@ -466,6 +475,8 @@ def add_datasample(self, draw_bbox: bool = False, show_kpt_idx: bool = False, skeleton_style: str = 'mmpose', + dataset_2d: str = 'CocoDataset', + dataset_3d: str = 'Human36mDataset', num_instances: int = -1, show: bool = False, wait_time: float = 0, @@ -502,6 +513,10 @@ def add_datasample(self, Defaults to ``False`` skeleton_style (str): Skeleton style selection. Defaults to ``'mmpose'`` + dataset_2d (str): Name of 2d keypoint dataset. Defaults to + ``'CocoDataset'`` + dataset_3d (str): Name of 3d keypoint dataset. Defaults to + ``'Human36mDataset'`` num_instances (int): Number of instances to be shown in 3D. If smaller than 0, all the instances in the pose_result will be shown. Otherwise, pad or truncate the pose_result to a length @@ -517,24 +532,31 @@ def add_datasample(self, det_img_data = None gt_img_data = None + scores_2d = None if draw_2d: det_img_data = image.copy() # draw bboxes & keypoints if 'pred_instances' in det_data_sample: - det_img_data = self._draw_instances_kpts( + det_img_data, scores_2d = self._draw_instances_kpts( det_img_data, det_data_sample.pred_instances, kpt_thr, show_kpt_idx, skeleton_style) if draw_bbox: det_img_data = self._draw_instances_bbox( det_img_data, det_data_sample.pred_instances) - + if scores_2d is not None: + if scores_2d.ndim == 2: + scores_2d = scores_2d[..., None] + scores_2d = np.squeeze( + convert_keypoint_definition(scores_2d, dataset_2d, dataset_3d), + axis=-1) pred_img_data = self._draw_3d_data_samples( image.copy(), data_sample, draw_gt=draw_gt, - num_instances=num_instances) + num_instances=num_instances, + scores_2d=scores_2d) # merge visualization results if det_img_data is not None and gt_img_data is not None: