[Refactor] 3d human pose demo #2554

Merged · 3 commits · Jul 21, 2023
248 changes: 139 additions & 109 deletions demo/body3d_pose_lifter_demo.py
@@ -12,7 +12,6 @@
import mmengine
import numpy as np
from mmengine.logging import print_log
from mmengine.structures import InstanceData

from mmpose.apis import (_track_by_iou, _track_by_oks, collect_multi_frames,
convert_keypoint_definition, extract_pose_sequence,
@@ -59,12 +58,13 @@ def parse_args():
default=False,
help='Whether to show visualizations')
parser.add_argument(
'--rebase-keypoint-height',
'--disable-rebase-keypoint',
action='store_true',
help='Rebase the predicted 3D pose so its lowest keypoint has a '
'height of 0 (landing on the ground). This is useful for '
'visualization when the model does not predict the global position '
'of the 3D pose.')
default=False,
help='Whether to disable rebasing the predicted 3D pose so its '
'lowest keypoint has a height of 0 (landing on the ground). Rebase '
'is useful for visualization when the model does not predict the '
'global position of the 3D pose.')
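
Note the semantics flip here: rebasing was opt-in via `--rebase-keypoint-height` and is now on by default, with `--disable-rebase-keypoint` as the opt-out. A minimal sketch of the inverted flag, outside the demo's full parser:

```python
import argparse

# Minimal sketch: with store_true and default=False, rebasing now runs
# unless the user explicitly disables it.
parser = argparse.ArgumentParser()
parser.add_argument('--disable-rebase-keypoint', action='store_true',
                    default=False)

args = parser.parse_args([])                 # no flag passed
assert not args.disable_rebase_keypoint      # rebase happens by default

args = parser.parse_args(['--disable-rebase-keypoint'])
assert args.disable_rebase_keypoint          # rebase skipped on request
```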
parser.add_argument(
'--norm-pose-2d',
action='store_true',
@@ -75,7 +75,7 @@
parser.add_argument(
'--num-instances',
type=int,
default=-1,
default=1,
help='The number of 3D poses to be visualized in every frame. If '
'less than 0, it will be set to the number of pose results in the '
'first frame.')
@@ -130,16 +130,74 @@ def parse_args():
return args


def get_area(results):
for i, data_sample in enumerate(results):
pred_instance = data_sample.pred_instances.cpu().numpy()
if 'bboxes' in pred_instance:
bboxes = pred_instance.bboxes
results[i].pred_instances.set_field(
np.array([(bbox[2] - bbox[0]) * (bbox[3] - bbox[1])
for bbox in bboxes]), 'areas')
def process_one_image(args, detector, frame, frame_idx, pose_estimator,
pose_est_frame, pose_est_results_last,
pose_est_results_list, next_id, pose_lifter,
visualize_frame, visualizer):
"""Visualize detected and predicted keypoints of one image.

Args:
args (Argument): Custom command-line arguments.
detector (mmdet.BaseDetector): The mmdet detector.
frame (np.ndarray): The image frame read from input image or video.
frame_idx (int): The index of current frame.
pose_estimator (TopdownPoseEstimator): The pose estimator for 2d pose.
pose_est_frame (np.ndarray | list(np.ndarray)): The frames for pose
estimation.
pose_est_results_last (list(PoseDataSample)): The results of pose
estimation from the last frame for tracking instances.
pose_est_results_list (list(list(PoseDataSample))): The list of all
pose estimation results converted by
``convert_keypoint_definition`` from previous frames. In the
pose-lifting stage it is used to obtain the 2d estimation sequence.
next_id (int): The next track id to be used.
pose_lifter (PoseLifter): The pose-lifter for estimating 3d pose.
visualize_frame (np.ndarray): The image for drawing the results on.
visualizer (Visualizer): The visualizer for visualizing the 2d and 3d
pose estimation results.

Returns:
pose_est_results (list(PoseDataSample)): The pose estimation result of
the current frame.
pose_est_results_list (list(list(PoseDataSample))): The list of all
converted pose estimation results until the current frame.
pred_3d_instances (InstanceData): The result of pose-lifting.
Specifically, the predicted keypoints and scores are saved at
``pred_3d_instances.keypoints`` and
``pred_3d_instances.keypoint_scores``.
next_id (int): The next track id to be used.
"""
pose_lift_dataset = pose_lifter.cfg.test_dataloader.dataset

det_result = inference_detector(detector, frame)
pred_instance = det_result.pred_instances.cpu().numpy()

# First stage: 2D pose detection
bboxes = pred_instance.bboxes
bboxes = bboxes[np.logical_and(pred_instance.labels == args.det_cat_id,
pred_instance.scores > args.bbox_thr)]

# estimate pose results for current image
pose_est_results = inference_topdown(pose_estimator, pose_est_frame,
bboxes)

if args.use_oks_tracking:
_track = partial(_track_by_oks)
else:
_track = _track_by_iou

pose_det_dataset = pose_estimator.cfg.test_dataloader.dataset
pose_est_results_converted = []

for i, data_sample in enumerate(pose_est_results):
pred_instances = data_sample.pred_instances.cpu().numpy()
keypoints = pred_instances.keypoints
# calculate area and bbox
if 'bboxes' in pred_instances:
areas = np.array([(bbox[2] - bbox[0]) * (bbox[3] - bbox[1])
for bbox in pred_instances.bboxes])
pose_est_results[i].pred_instances.set_field(areas, 'areas')
else:
keypoints = pred_instance.keypoints
areas, bboxes = [], []
for keypoint in keypoints:
xmin = np.min(keypoint[:, 0][keypoint[:, 0] > 0], initial=1e10)
@@ -148,72 +206,47 @@ def get_area(results):
ymax = np.max(keypoint[:, 1])
areas.append((xmax - xmin) * (ymax - ymin))
bboxes.append([xmin, ymin, xmax, ymax])
results[i].pred_instances.areas = np.array(areas)
results[i].pred_instances.bboxes = np.array(bboxes)
return results
pose_est_results[i].pred_instances.areas = np.array(areas)
pose_est_results[i].pred_instances.bboxes = np.array(bboxes)
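
The `else` branch recovers a bbox and area from the keypoints when the 2D results carry none. A self-contained numeric sketch of that computation:

```python
import numpy as np

# Zeros mark missing joints; the positive-coordinate filter keeps them out
# of the min, and `initial=1e10` guards against an empty selection.
keypoint = np.array([[120., 80.], [150., 95.], [0., 0.]])  # (K, 2)
xmin = np.min(keypoint[:, 0][keypoint[:, 0] > 0], initial=1e10)
xmax = np.max(keypoint[:, 0])
ymin = np.min(keypoint[:, 1][keypoint[:, 1] > 0], initial=1e10)
ymax = np.max(keypoint[:, 1])
print([xmin, ymin, xmax, ymax], (xmax - xmin) * (ymax - ymin))
# [120.0, 80.0, 150.0, 95.0] 450.0
```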


def get_pose_est_results(args, pose_estimator, frame, bboxes,
pose_est_results_last, next_id, pose_lift_dataset):
pose_det_dataset = pose_estimator.cfg.test_dataloader.dataset

# make person results for current image
pose_est_results = inference_topdown(pose_estimator, frame, bboxes)

pose_est_results = get_area(pose_est_results)
if args.use_oks_tracking:
_track = partial(_track_by_oks)
else:
_track = _track_by_iou

for i, result in enumerate(pose_est_results):
track_id, pose_est_results_last, match_result = _track(
result, pose_est_results_last, args.tracking_thr)
# track id
track_id, pose_est_results_last, _ = _track(data_sample,
pose_est_results_last,
args.tracking_thr)
if track_id == -1:
pred_instances = result.pred_instances.cpu().numpy()
keypoints = pred_instances.keypoints
if np.count_nonzero(keypoints[:, :, 1]) >= 3:
pose_est_results[i].set_field(next_id, 'track_id')
track_id = next_id
next_id += 1
else:
# If the number of keypoints detected is small,
# delete that person instance.
keypoints[:, :, 1] = -10
pose_est_results[i].pred_instances.set_field(
keypoints, 'keypoints')
bboxes = pred_instances.bboxes * 0
pose_est_results[i].pred_instances.set_field(bboxes, 'bboxes')
pose_est_results[i].set_field(-1, 'track_id')
pose_est_results[i].pred_instances.set_field(
pred_instances.bboxes * 0, 'bboxes')
pose_est_results[i].set_field(pred_instances, 'pred_instances')
else:
pose_est_results[i].set_field(track_id, 'track_id')
track_id = -1
pose_est_results[i].set_field(track_id, 'track_id')

del match_result
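
The unmatched branch above gates new track ids on a keypoint count. A self-contained sketch of that heuristic with a hypothetical keypoint array:

```python
import numpy as np

# An unmatched instance gets a fresh id only if at least 3 keypoints have
# a nonzero y-coordinate; otherwise it is suppressed by pushing its
# keypoints off-screen (and, in the demo, zeroing its bboxes).
next_id = 0
keypoints = np.zeros((1, 17, 3))        # (instances, K, [x, y, score])
keypoints[0, :3, 1] = [56., 60., 64.]   # three valid y-coordinates
if np.count_nonzero(keypoints[:, :, 1]) >= 3:
    track_id, next_id = next_id, next_id + 1
else:
    keypoints[:, :, 1] = -10            # move the instance out of frame
    track_id = -1
print(track_id, next_id)  # 0 1
```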

pose_est_results_converted = []
for pose_est_result in pose_est_results:
# convert keypoints for pose-lifting
pose_est_result_converted = PoseDataSample()
gt_instances = InstanceData()
pred_instances = InstanceData()
for k in pose_est_result.gt_instances.keys():
gt_instances.set_field(pose_est_result.gt_instances[k], k)
for k in pose_est_result.pred_instances.keys():
pred_instances.set_field(pose_est_result.pred_instances[k], k)
pose_est_result_converted.gt_instances = gt_instances
pose_est_result_converted.pred_instances = pred_instances
pose_est_result_converted.track_id = pose_est_result.track_id

keypoints = convert_keypoint_definition(pred_instances.keypoints,
pose_est_result_converted.set_field(
pose_est_results[i].pred_instances.clone(), 'pred_instances')
pose_est_result_converted.set_field(
pose_est_results[i].gt_instances.clone(), 'gt_instances')
keypoints = convert_keypoint_definition(keypoints,
pose_det_dataset['type'],
pose_lift_dataset['type'])
pose_est_result_converted.pred_instances.keypoints = keypoints
pose_est_result_converted.pred_instances.set_field(
keypoints, 'keypoints')
pose_est_result_converted.set_field(pose_est_results[i].track_id,
'track_id')
pose_est_results_converted.append(pose_est_result_converted)
return pose_est_results, pose_est_results_converted, next_id

pose_est_results_list.append(pose_est_results_converted.copy())
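
Each result is cloned and its keypoints remapped before being appended, so the lifting stage always consumes one keypoint convention. A hedged sketch of the conversion call; the dataset type strings here are illustrative, not read from any config:

```python
import numpy as np
from mmpose.apis import convert_keypoint_definition

# Remap (N, K, 2) keypoints from the 2D estimator's convention to the
# lifter's convention; 'CocoDataset' -> 'Human36mDataset' is one common
# pairing for this demo.
keypoints = np.zeros((1, 17, 2))
keypoints_converted = convert_keypoint_definition(keypoints, 'CocoDataset',
                                                  'Human36mDataset')
print(keypoints_converted.shape)  # (1, 17, 2)
```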

def get_pose_lift_results(args, visualizer, pose_lifter, pose_est_results_list,
frame, frame_idx, pose_est_results):
pose_lift_dataset = pose_lifter.cfg.test_dataloader.dataset
# Second stage: Pose lifting
# extract and pad input pose2d sequence
pose_seq_2d = extract_pose_sequence(
pose_est_results_list,
@@ -223,18 +256,17 @@ def get_pose_lift_results(args, visualizer, pose_lifter, pose_est_results_list,
step=pose_lift_dataset.get('seq_step', 1))

# 2D-to-3D pose lifting
width, height = frame.shape[:2]
pose_lift_results = inference_pose_lifter_model(
pose_lifter,
pose_seq_2d,
image_size=(width, height),
image_size=visualize_frame.shape[:2],
norm_pose_2d=args.norm_pose_2d)
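
This change also retires the old `width, height = frame.shape[:2]` unpacking, whose names were swapped: an ndarray image's shape is (height, width, channels). A small sketch of the convention:

```python
import numpy as np

# shape[:2] of an image array is (height, width), so the removed unpacking
# labeled the two values backwards; the refactor forwards shape[:2] as-is.
frame = np.zeros((720, 1280, 3), dtype=np.uint8)
height, width = frame.shape[:2]
print(height, width)  # 720 1280
```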

# Pose processing
for idx, pose_lift_res in enumerate(pose_lift_results):
pose_lift_res.track_id = pose_est_results[idx].get('track_id', 1e4)
# post-processing
for idx, pose_lift_result in enumerate(pose_lift_results):
pose_lift_result.track_id = pose_est_results[idx].get('track_id', 1e4)

pred_instances = pose_lift_res.pred_instances
pred_instances = pose_lift_result.pred_instances
keypoints = pred_instances.keypoints
keypoint_scores = pred_instances.keypoint_scores
if keypoint_scores.ndim == 3:
@@ -249,7 +281,7 @@ def get_pose_lift_results(args, visualizer, pose_lifter, pose_est_results_list,
keypoints[..., 2] = -keypoints[..., 2]

# rebase height (z-axis)
if args.rebase_keypoint_height:
if not args.disable_rebase_keypoint:
keypoints[..., 2] -= np.min(
keypoints[..., 2], axis=-1, keepdims=True)

@@ -260,6 +292,7 @@ def get_pose_lift_results(args, visualizer, pose_lifter, pose_est_results_list,

pred_3d_data_samples = merge_data_samples(pose_lift_results)
det_data_sample = merge_data_samples(pose_est_results)
pred_3d_instances = pred_3d_data_samples.get('pred_instances', None)

if args.num_instances < 0:
args.num_instances = len(pose_lift_results)
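
Together with the new default of 1 above, this fallback keeps the old expand-to-all behavior reachable. A trivial sketch:

```python
# A negative --num-instances (the old default of -1) expands to all
# detected poses; the new default of 1 draws a single pose.
num_instances = -1
pose_lift_results = [None, None, None]   # stand-in for three lifted poses
if num_instances < 0:
    num_instances = len(pose_lift_results)
print(num_instances)  # 3
```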
@@ -268,7 +301,7 @@ def get_pose_lift_results(args, visualizer, pose_lifter, pose_est_results_list,
if visualizer is not None:
visualizer.add_datasample(
'result',
frame,
visualize_frame,
data_sample=pred_3d_data_samples,
det_data_sample=det_data_sample,
draw_gt=False,
@@ -278,17 +311,7 @@ def get_pose_lift_results(args, visualizer, pose_lifter, pose_est_results_list,
num_instances=args.num_instances,
wait_time=args.show_interval)

return pred_3d_data_samples.get('pred_instances', None)


def get_bbox(args, detector, frame):
det_result = inference_detector(detector, frame)
pred_instance = det_result.pred_instances.cpu().numpy()

bboxes = pred_instance.bboxes
bboxes = bboxes[np.logical_and(pred_instance.labels == args.det_cat_id,
pred_instance.scores > args.bbox_thr)]
return bboxes
return pose_est_results, pose_est_results_list, pred_3d_instances, next_id


def main():
@@ -333,7 +356,6 @@ def main():
assert isinstance(pose_lifter, PoseLifter), \
'Only "PoseLifter" model is supported for the 2nd stage ' \
'(2D-to-3D lifting)'
pose_lift_dataset = pose_lifter.cfg.test_dataloader.dataset

pose_lifter.cfg.visualizer.radius = args.radius
pose_lifter.cfg.visualizer.line_width = args.thickness
@@ -372,27 +394,31 @@ def main():
pred_instances_list = []
if input_type == 'image':
frame = mmcv.imread(args.input, channel_order='rgb')

# First stage: 2D pose detection
bboxes = get_bbox(args, detector, frame)
pose_est_results, pose_est_results_converted, _ = get_pose_est_results(
args, pose_estimator, frame, bboxes, [], 0, pose_lift_dataset)
pose_est_results_list.append(pose_est_results_converted.copy())
pred_3d_pred = get_pose_lift_results(args, visualizer, pose_lifter,
pose_est_results_list, frame, 0,
pose_est_results)
_, _, pred_3d_instances, _ = process_one_image(
args=args,
detector=detector,
frame=frame,
frame_idx=0,
pose_estimator=pose_estimator,
pose_est_frame=frame,
pose_est_results_last=[],
pose_est_results_list=pose_est_results_list,
next_id=0,
pose_lifter=pose_lifter,
visualize_frame=frame,
visualizer=visualizer)

if args.save_predictions:
# save prediction results
pred_instances_list = split_instances(pred_3d_pred)
pred_instances_list = split_instances(pred_3d_instances)

if save_output:
frame_vis = visualizer.get_image()
mmcv.imwrite(mmcv.rgb2bgr(frame_vis), output_file)
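
The image branch reads RGB for inference and visualization, then converts back before writing, since `mmcv.imwrite` is OpenCV-backed and expects BGR. A minimal round-trip sketch with a synthetic frame (`'vis.jpg'` is a hypothetical path):

```python
import mmcv
import numpy as np

# The read side would use mmcv.imread(..., channel_order='rgb'); here a
# synthetic RGB frame stands in, flipped back to BGR for writing.
frame_rgb = np.zeros((240, 320, 3), dtype=np.uint8)
frame_rgb[..., 0] = 255                      # pure red in RGB
mmcv.imwrite(mmcv.rgb2bgr(frame_rgb), 'vis.jpg')
```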

elif input_type in ['webcam', 'video']:
next_id = 0
pose_est_results_converted = []
pose_est_results = []

if args.input == 'webcam':
video = cv2.VideoCapture(0)
@@ -415,33 +441,37 @@ def main():
if not success:
break

pose_est_results_last = pose_est_results_converted
pose_est_results_last = pose_est_results

# First stage: 2D pose detection
pose_est_frame = frame
if args.use_multi_frames:
frames = collect_multi_frames(video, frame_idx, indices,
args.online)
pose_est_frame = frames
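
With `--use-multi-frames`, the 2D stage consumes a window of frames around the current index instead of a single image. A hedged sketch; the offsets below are illustrative, while in the demo `indices` comes from the 2D config's frame-sampling settings:

```python
# collect_multi_frames reads the frames at frame_idx plus each offset from
# the VideoCapture, clamped to the clip bounds; args.online restricts the
# window to frames already seen.
indices = [-2, -1, 0, 1, 2]   # hypothetical sampling offsets
frames = collect_multi_frames(video, frame_idx, indices, args.online)
pose_est_frame = frames       # the 2D estimator then gets the whole window
```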

# make person results for current image
bboxes = get_bbox(args, detector, frame)
pose_est_results, pose_est_results_converted, next_id = get_pose_est_results( # noqa: E501
args, pose_estimator,
frames if args.use_multi_frames else frame, bboxes,
pose_est_results_last, next_id, pose_lift_dataset)
pose_est_results_list.append(pose_est_results_converted.copy())

# Second stage: Pose lifting
pred_3d_pred = get_pose_lift_results(args, visualizer, pose_lifter,
pose_est_results_list,
mmcv.bgr2rgb(frame),
frame_idx, pose_est_results)
(pose_est_results, pose_est_results_list, pred_3d_instances,
next_id) = process_one_image(
args=args,
detector=detector,
frame=frame,
frame_idx=frame_idx,
pose_estimator=pose_estimator,
pose_est_frame=pose_est_frame,
pose_est_results_last=pose_est_results_last,
pose_est_results_list=pose_est_results_list,
next_id=next_id,
pose_lifter=pose_lifter,
visualize_frame=mmcv.bgr2rgb(frame),
visualizer=visualizer)

if args.save_predictions:
# save prediction results
pred_instances_list.append(
dict(
frame_id=frame_idx,
instances=split_instances(pred_3d_pred)))
instances=split_instances(pred_3d_instances)))

if save_output:
frame_vis = visualizer.get_image()