diff --git a/demo/docs/en/3d_hand_demo.md b/demo/docs/en/3d_hand_demo.md
new file mode 100644
index 0000000000..b39008ff7b
--- /dev/null
+++ b/demo/docs/en/3d_hand_demo.md
@@ -0,0 +1,37 @@
+## 3D Hand Demo
+
+### 3D Hand Estimation Image Demo
+
+#### Using gt hand bounding boxes as input
+
+We provide a demo script to test a single image, given the ground-truth (gt) json file.
+
+```shell
+python demo/hand3d_internet_demo.py \
+    ${MMPOSE_CONFIG_FILE} ${MMPOSE_CHECKPOINT_FILE} \
+    --input ${INPUT_FILE} \
+    --output-root ${OUTPUT_ROOT} \
+    [--save-predictions] \
+    [--gt-joints-file ${GT_JOINTS_FILE}] \
+    [--disable-rebase-keypoint] \
+    [--show] \
+    [--device ${GPU_ID or CPU}] \
+    [--kpt-thr ${KPT_THR}] \
+    [--show-kpt-idx] \
+    [--radius ${RADIUS}] \
+    [--thickness ${THICKNESS}]
+```
+
+The pre-trained hand pose estimation model can be downloaded from the [model zoo](https://mmpose.readthedocs.io/en/latest/model_zoo/hand_3d_keypoint.html).
+Take the [InterNet model](https://download.openmmlab.com/mmpose/hand3d/internet/res50_intehand3dv1.0_all_256x256-42b7f2ac_20210702.pth) as an example:
+
+```shell
+python demo/hand3d_internet_demo.py \
+    configs/hand_3d_keypoint/internet/interhand3d/internet_res50_4xb16-20e_interhand3d-256x256.py \
+    https://download.openmmlab.com/mmpose/hand3d/internet/res50_intehand3dv1.0_all_256x256-42b7f2ac_20210702.pth \
+    --input tests/data/interhand2.6m/image69148.jpg \
+    --save-predictions \
+    --output-root vis_results
+```
diff --git a/demo/hand3d_internet_demo.py b/demo/hand3d_internet_demo.py
new file mode 100644
index 0000000000..64db01e454
--- /dev/null
+++ b/demo/hand3d_internet_demo.py
@@ -0,0 +1,279 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import logging
+import mimetypes
+import os
+import time
+from argparse import ArgumentParser
+
+import cv2
+import json_tricks as json
+import mmcv
+import mmengine
+import numpy as np
+from mmengine.logging import print_log
+
+from mmpose.apis import inference_topdown, init_model
+from mmpose.registry import VISUALIZERS
+from mmpose.structures import (PoseDataSample, merge_data_samples,
+                               split_instances)
+
+
+def parse_args():
+    parser = ArgumentParser()
+    parser.add_argument('config', help='Config file')
+    parser.add_argument('checkpoint', help='Checkpoint file')
+    parser.add_argument(
+        '--input', type=str, default='', help='Image/Video file')
+    parser.add_argument(
+        '--output-root',
+        type=str,
+        default='',
+        help='Root of the output image file. '
+        'Defaults to not saving the visualization images.')
+    parser.add_argument(
+        '--save-predictions',
+        action='store_true',
+        default=False,
+        help='whether to save predicted results')
+    parser.add_argument(
+        '--disable-rebase-keypoint',
+        action='store_true',
+        default=False,
+        help='Whether to disable rebasing the predicted 3D pose so its '
+        'lowest keypoint has a height of 0 (landing on the ground). Rebase '
+        'is useful for visualization when the model does not predict the '
+        'global position of the 3D pose.')
+    parser.add_argument(
+        '--show',
+        action='store_true',
+        default=False,
+        help='whether to show result')
+    parser.add_argument('--device', default='cpu', help='Device for inference')
+    parser.add_argument(
+        '--kpt-thr',
+        type=float,
+        default=0.3,
+        help='Keypoint score threshold for visualization')
+    parser.add_argument(
+        '--show-kpt-idx',
+        action='store_true',
+        default=False,
+        help='Whether to show the index of keypoints')
+    parser.add_argument(
+        '--radius',
+        type=int,
+        default=3,
+        help='Keypoint radius for visualization')
+    parser.add_argument(
+        '--thickness',
+        type=int,
+        default=1,
+        help='Link thickness for visualization')
+
+    args = parser.parse_args()
+    return args
+
+
+def process_one_image(args, img, model, visualizer=None, show_interval=0):
+    """Visualize predicted keypoints of one image."""
+    # inference a single image
+    pose_results = inference_topdown(model, img)
+    # post-processing
+    pose_results_2d = []
+    for idx, res in enumerate(pose_results):
+        pred_instances = res.pred_instances
+        keypoints = pred_instances.keypoints
+        rel_root_depth = pred_instances.rel_root_depth
+        scores = pred_instances.keypoint_scores
+        hand_type = pred_instances.hand_type
+
+        res_2d = PoseDataSample()
+        gt_instances = res.gt_instances.clone()
+        pred_instances = pred_instances.clone()
+        res_2d.gt_instances = gt_instances
+        res_2d.pred_instances = pred_instances
+
+        # add relative root depth to left hand joints
+        keypoints[:, 21:, 2] += rel_root_depth
+
+        # set joint scores according to hand type
+        scores[:, :21] *= hand_type[:, [0]]
+        scores[:, 21:] *= hand_type[:, [1]]
+        # normalize kpt score
+        if scores.max() > 1:
+            scores /= 255
+
+        res_2d.pred_instances.set_field(keypoints[..., :2].copy(), 'keypoints')
+
+        # rotate the keypoint to make z-axis correspondent to height
+        # for better visualization
+        vis_R = np.array([[1, 0, 0], [0, 0, -1], [0, 1, 0]])
+        keypoints[..., :3] = keypoints[..., :3] @ vis_R
+
+        # rebase height (z-axis)
+        if not args.disable_rebase_keypoint:
+            valid = scores > 0
+            keypoints[..., 2] -= np.min(
+                keypoints[valid, 2], axis=-1, keepdims=True)
+
+        pose_results[idx].pred_instances.keypoints = keypoints
+        pose_results[idx].pred_instances.keypoint_scores = scores
+        pose_results_2d.append(res_2d)
+
+    data_samples = merge_data_samples(pose_results)
+    data_samples_2d = merge_data_samples(pose_results_2d)
+
+    # show the results
+    if isinstance(img, str):
+        img = mmcv.imread(img, channel_order='rgb')
+    elif isinstance(img, np.ndarray):
+        img = mmcv.bgr2rgb(img)
+
+    if visualizer is not None:
+        visualizer.add_datasample(
+            'result',
+            img,
+            data_sample=data_samples,
+            det_data_sample=data_samples_2d,
+            draw_gt=False,
+            draw_bbox=True,
+            kpt_thr=args.kpt_thr,
+            convert_keypoint=False,
+            axis_azimuth=-115,
+            axis_limit=200,
+            axis_elev=15,
+            show_kpt_idx=args.show_kpt_idx,
+            show=args.show,
+            wait_time=show_interval)
+
+    # if there is no instance detected, return None
+    return data_samples.get('pred_instances', None)
+
+
+def main():
+    args = parse_args()
+
+    assert args.input != ''
+    assert args.show or (args.output_root != '')
+
+    output_file = None
+    if args.output_root:
+        mmengine.mkdir_or_exist(args.output_root)
+        output_file = os.path.join(args.output_root,
+                                   os.path.basename(args.input))
+        if args.input == 'webcam':
+            output_file += '.mp4'
+
+    if args.save_predictions:
+        assert args.output_root != ''
+        args.pred_save_path = f'{args.output_root}/results_' \
+            f'{os.path.splitext(os.path.basename(args.input))[0]}.json'
+
+    # build the model from a config file and a checkpoint file
+    model = init_model(args.config, args.checkpoint, device=args.device)
+
+    # init visualizer
+    model.cfg.visualizer.radius = args.radius
+    model.cfg.visualizer.line_width = args.thickness
+
+    visualizer = VISUALIZERS.build(model.cfg.visualizer)
+    visualizer.set_dataset_meta(model.dataset_meta)
+
+    if args.input == 'webcam':
+        input_type = 'webcam'
+    else:
+        input_type = mimetypes.guess_type(args.input)[0].split('/')[0]
+
+    if input_type == 'image':
+        # inference
+        pred_instances = process_one_image(args, args.input, model, visualizer)
+
+        if args.save_predictions:
+            pred_instances_list = split_instances(pred_instances)
+
+        if output_file:
+            img_vis = visualizer.get_image()
+            mmcv.imwrite(mmcv.rgb2bgr(img_vis), output_file)
+
+    elif input_type in ['webcam', 'video']:
+
+        if args.input == 'webcam':
+            cap = cv2.VideoCapture(0)
+        else:
+            cap = cv2.VideoCapture(args.input)
+
+        video_writer = None
+        pred_instances_list = []
+        frame_idx = 0
+
+        while cap.isOpened():
+            success, frame = cap.read()
+            frame_idx += 1
+
+            if not success:
+                break
+
+            # topdown pose estimation
+            pred_instances = process_one_image(args, frame, model, visualizer,
+                                               0.001)
+
+            if args.save_predictions:
+                # save prediction results
+                pred_instances_list.append(
+                    dict(
+                        frame_id=frame_idx,
+                        instances=split_instances(pred_instances)))
+
+            # output videos
+            if output_file:
+                frame_vis = visualizer.get_image()
+
+                if video_writer is None:
+                    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
+                    # the size of the image with visualization may vary
+                    # depending on the presence of heatmaps
+                    video_writer = cv2.VideoWriter(
+                        output_file,
+                        fourcc,
+                        25,  # saved fps
+                        (frame_vis.shape[1], frame_vis.shape[0]))
+
+                video_writer.write(mmcv.rgb2bgr(frame_vis))
+
+            if args.show:
+                # press ESC to exit
+                if cv2.waitKey(5) & 0xFF == 27:
+                    break
+
+                time.sleep(args.show_interval)
+
+        if video_writer:
+            video_writer.release()
+
+        cap.release()
+
+    else:
+        args.save_predictions = False
+        raise ValueError(
+            f'file {os.path.basename(args.input)} has invalid format.')
+
+    if args.save_predictions:
+        with open(args.pred_save_path, 'w') as f:
+            json.dump(
+                dict(
+                    meta_info=model.dataset_meta,
+                    instance_info=pred_instances_list),
+                f,
+                indent='\t')
+        print(f'predictions have been saved at {args.pred_save_path}')
+
+    if output_file is not None:
+        input_type = input_type.replace('webcam', 'video')
+        print_log(
+            f'the output {input_type} has been saved at {output_file}',
+            logger='current',
+            level=logging.INFO)
+
+
+if __name__ == '__main__':
+    main()
diff --git a/mmpose/visualization/local_visualizer_3d.py b/mmpose/visualization/local_visualizer_3d.py
index 63a1d5b47c..56f3bbfba1 100644
--- a/mmpose/visualization/local_visualizer_3d.py
+++ b/mmpose/visualization/local_visualizer_3d.py
@@ -85,6 +85,7 @@ def _draw_3d_data_samples(self,
                               axis_limit: float = 1.7,
                               axis_dist: float = 10.0,
                               axis_elev: float = 15.0,
+                              show_kpt_idx: bool = False,
                               scores_2d: Optional[np.ndarray] = None):
         """Draw keypoints and skeletons (optional) of GT or prediction.
 
@@ -109,13 +110,16 @@ def _draw_3d_data_samples(self,
                 - y: [y_c - axis_limit/2, y_c + axis_limit/2]
                 - z: [0, axis_limit]
                 Where x_c, y_c is the mean value of x and y coordinates
+            show_kpt_idx (bool): Whether to show the index of keypoints.
+                Defaults to ``False``
             scores_2d (np.ndarray, optional): Keypoint scores of 2d estimation
                 that will be used to filter 3d instances.
 
         Returns:
             Tuple(np.ndarray): the drawn image which channel is RGB.
         """
-        vis_height, vis_width, _ = image.shape
+        vis_width = max(image.shape)
+        vis_height = vis_width
 
         if 'pred_instances' in pose_samples:
             pred_instances = pose_samples.pred_instances
@@ -150,6 +154,7 @@ def _draw_3d_instances_kpts(keypoints,
                                     scores_2d,
                                     keypoints_visible,
                                     fig_idx,
+                                    show_kpt_idx,
                                     title=None):
 
             for idx, (kpts, score, score_2d) in enumerate(
@@ -169,7 +174,6 @@ def _draw_3d_instances_kpts(keypoints,
                 ax.set_xticklabels([])
                 ax.set_yticklabels([])
                 ax.set_zticklabels([])
-                ax.scatter([0], [0], [0], marker='o', color='red')
                 if title:
                     ax.set_title(f'{title} ({idx})')
                 ax.dist = axis_dist
@@ -199,9 +203,10 @@ def _draw_3d_instances_kpts(keypoints,
 
                 ax.scatter(x_3d, y_3d, z_3d, marker='o', c=kpt_color)
 
-                for kpt_idx in range(len(x_3d)):
-                    ax.text(x_3d[kpt_idx][0], y_3d[kpt_idx][0],
-                            z_3d[kpt_idx][0], str(kpt_idx))
+                if show_kpt_idx:
+                    for kpt_idx in range(len(x_3d)):
+                        ax.text(x_3d[kpt_idx][0], y_3d[kpt_idx][0],
+                                z_3d[kpt_idx][0], str(kpt_idx))
 
                 if self.skeleton is not None and self.link_color is not None:
                     if self.link_color is None or isinstance(
@@ -247,7 +252,8 @@ def _draw_3d_instances_kpts(keypoints,
                 keypoints_visible = np.ones(keypoints.shape[:-1])
 
             _draw_3d_instances_kpts(keypoints, scores, scores_2d,
-                                    keypoints_visible, 1, 'Prediction')
+                                    keypoints_visible, 1, show_kpt_idx,
+                                    'Prediction')
 
         if draw_gt and 'gt_instances' in pose_samples:
             gt_instances = pose_samples.gt_instances
@@ -260,9 +266,22 @@ def _draw_3d_instances_kpts(keypoints,
                     keypoints_visible = gt_instances.lifting_target_visible
                 else:
                     keypoints_visible = np.ones(keypoints.shape[:-1])
+            elif 'keypoints_gt' in gt_instances:
+                keypoints = gt_instances.get('keypoints_gt',
+                                             gt_instances.keypoints_gt)
+                scores = np.ones(keypoints.shape[:-1])
 
-            _draw_3d_instances_kpts(keypoints, scores, keypoints_visible,
-                                    2, 'Ground Truth')
+                if 'keypoints_visible' in gt_instances:
+                    keypoints_visible = gt_instances.keypoints_visible
+                else:
+                    keypoints_visible = np.ones(keypoints.shape[:-1])
+            else:
+                raise ValueError('to visualize ground truth results, '
+                                 'data sample must contain '
+                                 '"lifting_target" or "keypoints_gt"')
+
+            _draw_3d_instances_kpts(keypoints, scores, keypoints_visible, 2,
+                                    show_kpt_idx, 'Ground Truth')
 
         # convert figure to numpy array
         fig.tight_layout()
@@ -357,7 +376,7 @@ def _draw_instances_kpts(self,
 
             for kpts, score, visible in zip(keypoints, scores,
                                             keypoints_visible):
-                kpts = np.array(kpts, copy=False)
+                kpts = np.array(kpts[..., :2], copy=False)
 
                 if kpt_color is None or isinstance(kpt_color, str):
                     kpt_color = [kpt_color] * len(kpts)
@@ -477,6 +496,11 @@ def add_datasample(self,
                        skeleton_style: str = 'mmpose',
                        dataset_2d: str = 'coco',
                        dataset_3d: str = 'h36m',
+                       convert_keypoint: bool = True,
+                       axis_azimuth: float = 70,
+                       axis_limit: float = 1.7,
+                       axis_dist: float = 10.0,
+                       axis_elev: float = 15.0,
                        num_instances: int = -1,
                        show: bool = False,
                        wait_time: float = 0,
@@ -517,6 +541,17 @@
                 ``'CocoDataset'``
             dataset_3d (str): Name of 3d keypoint dataset. Defaults to
                 ``'Human36mDataset'``
+            convert_keypoint (bool): Whether to convert keypoint definition.
+                Defaults to ``True``
+            axis_azimuth (float): axis azimuth angle for 3D visualizations.
+            axis_dist (float): axis distance for 3D visualizations.
+            axis_elev (float): axis elevation view angle for 3D visualizations.
+            axis_limit (float): The axis limit to visualize 3d pose. The xyz
+                range will be set as:
+                - x: [x_c - axis_limit/2, x_c + axis_limit/2]
+                - y: [y_c - axis_limit/2, y_c + axis_limit/2]
+                - z: [0, axis_limit]
+                Where x_c, y_c is the mean value of x and y coordinates
             num_instances (int): Number of instances to be shown in 3D. If
                 smaller than 0, all the instances in the pose_result will be
                 shown. Otherwise, pad or truncate the pose_result to a length
@@ -531,21 +566,24 @@
         """
         det_img_data = None
-        gt_img_data = None
        scores_2d = None
 
         if draw_2d:
             det_img_data = image.copy()
 
             # draw bboxes & keypoints
-            if 'pred_instances' in det_data_sample:
+            if (det_data_sample is not None
+                    and 'pred_instances' in det_data_sample):
                 det_img_data, scores_2d = self._draw_instances_kpts(
-                    det_img_data, det_data_sample.pred_instances, kpt_thr,
-                    show_kpt_idx, skeleton_style)
+                    image=det_img_data,
+                    instances=det_data_sample.pred_instances,
+                    kpt_thr=kpt_thr,
+                    show_kpt_idx=show_kpt_idx,
+                    skeleton_style=skeleton_style)
             if draw_bbox:
                 det_img_data = self._draw_instances_bbox(
                     det_img_data, det_data_sample.pred_instances)
 
-        if scores_2d is not None:
+        if scores_2d is not None and convert_keypoint:
             if scores_2d.ndim == 2:
                 scores_2d = scores_2d[..., None]
             scores_2d = np.squeeze(
@@ -556,16 +594,25 @@
             data_sample,
             draw_gt=draw_gt,
             num_instances=num_instances,
+            axis_azimuth=axis_azimuth,
+            axis_limit=axis_limit,
+            show_kpt_idx=show_kpt_idx,
+            axis_dist=axis_dist,
+            axis_elev=axis_elev,
             scores_2d=scores_2d)
 
         # merge visualization results
-        if det_img_data is not None and gt_img_data is not None:
-            drawn_img = np.concatenate(
-                (det_img_data, pred_img_data, gt_img_data), axis=1)
-        elif det_img_data is not None:
+        if det_img_data is not None:
+            width = max(pred_img_data.shape[1] - det_img_data.shape[1], 0)
+            height = max(pred_img_data.shape[0] - det_img_data.shape[0], 0)
+            det_img_data = cv2.copyMakeBorder(
+                det_img_data,
+                height // 2,
+                (height // 2 + 1) if height % 2 == 1 else height // 2,
+                width // 2, (width // 2 + 1) if width % 2 == 1 else width // 2,
+                cv2.BORDER_CONSTANT,
+                value=(255, 255, 255))
             drawn_img = np.concatenate((det_img_data, pred_img_data), axis=1)
-        elif gt_img_data is not None:
-            drawn_img = np.concatenate((det_img_data, gt_img_data), axis=1)
         else:
             drawn_img = pred_img_data
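For reference, the call sequence that `demo/hand3d_internet_demo.py` wires together can also be driven from the Python API. The snippet below is a minimal sketch and not part of the patch: it assumes the patch above is applied and that the InterNet config/checkpoint listed in the documentation are reachable, and it skips the hand-specific post-processing (root-depth offset, hand-type score masking, axis rotation, height rebasing) that `process_one_image` performs, so the rendering will differ from the demo output.

```python
# Minimal sketch (not part of the patch): run InterNet 3D hand inference and
# visualization through the MMPose APIs used by demo/hand3d_internet_demo.py.
import mmcv

from mmpose.apis import inference_topdown, init_model
from mmpose.registry import VISUALIZERS
from mmpose.structures import merge_data_samples

config = ('configs/hand_3d_keypoint/internet/interhand3d/'
          'internet_res50_4xb16-20e_interhand3d-256x256.py')
checkpoint = ('https://download.openmmlab.com/mmpose/hand3d/internet/'
              'res50_intehand3dv1.0_all_256x256-42b7f2ac_20210702.pth')
img_path = 'tests/data/interhand2.6m/image69148.jpg'

model = init_model(config, checkpoint, device='cpu')

# Top-down inference on the whole image (no hand bounding boxes supplied).
results = inference_topdown(model, img_path)
data_sample = merge_data_samples(results)

# Build the 3D visualizer from the model config, as the demo script does.
visualizer = VISUALIZERS.build(model.cfg.visualizer)
visualizer.set_dataset_meta(model.dataset_meta)

img = mmcv.imread(img_path, channel_order='rgb')
visualizer.add_datasample(
    'result',
    img,
    data_sample=data_sample,
    draw_gt=False,
    convert_keypoint=False,  # keep the InterHand keypoint definition
    show=False)

# Hypothetical output path; mmcv.imwrite creates the directory if needed.
mmcv.imwrite(mmcv.rgb2bgr(visualizer.get_image()),
             'vis_results/image69148_api.jpg')
```

Compared with this sketch, the demo script additionally masks the per-hand scores with `hand_type`, adds `rel_root_depth` to the left-hand joints, and rotates and rebases the keypoints before calling `add_datasample` (see `process_one_image` above).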