Skip to content

Commit

Permalink
refine code
Browse files Browse the repository at this point in the history
  • Loading branch information
Ben-Louis committed Aug 29, 2023
1 parent 17fe05a commit bd39e66
Show file tree
Hide file tree
Showing 7 changed files with 32 additions and 65 deletions.
6 changes: 3 additions & 3 deletions mmpose/codecs/annotation_processors.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,16 @@
NEG_INF = -1e6


class AnnotationProcessor(BaseKeypointCodec):
class BaseAnnotationProcessor(BaseKeypointCodec):
"""Base class for annotation processors."""

def decode(self, *args, **kwargs):
pass


@KEYPOINT_CODECS.register_module()
class YOLOXPoseAnnotationProcessor(AnnotationProcessor):
"""Processor for YOLOXPose dataset annotations.
class YOLOXPoseAnnotationProcessor(BaseAnnotationProcessor):
"""Convert dataset annotations to the input format of YOLOX-Pose.
This processor expands bounding boxes and converts category IDs to labels.
Expand Down
31 changes: 3 additions & 28 deletions mmpose/evaluation/functional/nms.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
import torch
from torch import Tensor

from mmpose.structures.bbox import bbox_overlaps


def nms(dets: np.ndarray, thr: float) -> List[int]:
"""Greedily select boxes with high confidence and overlap <= thr.
Expand Down Expand Up @@ -329,37 +331,10 @@ def nearby_joints_nms(
return keep_pose_inds


def compute_iou(bbox, bboxes):
"""Compute the Intersection-over-Union (IoU) for a bounding box with a list
of bounding boxes.
Args:
bbox (Tensor): the bounding box (4 elements for x1, y1, x2, y2).
param (bboxes): Tensor, list of bounding boxes with the same format
as bbox.
return (Tensor): the IoU value between bbox and each of the bounding
boxes in bboxes.
"""

bboxes_overlap = torch.stack((
bboxes[:, 0].clip(min=bbox[0]),
bboxes[:, 1].clip(min=bbox[1]),
bboxes[:, 2].clip(max=bbox[2]),
bboxes[:, 3].clip(max=bbox[3]),
),
dim=1)
a0 = torch.prod(bbox[2:] - bbox[:2])
a1 = torch.prod(bboxes[:, 2:] - bboxes[:, :2], dim=1)
a2 = torch.prod(
(bboxes_overlap[:, 2:] - bboxes_overlap[:, :2]).clip(min=0), dim=1)
iou = a2 / (a1 + a0 - a2)
return iou


def nms_torch(bboxes: Tensor,
scores: Tensor,
threshold: float = 0.65,
iou_calculator=compute_iou,
iou_calculator=bbox_overlaps,
return_group: bool = False):
"""Perform Non-Maximum Suppression (NMS) on a set of bounding boxes using
their corresponding scores.
Expand Down
6 changes: 3 additions & 3 deletions mmpose/evaluation/metrics/coco_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,8 +180,8 @@ def process(self, data_batch: Sequence[dict],
pred['keypoints'] = keypoints
pred['keypoint_scores'] = keypoint_scores
pred['category_id'] = data_sample.get('category_id', 1)
if 'bbox' in data_sample['pred_instances']:
pred['bboxes'] = bbox_xyxy2xywh(
if 'bboxes' in data_sample['pred_instances']:
pred['bbox'] = bbox_xyxy2xywh(
data_sample['pred_instances']['bboxes'])

if 'bbox_scores' in data_sample['pred_instances']:
Expand Down Expand Up @@ -370,7 +370,7 @@ def compute_metrics(self, results: list) -> Dict[str, float]:
'bbox_score': pred['bbox_scores'][idx],
}
if 'bbox' in pred:
instance['bbox'] = pred['bboxes'][idx]
instance['bbox'] = pred['bbox'][idx]

if 'areas' in pred:
instance['area'] = pred['areas'][idx]
Expand Down
18 changes: 3 additions & 15 deletions mmpose/models/pose_estimators/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,6 @@ class BasePoseEstimator(BaseModel, metaclass=ABCMeta):
init_cfg (dict | ConfigDict): The model initialization config.
Defaults to ``None``
use_syncbn (bool): whether to use SyncBatchNorm. Defaults to False.
switch_to_deploy (bool): whether to switch the sub-modules to deploy
mode. Defaults to False.
metainfo (dict): Meta information for dataset, such as keypoints
definition and properties. If set, the metainfo of the input data
batch will be overridden. For more details, please refer to
Expand All @@ -44,12 +42,13 @@ def __init__(self,
test_cfg: OptConfigType = None,
data_preprocessor: OptConfigType = None,
use_syncbn: bool = False,
switch_to_deploy: bool = False,
init_cfg: OptMultiConfig = None,
metainfo: Optional[dict] = None):
super().__init__(
data_preprocessor=data_preprocessor, init_cfg=init_cfg)
self.metainfo = self._load_metainfo(metainfo)
self.train_cfg = train_cfg if train_cfg else {}
self.test_cfg = test_cfg if test_cfg else {}

self.backbone = MODELS.build(backbone)

Expand All @@ -64,18 +63,7 @@ def __init__(self,

if head is not None:
self.head = MODELS.build(head)

self.train_cfg = train_cfg if train_cfg else {}
self.test_cfg = test_cfg if test_cfg else {}

# adjust model structure in deploy mode
if switch_to_deploy:
for name, layer in self.named_modules():
if callable(getattr(layer, 'switch_to_deploy', None)):
print_log(
f'module {name} has been switched to deploy mode',
'current')
layer.switch_to_deploy()
self.head.test_cfg = self.test_cfg.copy()

# Register the hook to automatically convert old version state dicts
self._register_load_state_dict_pre_hook(self._load_state_dict_pre_hook)
Expand Down
6 changes: 0 additions & 6 deletions mmpose/models/pose_estimators/bottomup.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,6 @@ class BottomupPoseEstimator(BasePoseEstimator):
test_cfg (dict, optional): The runtime config for testing process.
Defaults to ``None``
use_syncbn (bool): whether to use SyncBatchNorm. Defaults to False.
switch_to_deploy (bool): whether to switch the sub-modules to deploy
mode. Defaults to False.
data_preprocessor (dict, optional): The data preprocessing config to
build the instance of :class:`BaseDataPreprocessor`. Defaults to
``None``.
Expand All @@ -40,7 +38,6 @@ def __init__(self,
train_cfg: OptConfigType = None,
test_cfg: OptConfigType = None,
use_syncbn: bool = False,
switch_to_deploy: bool = False,
data_preprocessor: OptConfigType = None,
init_cfg: OptMultiConfig = None):
super().__init__(
Expand All @@ -50,12 +47,9 @@ def __init__(self,
train_cfg=train_cfg,
test_cfg=test_cfg,
use_syncbn=use_syncbn,
switch_to_deploy=switch_to_deploy,
data_preprocessor=data_preprocessor,
init_cfg=init_cfg)

self.head.test_cfg = self.test_cfg.copy()

def loss(self, inputs: Tensor, data_samples: SampleList) -> dict:
"""Calculate losses from a batch of inputs and data samples.
Expand Down
25 changes: 15 additions & 10 deletions mmpose/models/task_modules/assigners/sim_ota_assigner.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,20 @@ class SimOTAAssigner:
"""Computes matching between predictions and ground truth.
Args:
center_radius (float): Ground truth center size
to judge whether a prior is in center. Defaults to 2.5.
candidate_topk (int): The candidate top-k which used to
get top-k ious to calculate dynamic-k. Defaults to 10.
iou_weight (float): The scale factor for regression
iou cost. Defaults to 3.0.
cls_weight (float): The scale factor for classification
cost. Defaults to 1.0.
iou_calculator (ConfigType): Config of overlaps Calculator.
center_radius (float): Radius of center area to determine
if a prior is in the center of a gt. Defaults to 2.5.
candidate_topk (int): Top-k ious candidates to calculate dynamic-k.
Defaults to 10.
iou_weight (float): Weight of bbox iou cost. Defaults to 3.0.
cls_weight (float): Weight of classification cost. Defaults to 1.0.
oks_weight (float): Weight of keypoint OKS cost. Defaults to 3.0.
vis_weight (float): Weight of keypoint visibility cost. Defaults to 0.0
dynamic_k_indicator (str): Cost type for calculating dynamic-k,
either 'iou' or 'oks'. Defaults to 'iou'.
iou_calculator (dict): Config of IoU calculation method.
Defaults to dict(type='BBoxOverlaps2D').
oks_calculator (dict): Config of OKS calculation method.
Defaults to dict(type='PoseOKS').
"""

def __init__(self,
Expand Down Expand Up @@ -70,7 +74,8 @@ def assign(self, pred_instances: InstanceData, gt_instances: InstanceData,
annotations. It usually includes ``bboxes``, with shape (k, 4),
and ``labels``, with shape (k, ).
Returns:
obj:`AssignResult`: The assigned result.
dict: Assignment result containing assigned gt indices,
max iou overlaps, assigned labels, etc.
"""
gt_bboxes = gt_instances.bboxes
gt_labels = gt_instances.labels
Expand Down
5 changes: 5 additions & 0 deletions mmpose/structures/bbox/bbox_overlaps.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,11 @@ def bbox_overlaps(bboxes1,
assert (bboxes1.size(-1) == 4 or bboxes1.size(0) == 0)
assert (bboxes2.size(-1) == 4 or bboxes2.size(0) == 0)

if bboxes1.ndim == 1:
bboxes1 = bboxes1.unsqueeze(0)
if bboxes2.ndim == 1:
bboxes2 = bboxes2.unsqueeze(0)

assert bboxes1.shape[:-2] == bboxes2.shape[:-2]
batch_shape = bboxes1.shape[:-2]

Expand Down

0 comments on commit bd39e66

Please sign in to comment.