diff --git a/mmpose/codecs/annotation_processors.py b/mmpose/codecs/annotation_processors.py index 1df50e82f1..6e85b42cbb 100644 --- a/mmpose/codecs/annotation_processors.py +++ b/mmpose/codecs/annotation_processors.py @@ -10,7 +10,7 @@ NEG_INF = -1e6 -class AnnotationProcessor(BaseKeypointCodec): +class BaseAnnotationProcessor(BaseKeypointCodec): """Base class for annotation processors.""" def decode(self, *args, **kwargs): @@ -18,8 +18,8 @@ def decode(self, *args, **kwargs): @KEYPOINT_CODECS.register_module() -class YOLOXPoseAnnotationProcessor(AnnotationProcessor): - """Processor for YOLOXPose dataset annotations. +class YOLOXPoseAnnotationProcessor(BaseAnnotationProcessor): + """Convert dataset annotations to the input format of YOLOX-Pose. This processor expands bounding boxes and converts category IDs to labels. diff --git a/mmpose/evaluation/functional/nms.py b/mmpose/evaluation/functional/nms.py index d26a7258b9..801fee7764 100644 --- a/mmpose/evaluation/functional/nms.py +++ b/mmpose/evaluation/functional/nms.py @@ -10,6 +10,8 @@ import torch from torch import Tensor +from mmpose.structures.bbox import bbox_overlaps + def nms(dets: np.ndarray, thr: float) -> List[int]: """Greedily select boxes with high confidence and overlap <= thr. @@ -329,37 +331,10 @@ def nearby_joints_nms( return keep_pose_inds -def compute_iou(bbox, bboxes): - """Compute the Intersection-over-Union (IoU) for a bounding box with a list - of bounding boxes. - - Args: - bbox (Tensor): the bounding box (4 elements for x1, y1, x2, y2). - param (bboxes): Tensor, list of bounding boxes with the same format - as bbox. - return (Tensor): the IoU value between bbox and each of the bounding - boxes in bboxes. - """ - - bboxes_overlap = torch.stack(( - bboxes[:, 0].clip(min=bbox[0]), - bboxes[:, 1].clip(min=bbox[1]), - bboxes[:, 2].clip(max=bbox[2]), - bboxes[:, 3].clip(max=bbox[3]), - ), - dim=1) - a0 = torch.prod(bbox[2:] - bbox[:2]) - a1 = torch.prod(bboxes[:, 2:] - bboxes[:, :2], dim=1) - a2 = torch.prod( - (bboxes_overlap[:, 2:] - bboxes_overlap[:, :2]).clip(min=0), dim=1) - iou = a2 / (a1 + a0 - a2) - return iou - - def nms_torch(bboxes: Tensor, scores: Tensor, threshold: float = 0.65, - iou_calculator=compute_iou, + iou_calculator=bbox_overlaps, return_group: bool = False): """Perform Non-Maximum Suppression (NMS) on a set of bounding boxes using their corresponding scores. diff --git a/mmpose/evaluation/metrics/coco_metric.py b/mmpose/evaluation/metrics/coco_metric.py index 8a3cb40413..27838e0eae 100644 --- a/mmpose/evaluation/metrics/coco_metric.py +++ b/mmpose/evaluation/metrics/coco_metric.py @@ -180,8 +180,8 @@ def process(self, data_batch: Sequence[dict], pred['keypoints'] = keypoints pred['keypoint_scores'] = keypoint_scores pred['category_id'] = data_sample.get('category_id', 1) - if 'bbox' in data_sample['pred_instances']: - pred['bboxes'] = bbox_xyxy2xywh( + if 'bboxes' in data_sample['pred_instances']: + pred['bbox'] = bbox_xyxy2xywh( data_sample['pred_instances']['bboxes']) if 'bbox_scores' in data_sample['pred_instances']: @@ -370,7 +370,7 @@ def compute_metrics(self, results: list) -> Dict[str, float]: 'bbox_score': pred['bbox_scores'][idx], } if 'bbox' in pred: - instance['bbox'] = pred['bboxes'][idx] + instance['bbox'] = pred['bbox'][idx] if 'areas' in pred: instance['area'] = pred['areas'][idx] diff --git a/mmpose/models/pose_estimators/base.py b/mmpose/models/pose_estimators/base.py index c58a392c3d..e98b2caeb8 100644 --- a/mmpose/models/pose_estimators/base.py +++ b/mmpose/models/pose_estimators/base.py @@ -25,8 +25,6 @@ class BasePoseEstimator(BaseModel, metaclass=ABCMeta): init_cfg (dict | ConfigDict): The model initialization config. Defaults to ``None`` use_syncbn (bool): whether to use SyncBatchNorm. Defaults to False. - switch_to_deploy (bool): whether to switch the sub-modules to deploy - mode. Defaults to False. metainfo (dict): Meta information for dataset, such as keypoints definition and properties. If set, the metainfo of the input data batch will be overridden. For more details, please refer to @@ -44,12 +42,13 @@ def __init__(self, test_cfg: OptConfigType = None, data_preprocessor: OptConfigType = None, use_syncbn: bool = False, - switch_to_deploy: bool = False, init_cfg: OptMultiConfig = None, metainfo: Optional[dict] = None): super().__init__( data_preprocessor=data_preprocessor, init_cfg=init_cfg) self.metainfo = self._load_metainfo(metainfo) + self.train_cfg = train_cfg if train_cfg else {} + self.test_cfg = test_cfg if test_cfg else {} self.backbone = MODELS.build(backbone) @@ -64,18 +63,7 @@ def __init__(self, if head is not None: self.head = MODELS.build(head) - - self.train_cfg = train_cfg if train_cfg else {} - self.test_cfg = test_cfg if test_cfg else {} - - # adjust model structure in deploy mode - if switch_to_deploy: - for name, layer in self.named_modules(): - if callable(getattr(layer, 'switch_to_deploy', None)): - print_log( - f'module {name} has been switched to deploy mode', - 'current') - layer.switch_to_deploy() + self.head.test_cfg = self.test_cfg.copy() # Register the hook to automatically convert old version state dicts self._register_load_state_dict_pre_hook(self._load_state_dict_pre_hook) diff --git a/mmpose/models/pose_estimators/bottomup.py b/mmpose/models/pose_estimators/bottomup.py index bb01e733b6..7b82980a13 100644 --- a/mmpose/models/pose_estimators/bottomup.py +++ b/mmpose/models/pose_estimators/bottomup.py @@ -24,8 +24,6 @@ class BottomupPoseEstimator(BasePoseEstimator): test_cfg (dict, optional): The runtime config for testing process. Defaults to ``None`` use_syncbn (bool): whether to use SyncBatchNorm. Defaults to False. - switch_to_deploy (bool): whether to switch the sub-modules to deploy - mode. Defaults to False. data_preprocessor (dict, optional): The data preprocessing config to build the instance of :class:`BaseDataPreprocessor`. Defaults to ``None``. @@ -40,7 +38,6 @@ def __init__(self, train_cfg: OptConfigType = None, test_cfg: OptConfigType = None, use_syncbn: bool = False, - switch_to_deploy: bool = False, data_preprocessor: OptConfigType = None, init_cfg: OptMultiConfig = None): super().__init__( @@ -50,12 +47,9 @@ def __init__(self, train_cfg=train_cfg, test_cfg=test_cfg, use_syncbn=use_syncbn, - switch_to_deploy=switch_to_deploy, data_preprocessor=data_preprocessor, init_cfg=init_cfg) - self.head.test_cfg = self.test_cfg.copy() - def loss(self, inputs: Tensor, data_samples: SampleList) -> dict: """Calculate losses from a batch of inputs and data samples. diff --git a/mmpose/models/task_modules/assigners/sim_ota_assigner.py b/mmpose/models/task_modules/assigners/sim_ota_assigner.py index b2489f209c..69c7ed677e 100644 --- a/mmpose/models/task_modules/assigners/sim_ota_assigner.py +++ b/mmpose/models/task_modules/assigners/sim_ota_assigner.py @@ -18,16 +18,20 @@ class SimOTAAssigner: """Computes matching between predictions and ground truth. Args: - center_radius (float): Ground truth center size - to judge whether a prior is in center. Defaults to 2.5. - candidate_topk (int): The candidate top-k which used to - get top-k ious to calculate dynamic-k. Defaults to 10. - iou_weight (float): The scale factor for regression - iou cost. Defaults to 3.0. - cls_weight (float): The scale factor for classification - cost. Defaults to 1.0. - iou_calculator (ConfigType): Config of overlaps Calculator. + center_radius (float): Radius of center area to determine + if a prior is in the center of a gt. Defaults to 2.5. + candidate_topk (int): Top-k ious candidates to calculate dynamic-k. + Defaults to 10. + iou_weight (float): Weight of bbox iou cost. Defaults to 3.0. + cls_weight (float): Weight of classification cost. Defaults to 1.0. + oks_weight (float): Weight of keypoint OKS cost. Defaults to 3.0. + vis_weight (float): Weight of keypoint visibility cost. Defaults to 0.0 + dynamic_k_indicator (str): Cost type for calculating dynamic-k, + either 'iou' or 'oks'. Defaults to 'iou'. + iou_calculator (dict): Config of IoU calculation method. Defaults to dict(type='BBoxOverlaps2D'). + oks_calculator (dict): Config of OKS calculation method. + Defaults to dict(type='PoseOKS'). """ def __init__(self, @@ -70,7 +74,8 @@ def assign(self, pred_instances: InstanceData, gt_instances: InstanceData, annotations. It usually includes ``bboxes``, with shape (k, 4), and ``labels``, with shape (k, ). Returns: - obj:`AssignResult`: The assigned result. + dict: Assignment result containing assigned gt indices, + max iou overlaps, assigned labels, etc. """ gt_bboxes = gt_instances.bboxes gt_labels = gt_instances.labels diff --git a/mmpose/structures/bbox/bbox_overlaps.py b/mmpose/structures/bbox/bbox_overlaps.py index 012ca3848e..682008c337 100644 --- a/mmpose/structures/bbox/bbox_overlaps.py +++ b/mmpose/structures/bbox/bbox_overlaps.py @@ -51,6 +51,11 @@ def bbox_overlaps(bboxes1, assert (bboxes1.size(-1) == 4 or bboxes1.size(0) == 0) assert (bboxes2.size(-1) == 4 or bboxes2.size(0) == 0) + if bboxes1.ndim == 1: + bboxes1 = bboxes1.unsqueeze(0) + if bboxes2.ndim == 1: + bboxes2 = bboxes2.unsqueeze(0) + assert bboxes1.shape[:-2] == bboxes2.shape[:-2] batch_shape = bboxes1.shape[:-2]