[Refactor] 3d human pose demo #2554

Merged
merged 3 commits into from Jul 21, 2023

Changes from 2 commits
198 changes: 98 additions & 100 deletions demo/body3d_pose_lifter_demo.py
@@ -12,7 +12,6 @@
import mmengine
import numpy as np
from mmengine.logging import print_log
from mmengine.structures import InstanceData

from mmpose.apis import (_track_by_iou, _track_by_oks, collect_multi_frames,
convert_keypoint_definition, extract_pose_sequence,
@@ -130,16 +129,43 @@ def parse_args():
return args


def get_area(results):
for i, data_sample in enumerate(results):
pred_instance = data_sample.pred_instances.cpu().numpy()
if 'bboxes' in pred_instance:
bboxes = pred_instance.bboxes
results[i].pred_instances.set_field(
np.array([(bbox[2] - bbox[0]) * (bbox[3] - bbox[1])
for bbox in bboxes]), 'areas')
def process_one_image(args, detector, frame, frame_idx, pose_estimator,
pose_est_frame, pose_est_results_last,
pose_est_results_list, next_id, pose_lifter,
pose_lift_frame, visualizer):
"""Visualize detected and predicted keypoints of one image."""

pose_lift_dataset = pose_lifter.cfg.test_dataloader.dataset

det_result = inference_detector(detector, frame)
pred_instance = det_result.pred_instances.cpu().numpy()

# First stage: 2D pose detection
bboxes = pred_instance.bboxes
bboxes = bboxes[np.logical_and(pred_instance.labels == args.det_cat_id,
pred_instance.scores > args.bbox_thr)]

# estimate pose results for current image
pose_est_results = inference_topdown(pose_estimator, pose_est_frame,
bboxes)

if args.use_oks_tracking:
_track = partial(_track_by_oks)
else:
_track = _track_by_iou

pose_det_dataset = pose_estimator.cfg.test_dataloader.dataset
pose_est_results_converted = []

for i, data_sample in enumerate(pose_est_results):
pred_instances = data_sample.pred_instances.cpu().numpy()
keypoints = pred_instances.keypoints
# calculate area and bbox
if 'bboxes' in pred_instances:
areas = np.array([(bbox[2] - bbox[0]) * (bbox[3] - bbox[1])
for bbox in pred_instances.bboxes])
pose_est_results[i].pred_instances.set_field(areas, 'areas')
else:
keypoints = pred_instance.keypoints
areas, bboxes = [], []
for keypoint in keypoints:
xmin = np.min(keypoint[:, 0][keypoint[:, 0] > 0], initial=1e10)
@@ -148,72 +174,47 @@ def get_area(results):
ymax = np.max(keypoint[:, 1])
areas.append((xmax - xmin) * (ymax - ymin))
bboxes.append([xmin, ymin, xmax, ymax])
results[i].pred_instances.areas = np.array(areas)
results[i].pred_instances.bboxes = np.array(bboxes)
return results
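
Both the old `get_area` and the new inline branch fall back to deriving a bbox from the keypoints when the detector output carries none. A minimal standalone sketch of that fallback — the xmax/ymin lines are inferred from the visible pattern and are an assumption; `initial=1e10` keeps `np.min` well-defined when the positive mask is empty:

import numpy as np

def bbox_area_from_keypoints(keypoint):
    # Joints that were not detected are assumed to sit at non-positive
    # coordinates, so they are masked out of the min computation.
    xmin = np.min(keypoint[:, 0][keypoint[:, 0] > 0], initial=1e10)
    xmax = np.max(keypoint[:, 0])
    ymin = np.min(keypoint[:, 1][keypoint[:, 1] > 0], initial=1e10)
    ymax = np.max(keypoint[:, 1])
    return [xmin, ymin, xmax, ymax], (xmax - xmin) * (ymax - ymin)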


def get_pose_est_results(args, pose_estimator, frame, bboxes,
pose_est_results_last, next_id, pose_lift_dataset):
pose_det_dataset = pose_estimator.cfg.test_dataloader.dataset

# make person results for current image
pose_est_results = inference_topdown(pose_estimator, frame, bboxes)
pose_est_results[i].pred_instances.areas = np.array(areas)
pose_est_results[i].pred_instances.bboxes = np.array(bboxes)

pose_est_results = get_area(pose_est_results)
if args.use_oks_tracking:
_track = partial(_track_by_oks)
else:
_track = _track_by_iou

for i, result in enumerate(pose_est_results):
track_id, pose_est_results_last, match_result = _track(
result, pose_est_results_last, args.tracking_thr)
# track id
track_id, pose_est_results_last, _ = _track(data_sample,
pose_est_results_last,
args.tracking_thr)
if track_id == -1:
pred_instances = result.pred_instances.cpu().numpy()
keypoints = pred_instances.keypoints
if np.count_nonzero(keypoints[:, :, 1]) >= 3:
pose_est_results[i].set_field(next_id, 'track_id')
track_id = next_id
next_id += 1
else:
# If the number of keypoints detected is small,
# delete that person instance.
keypoints[:, :, 1] = -10
pose_est_results[i].pred_instances.set_field(
keypoints, 'keypoints')
bboxes = pred_instances.bboxes * 0
pose_est_results[i].pred_instances.set_field(bboxes, 'bboxes')
pose_est_results[i].set_field(-1, 'track_id')
pose_est_results[i].pred_instances.set_field(
pred_instances.bboxes * 0, 'bboxes')
pose_est_results[i].set_field(pred_instances, 'pred_instances')
else:
pose_est_results[i].set_field(track_id, 'track_id')
track_id = -1
pose_est_results[i].set_field(track_id, 'track_id')

del match_result

pose_est_results_converted = []
for pose_est_result in pose_est_results:
# convert keypoints for pose-lifting
pose_est_result_converted = PoseDataSample()
gt_instances = InstanceData()
pred_instances = InstanceData()
for k in pose_est_result.gt_instances.keys():
gt_instances.set_field(pose_est_result.gt_instances[k], k)
for k in pose_est_result.pred_instances.keys():
pred_instances.set_field(pose_est_result.pred_instances[k], k)
pose_est_result_converted.gt_instances = gt_instances
pose_est_result_converted.pred_instances = pred_instances
pose_est_result_converted.track_id = pose_est_result.track_id

keypoints = convert_keypoint_definition(pred_instances.keypoints,
pose_est_result_converted.set_field(
pose_est_results[i].pred_instances.clone(), 'pred_instances')
pose_est_result_converted.set_field(
pose_est_results[i].gt_instances.clone(), 'gt_instances')
keypoints = convert_keypoint_definition(keypoints,
pose_det_dataset['type'],
pose_lift_dataset['type'])
pose_est_result_converted.pred_instances.keypoints = keypoints
pose_est_result_converted.pred_instances.set_field(
keypoints, 'keypoints')
pose_est_result_converted.set_field(pose_est_results[i].track_id,
'track_id')
pose_est_results_converted.append(pose_est_result_converted)
return pose_est_results, pose_est_results_converted, next_id
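
For context on the conversion step above: `convert_keypoint_definition` (imported at the top of the file) remaps the 2D estimator's keypoint layout to the one the lifter was trained on. A hypothetical usage sketch — the dataset type strings are assumptions modeled on common mmpose configs, not taken from this diff:

from mmpose.apis import convert_keypoint_definition

# keypoints: (N, 17, D) array from the 2D top-down stage in COCO layout
keypoints_h36m = convert_keypoint_definition(
    keypoints, 'CocoDataset', 'Human36mDataset')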

pose_est_results_list.append(pose_est_results_converted.copy())

def get_pose_lift_results(args, visualizer, pose_lifter, pose_est_results_list,
frame, frame_idx, pose_est_results):
pose_lift_dataset = pose_lifter.cfg.test_dataloader.dataset
# Second stage: Pose lifting
# extract and pad input pose2d sequence
pose_seq_2d = extract_pose_sequence(
pose_est_results_list,
@@ -223,18 +224,17 @@ def get_pose_lift_results(args, visualizer, pose_lifter, pose_est_results_list,
step=pose_lift_dataset.get('seq_step', 1))

# 2D-to-3D pose lifting
width, height = frame.shape[:2]
pose_lift_results = inference_pose_lifter_model(
pose_lifter,
pose_seq_2d,
image_size=(width, height),
image_size=pose_lift_frame.shape[:2],
norm_pose_2d=args.norm_pose_2d)

# Pose processing
for idx, pose_lift_res in enumerate(pose_lift_results):
pose_lift_res.track_id = pose_est_results[idx].get('track_id', 1e4)
# post-processing
for idx, pose_lift_result in enumerate(pose_lift_results):
pose_lift_result.track_id = pose_est_results[idx].get('track_id', 1e4)

pred_instances = pose_lift_res.pred_instances
pred_instances = pose_lift_result.pred_instances
keypoints = pred_instances.keypoints
keypoint_scores = pred_instances.keypoint_scores
if keypoint_scores.ndim == 3:
@@ -260,6 +260,7 @@ def get_pose_lift_results(args, visualizer, pose_lifter, pose_est_results_list,

pred_3d_data_samples = merge_data_samples(pose_lift_results)
det_data_sample = merge_data_samples(pose_est_results)
pred_3d_pred = pred_3d_data_samples.get('pred_instances', None)

if args.num_instances < 0:
args.num_instances = len(pose_lift_results)
@@ -268,7 +269,7 @@ def get_pose_lift_results(args, visualizer, pose_lifter, pose_est_results_list,
if visualizer is not None:
visualizer.add_datasample(
'result',
frame,
pose_lift_frame,
data_sample=pred_3d_data_samples,
det_data_sample=det_data_sample,
draw_gt=False,
@@ -278,17 +279,7 @@ def get_pose_lift_results(args, visualizer, pose_lifter, pose_est_results_list,
num_instances=args.num_instances,
wait_time=args.show_interval)

return pred_3d_data_samples.get('pred_instances', None)


def get_bbox(args, detector, frame):
det_result = inference_detector(detector, frame)
pred_instance = det_result.pred_instances.cpu().numpy()

bboxes = pred_instance.bboxes
bboxes = bboxes[np.logical_and(pred_instance.labels == args.det_cat_id,
pred_instance.scores > args.bbox_thr)]
return bboxes
return pose_est_results, pose_est_results_list, pred_3d_pred, next_id
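
Taken together, the refactor collapses `get_bbox`, `get_area`, `get_pose_est_results`, and `get_pose_lift_results` into this single `process_one_image`, which takes the tracking state in and returns it updated. A condensed driver sketch of the per-frame call pattern, assuming the models, `video`, and `visualizer` are built as in `main()` below:

pose_est_results, pose_est_results_list, next_id = [], [], 0
frame_idx = 0
while True:
    success, frame = video.read()
    if not success:
        break
    frame_idx += 1
    # last frame's results seed the tracker for this frame
    pose_est_results, pose_est_results_list, pred_3d_pred, next_id = \
        process_one_image(
            args=args,
            detector=detector,
            frame=frame,
            frame_idx=frame_idx,
            pose_estimator=pose_estimator,
            pose_est_frame=frame,  # or a multi-frame stack, see below
            pose_est_results_last=pose_est_results,
            pose_est_results_list=pose_est_results_list,
            next_id=next_id,
            pose_lifter=pose_lifter,
            pose_lift_frame=mmcv.bgr2rgb(frame),
            visualizer=visualizer)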


def main():
@@ -333,7 +324,6 @@ def main():
assert isinstance(pose_lifter, PoseLifter), \
'Only "PoseLifter" model is supported for the 2nd stage ' \
'(2D-to-3D lifting)'
pose_lift_dataset = pose_lifter.cfg.test_dataloader.dataset

pose_lifter.cfg.visualizer.radius = args.radius
pose_lifter.cfg.visualizer.line_width = args.thickness
@@ -372,15 +362,19 @@ def main():
pred_instances_list = []
if input_type == 'image':
frame = mmcv.imread(args.input, channel_order='rgb')

# First stage: 2D pose detection
bboxes = get_bbox(args, detector, frame)
pose_est_results, pose_est_results_converted, _ = get_pose_est_results(
args, pose_estimator, frame, bboxes, [], 0, pose_lift_dataset)
pose_est_results_list.append(pose_est_results_converted.copy())
pred_3d_pred = get_pose_lift_results(args, visualizer, pose_lifter,
pose_est_results_list, frame, 0,
pose_est_results)
_, _, pred_3d_pred, _ = process_one_image(
args=args,
detector=detector,
frame=frame,
frame_idx=0,
pose_estimator=pose_estimator,
pose_est_frame=frame,
pose_est_results_last=[],
pose_est_results_list=pose_est_results_list,
next_id=0,
pose_lifter=pose_lifter,
pose_lift_frame=frame,
visualizer=visualizer)

if args.save_predictions:
# save prediction results
@@ -392,7 +386,7 @@

elif input_type in ['webcam', 'video']:
next_id = 0
pose_est_results_converted = []
pose_est_results = []

if args.input == 'webcam':
video = cv2.VideoCapture(0)
@@ -415,26 +409,30 @@
if not success:
break

pose_est_results_last = pose_est_results_converted
pose_est_results_last = pose_est_results

# First stage: 2D pose detection
pose_est_frame = frame
if args.use_multi_frames:
frames = collect_multi_frames(video, frame_idx, indices,
args.online)
pose_est_frame = frames

# make person results for current image
bboxes = get_bbox(args, detector, frame)
pose_est_results, pose_est_results_converted, next_id = get_pose_est_results( # noqa: E501
args, pose_estimator,
frames if args.use_multi_frames else frame, bboxes,
pose_est_results_last, next_id, pose_lift_dataset)
pose_est_results_list.append(pose_est_results_converted.copy())

# Second stage: Pose lifting
pred_3d_pred = get_pose_lift_results(args, visualizer, pose_lifter,
pose_est_results_list,
mmcv.bgr2rgb(frame),
frame_idx, pose_est_results)
(pose_est_results, pose_est_results_list, pred_3d_pred,
next_id) = process_one_image(
args=args,
detector=detector,
frame=frame,
frame_idx=frame_idx,
pose_estimator=pose_estimator,
pose_est_frame=pose_est_frame,
pose_est_results_last=pose_est_results_last,
pose_est_results_list=pose_est_results_list,
next_id=next_id,
pose_lifter=pose_lifter,
pose_lift_frame=mmcv.bgr2rgb(frame),
visualizer=visualizer)
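
`collect_multi_frames` in the branch above gathers the neighboring frames a multi-frame 2D model expects. A rough sketch of the idea — not the library implementation; the offset semantics, clamping, and `online` handling are assumptions:

import cv2
import numpy as np

def collect_frames_sketch(video, frame_id, indices, online=False):
    # indices are frame offsets from the model config, e.g. [-2, -1, 0, 1, 2];
    # in online mode the model must not peek at future frames.
    total = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
    frames = []
    for offset in indices:
        if online:
            offset = min(offset, 0)
        support_idx = int(np.clip(frame_id + offset, 0, total - 1))
        video.set(cv2.CAP_PROP_POS_FRAMES, support_idx)
        ok, support_frame = video.read()
        if ok:
            frames.append(support_frame)
    video.set(cv2.CAP_PROP_POS_FRAMES, frame_id + 1)  # restore read position
    return frames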

if args.save_predictions:
# save prediction results