[Feature] Add demo script for 3d hand pose (#2710)
## 3D Hand Demo

<img src="https://user-images.githubusercontent.com/28900607/121288285-b8fcbf00-c915-11eb-98e4-ba846de12987.gif" width="600px" alt><br>

### 3D Hand Estimation Image Demo

#### Using gt hand bounding boxes as input

We provide a demo script to test a single image, given the ground-truth (gt) json file.

```shell
python demo/hand3d_internet_demo.py \
    ${MMPOSE_CONFIG_FILE} ${MMPOSE_CHECKPOINT_FILE} \
    --input ${INPUT_FILE} \
    --output-root ${OUTPUT_ROOT} \
    [--save-predictions] \
    [--gt-joints-file ${GT_JOINTS_FILE}] \
    [--disable-rebase-keypoint] \
    [--show] \
    [--device ${GPU_ID or CPU}] \
    [--kpt-thr ${KPT_THR}] \
    [--show-kpt-idx] \
    [--radius ${RADIUS}] \
    [--thickness ${THICKNESS}]
```
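As a side note, the rebase step (skipped with `--disable-rebase-keypoint`) shifts each predicted 3D pose so that its lowest keypoint sits at height 0, which helps visualization when the model does not predict the global position. A minimal numpy sketch of the idea, with toy values rather than the demo's actual data:

```python
import numpy as np

# toy z-coordinates (heights) of one predicted hand pose
z = np.array([12.0, 30.5, 25.1])

# rebase: the lowest keypoint lands on the ground (z = 0)
z_rebased = z - z.min()
print(z_rebased)  # [ 0.  18.5 13.1]
```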
The pre-trained hand pose estimation model can be downloaded from the [model zoo](https://mmpose.readthedocs.io/en/latest/model_zoo/hand_3d_keypoint.html).
Take the [InterNet model](https://download.openmmlab.com/mmpose/hand3d/internet/res50_intehand3dv1.0_all_256x256-42b7f2ac_20210702.pth) as an example:

```shell
python demo/hand3d_internet_demo.py \
    configs/hand_3d_keypoint/internet/interhand3d/internet_res50_4xb16-20e_interhand3d-256x256.py \
    https://download.openmmlab.com/mmpose/hand3d/internet/res50_intehand3dv1.0_all_256x256-42b7f2ac_20210702.pth \
    --input tests/data/interhand2.6m/image69148.jpg \
    --save-predictions \
    --output-root vis_results
```
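With `--save-predictions`, the demo dumps its results via `json_tricks` (see the end of `main()` below). A minimal sketch for loading them back, assuming the example command above was run so the file is `vis_results/results_image69148.json`:

```python
import json_tricks as json  # same library the demo uses for dumping

with open('vis_results/results_image69148.json') as f:
    results = json.load(f)

# 'meta_info' is the model's dataset meta; 'instance_info' holds the
# per-instance predictions produced by split_instances() in the script
for instance in results['instance_info']:
    print(len(instance['keypoints']), len(instance['keypoint_scores']))
```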
demo/hand3d_internet_demo.py
# Copyright (c) OpenMMLab. All rights reserved.
import logging
import mimetypes
import os
import time
from argparse import ArgumentParser

import cv2
import json_tricks as json
import mmcv
import mmengine
import numpy as np
from mmengine.logging import print_log

from mmpose.apis import inference_topdown, init_model
from mmpose.registry import VISUALIZERS
from mmpose.structures import (PoseDataSample, merge_data_samples,
                               split_instances)


def parse_args():
    parser = ArgumentParser()
    parser.add_argument('config', help='Config file')
    parser.add_argument('checkpoint', help='Checkpoint file')
    parser.add_argument(
        '--input', type=str, default='', help='Image/Video file')
    parser.add_argument(
        '--output-root',
        type=str,
        default='',
        help='root of the output img file. '
        'Default not saving the visualization images.')
    parser.add_argument(
        '--save-predictions',
        action='store_true',
        default=False,
        help='whether to save predicted results')
    parser.add_argument(
        '--disable-rebase-keypoint',
        action='store_true',
        default=False,
        help='Whether to disable rebasing the predicted 3D pose so its '
        'lowest keypoint has a height of 0 (landing on the ground). Rebase '
        'is useful for visualization when the model does not predict the '
        'global position of the 3D pose.')
    parser.add_argument(
        '--show',
        action='store_true',
        default=False,
        help='whether to show result')
    parser.add_argument(
        '--show-interval',
        type=int,
        default=0,
        help='Sleep seconds per frame; used by the video/webcam loop')
    parser.add_argument('--device', default='cpu', help='Device for inference')
    parser.add_argument(
        '--kpt-thr',
        type=float,
        default=0.3,
        help='Visualizing keypoint thresholds')
    parser.add_argument(
        '--show-kpt-idx',
        action='store_true',
        default=False,
        help='Whether to show the index of keypoints')
    parser.add_argument(
        '--radius',
        type=int,
        default=3,
        help='Keypoint radius for visualization')
    parser.add_argument(
        '--thickness',
        type=int,
        default=1,
        help='Link thickness for visualization')

    args = parser.parse_args()
    return args


def process_one_image(args, img, model, visualizer=None, show_interval=0):
    """Visualize predicted keypoints of one image."""
    # inference a single image
    pose_results = inference_topdown(model, img)
    # post-processing
    pose_results_2d = []
    for idx, res in enumerate(pose_results):
        pred_instances = res.pred_instances
        keypoints = pred_instances.keypoints
        rel_root_depth = pred_instances.rel_root_depth
        scores = pred_instances.keypoint_scores
        hand_type = pred_instances.hand_type

        res_2d = PoseDataSample()
        gt_instances = res.gt_instances.clone()
        pred_instances = pred_instances.clone()
        res_2d.gt_instances = gt_instances
        res_2d.pred_instances = pred_instances

        # add relative root depth to left hand joints
        # (InterHand layout: joints 0-20 are the right hand, 21-41 the left)
        keypoints[:, 21:, 2] += rel_root_depth

        # set joint scores according to hand type
        scores[:, :21] *= hand_type[:, [0]]
        scores[:, 21:] *= hand_type[:, [1]]
        # normalize kpt score
        if scores.max() > 1:
            scores /= 255

        res_2d.pred_instances.set_field(keypoints[..., :2].copy(), 'keypoints')

        # rotate the keypoints to make the z-axis correspond to height
        # for better visualization
        vis_R = np.array([[1, 0, 0], [0, 0, -1], [0, 1, 0]])
        keypoints[..., :3] = keypoints[..., :3] @ vis_R

        # rebase height (z-axis)
        if not args.disable_rebase_keypoint:
            valid = scores > 0
            keypoints[..., 2] -= np.min(
                keypoints[valid, 2], axis=-1, keepdims=True)

        pose_results[idx].pred_instances.keypoints = keypoints
        pose_results[idx].pred_instances.keypoint_scores = scores
        pose_results_2d.append(res_2d)

    data_samples = merge_data_samples(pose_results)
    data_samples_2d = merge_data_samples(pose_results_2d)

    # show the results
    if isinstance(img, str):
        img = mmcv.imread(img, channel_order='rgb')
    elif isinstance(img, np.ndarray):
        img = mmcv.bgr2rgb(img)

    if visualizer is not None:
        visualizer.add_datasample(
            'result',
            img,
            data_sample=data_samples,
            det_data_sample=data_samples_2d,
            draw_gt=False,
            draw_bbox=True,
            kpt_thr=args.kpt_thr,
            convert_keypoint=False,
            axis_azimuth=-115,
            axis_limit=200,
            axis_elev=15,
            show_kpt_idx=args.show_kpt_idx,
            show=args.show,
            wait_time=show_interval)

    # if there is no instance detected, return None
    return data_samples.get('pred_instances', None)


def main():
    args = parse_args()

    assert args.input != ''
    assert args.show or (args.output_root != '')

    output_file = None
    if args.output_root:
        mmengine.mkdir_or_exist(args.output_root)
        output_file = os.path.join(args.output_root,
                                   os.path.basename(args.input))
        if args.input == 'webcam':
            output_file += '.mp4'

    if args.save_predictions:
        assert args.output_root != ''
        args.pred_save_path = f'{args.output_root}/results_' \
            f'{os.path.splitext(os.path.basename(args.input))[0]}.json'

    # build the model from a config file and a checkpoint file
    model = init_model(args.config, args.checkpoint, device=args.device)

    # init visualizer
    model.cfg.visualizer.radius = args.radius
    model.cfg.visualizer.line_width = args.thickness

    visualizer = VISUALIZERS.build(model.cfg.visualizer)
    visualizer.set_dataset_meta(model.dataset_meta)

    if args.input == 'webcam':
        input_type = 'webcam'
    else:
        input_type = mimetypes.guess_type(args.input)[0].split('/')[0]

    if input_type == 'image':
        # inference
        pred_instances = process_one_image(args, args.input, model, visualizer)

        if args.save_predictions:
            pred_instances_list = split_instances(pred_instances)

        if output_file:
            img_vis = visualizer.get_image()
            mmcv.imwrite(mmcv.rgb2bgr(img_vis), output_file)

    elif input_type in ['webcam', 'video']:

        if args.input == 'webcam':
            cap = cv2.VideoCapture(0)
        else:
            cap = cv2.VideoCapture(args.input)

        video_writer = None
        pred_instances_list = []
        frame_idx = 0

        while cap.isOpened():
            success, frame = cap.read()
            frame_idx += 1

            if not success:
                break

            # topdown pose estimation
            pred_instances = process_one_image(args, frame, model, visualizer,
                                               0.001)

            if args.save_predictions:
                # save prediction results
                pred_instances_list.append(
                    dict(
                        frame_id=frame_idx,
                        instances=split_instances(pred_instances)))

            # output videos
            if output_file:
                frame_vis = visualizer.get_image()

                if video_writer is None:
                    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
                    # the size of the image with visualization may vary
                    # depending on the presence of heatmaps
                    video_writer = cv2.VideoWriter(
                        output_file,
                        fourcc,
                        25,  # saved fps
                        (frame_vis.shape[1], frame_vis.shape[0]))

                video_writer.write(mmcv.rgb2bgr(frame_vis))

            if args.show:
                # press ESC to exit
                if cv2.waitKey(5) & 0xFF == 27:
                    break

            time.sleep(args.show_interval)

        if video_writer:
            video_writer.release()

        cap.release()

    else:
        args.save_predictions = False
        raise ValueError(
            f'file {os.path.basename(args.input)} has invalid format.')

    if args.save_predictions:
        with open(args.pred_save_path, 'w') as f:
            json.dump(
                dict(
                    meta_info=model.dataset_meta,
                    instance_info=pred_instances_list),
                f,
                indent='\t')
        print(f'predictions have been saved at {args.pred_save_path}')

    if output_file is not None:
        input_type = input_type.replace('webcam', 'video')
        print_log(
            f'the output {input_type} has been saved at {output_file}',
            logger='current',
            level=logging.INFO)


if __name__ == '__main__':
    main()
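For completeness, the script's core inference path can also be driven programmatically with the same mmpose APIs (`init_model`, `inference_topdown`); a minimal sketch, reusing the config and checkpoint from the README example above and skipping the visualization machinery:

```python
from mmpose.apis import inference_topdown, init_model
from mmpose.structures import merge_data_samples

# config and checkpoint taken from the README example above
config = ('configs/hand_3d_keypoint/internet/interhand3d/'
          'internet_res50_4xb16-20e_interhand3d-256x256.py')
checkpoint = ('https://download.openmmlab.com/mmpose/hand3d/internet/'
              'res50_intehand3dv1.0_all_256x256-42b7f2ac_20210702.pth')

model = init_model(config, checkpoint, device='cpu')
results = inference_topdown(model, 'tests/data/interhand2.6m/image69148.jpg')
pred = merge_data_samples(results).pred_instances
print(pred.keypoints.shape)  # e.g. (num_instances, 42, 3)
```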