diff --git a/demo/docs/en/3d_hand_demo.md b/demo/docs/en/3d_hand_demo.md
new file mode 100644
index 0000000000..b39008ff7b
--- /dev/null
+++ b/demo/docs/en/3d_hand_demo.md
@@ -0,0 +1,37 @@
+## 3D Hand Demo
+
+### 3D Hand Estimation Image Demo
+
+#### Using the full image as input
+
+We provide a demo script to estimate the 3D hand pose of a single image, using the full image as the hand region.
+
+```shell
+python demo/hand3d_internet_demo.py \
+ ${MMPOSE_CONFIG_FILE} ${MMPOSE_CHECKPOINT_FILE} \
+ --input ${INPUT_FILE} \
+ --output-root ${OUTPUT_ROOT} \
+ [--save-predictions] \
+ [--disable-rebase-keypoint] \
+ [--show] \
+ [--device ${GPU_ID or CPU}] \
+ [--kpt-thr ${KPT_THR}] \
+ [--show-kpt-idx] \
+ [--radius ${RADIUS}] \
+ [--thickness ${THICKNESS}]
+```
+
+The pre-trained hand pose estimation model can be downloaded from [model zoo](https://mmpose.readthedocs.io/en/latest/model_zoo/hand_3d_keypoint.html).
+Take the [InterNet model](https://download.openmmlab.com/mmpose/hand3d/internet/res50_intehand3dv1.0_all_256x256-42b7f2ac_20210702.pth) as an example:
+
+```shell
+python demo/hand3d_internet_demo.py \
+ configs/hand_3d_keypoint/internet/interhand3d/internet_res50_4xb16-20e_interhand3d-256x256.py \
+ https://download.openmmlab.com/mmpose/hand3d/internet/res50_intehand3dv1.0_all_256x256-42b7f2ac_20210702.pth \
+ --input tests/data/interhand2.6m/image69148.jpg \
+ --save-predictions \
+ --output-root vis_results
+```
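+
+The demo script also accepts a video file or `--input webcam` as input; the rendered video is then written to the output root. Below is a minimal sketch reusing the InterNet checkpoint above, assuming a local video file `${VIDEO_FILE}`:
+
+```shell
+python demo/hand3d_internet_demo.py \
+    configs/hand_3d_keypoint/internet/interhand3d/internet_res50_4xb16-20e_interhand3d-256x256.py \
+    https://download.openmmlab.com/mmpose/hand3d/internet/res50_intehand3dv1.0_all_256x256-42b7f2ac_20210702.pth \
+    --input ${VIDEO_FILE} \
+    --output-root vis_results
+```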
diff --git a/demo/hand3d_internet_demo.py b/demo/hand3d_internet_demo.py
new file mode 100644
index 0000000000..64db01e454
--- /dev/null
+++ b/demo/hand3d_internet_demo.py
@@ -0,0 +1,279 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import logging
+import mimetypes
+import os
+import time
+from argparse import ArgumentParser
+
+import cv2
+import json_tricks as json
+import mmcv
+import mmengine
+import numpy as np
+from mmengine.logging import print_log
+
+from mmpose.apis import inference_topdown, init_model
+from mmpose.registry import VISUALIZERS
+from mmpose.structures import (PoseDataSample, merge_data_samples,
+ split_instances)
+
+
+def parse_args():
+ parser = ArgumentParser()
+ parser.add_argument('config', help='Config file')
+ parser.add_argument('checkpoint', help='Checkpoint file')
+ parser.add_argument(
+ '--input', type=str, default='', help='Image/Video file')
+ parser.add_argument(
+ '--output-root',
+ type=str,
+ default='',
+        help='Root directory of the output visualization files. '
+        'By default, the visualization results will not be saved.')
+ parser.add_argument(
+ '--save-predictions',
+ action='store_true',
+ default=False,
+ help='whether to save predicted results')
+ parser.add_argument(
+ '--disable-rebase-keypoint',
+ action='store_true',
+ default=False,
+ help='Whether to disable rebasing the predicted 3D pose so its '
+ 'lowest keypoint has a height of 0 (landing on the ground). Rebase '
+        'is useful for visualization when the model does not predict the '
+ 'global position of the 3D pose.')
+ parser.add_argument(
+ '--show',
+ action='store_true',
+ default=False,
+ help='whether to show result')
+ parser.add_argument('--device', default='cpu', help='Device for inference')
+ parser.add_argument(
+ '--kpt-thr',
+ type=float,
+ default=0.3,
+        help='Keypoint score threshold for visualization')
+ parser.add_argument(
+ '--show-kpt-idx',
+ action='store_true',
+ default=False,
+ help='Whether to show the index of keypoints')
+ parser.add_argument(
+ '--radius',
+ type=int,
+ default=3,
+ help='Keypoint radius for visualization')
+ parser.add_argument(
+ '--thickness',
+ type=int,
+ default=1,
+        help='Link thickness for visualization')
+    parser.add_argument(
+        '--show-interval',
+        type=float,
+        default=0,
+        help='Sleep seconds per frame when processing video input')
+
+ args = parser.parse_args()
+ return args
+
+
+def process_one_image(args, img, model, visualizer=None, show_interval=0):
+ """Visualize predicted keypoints of one image."""
+ # inference a single image
+ pose_results = inference_topdown(model, img)
+ # post-processing
+ pose_results_2d = []
+ for idx, res in enumerate(pose_results):
+ pred_instances = res.pred_instances
+ keypoints = pred_instances.keypoints
+ rel_root_depth = pred_instances.rel_root_depth
+ scores = pred_instances.keypoint_scores
+ hand_type = pred_instances.hand_type
+
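+        # keep a separate copy of the predictions for the 2D view; the
+        # original sample is modified in place below for the 3D view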
+ res_2d = PoseDataSample()
+ gt_instances = res.gt_instances.clone()
+ pred_instances = pred_instances.clone()
+ res_2d.gt_instances = gt_instances
+ res_2d.pred_instances = pred_instances
+
+ # add relative root depth to left hand joints
+ keypoints[:, 21:, 2] += rel_root_depth
+
+ # set joint scores according to hand type
+ scores[:, :21] *= hand_type[:, [0]]
+ scores[:, 21:] *= hand_type[:, [1]]
+ # normalize kpt score
+ if scores.max() > 1:
+ scores /= 255
+
+ res_2d.pred_instances.set_field(keypoints[..., :2].copy(), 'keypoints')
+
+ # rotate the keypoint to make z-axis correspondent to height
+ # for better visualization
+ vis_R = np.array([[1, 0, 0], [0, 0, -1], [0, 1, 0]])
+ keypoints[..., :3] = keypoints[..., :3] @ vis_R
+
+ # rebase height (z-axis)
+ if not args.disable_rebase_keypoint:
+ valid = scores > 0
+ keypoints[..., 2] -= np.min(
+ keypoints[valid, 2], axis=-1, keepdims=True)
+
+ pose_results[idx].pred_instances.keypoints = keypoints
+ pose_results[idx].pred_instances.keypoint_scores = scores
+ pose_results_2d.append(res_2d)
+
+ data_samples = merge_data_samples(pose_results)
+ data_samples_2d = merge_data_samples(pose_results_2d)
+
+ # show the results
+ if isinstance(img, str):
+ img = mmcv.imread(img, channel_order='rgb')
+ elif isinstance(img, np.ndarray):
+ img = mmcv.bgr2rgb(img)
+
+ if visualizer is not None:
+ visualizer.add_datasample(
+ 'result',
+ img,
+ data_sample=data_samples,
+ det_data_sample=data_samples_2d,
+ draw_gt=False,
+ draw_bbox=True,
+ kpt_thr=args.kpt_thr,
+ convert_keypoint=False,
+ axis_azimuth=-115,
+ axis_limit=200,
+ axis_elev=15,
+ show_kpt_idx=args.show_kpt_idx,
+ show=args.show,
+ wait_time=show_interval)
+
+ # if there is no instance detected, return None
+ return data_samples.get('pred_instances', None)
+
+
+def main():
+ args = parse_args()
+
+ assert args.input != ''
+ assert args.show or (args.output_root != '')
+
+ output_file = None
+ if args.output_root:
+ mmengine.mkdir_or_exist(args.output_root)
+ output_file = os.path.join(args.output_root,
+ os.path.basename(args.input))
+ if args.input == 'webcam':
+ output_file += '.mp4'
+
+ if args.save_predictions:
+ assert args.output_root != ''
+ args.pred_save_path = f'{args.output_root}/results_' \
+ f'{os.path.splitext(os.path.basename(args.input))[0]}.json'
+
+ # build the model from a config file and a checkpoint file
+ model = init_model(args.config, args.checkpoint, device=args.device)
+
+ # init visualizer
+ model.cfg.visualizer.radius = args.radius
+ model.cfg.visualizer.line_width = args.thickness
+
+ visualizer = VISUALIZERS.build(model.cfg.visualizer)
+ visualizer.set_dataset_meta(model.dataset_meta)
+
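+    # decide whether the input is an image, a video file or a webcam stream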
+ if args.input == 'webcam':
+ input_type = 'webcam'
+ else:
+ input_type = mimetypes.guess_type(args.input)[0].split('/')[0]
+
+ if input_type == 'image':
+ # inference
+ pred_instances = process_one_image(args, args.input, model, visualizer)
+
+ if args.save_predictions:
+ pred_instances_list = split_instances(pred_instances)
+
+ if output_file:
+ img_vis = visualizer.get_image()
+ mmcv.imwrite(mmcv.rgb2bgr(img_vis), output_file)
+
+ elif input_type in ['webcam', 'video']:
+
+ if args.input == 'webcam':
+ cap = cv2.VideoCapture(0)
+ else:
+ cap = cv2.VideoCapture(args.input)
+
+ video_writer = None
+ pred_instances_list = []
+ frame_idx = 0
+
+ while cap.isOpened():
+ success, frame = cap.read()
+ frame_idx += 1
+
+ if not success:
+ break
+
+ # topdown pose estimation
+ pred_instances = process_one_image(args, frame, model, visualizer,
+ 0.001)
+
+ if args.save_predictions:
+ # save prediction results
+ pred_instances_list.append(
+ dict(
+ frame_id=frame_idx,
+ instances=split_instances(pred_instances)))
+
+ # output videos
+ if output_file:
+ frame_vis = visualizer.get_image()
+
+ if video_writer is None:
+ fourcc = cv2.VideoWriter_fourcc(*'mp4v')
+ # the size of the image with visualization may vary
+ # depending on the presence of heatmaps
+ video_writer = cv2.VideoWriter(
+ output_file,
+ fourcc,
+ 25, # saved fps
+ (frame_vis.shape[1], frame_vis.shape[0]))
+
+ video_writer.write(mmcv.rgb2bgr(frame_vis))
+
+ if args.show:
+ # press ESC to exit
+ if cv2.waitKey(5) & 0xFF == 27:
+ break
+
+ time.sleep(args.show_interval)
+
+ if video_writer:
+ video_writer.release()
+
+ cap.release()
+
+ else:
+ args.save_predictions = False
+ raise ValueError(
+ f'file {os.path.basename(args.input)} has invalid format.')
+
+ if args.save_predictions:
+ with open(args.pred_save_path, 'w') as f:
+ json.dump(
+ dict(
+ meta_info=model.dataset_meta,
+ instance_info=pred_instances_list),
+ f,
+ indent='\t')
+ print(f'predictions have been saved at {args.pred_save_path}')
+
+ if output_file is not None:
+ input_type = input_type.replace('webcam', 'video')
+ print_log(
+ f'the output {input_type} has been saved at {output_file}',
+ logger='current',
+ level=logging.INFO)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/mmpose/visualization/local_visualizer_3d.py b/mmpose/visualization/local_visualizer_3d.py
index 63a1d5b47c..56f3bbfba1 100644
--- a/mmpose/visualization/local_visualizer_3d.py
+++ b/mmpose/visualization/local_visualizer_3d.py
@@ -85,6 +85,7 @@ def _draw_3d_data_samples(self,
axis_limit: float = 1.7,
axis_dist: float = 10.0,
axis_elev: float = 15.0,
+ show_kpt_idx: bool = False,
scores_2d: Optional[np.ndarray] = None):
"""Draw keypoints and skeletons (optional) of GT or prediction.
@@ -109,13 +110,16 @@ def _draw_3d_data_samples(self,
- y: [y_c - axis_limit/2, y_c + axis_limit/2]
- z: [0, axis_limit]
Where x_c, y_c is the mean value of x and y coordinates
+ show_kpt_idx (bool): Whether to show the index of keypoints.
+ Defaults to ``False``
scores_2d (np.ndarray, optional): Keypoint scores of 2d estimation
that will be used to filter 3d instances.
Returns:
Tuple(np.ndarray): the drawn image which channel is RGB.
"""
- vis_height, vis_width, _ = image.shape
+ vis_width = max(image.shape)
+ vis_height = vis_width
if 'pred_instances' in pose_samples:
pred_instances = pose_samples.pred_instances
@@ -150,6 +154,7 @@ def _draw_3d_instances_kpts(keypoints,
scores_2d,
keypoints_visible,
fig_idx,
+ show_kpt_idx,
title=None):
for idx, (kpts, score, score_2d) in enumerate(
@@ -169,7 +174,6 @@ def _draw_3d_instances_kpts(keypoints,
ax.set_xticklabels([])
ax.set_yticklabels([])
ax.set_zticklabels([])
- ax.scatter([0], [0], [0], marker='o', color='red')
if title:
ax.set_title(f'{title} ({idx})')
ax.dist = axis_dist
@@ -199,9 +203,10 @@ def _draw_3d_instances_kpts(keypoints,
ax.scatter(x_3d, y_3d, z_3d, marker='o', c=kpt_color)
- for kpt_idx in range(len(x_3d)):
- ax.text(x_3d[kpt_idx][0], y_3d[kpt_idx][0],
- z_3d[kpt_idx][0], str(kpt_idx))
+ if show_kpt_idx:
+ for kpt_idx in range(len(x_3d)):
+ ax.text(x_3d[kpt_idx][0], y_3d[kpt_idx][0],
+ z_3d[kpt_idx][0], str(kpt_idx))
if self.skeleton is not None and self.link_color is not None:
if self.link_color is None or isinstance(
@@ -247,7 +252,8 @@ def _draw_3d_instances_kpts(keypoints,
keypoints_visible = np.ones(keypoints.shape[:-1])
_draw_3d_instances_kpts(keypoints, scores, scores_2d,
- keypoints_visible, 1, 'Prediction')
+ keypoints_visible, 1, show_kpt_idx,
+ 'Prediction')
if draw_gt and 'gt_instances' in pose_samples:
gt_instances = pose_samples.gt_instances
@@ -260,9 +266,22 @@ def _draw_3d_instances_kpts(keypoints,
keypoints_visible = gt_instances.lifting_target_visible
else:
keypoints_visible = np.ones(keypoints.shape[:-1])
+ elif 'keypoints_gt' in gt_instances:
+ keypoints = gt_instances.get('keypoints_gt',
+ gt_instances.keypoints_gt)
+ scores = np.ones(keypoints.shape[:-1])
- _draw_3d_instances_kpts(keypoints, scores, keypoints_visible,
- 2, 'Ground Truth')
+ if 'keypoints_visible' in gt_instances:
+ keypoints_visible = gt_instances.keypoints_visible
+ else:
+ keypoints_visible = np.ones(keypoints.shape[:-1])
+ else:
+ raise ValueError('to visualize ground truth results, '
+ 'data sample must contain '
+ '"lifting_target" or "keypoints_gt"')
+
+ _draw_3d_instances_kpts(keypoints, scores, keypoints_visible, 2,
+ show_kpt_idx, 'Ground Truth')
# convert figure to numpy array
fig.tight_layout()
@@ -357,7 +376,7 @@ def _draw_instances_kpts(self,
for kpts, score, visible in zip(keypoints, scores,
keypoints_visible):
- kpts = np.array(kpts, copy=False)
+ kpts = np.array(kpts[..., :2], copy=False)
if kpt_color is None or isinstance(kpt_color, str):
kpt_color = [kpt_color] * len(kpts)
@@ -477,6 +496,11 @@ def add_datasample(self,
skeleton_style: str = 'mmpose',
dataset_2d: str = 'coco',
dataset_3d: str = 'h36m',
+ convert_keypoint: bool = True,
+ axis_azimuth: float = 70,
+ axis_limit: float = 1.7,
+ axis_dist: float = 10.0,
+ axis_elev: float = 15.0,
num_instances: int = -1,
show: bool = False,
wait_time: float = 0,
@@ -517,6 +541,17 @@ def add_datasample(self,
``'CocoDataset'``
dataset_3d (str): Name of 3d keypoint dataset. Defaults to
``'Human36mDataset'``
+ convert_keypoint (bool): Whether to convert keypoint definition.
+ Defaults to ``True``
+            axis_azimuth (float): Axis azimuth angle for 3D visualizations.
+                Defaults to ``70``
+            axis_dist (float): Axis distance for 3D visualizations.
+                Defaults to ``10.0``
+            axis_elev (float): Axis elevation view angle for 3D
+                visualizations. Defaults to ``15.0``
+            axis_limit (float): The axis limit to visualize 3d pose. The xyz
+                range will be set as:
+                - x: [x_c - axis_limit/2, x_c + axis_limit/2]
+                - y: [y_c - axis_limit/2, y_c + axis_limit/2]
+                - z: [0, axis_limit]
+                where x_c and y_c are the mean values of the x and y
+                coordinates. Defaults to ``1.7``
num_instances (int): Number of instances to be shown in 3D. If
smaller than 0, all the instances in the pose_result will be
shown. Otherwise, pad or truncate the pose_result to a length
@@ -531,21 +566,24 @@ def add_datasample(self,
"""
det_img_data = None
- gt_img_data = None
scores_2d = None
if draw_2d:
det_img_data = image.copy()
# draw bboxes & keypoints
- if 'pred_instances' in det_data_sample:
+ if (det_data_sample is not None
+ and 'pred_instances' in det_data_sample):
det_img_data, scores_2d = self._draw_instances_kpts(
- det_img_data, det_data_sample.pred_instances, kpt_thr,
- show_kpt_idx, skeleton_style)
+ image=det_img_data,
+ instances=det_data_sample.pred_instances,
+ kpt_thr=kpt_thr,
+ show_kpt_idx=show_kpt_idx,
+ skeleton_style=skeleton_style)
if draw_bbox:
det_img_data = self._draw_instances_bbox(
det_img_data, det_data_sample.pred_instances)
- if scores_2d is not None:
+ if scores_2d is not None and convert_keypoint:
if scores_2d.ndim == 2:
scores_2d = scores_2d[..., None]
scores_2d = np.squeeze(
@@ -556,16 +594,25 @@ def add_datasample(self,
data_sample,
draw_gt=draw_gt,
num_instances=num_instances,
+ axis_azimuth=axis_azimuth,
+ axis_limit=axis_limit,
+ show_kpt_idx=show_kpt_idx,
+ axis_dist=axis_dist,
+ axis_elev=axis_elev,
scores_2d=scores_2d)
# merge visualization results
- if det_img_data is not None and gt_img_data is not None:
- drawn_img = np.concatenate(
- (det_img_data, pred_img_data, gt_img_data), axis=1)
- elif det_img_data is not None:
+ if det_img_data is not None:
+ width = max(pred_img_data.shape[1] - det_img_data.shape[1], 0)
+ height = max(pred_img_data.shape[0] - det_img_data.shape[0], 0)
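+            # pad the 2D image with a white border so that it matches the
+            # size of the rendered 3D figure before horizontal concatenation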
+ det_img_data = cv2.copyMakeBorder(
+ det_img_data,
+ height // 2,
+ (height // 2 + 1) if height % 2 == 1 else height // 2,
+ width // 2, (width // 2 + 1) if width % 2 == 1 else width // 2,
+ cv2.BORDER_CONSTANT,
+ value=(255, 255, 255))
drawn_img = np.concatenate((det_img_data, pred_img_data), axis=1)
- elif gt_img_data is not None:
- drawn_img = np.concatenate((det_img_data, gt_img_data), axis=1)
else:
drawn_img = pred_img_data