Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Fix] Visualization in 3d demo #2565

Merged
merged 1 commit into from
Jul 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions demo/body3d_pose_lifter_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,8 @@ def process_one_image(args, detector, frame, frame_idx, pose_estimator,
data_sample=pred_3d_data_samples,
det_data_sample=det_data_sample,
draw_gt=False,
dataset_2d=pose_det_dataset['type'],
dataset_3d=pose_lift_dataset['type'],
show=args.show,
draw_bbox=True,
kpt_thr=args.kpt_thr,
Expand Down
80 changes: 51 additions & 29 deletions mmpose/visualization/local_visualizer_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from mmengine.dist import master_only
from mmengine.structures import InstanceData

from mmpose.apis import convert_keypoint_definition
from mmpose.registry import VISUALIZERS
from mmpose.structures import PoseDataSample
from . import PoseLocalVisualizer
Expand Down Expand Up @@ -74,18 +75,17 @@ def __init__(
self.det_dataset_skeleton = det_dataset_skeleton
self.det_dataset_link_color = det_dataset_link_color

def _draw_3d_data_samples(
self,
image: np.ndarray,
pose_samples: PoseDataSample,
draw_gt: bool = True,
kpt_thr: float = 0.3,
num_instances=-1,
axis_azimuth: float = 70,
axis_limit: float = 1.7,
axis_dist: float = 10.0,
axis_elev: float = 15.0,
):
def _draw_3d_data_samples(self,
image: np.ndarray,
pose_samples: PoseDataSample,
draw_gt: bool = True,
kpt_thr: float = 0.3,
num_instances=-1,
axis_azimuth: float = 70,
axis_limit: float = 1.7,
axis_dist: float = 10.0,
axis_elev: float = 15.0,
scores_2d: Optional[np.ndarray] = None):
"""Draw keypoints and skeletons (optional) of GT or prediction.

Args:
Expand All @@ -109,6 +109,8 @@ def _draw_3d_data_samples(
- y: [y_c - axis_limit/2, y_c + axis_limit/2]
- z: [0, axis_limit]
Where x_c, y_c is the mean value of x and y coordinates
scores_2d (np.ndarray, optional): Keypoint scores of 2d estimation
that will be used to filter 3d instances.

Returns:
Tuple(np.ndarray): the drawn image, with channels in RGB order.
Expand Down Expand Up @@ -145,20 +147,21 @@ def _draw_3d_data_samples(

def _draw_3d_instances_kpts(keypoints,
scores,
scores_2d,
keypoints_visible,
fig_idx,
title=None):

for idx, (kpts, score, visible) in enumerate(
zip(keypoints, scores, keypoints_visible)):
for idx, (kpts, score, score_2d) in enumerate(
zip(keypoints, scores, scores_2d)):

valid = np.logical_and(score >= kpt_thr,
valid = np.logical_and(score >= kpt_thr, score_2d >= kpt_thr,
np.any(~np.isnan(kpts), axis=-1))

kpts_valid = kpts[valid]
ax = fig.add_subplot(
1, num_fig, fig_idx * (idx + 1), projection='3d')
ax.view_init(elev=axis_elev, azim=axis_azimuth)
ax.set_zlim3d([0, axis_limit])
ax.set_aspect('auto')
ax.set_xticks([])
ax.set_yticks([])
Expand All @@ -171,13 +174,14 @@ def _draw_3d_instances_kpts(keypoints,
ax.set_title(f'{title} ({idx})')
ax.dist = axis_dist

x_c = np.mean(kpts[valid, 0]) if valid.any() else 0
y_c = np.mean(kpts[valid, 1]) if valid.any() else 0
x_c = np.mean(kpts_valid[:, 0]) if valid.any() else 0
y_c = np.mean(kpts_valid[:, 1]) if valid.any() else 0
z_c = np.mean(kpts_valid[:, 2]) if valid.any() else 0

ax.set_xlim3d([x_c - axis_limit / 2, x_c + axis_limit / 2])
ax.set_ylim3d([y_c - axis_limit / 2, y_c + axis_limit / 2])

kpts = np.array(kpts, copy=False)
ax.set_zlim3d(
[min(0, z_c - axis_limit / 2), z_c + axis_limit / 2])

if self.kpt_color is None or isinstance(self.kpt_color, str):
kpt_color = [self.kpt_color] * len(kpts)
Expand All @@ -189,8 +193,7 @@ def _draw_3d_instances_kpts(keypoints,
f'({len(self.kpt_color)}) does not matches '
f'that of keypoints ({len(kpts)})')

kpts = kpts[valid]
x_3d, y_3d, z_3d = np.split(kpts[:, :3], [1, 2], axis=1)
x_3d, y_3d, z_3d = np.split(kpts_valid[:, :3], [1, 2], axis=1)

kpt_color = kpt_color[valid][..., ::-1] / 255.

Expand Down Expand Up @@ -218,7 +221,9 @@ def _draw_3d_instances_kpts(keypoints,
ys_3d = kpts[sk_indices, 1]
zs_3d = kpts[sk_indices, 2]
kpt_score = score[sk_indices]
if kpt_score.min() > kpt_thr:
kpt_score_2d = score_2d[sk_indices]
if kpt_score.min() > kpt_thr and kpt_score_2d.min(
) > kpt_thr:
# matplotlib uses RGB color in [0, 1] value range
_color = link_color[sk_id][::-1] / 255.
ax.plot(
Expand All @@ -233,13 +238,16 @@ def _draw_3d_instances_kpts(keypoints,
else:
scores = np.ones(keypoints.shape[:-1])

if scores_2d is None:
scores_2d = np.ones(keypoints.shape[:-1])

if 'keypoints_visible' in pred_instances:
keypoints_visible = pred_instances.keypoints_visible
else:
keypoints_visible = np.ones(keypoints.shape[:-1])

_draw_3d_instances_kpts(keypoints, scores, keypoints_visible, 1,
'Prediction')
_draw_3d_instances_kpts(keypoints, scores, scores_2d,
keypoints_visible, 1, 'Prediction')

if draw_gt and 'gt_instances' in pose_samples:
gt_instances = pose_samples.gt_instances
Expand Down Expand Up @@ -300,6 +308,7 @@ def _draw_instances_kpts(self,

self.set_image(image)
img_h, img_w, _ = image.shape
scores = None

if 'keypoints' in instances:
keypoints = instances.get('transformed_keypoints',
Expand Down Expand Up @@ -452,7 +461,7 @@ def _draw_instances_kpts(self,
self.draw_lines(
X, Y, color, line_widths=self.line_width)

return self.get_image()
return self.get_image(), scores

@master_only
def add_datasample(self,
Expand All @@ -466,6 +475,8 @@ def add_datasample(self,
draw_bbox: bool = False,
show_kpt_idx: bool = False,
skeleton_style: str = 'mmpose',
dataset_2d: str = 'CocoDataset',
dataset_3d: str = 'Human36mDataset',
num_instances: int = -1,
show: bool = False,
wait_time: float = 0,
Expand Down Expand Up @@ -502,6 +513,10 @@ def add_datasample(self,
Defaults to ``False``
skeleton_style (str): Skeleton style selection. Defaults to
``'mmpose'``
dataset_2d (str): Name of 2d keypoint dataset. Defaults to
``'CocoDataset'``
dataset_3d (str): Name of 3d keypoint dataset. Defaults to
``'Human36mDataset'``
num_instances (int): Number of instances to be shown in 3D. If
smaller than 0, all the instances in the pose_result will be
shown. Otherwise, pad or truncate the pose_result to a length
Expand All @@ -517,24 +532,31 @@ def add_datasample(self,

det_img_data = None
gt_img_data = None
scores_2d = None

if draw_2d:
det_img_data = image.copy()

# draw bboxes & keypoints
if 'pred_instances' in det_data_sample:
det_img_data = self._draw_instances_kpts(
det_img_data, scores_2d = self._draw_instances_kpts(
det_img_data, det_data_sample.pred_instances, kpt_thr,
show_kpt_idx, skeleton_style)
if draw_bbox:
det_img_data = self._draw_instances_bbox(
det_img_data, det_data_sample.pred_instances)

if scores_2d is not None:
if scores_2d.ndim == 2:
scores_2d = scores_2d[..., None]
scores_2d = np.squeeze(
convert_keypoint_definition(scores_2d, dataset_2d, dataset_3d),
axis=-1)
pred_img_data = self._draw_3d_data_samples(
image.copy(),
data_sample,
draw_gt=draw_gt,
num_instances=num_instances)
num_instances=num_instances,
scores_2d=scores_2d)

# merge visualization results
if det_img_data is not None and gt_img_data is not None:
Expand Down