diff --git a/demo/bottomup_demo.py b/demo/bottomup_demo.py index 4616f54b3e..3d6fee7a03 100644 --- a/demo/bottomup_demo.py +++ b/demo/bottomup_demo.py @@ -11,6 +11,7 @@ import numpy as np from mmpose.apis import inference_bottomup, init_model +from mmpose.registry import VISUALIZERS from mmpose.structures import split_instances @@ -128,20 +129,18 @@ def main(): device=args.device, cfg_options=cfg_options) + # build visualizer + model.cfg.visualizer.radius = args.radius + model.cfg.visualizer.line_width = args.thickness + visualizer = VISUALIZERS.build(model.cfg.visualizer) + visualizer.set_dataset_meta(model.dataset_meta) + if args.input == 'webcam': input_type = 'webcam' else: input_type = mimetypes.guess_type(args.input)[0].split('/')[0] if input_type == 'image': - # init visualizer - from mmpose.registry import VISUALIZERS - - model.cfg.visualizer.radius = args.radius - model.cfg.visualizer.line_width = args.thickness - visualizer = VISUALIZERS.build(model.cfg.visualizer) - visualizer.set_dataset_meta(model.dataset_meta) - # inference pred_instances = process_one_image( args, args.input, model, visualizer, show_interval=0) @@ -154,22 +153,6 @@ def main(): mmcv.imwrite(mmcv.rgb2bgr(img_vis), output_file) elif input_type in ['webcam', 'video']: - from mmpose.visualization import FastVisualizer - - visualizer = FastVisualizer( - model.dataset_meta, - radius=args.radius, - line_width=args.thickness, - kpt_thr=args.kpt_thr) - - if args.draw_heatmap: - # init Localvisualizer - from mmpose.registry import VISUALIZERS - - model.cfg.visualizer.radius = args.radius - model.cfg.visualizer.line_width = args.thickness - local_visualizer = VISUALIZERS.build(model.cfg.visualizer) - local_visualizer.set_dataset_meta(model.dataset_meta) if args.input == 'webcam': cap = cv2.VideoCapture(0) @@ -187,15 +170,8 @@ def main(): if not success: break - # bottom-up pose estimation - if args.draw_heatmap: - pred_instances = process_one_image(args, frame, model, - local_visualizer, 0.001) - else: - pred_instances = process_one_image(args, frame, model) - # visualization - visualizer.draw_pose(frame, pred_instances) - cv2.imshow('MMPose Demo [Press ESC to Exit]', frame) + pred_instances = process_one_image(args, frame, model, visualizer, + 0.001) if args.save_predictions: # save prediction results @@ -206,10 +182,7 @@ def main(): # output videos if output_file: - if args.draw_heatmap: - frame_vis = local_visualizer.get_image() - else: - frame_vis = frame.copy()[:, :, ::-1] + frame_vis = visualizer.get_image() if video_writer is None: fourcc = cv2.VideoWriter_fourcc(*'mp4v') diff --git a/demo/topdown_demo_with_mmdet.py b/demo/topdown_demo_with_mmdet.py index cd001e8db6..a143795693 100644 --- a/demo/topdown_demo_with_mmdet.py +++ b/demo/topdown_demo_with_mmdet.py @@ -13,6 +13,7 @@ from mmpose.apis import inference_topdown from mmpose.apis import init_model as init_pose_estimator from mmpose.evaluation.functional import nms +from mmpose.registry import VISUALIZERS from mmpose.structures import merge_data_samples, split_instances from mmpose.utils import adapt_mmdet_pipeline @@ -186,24 +187,22 @@ def main(): cfg_options=dict( model=dict(test_cfg=dict(output_heatmaps=args.draw_heatmap)))) + # build visualizer + pose_estimator.cfg.visualizer.radius = args.radius + pose_estimator.cfg.visualizer.alpha = args.alpha + pose_estimator.cfg.visualizer.line_width = args.thickness + visualizer = VISUALIZERS.build(pose_estimator.cfg.visualizer) + # the dataset_meta is loaded from the checkpoint and + # then pass to the model in init_pose_estimator + visualizer.set_dataset_meta( + pose_estimator.dataset_meta, skeleton_style=args.skeleton_style) + if args.input == 'webcam': input_type = 'webcam' else: input_type = mimetypes.guess_type(args.input)[0].split('/')[0] if input_type == 'image': - # init visualizer - from mmpose.registry import VISUALIZERS - - pose_estimator.cfg.visualizer.radius = args.radius - pose_estimator.cfg.visualizer.alpha = args.alpha - pose_estimator.cfg.visualizer.line_width = args.thickness - visualizer = VISUALIZERS.build(pose_estimator.cfg.visualizer) - - # the dataset_meta is loaded from the checkpoint and - # then pass to the model in init_pose_estimator - visualizer.set_dataset_meta( - pose_estimator.dataset_meta, skeleton_style=args.skeleton_style) # inference pred_instances = process_one_image(args, args.input, detector, @@ -218,28 +217,6 @@ def main(): mmcv.imwrite(mmcv.rgb2bgr(img_vis), output_file) elif input_type in ['webcam', 'video']: - from mmpose.visualization import FastVisualizer - - visualizer = FastVisualizer( - pose_estimator.dataset_meta, - radius=args.radius, - line_width=args.thickness, - kpt_thr=args.kpt_thr) - - if args.draw_heatmap: - # init Localvisualizer - from mmpose.registry import VISUALIZERS - - pose_estimator.cfg.visualizer.radius = args.radius - pose_estimator.cfg.visualizer.alpha = args.alpha - pose_estimator.cfg.visualizer.line_width = args.thickness - local_visualizer = VISUALIZERS.build(pose_estimator.cfg.visualizer) - - # the dataset_meta is loaded from the checkpoint and - # then pass to the model in init_pose_estimator - local_visualizer.set_dataset_meta( - pose_estimator.dataset_meta, - skeleton_style=args.skeleton_style) if args.input == 'webcam': cap = cv2.VideoCapture(0) @@ -258,16 +235,9 @@ def main(): break # topdown pose estimation - if args.draw_heatmap: - pred_instances = process_one_image(args, frame, detector, - pose_estimator, - local_visualizer, 0.001) - else: - pred_instances = process_one_image(args, frame, detector, - pose_estimator) - # visualization - visualizer.draw_pose(frame, pred_instances) - cv2.imshow('MMPose Demo [Press ESC to Exit]', frame) + pred_instances = process_one_image(args, frame, detector, + pose_estimator, visualizer, + 0.001) if args.save_predictions: # save prediction results @@ -278,10 +248,7 @@ def main(): # output videos if output_file: - if args.draw_heatmap: - frame_vis = local_visualizer.get_image() - else: - frame_vis = frame.copy()[:, :, ::-1] + frame_vis = visualizer.get_image() if video_writer is None: fourcc = cv2.VideoWriter_fourcc(*'mp4v') diff --git a/mmpose/apis/inferencers/base_mmpose_inferencer.py b/mmpose/apis/inferencers/base_mmpose_inferencer.py index 15312c6bb7..86e61463b6 100644 --- a/mmpose/apis/inferencers/base_mmpose_inferencer.py +++ b/mmpose/apis/inferencers/base_mmpose_inferencer.py @@ -159,6 +159,9 @@ def _get_webcam_inputs(self, inputs: str) -> Generator: Raises: ValueError: If the inputs string is not in the expected format. """ + assert getattr(self.visualizer, 'backend', None) == 'opencv', \ + 'Visualizer must utilize the OpenCV backend in order to ' \ + 'support webcam inputs.' # Ensure the inputs string is in the expected format. inputs = inputs.lower() @@ -187,12 +190,9 @@ def _get_webcam_inputs(self, inputs: str) -> Generator: self.video_info = dict( fps=10, name='webcam.mp4', writer=None, predictions=[]) - # Set up webcam reader generator function. - self._window_closing = False - def _webcam_reader() -> Generator: while True: - if self._window_closing: + if cv2.waitKey(5) & 0xFF == 27: vcap.release() break @@ -322,16 +322,6 @@ def visualize(self, kpt_thr=kpt_thr) results.append(visualization) - if show and not hasattr(self, '_window_close_cid'): - if window_close_event_handler is None: - window_close_event_handler = \ - self._visualization_window_on_close - self._window_close_cid = \ - self.visualizer.manager.canvas.mpl_connect( - 'close_event', - window_close_event_handler - ) - if vis_out_dir: out_img = mmcv.rgb2bgr(visualization) diff --git a/mmpose/visualization/__init__.py b/mmpose/visualization/__init__.py index 73fbd645a9..357d40a707 100644 --- a/mmpose/visualization/__init__.py +++ b/mmpose/visualization/__init__.py @@ -1,5 +1,4 @@ # Copyright (c) OpenMMLab. All rights reserved. -from .fast_visualizer import FastVisualizer from .local_visualizer import PoseLocalVisualizer -__all__ = ['PoseLocalVisualizer', 'FastVisualizer'] +__all__ = ['PoseLocalVisualizer'] diff --git a/mmpose/visualization/fast_visualizer.py b/mmpose/visualization/fast_visualizer.py deleted file mode 100644 index fa0cb38527..0000000000 --- a/mmpose/visualization/fast_visualizer.py +++ /dev/null @@ -1,78 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import cv2 - - -class FastVisualizer: - """MMPose Fast Visualizer. - - A simple yet fast visualizer for video/webcam inference. - - Args: - metainfo (dict): pose meta information - radius (int, optional)): Keypoint radius for visualization. - Defaults to 6. - line_width (int, optional): Link width for visualization. - Defaults to 3. - kpt_thr (float, optional): Threshold for keypoints' confidence score, - keypoints with score below this value will not be drawn. - Defaults to 0.3. - """ - - def __init__(self, metainfo, radius=6, line_width=3, kpt_thr=0.3): - self.radius = radius - self.line_width = line_width - self.kpt_thr = kpt_thr - - self.keypoint_id2name = metainfo['keypoint_id2name'] - self.keypoint_name2id = metainfo['keypoint_name2id'] - self.keypoint_colors = metainfo['keypoint_colors'] - self.skeleton_links = metainfo['skeleton_links'] - self.skeleton_link_colors = metainfo['skeleton_link_colors'] - - def draw_pose(self, img, instances): - """Draw pose estimations on the given image. - - This method draws keypoints and skeleton links on the input image - using the provided instances. - - Args: - img (numpy.ndarray): The input image on which to - draw the pose estimations. - instances (object): An object containing detected instances' - information, including keypoints and keypoint_scores. - - Returns: - None: The input image will be modified in place. - """ - - if instances is None: - print('no instance detected') - return - - keypoints = instances.keypoints - scores = instances.keypoint_scores - - for kpts, score in zip(keypoints, scores): - for sk_id, sk in enumerate(self.skeleton_links): - if score[sk[0]] < self.kpt_thr or score[sk[1]] < self.kpt_thr: - # skip the link that should not be drawn - continue - - pos1 = (int(kpts[sk[0], 0]), int(kpts[sk[0], 1])) - pos2 = (int(kpts[sk[1], 0]), int(kpts[sk[1], 1])) - - color = self.skeleton_link_colors[sk_id].tolist() - cv2.line(img, pos1, pos2, color, thickness=self.line_width) - - for kid, kpt in enumerate(kpts): - if score[kid] < self.kpt_thr: - # skip the point that should not be drawn - continue - - x_coord, y_coord = int(kpt[0]), int(kpt[1]) - - color = self.keypoint_colors[kid].tolist() - cv2.circle(img, (int(x_coord), int(y_coord)), self.radius, - color, -1) - cv2.circle(img, (int(x_coord), int(y_coord)), self.radius, - (255, 255, 255)) diff --git a/mmpose/visualization/local_visualizer.py b/mmpose/visualization/local_visualizer.py index b19e89dea6..205993c006 100644 --- a/mmpose/visualization/local_visualizer.py +++ b/mmpose/visualization/local_visualizer.py @@ -8,11 +8,11 @@ import torch from mmengine.dist import master_only from mmengine.structures import InstanceData, PixelData -from mmengine.visualization import Visualizer from mmpose.datasets.datasets.utils import parse_pose_metainfo from mmpose.registry import VISUALIZERS from mmpose.structures import PoseDataSample +from .opencv_backend_visualizer import OpencvBackendVisualizer from .simcc_vis import SimCCVisualizer @@ -42,7 +42,7 @@ def _get_adaptive_scales(areas: np.ndarray, @VISUALIZERS.register_module() -class PoseLocalVisualizer(Visualizer): +class PoseLocalVisualizer(OpencvBackendVisualizer): """MMPose Local Visualizer. Args: @@ -115,8 +115,15 @@ def __init__(self, line_width: Union[int, float] = 1, radius: Union[int, float] = 3, show_keypoint_weight: bool = False, + backend: str = 'opencv', alpha: float = 0.8): - super().__init__(name, image, vis_backends, save_dir) + super().__init__( + name=name, + image=image, + vis_backends=vis_backends, + save_dir=save_dir, + backend=backend) + self.bbox_color = bbox_color self.kpt_color = kpt_color self.link_color = link_color @@ -297,35 +304,6 @@ def _draw_instances_kpts(self, f'({len(self.kpt_color)}) does not matches ' f'that of keypoints ({len(kpts)})') - # draw each point on image - for kid, kpt in enumerate(kpts): - if score[kid] < kpt_thr or not visible[ - kid] or kpt_color[kid] is None: - # skip the point that should not be drawn - continue - - color = kpt_color[kid] - if not isinstance(color, str): - color = tuple(int(c) for c in color) - transparency = self.alpha - if self.show_keypoint_weight: - transparency *= max(0, min(1, score[kid])) - self.draw_circles( - kpt, - radius=np.array([self.radius]), - face_colors=color, - edge_colors=color, - alpha=transparency, - line_widths=self.radius) - if show_kpt_idx: - self.draw_texts( - str(kid), - kpt, - colors=color, - font_sizes=self.radius * 3, - vertical_alignments='bottom', - horizontal_alignments='center') - # draw links if self.skeleton is not None and self.link_color is not None: if self.link_color is None or isinstance( @@ -385,6 +363,37 @@ def _draw_instances_kpts(self, self.draw_lines( X, Y, color, line_widths=self.line_width) + # draw each point on image + for kid, kpt in enumerate(kpts): + if score[kid] < kpt_thr or not visible[ + kid] or kpt_color[kid] is None: + # skip the point that should not be drawn + continue + + color = kpt_color[kid] + if not isinstance(color, str): + color = tuple(int(c) for c in color) + transparency = self.alpha + if self.show_keypoint_weight: + transparency *= max(0, min(1, score[kid])) + self.draw_circles( + kpt, + radius=np.array([self.radius]), + face_colors=color, + edge_colors=color, + alpha=transparency, + line_widths=self.radius) + if show_kpt_idx: + kpt[0] += self.radius + kpt[1] -= self.radius + self.draw_texts( + str(kid), + kpt, + colors=color, + font_sizes=self.radius * 3, + vertical_alignments='bottom', + horizontal_alignments='center') + return self.get_image() def _draw_instance_heatmap( diff --git a/mmpose/visualization/opencv_backend_visualizer.py b/mmpose/visualization/opencv_backend_visualizer.py new file mode 100644 index 0000000000..66a7731c76 --- /dev/null +++ b/mmpose/visualization/opencv_backend_visualizer.py @@ -0,0 +1,444 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional, Union + +import cv2 +import mmcv +import numpy as np +import torch +from mmengine.dist import master_only +from mmengine.visualization import Visualizer + + +class OpencvBackendVisualizer(Visualizer): + """Base visualizer with opencv backend support. + + Args: + name (str): Name of the instance. Defaults to 'visualizer'. + image (np.ndarray, optional): the origin image to draw. The format + should be RGB. Defaults to None. + vis_backends (list, optional): Visual backend config list. + Defaults to None. + save_dir (str, optional): Save file dir for all storage backends. + If it is None, the backend storage will not save any data. + fig_save_cfg (dict): Keyword parameters of figure for saving. + Defaults to empty dict. + fig_show_cfg (dict): Keyword parameters of figure for showing. + Defaults to empty dict. + backend (str): Backend used to draw elements on the image and display + the image. Defaults to 'matplotlib'. + """ + + def __init__(self, + name='visualizer', + backend: str = 'matplotlib', + *args, + **kwargs): + super().__init__(name, *args, **kwargs) + assert backend in ('opencv', 'matplotlib'), f'the argument ' \ + f'\'backend\' must be either \'opencv\' or \'matplotlib\', ' \ + f'but got \'{backend}\'.' + self.backend = backend + + @master_only + def set_image(self, image: np.ndarray) -> None: + """Set the image to draw. + + Args: + image (np.ndarray): The image to draw. + backend (str): The backend to save the image. + """ + assert image is not None + image = image.astype('uint8') + self._image = image + self.width, self.height = image.shape[1], image.shape[0] + self._default_font_size = max( + np.sqrt(self.height * self.width) // 90, 10) + + if self.backend == 'matplotlib': + # add a small 1e-2 to avoid precision lost due to matplotlib's + # truncation (https://github.com/matplotlib/matplotlib/issues/15363) # noqa + self.fig_save.set_size_inches( # type: ignore + (self.width + 1e-2) / self.dpi, + (self.height + 1e-2) / self.dpi) + # self.canvas = mpl.backends.backend_cairo.FigureCanvasCairo(fig) + self.ax_save.cla() + self.ax_save.axis(False) + self.ax_save.imshow( + image, + extent=(0, self.width, self.height, 0), + interpolation='none') + + @master_only + def get_image(self) -> np.ndarray: + """Get the drawn image. The format is RGB. + + Returns: + np.ndarray: the drawn image which channel is RGB. + """ + assert self._image is not None, 'Please set image using `set_image`' + if self.backend == 'matplotlib': + return super().get_image() + else: + return self._image + + @master_only + def draw_circles(self, + center: Union[np.ndarray, torch.Tensor], + radius: Union[np.ndarray, torch.Tensor], + face_colors: Union[str, tuple, List[str], + List[tuple]] = 'none', + **kwargs) -> 'Visualizer': + """Draw single or multiple circles. + + Args: + center (Union[np.ndarray, torch.Tensor]): The x coordinate of + each line' start and end points. + radius (Union[np.ndarray, torch.Tensor]): The y coordinate of + each line' start and end points. + edge_colors (Union[str, tuple, List[str], List[tuple]]): The + colors of circles. ``colors`` can have the same length with + lines or just single value. If ``colors`` is single value, + all the lines will have the same colors. Reference to + https://matplotlib.org/stable/gallery/color/named_colors.html + for more details. Defaults to 'g. + line_styles (Union[str, List[str]]): The linestyle + of lines. ``line_styles`` can have the same length with + texts or just single value. If ``line_styles`` is single + value, all the lines will have the same linestyle. + Reference to + https://matplotlib.org/stable/api/collections_api.html?highlight=collection#matplotlib.collections.AsteriskPolygonCollection.set_linestyle + for more details. Defaults to '-'. + line_widths (Union[Union[int, float], List[Union[int, float]]]): + The linewidth of lines. ``line_widths`` can have + the same length with lines or just single value. + If ``line_widths`` is single value, all the lines will + have the same linewidth. Defaults to 2. + face_colors (Union[str, tuple, List[str], List[tuple]]): + The face colors. Defaults to None. + alpha (Union[int, float]): The transparency of circles. + Defaults to 0.8. + """ + if self.backend == 'matplotlib': + super().draw_circles( + center=center, + radius=radius, + face_colors=face_colors, + **kwargs) + elif self.backend == 'opencv': + if isinstance(face_colors, str): + face_colors = mmcv.color_val(face_colors) + self._image = cv2.circle(self._image, + (int(center[0]), int(center[1])), + int(radius), face_colors, -1) + else: + raise ValueError(f'got unsupported backend {self.backend}') + + @master_only + def draw_texts( + self, + texts: Union[str, List[str]], + positions: Union[np.ndarray, torch.Tensor], + font_sizes: Optional[Union[int, List[int]]] = None, + colors: Union[str, tuple, List[str], List[tuple]] = 'g', + vertical_alignments: Union[str, List[str]] = 'top', + horizontal_alignments: Union[str, List[str]] = 'left', + bboxes: Optional[Union[dict, List[dict]]] = None, + **kwargs, + ) -> 'Visualizer': + """Draw single or multiple text boxes. + + Args: + texts (Union[str, List[str]]): Texts to draw. + positions (Union[np.ndarray, torch.Tensor]): The position to draw + the texts, which should have the same length with texts and + each dim contain x and y. + font_sizes (Union[int, List[int]], optional): The font size of + texts. ``font_sizes`` can have the same length with texts or + just single value. If ``font_sizes`` is single value, all the + texts will have the same font size. Defaults to None. + colors (Union[str, tuple, List[str], List[tuple]]): The colors + of texts. ``colors`` can have the same length with texts or + just single value. If ``colors`` is single value, all the + texts will have the same colors. Reference to + https://matplotlib.org/stable/gallery/color/named_colors.html + for more details. Defaults to 'g. + vertical_alignments (Union[str, List[str]]): The verticalalignment + of texts. verticalalignment controls whether the y positional + argument for the text indicates the bottom, center or top side + of the text bounding box. + ``vertical_alignments`` can have the same length with + texts or just single value. If ``vertical_alignments`` is + single value, all the texts will have the same + verticalalignment. verticalalignment can be 'center' or + 'top', 'bottom' or 'baseline'. Defaults to 'top'. + horizontal_alignments (Union[str, List[str]]): The + horizontalalignment of texts. Horizontalalignment controls + whether the x positional argument for the text indicates the + left, center or right side of the text bounding box. + ``horizontal_alignments`` can have + the same length with texts or just single value. + If ``horizontal_alignments`` is single value, all the texts + will have the same horizontalalignment. Horizontalalignment + can be 'center','right' or 'left'. Defaults to 'left'. + font_families (Union[str, List[str]]): The font family of + texts. ``font_families`` can have the same length with texts or + just single value. If ``font_families`` is single value, all + the texts will have the same font family. + font_familiy can be 'serif', 'sans-serif', 'cursive', 'fantasy' + or 'monospace'. Defaults to 'sans-serif'. + bboxes (Union[dict, List[dict]], optional): The bounding box of the + texts. If bboxes is None, there are no bounding box around + texts. ``bboxes`` can have the same length with texts or + just single value. If ``bboxes`` is single value, all + the texts will have the same bbox. Reference to + https://matplotlib.org/stable/api/_as_gen/matplotlib.patches.FancyBboxPatch.html#matplotlib.patches.FancyBboxPatch + for more details. Defaults to None. + font_properties (Union[FontProperties, List[FontProperties]], optional): + The font properties of texts. FontProperties is + a ``font_manager.FontProperties()`` object. + If you want to draw Chinese texts, you need to prepare + a font file that can show Chinese characters properly. + For example: `simhei.ttf`, `simsun.ttc`, `simkai.ttf` and so on. + Then set ``font_properties=matplotlib.font_manager.FontProperties(fname='path/to/font_file')`` + ``font_properties`` can have the same length with texts or + just single value. If ``font_properties`` is single value, + all the texts will have the same font properties. + Defaults to None. + `New in version 0.6.0.` + """ # noqa: E501 + + if self.backend == 'matplotlib': + super().draw_texts( + texts=texts, + positions=positions, + font_sizes=font_sizes, + colors=colors, + vertical_alignments=vertical_alignments, + horizontal_alignments=horizontal_alignments, + bboxes=bboxes, + **kwargs) + + elif self.backend == 'opencv': + font_scale = max(0.1, font_sizes / 30) + thickness = max(1, font_sizes // 15) + + text_size, text_baseline = cv2.getTextSize(texts, + cv2.FONT_HERSHEY_DUPLEX, + font_scale, thickness) + + x = int(positions[0]) + if horizontal_alignments == 'right': + x = max(0, x - text_size[0]) + y = int(positions[1]) + if vertical_alignments == 'top': + y = min(self.height, y + text_size[1]) + + if bboxes is not None: + bbox_color = bboxes[0]['facecolor'] + if isinstance(bbox_color, str): + bbox_color = mmcv.color_val(bbox_color) + + y = y - text_baseline // 2 + self._image = cv2.rectangle( + self._image, (x, y - text_size[1] - text_baseline // 2), + (x + text_size[0], y + text_baseline // 2), bbox_color, + cv2.FILLED) + + self._image = cv2.putText(self._image, texts, (x, y), + cv2.FONT_HERSHEY_SIMPLEX, font_scale, + colors, thickness - 1) + else: + raise ValueError(f'got unsupported backend {self.backend}') + + @master_only + def draw_bboxes(self, + bboxes: Union[np.ndarray, torch.Tensor], + edge_colors: Union[str, tuple, List[str], + List[tuple]] = 'g', + line_widths: Union[Union[int, float], + List[Union[int, float]]] = 2, + **kwargs) -> 'Visualizer': + """Draw single or multiple bboxes. + + Args: + bboxes (Union[np.ndarray, torch.Tensor]): The bboxes to draw with + the format of(x1,y1,x2,y2). + edge_colors (Union[str, tuple, List[str], List[tuple]]): The + colors of bboxes. ``colors`` can have the same length with + lines or just single value. If ``colors`` is single value, all + the lines will have the same colors. Refer to `matplotlib. + colors` for full list of formats that are accepted. + Defaults to 'g'. + line_styles (Union[str, List[str]]): The linestyle + of lines. ``line_styles`` can have the same length with + texts or just single value. If ``line_styles`` is single + value, all the lines will have the same linestyle. + Reference to + https://matplotlib.org/stable/api/collections_api.html?highlight=collection#matplotlib.collections.AsteriskPolygonCollection.set_linestyle + for more details. Defaults to '-'. + line_widths (Union[Union[int, float], List[Union[int, float]]]): + The linewidth of lines. ``line_widths`` can have + the same length with lines or just single value. + If ``line_widths`` is single value, all the lines will + have the same linewidth. Defaults to 2. + face_colors (Union[str, tuple, List[str], List[tuple]]): + The face colors. Defaults to None. + alpha (Union[int, float]): The transparency of bboxes. + Defaults to 0.8. + """ + if self.backend == 'matplotlib': + super().draw_bboxes( + bboxes=bboxes, + edge_colors=edge_colors, + line_widths=line_widths, + **kwargs) + + elif self.backend == 'opencv': + self._image = mmcv.imshow_bboxes( + self._image, + bboxes, + edge_colors, + top_k=-1, + thickness=line_widths, + show=False) + else: + raise ValueError(f'got unsupported backend {self.backend}') + + @master_only + def draw_lines(self, + x_datas: Union[np.ndarray, torch.Tensor], + y_datas: Union[np.ndarray, torch.Tensor], + colors: Union[str, tuple, List[str], List[tuple]] = 'g', + line_widths: Union[Union[int, float], + List[Union[int, float]]] = 2, + **kwargs) -> 'Visualizer': + """Draw single or multiple line segments. + + Args: + x_datas (Union[np.ndarray, torch.Tensor]): The x coordinate of + each line' start and end points. + y_datas (Union[np.ndarray, torch.Tensor]): The y coordinate of + each line' start and end points. + colors (Union[str, tuple, List[str], List[tuple]]): The colors of + lines. ``colors`` can have the same length with lines or just + single value. If ``colors`` is single value, all the lines + will have the same colors. Reference to + https://matplotlib.org/stable/gallery/color/named_colors.html + for more details. Defaults to 'g'. + line_styles (Union[str, List[str]]): The linestyle + of lines. ``line_styles`` can have the same length with + texts or just single value. If ``line_styles`` is single + value, all the lines will have the same linestyle. + Reference to + https://matplotlib.org/stable/api/collections_api.html?highlight=collection#matplotlib.collections.AsteriskPolygonCollection.set_linestyle + for more details. Defaults to '-'. + line_widths (Union[Union[int, float], List[Union[int, float]]]): + The linewidth of lines. ``line_widths`` can have + the same length with lines or just single value. + If ``line_widths`` is single value, all the lines will + have the same linewidth. Defaults to 2. + """ + if self.backend == 'matplotlib': + super().draw_lines( + x_datas=x_datas, + y_datas=y_datas, + colors=colors, + line_widths=line_widths, + **kwargs) + + elif self.backend == 'opencv': + + self._image = cv2.line( + self._image, (x_datas[0], y_datas[0]), + (x_datas[1], y_datas[1]), + colors, + thickness=line_widths) + else: + raise ValueError(f'got unsupported backend {self.backend}') + + @master_only + def draw_polygons(self, + polygons: Union[Union[np.ndarray, torch.Tensor], + List[Union[np.ndarray, torch.Tensor]]], + edge_colors: Union[str, tuple, List[str], + List[tuple]] = 'g', + **kwargs) -> 'Visualizer': + """Draw single or multiple bboxes. + + Args: + polygons (Union[Union[np.ndarray, torch.Tensor],\ + List[Union[np.ndarray, torch.Tensor]]]): The polygons to draw + with the format of (x1,y1,x2,y2,...,xn,yn). + edge_colors (Union[str, tuple, List[str], List[tuple]]): The + colors of polygons. ``colors`` can have the same length with + lines or just single value. If ``colors`` is single value, + all the lines will have the same colors. Refer to + `matplotlib.colors` for full list of formats that are accepted. + Defaults to 'g. + line_styles (Union[str, List[str]]): The linestyle + of lines. ``line_styles`` can have the same length with + texts or just single value. If ``line_styles`` is single + value, all the lines will have the same linestyle. + Reference to + https://matplotlib.org/stable/api/collections_api.html?highlight=collection#matplotlib.collections.AsteriskPolygonCollection.set_linestyle + for more details. Defaults to '-'. + line_widths (Union[Union[int, float], List[Union[int, float]]]): + The linewidth of lines. ``line_widths`` can have + the same length with lines or just single value. + If ``line_widths`` is single value, all the lines will + have the same linewidth. Defaults to 2. + face_colors (Union[str, tuple, List[str], List[tuple]]): + The face colors. Defaults to None. + alpha (Union[int, float]): The transparency of polygons. + Defaults to 0.8. + """ + if self.backend == 'matplotlib': + super().draw_polygons( + polygons=polygons, edge_colors=edge_colors, **kwargs) + + elif self.backend == 'opencv': + + self._image = cv2.fillConvexPoly(self._image, polygons, + edge_colors) + else: + raise ValueError(f'got unsupported backend {self.backend}') + + @master_only + def show(self, + drawn_img: Optional[np.ndarray] = None, + win_name: str = 'image', + wait_time: float = 0., + continue_key=' ') -> None: + """Show the drawn image. + + Args: + drawn_img (np.ndarray, optional): The image to show. If drawn_img + is None, it will show the image got by Visualizer. Defaults + to None. + win_name (str): The image title. Defaults to 'image'. + wait_time (float): Delay in seconds. 0 is the special + value that means "forever". Defaults to 0. + continue_key (str): The key for users to continue. Defaults to + the space key. + """ + if self.backend == 'matplotlib': + super().show( + drawn_img=drawn_img, + win_name=win_name, + wait_time=wait_time, + continue_key=continue_key) + + elif self.backend == 'opencv': + # Keep images are shown in the same window, and the title of window + # will be updated with `win_name`. + if not hasattr(self, win_name): + self._cv_win_name = win_name + cv2.namedWindow(winname=f'{id(self)}') + cv2.setWindowTitle(f'{id(self)}', win_name) + else: + cv2.setWindowTitle(f'{id(self)}', win_name) + shown_img = self.get_image() if drawn_img is None else drawn_img + cv2.imshow(str(id(self)), mmcv.bgr2rgb(shown_img)) + cv2.waitKey(int(np.ceil(wait_time * 1000))) + else: + raise ValueError(f'got unsupported backend {self.backend}') diff --git a/tests/test_visualization/test_fast_visualizer.py b/tests/test_visualization/test_fast_visualizer.py deleted file mode 100644 index f4a24ca1f9..0000000000 --- a/tests/test_visualization/test_fast_visualizer.py +++ /dev/null @@ -1,71 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from unittest import TestCase - -import numpy as np - -from mmpose.visualization import FastVisualizer - - -class TestFastVisualizer(TestCase): - - def setUp(self): - self.metainfo = { - 'keypoint_id2name': { - 0: 'nose', - 1: 'left_eye', - 2: 'right_eye' - }, - 'keypoint_name2id': { - 'nose': 0, - 'left_eye': 1, - 'right_eye': 2 - }, - 'keypoint_colors': np.array([[255, 0, 0], [0, 255, 0], [0, 0, - 255]]), - 'skeleton_links': [(0, 1), (1, 2)], - 'skeleton_link_colors': np.array([[255, 255, 0], [255, 0, 255]]) - } - self.visualizer = FastVisualizer(self.metainfo) - - def test_init(self): - self.assertEqual(self.visualizer.radius, 6) - self.assertEqual(self.visualizer.line_width, 3) - self.assertEqual(self.visualizer.kpt_thr, 0.3) - self.assertEqual(self.visualizer.keypoint_id2name, - self.metainfo['keypoint_id2name']) - self.assertEqual(self.visualizer.keypoint_name2id, - self.metainfo['keypoint_name2id']) - np.testing.assert_array_equal(self.visualizer.keypoint_colors, - self.metainfo['keypoint_colors']) - self.assertEqual(self.visualizer.skeleton_links, - self.metainfo['skeleton_links']) - np.testing.assert_array_equal(self.visualizer.skeleton_link_colors, - self.metainfo['skeleton_link_colors']) - - def test_draw_pose(self): - img = np.zeros((480, 640, 3), dtype=np.uint8) - instances = type('Instances', (object, ), {})() - instances.keypoints = np.array([[[100, 100], [200, 200], [300, 300]]], - dtype=np.float32) - instances.keypoint_scores = np.array([[0.5, 0.5, 0.5]], - dtype=np.float32) - - self.visualizer.draw_pose(img, instances) - - # Check if keypoints are drawn - self.assertNotEqual(img[100, 100].tolist(), [0, 0, 0]) - self.assertNotEqual(img[200, 200].tolist(), [0, 0, 0]) - self.assertNotEqual(img[300, 300].tolist(), [0, 0, 0]) - - # Check if skeleton links are drawn - self.assertNotEqual(img[150, 150].tolist(), [0, 0, 0]) - self.assertNotEqual(img[250, 250].tolist(), [0, 0, 0]) - - def test_draw_pose_with_none_instances(self): - img = np.zeros((480, 640, 3), dtype=np.uint8) - instances = None - - self.visualizer.draw_pose(img, instances) - - # Check if the image is still empty (black) - self.assertEqual(np.count_nonzero(img), 0)