Skip to content

Commit

Permalink
[Feature] Add demo script for 3d hand pose (#2710)
Browse files Browse the repository at this point in the history
  • Loading branch information
LareinaM authored Sep 22, 2023
1 parent 52c245c commit 26f3688
Show file tree
Hide file tree
Showing 3 changed files with 383 additions and 20 deletions.
37 changes: 37 additions & 0 deletions demo/docs/en/3d_hand_demo.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
## 3D Hand Demo

<img src="https://user-images.githubusercontent.com/28900607/121288285-b8fcbf00-c915-11eb-98e4-ba846de12987.gif" width="600px" alt><br>

### 3D Hand Estimation Image Demo

#### Using gt hand bounding boxes as input

We provide a demo script to test a single image, given gt json file.

```shell
python demo/hand3d_internet_demo.py \
${MMPOSE_CONFIG_FILE} ${MMPOSE_CHECKPOINT_FILE} \
--input ${INPUT_FILE} \
--output-root ${OUTPUT_ROOT} \
[--save-predictions] \
[--gt-joints-file ${GT_JOINTS_FILE}]\
[--disable-rebase-keypoint] \
[--show] \
[--device ${GPU_ID or CPU}] \
[--kpt-thr ${KPT_THR}] \
[--show-kpt-idx] \
[--radius ${RADIUS}] \
[--thickness ${THICKNESS}]
```

The pre-trained hand pose estimation model can be downloaded from [model zoo](https://mmpose.readthedocs.io/en/latest/model_zoo/hand_3d_keypoint.html).
Take [internet model](https://download.openmmlab.com/mmpose/hand3d/internet/res50_intehand3dv1.0_all_256x256-42b7f2ac_20210702.pth) as an example:

```shell
python demo/hand3d_internet_demo.py \
configs/hand_3d_keypoint/internet/interhand3d/internet_res50_4xb16-20e_interhand3d-256x256.py \
https://download.openmmlab.com/mmpose/hand3d/internet/res50_intehand3dv1.0_all_256x256-42b7f2ac_20210702.pth \
--input tests/data/interhand2.6m/image69148.jpg \
--save-predictions \
--output-root vis_results
```
279 changes: 279 additions & 0 deletions demo/hand3d_internet_demo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,279 @@
# Copyright (c) OpenMMLab. All rights reserved.
import logging
import mimetypes
import os
import time
from argparse import ArgumentParser

import cv2
import json_tricks as json
import mmcv
import mmengine
import numpy as np
from mmengine.logging import print_log

from mmpose.apis import inference_topdown, init_model
from mmpose.registry import VISUALIZERS
from mmpose.structures import (PoseDataSample, merge_data_samples,
split_instances)


def parse_args():
parser = ArgumentParser()
parser.add_argument('config', help='Config file')
parser.add_argument('checkpoint', help='Checkpoint file')
parser.add_argument(
'--input', type=str, default='', help='Image/Video file')
parser.add_argument(
'--output-root',
type=str,
default='',
help='root of the output img file. '
'Default not saving the visualization images.')
parser.add_argument(
'--save-predictions',
action='store_true',
default=False,
help='whether to save predicted results')
parser.add_argument(
'--disable-rebase-keypoint',
action='store_true',
default=False,
help='Whether to disable rebasing the predicted 3D pose so its '
'lowest keypoint has a height of 0 (landing on the ground). Rebase '
'is useful for visualization when the model do not predict the '
'global position of the 3D pose.')
parser.add_argument(
'--show',
action='store_true',
default=False,
help='whether to show result')
parser.add_argument('--device', default='cpu', help='Device for inference')
parser.add_argument(
'--kpt-thr',
type=float,
default=0.3,
help='Visualizing keypoint thresholds')
parser.add_argument(
'--show-kpt-idx',
action='store_true',
default=False,
help='Whether to show the index of keypoints')
parser.add_argument(
'--radius',
type=int,
default=3,
help='Keypoint radius for visualization')
parser.add_argument(
'--thickness',
type=int,
default=1,
help='Link thickness for visualization')

args = parser.parse_args()
return args


def process_one_image(args, img, model, visualizer=None, show_interval=0):
"""Visualize predicted keypoints of one image."""
# inference a single image
pose_results = inference_topdown(model, img)
# post-processing
pose_results_2d = []
for idx, res in enumerate(pose_results):
pred_instances = res.pred_instances
keypoints = pred_instances.keypoints
rel_root_depth = pred_instances.rel_root_depth
scores = pred_instances.keypoint_scores
hand_type = pred_instances.hand_type

res_2d = PoseDataSample()
gt_instances = res.gt_instances.clone()
pred_instances = pred_instances.clone()
res_2d.gt_instances = gt_instances
res_2d.pred_instances = pred_instances

# add relative root depth to left hand joints
keypoints[:, 21:, 2] += rel_root_depth

# set joint scores according to hand type
scores[:, :21] *= hand_type[:, [0]]
scores[:, 21:] *= hand_type[:, [1]]
# normalize kpt score
if scores.max() > 1:
scores /= 255

res_2d.pred_instances.set_field(keypoints[..., :2].copy(), 'keypoints')

# rotate the keypoint to make z-axis correspondent to height
# for better visualization
vis_R = np.array([[1, 0, 0], [0, 0, -1], [0, 1, 0]])
keypoints[..., :3] = keypoints[..., :3] @ vis_R

# rebase height (z-axis)
if not args.disable_rebase_keypoint:
valid = scores > 0
keypoints[..., 2] -= np.min(
keypoints[valid, 2], axis=-1, keepdims=True)

pose_results[idx].pred_instances.keypoints = keypoints
pose_results[idx].pred_instances.keypoint_scores = scores
pose_results_2d.append(res_2d)

data_samples = merge_data_samples(pose_results)
data_samples_2d = merge_data_samples(pose_results_2d)

# show the results
if isinstance(img, str):
img = mmcv.imread(img, channel_order='rgb')
elif isinstance(img, np.ndarray):
img = mmcv.bgr2rgb(img)

if visualizer is not None:
visualizer.add_datasample(
'result',
img,
data_sample=data_samples,
det_data_sample=data_samples_2d,
draw_gt=False,
draw_bbox=True,
kpt_thr=args.kpt_thr,
convert_keypoint=False,
axis_azimuth=-115,
axis_limit=200,
axis_elev=15,
show_kpt_idx=args.show_kpt_idx,
show=args.show,
wait_time=show_interval)

# if there is no instance detected, return None
return data_samples.get('pred_instances', None)


def main():
args = parse_args()

assert args.input != ''
assert args.show or (args.output_root != '')

output_file = None
if args.output_root:
mmengine.mkdir_or_exist(args.output_root)
output_file = os.path.join(args.output_root,
os.path.basename(args.input))
if args.input == 'webcam':
output_file += '.mp4'

if args.save_predictions:
assert args.output_root != ''
args.pred_save_path = f'{args.output_root}/results_' \
f'{os.path.splitext(os.path.basename(args.input))[0]}.json'

# build the model from a config file and a checkpoint file
model = init_model(args.config, args.checkpoint, device=args.device)

# init visualizer
model.cfg.visualizer.radius = args.radius
model.cfg.visualizer.line_width = args.thickness

visualizer = VISUALIZERS.build(model.cfg.visualizer)
visualizer.set_dataset_meta(model.dataset_meta)

if args.input == 'webcam':
input_type = 'webcam'
else:
input_type = mimetypes.guess_type(args.input)[0].split('/')[0]

if input_type == 'image':
# inference
pred_instances = process_one_image(args, args.input, model, visualizer)

if args.save_predictions:
pred_instances_list = split_instances(pred_instances)

if output_file:
img_vis = visualizer.get_image()
mmcv.imwrite(mmcv.rgb2bgr(img_vis), output_file)

elif input_type in ['webcam', 'video']:

if args.input == 'webcam':
cap = cv2.VideoCapture(0)
else:
cap = cv2.VideoCapture(args.input)

video_writer = None
pred_instances_list = []
frame_idx = 0

while cap.isOpened():
success, frame = cap.read()
frame_idx += 1

if not success:
break

# topdown pose estimation
pred_instances = process_one_image(args, frame, model, visualizer,
0.001)

if args.save_predictions:
# save prediction results
pred_instances_list.append(
dict(
frame_id=frame_idx,
instances=split_instances(pred_instances)))

# output videos
if output_file:
frame_vis = visualizer.get_image()

if video_writer is None:
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
# the size of the image with visualization may vary
# depending on the presence of heatmaps
video_writer = cv2.VideoWriter(
output_file,
fourcc,
25, # saved fps
(frame_vis.shape[1], frame_vis.shape[0]))

video_writer.write(mmcv.rgb2bgr(frame_vis))

if args.show:
# press ESC to exit
if cv2.waitKey(5) & 0xFF == 27:
break

time.sleep(args.show_interval)

if video_writer:
video_writer.release()

cap.release()

else:
args.save_predictions = False
raise ValueError(
f'file {os.path.basename(args.input)} has invalid format.')

if args.save_predictions:
with open(args.pred_save_path, 'w') as f:
json.dump(
dict(
meta_info=model.dataset_meta,
instance_info=pred_instances_list),
f,
indent='\t')
print(f'predictions have been saved at {args.pred_save_path}')

if output_file is not None:
input_type = input_type.replace('webcam', 'video')
print_log(
f'the output {input_type} has been saved at {output_file}',
logger='current',
level=logging.INFO)


if __name__ == '__main__':
main()
Loading

0 comments on commit 26f3688

Please sign in to comment.