diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index c21b19fff1..f17b013e00 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -43,8 +43,6 @@ repos:
   - id: check-added-large-files
   - id: check-ast
   - id: check-builtin-literals
-    args:
-    - --no-allow-dict-kwargs
  - id: check-case-conflict
   - id: check-docstring-first
   - id: check-executables-have-shebangs
diff --git a/composer/metrics/__init__.py b/composer/metrics/__init__.py
index 8ba0ec5b46..ff36cfb870 100644
--- a/composer/metrics/__init__.py
+++ b/composer/metrics/__init__.py
@@ -3,10 +3,11 @@
 
 """A collection of common torchmetrics."""
 
+from composer.metrics.map import MAP
 from composer.metrics.metrics import CrossEntropy, Dice, LossMetric, MIoU
 from composer.metrics.nlp import BinaryF1Score, HFCrossEntropy, LanguageCrossEntropy, MaskedAccuracy, Perplexity
 
 __all__ = [
-    'MIoU', 'Dice', 'CrossEntropy', 'LossMetric', 'Perplexity', 'BinaryF1Score', 'HFCrossEntropy',
+    'MAP', 'MIoU', 'Dice', 'CrossEntropy', 'LossMetric', 'Perplexity', 'BinaryF1Score', 'HFCrossEntropy',
     'LanguageCrossEntropy', 'MaskedAccuracy'
 ]
diff --git a/composer/metrics/map.py b/composer/metrics/map.py
new file mode 100644
index 0000000000..869771d47b
--- /dev/null
+++ b/composer/metrics/map.py
@@ -0,0 +1,363 @@
+# Copyright 2022 MosaicML Composer authors
+# SPDX-License-Identifier: Apache-2.0
+
+# Adapted from https://github.com/Lightning-AI/metrics/blob/1c42f6643f9241089e55a4d899f14da480021e16/torchmetrics/detection/map.py
+# As of 9/21/22, the MAP implementations in current torchmetrics releases are either incorrect or slow.
+# Relevant issues:
+# https://github.com/Lightning-AI/metrics/issues/1024
+# https://github.com/Lightning-AI/metrics/issues/1164
+
+"""MAP torchmetric for object detection."""
+# type: ignore
+import logging
+import sys
+from dataclasses import dataclass
+from typing import Any, Callable, Dict, List, Optional, Sequence, Union
+
+import torch
+from torch import Tensor
+from torchmetrics.metric import Metric
+from torchvision.ops import box_convert
+
+from composer.utils.import_helpers import MissingConditionalImportError
+
+__all__ = ['MAP']
+
+log = logging.getLogger(__name__)
+
+
+@dataclass
+class MAPMetricResults:
+    """Dataclass to wrap the final mAP results."""
+    map: Tensor
+    map_50: Tensor
+    map_75: Tensor
+    map_small: Tensor
+    map_medium: Tensor
+    map_large: Tensor
+    mar_1: Tensor
+    mar_10: Tensor
+    mar_100: Tensor
+    mar_small: Tensor
+    mar_medium: Tensor
+    mar_large: Tensor
+    map_per_class: Tensor
+    mar_100_per_class: Tensor
+
+    def __getitem__(self, key: str) -> Union[Tensor, List[Tensor]]:
+        return getattr(self, key)
+
+
+# noinspection PyMethodMayBeStatic
+class WriteToLog:
+    """File-like object that redirects writes to ``log.debug()``."""
+
+    def write(self, buf: str) -> None:  # skipcq: PY-D0003, PYL-R0201
+        for line in buf.rstrip().splitlines():
+            log.debug(line.rstrip())
+
+    def flush(self) -> None:  # skipcq: PY-D0003, PYL-R0201
+        for handler in log.handlers:
+            handler.flush()
+
+    def close(self) -> None:  # skipcq: PY-D0003, PYL-R0201
+        for handler in log.handlers:
+            handler.close()
+
+
+class _hide_prints:
+    """Internal helper context manager to suppress the default output of the pycocotools package."""
+
+    def __init__(self) -> None:
+        self._original_stdout = None
+
+    def __enter__(self) -> None:
+        self._original_stdout = sys.stdout  # type: ignore
+        sys.stdout = WriteToLog()  # type: ignore
+
+    def __exit__(self, exc_type, exc_val, exc_tb) -> None:  # type: ignore
+        sys.stdout.close()
+        sys.stdout = self._original_stdout  # type: ignore
+
+
+def _input_validator(preds: List[Dict[str, torch.Tensor]], targets: List[Dict[str, torch.Tensor]]) -> None:
+    """Ensure the correct input format of `preds` and `targets`."""
+    if not isinstance(preds, Sequence):
+        raise ValueError('Expected argument `preds` to be of type List')
+    if not isinstance(targets, Sequence):
+        raise ValueError('Expected argument `target` to be of type List')
+    if len(preds) != len(targets):
+        raise ValueError('Expected argument `preds` and `target` to have the same length')
+
+    for k in ['boxes', 'scores', 'labels']:
+        if any(k not in p for p in preds):
+            raise ValueError(f'Expected all dicts in `preds` to contain the `{k}` key')
+
+    for k in ['boxes', 'labels']:
+        if any(k not in p for p in targets):
+            raise ValueError(f'Expected all dicts in `target` to contain the `{k}` key')
+
+    if any(type(pred['boxes']) is not torch.Tensor for pred in preds):
+        raise ValueError('Expected all boxes in `preds` to be of type torch.Tensor')
+    if any(type(pred['scores']) is not torch.Tensor for pred in preds):
+        raise ValueError('Expected all scores in `preds` to be of type torch.Tensor')
+    if any(type(pred['labels']) is not torch.Tensor for pred in preds):
+        raise ValueError('Expected all labels in `preds` to be of type torch.Tensor')
+    if any(type(target['boxes']) is not torch.Tensor for target in targets):
+        raise ValueError('Expected all boxes in `target` to be of type torch.Tensor')
+    if any(type(target['labels']) is not torch.Tensor for target in targets):
+        raise ValueError('Expected all labels in `target` to be of type torch.Tensor')
+
+    for i, item in enumerate(targets):
+        if item['boxes'].size(0) != item['labels'].size(0):
+            raise ValueError(
+                f'Input boxes and labels of sample {i} in targets have a'
+                f" different length (expected {item['boxes'].size(0)} labels, got {item['labels'].size(0)})")
+    for i, item in enumerate(preds):
+        if item['boxes'].size(0) != item['labels'].size(0) != item['scores'].size(0):
+            raise ValueError(f'Input boxes, labels and scores of sample {i} in preds have a'
+                             f" different length (expected {item['boxes'].size(0)} labels and scores,"
+                             f" got {item['labels'].size(0)} labels and {item['scores'].size(0)})")
+
+
+class MAP(Metric):
+    """Computes the Mean-Average-Precision (mAP) and Mean-Average-Recall (mAR) for object detection predictions.
+
+    Optionally, the mAP and mAR values can be calculated per class.
+    Predicted boxes and targets have to be in Pascal VOC format
+    (xmin-top left, ymin-top left, xmax-bottom right, ymax-bottom right).
+    See the :meth:`update` method for more information about the input format to this metric.
+    See `this blog <https://jonathan-hui.medium.com/map-mean-average-precision-for-object-detection-45c121a31173>`_
+    for more details on mAP and mAR.
+
+    .. warning:: This metric is a wrapper for the `pycocotools <https://github.com/cocodataset/cocoapi/tree/master/PythonAPI/pycocotools>`_
+        package, which is a standard implementation of the mAP metric for object detection. Using this metric
+        therefore requires you to have ``pycocotools`` installed. Please install it with ``pip install pycocotools``.
+
+    .. warning:: As the pycocotools library cannot deal with tensors directly, all results have to be transferred
+        to the CPU; this may have a performance impact on your training.
+
+    Args:
+        class_metrics (bool, optional): Option to enable per-class metrics for mAP and mAR_100. Has a performance impact. Default: ``False``.
+        compute_on_step (bool, optional): Forward only calls ``update()`` and returns ``None`` if this is set to ``False``. Default: ``True``.
+        dist_sync_on_step (bool, optional): Synchronize metric state across processes at each ``forward()`` before returning the value at the step. Default: ``False``.
+        process_group (any, optional): Specify the process group on which synchronization is called. Default: ``None`` (which selects the entire world).
+        dist_sync_fn (callable, optional): Callback that performs the allgather operation on the metric state. When ``None``, DDP will be used to perform the all_gather. Default: ``None``.
+
+    Raises:
+        ValueError: If ``class_metrics`` is not a boolean.
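+
+    Example:
+        .. code-block:: python
+
+            # Minimal usage sketch (illustrative only; box values are made up):
+            import torch
+            from composer.metrics import MAP
+
+            map_metric = MAP()
+            preds = [{
+                'boxes': torch.tensor([[10.0, 20.0, 50.0, 60.0]]),  # xyxy, absolute coordinates
+                'scores': torch.tensor([0.9]),
+                'labels': torch.tensor([0]),
+            }]
+            targets = [{
+                'boxes': torch.tensor([[12.0, 20.0, 50.0, 60.0]]),
+                'labels': torch.tensor([0]),
+            }]
+            map_metric.update(preds, targets)
+            results = map_metric.compute()  # dict with 'map', 'map_50', 'mar_100', ...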
+    """
+
+    def __init__(
+        self,
+        class_metrics: bool = False,
+        compute_on_step: bool = True,
+        dist_sync_on_step: bool = False,
+        process_group: Optional[Any] = None,
+        dist_sync_fn: Callable = None,  # type: ignore
+    ) -> None:  # type: ignore
+        super().__init__(
+            compute_on_step=compute_on_step,
+            dist_sync_on_step=dist_sync_on_step,
+            process_group=process_group,
+            dist_sync_fn=dist_sync_fn,
+        )
+        try:
+            from pycocotools.coco import COCO
+            from pycocotools.cocoeval import COCOeval
+        except ImportError as e:
+            raise MissingConditionalImportError(extra_deps_group='coco', conda_package='pycocotools') from e
+
+        self.COCO = COCO
+        self.COCOeval = COCOeval
+
+        if not isinstance(class_metrics, bool):
+            raise ValueError('Expected argument `class_metrics` to be a boolean')
+        self.class_metrics = class_metrics
+
+        self.add_state('detection_boxes', default=[], dist_reduce_fx=None)
+        self.add_state('detection_scores', default=[], dist_reduce_fx=None)
+        self.add_state('detection_labels', default=[], dist_reduce_fx=None)
+        self.add_state('groundtruth_boxes', default=[], dist_reduce_fx=None)
+        self.add_state('groundtruth_labels', default=[], dist_reduce_fx=None)
+
+    def update(self, preds: List[Dict[str, Tensor]], target: List[Dict[str, Tensor]]) -> None:  # type: ignore
+        """Add detections and groundtruth to the metric.
+
+        Args:
+            preds (list[Dict[str, ~torch.Tensor]]): A list of dictionaries containing the key-values:
+
+                ``boxes`` (torch.FloatTensor): [num_boxes, 4] predicted boxes of the format [xmin, ymin, xmax, ymax] in absolute image coordinates.
+
+                ``scores`` (torch.FloatTensor): of shape [num_boxes] containing detection scores for the boxes.
+
+                ``labels`` (torch.IntTensor): of shape [num_boxes] containing 0-indexed detection classes for the boxes.
+
+            target (list[Dict[str, ~torch.Tensor]]): A list of dictionaries containing the key-values:
+
+                ``boxes`` (torch.FloatTensor): [num_boxes, 4] ground truth boxes of the format [xmin, ymin, xmax, ymax] in absolute image coordinates.
+
+                ``labels`` (torch.IntTensor): of shape [num_boxes] containing 1-indexed groundtruth classes for the boxes.
+
+        Raises:
+            ValueError: If ``preds`` and ``target`` are not of the same length.
+            ValueError: If any of ``preds.boxes``, ``preds.scores`` and ``preds.labels`` are not of the same length.
+            ValueError: If any of ``target.boxes`` and ``target.labels`` are not of the same length.
+            ValueError: If any box is not of type float and of length 4.
+            ValueError: If any class is not of type int and of length 1.
+            ValueError: If any score is not of type float and of length 1.
+        """
+        _input_validator(preds, target)
+
+        for item in preds:
+            self.detection_boxes.append(item['boxes'])  # type: ignore
+            self.detection_scores.append(item['scores'])  # type: ignore
+            self.detection_labels.append(item['labels'])  # type: ignore
+
+        for item in target:
+            self.groundtruth_boxes.append(item['boxes'])  # type: ignore
+            self.groundtruth_labels.append(item['labels'])  # type: ignore
+
+    def compute(self) -> dict:
+        """Compute the Mean-Average-Precision (mAP) and Mean-Average-Recall (mAR) scores.
+
+        All detections added in the ``update()`` method are included.
+
+        Note:
+            The main `map` score is calculated with @[ IoU=0.50:0.95 | area=all | maxDets=100 ].
+
+        Returns:
+            MAPMetricResults (dict): containing:
+
+                ``map`` (torch.Tensor): mAP averaged over IoU thresholds 0.50:0.95.
+
+                ``map_50`` (torch.Tensor): mAP at IoU 0.50.
+
+                ``map_75`` (torch.Tensor): mAP at IoU 0.75.
+
+                ``map_small`` (torch.Tensor): mAP (IoU 0.50:0.95) for small objects.
+
+                ``map_medium`` (torch.Tensor): mAP (IoU 0.50:0.95) for medium objects.
+
+                ``map_large`` (torch.Tensor): mAP (IoU 0.50:0.95) for large objects.
+
+                ``mar_1`` (torch.Tensor): mAR with at most 1 detection per image.
+
+                ``mar_10`` (torch.Tensor): mAR with at most 10 detections per image.
+
+                ``mar_100`` (torch.Tensor): mAR with at most 100 detections per image.
+
+                ``mar_small`` (torch.Tensor): mAR at 100 max detections for small objects.
+
+                ``mar_medium`` (torch.Tensor): mAR at 100 max detections for medium objects.
+
+                ``mar_large`` (torch.Tensor): mAR at 100 max detections for large objects.
+
+                ``map_per_class`` (torch.Tensor): mAP value for each class (-1 if class metrics are disabled).
+
+                ``mar_100_per_class`` (torch.Tensor): mAR at 100 max detections for each class (-1 if class metrics are disabled).
+        """
+        coco_target, coco_preds = self.COCO(), self.COCO()  # type: ignore
+        coco_target.dataset = self._get_coco_format(self.groundtruth_boxes, self.groundtruth_labels)  # type: ignore
+        coco_preds.dataset = self._get_coco_format(  # type: ignore
+            self.detection_boxes,  # type: ignore
+            self.detection_labels,  # type: ignore
+            self.detection_scores)  # type: ignore
+
+        with _hide_prints():
+            coco_target.createIndex()
+            coco_preds.createIndex()
+            coco_eval = self.COCOeval(coco_target, coco_preds, 'bbox')  # type: ignore
+            coco_eval.evaluate()
+            coco_eval.accumulate()
+            coco_eval.summarize()
+            stats = coco_eval.stats
+
+        map_per_class_values: Tensor = torch.Tensor([-1])
+        mar_100_per_class_values: Tensor = torch.Tensor([-1])
+        # if class mode is enabled, evaluate metrics per class
+        if self.class_metrics:
+            map_per_class_list = []
+            mar_100_per_class_list = []
+            for class_id in torch.cat(self.detection_labels +
+                                      self.groundtruth_labels).unique().cpu().tolist():  # type: ignore
+                coco_eval.params.catIds = [class_id]
+                with _hide_prints():
+                    coco_eval.evaluate()
+                    coco_eval.accumulate()
+                    coco_eval.summarize()
+                    class_stats = coco_eval.stats
+
+                map_per_class_list.append(torch.Tensor([class_stats[0]]))
+                mar_100_per_class_list.append(torch.Tensor([class_stats[8]]))
+            map_per_class_values = torch.Tensor(map_per_class_list)
+            mar_100_per_class_values = torch.Tensor(mar_100_per_class_list)
+
+        metrics = MAPMetricResults(
+            map=torch.Tensor([stats[0]]),
+            map_50=torch.Tensor([stats[1]]),
+            map_75=torch.Tensor([stats[2]]),
+            map_small=torch.Tensor([stats[3]]),
+            map_medium=torch.Tensor([stats[4]]),
+            map_large=torch.Tensor([stats[5]]),
+            mar_1=torch.Tensor([stats[6]]),
+            mar_10=torch.Tensor([stats[7]]),
+            mar_100=torch.Tensor([stats[8]]),
+            mar_small=torch.Tensor([stats[9]]),
+            mar_medium=torch.Tensor([stats[10]]),
+            mar_large=torch.Tensor([stats[11]]),
+            map_per_class=map_per_class_values,
+            mar_100_per_class=mar_100_per_class_values,
+        )
+        return metrics.__dict__
+
+    def _get_coco_format(self,
+                         boxes: List[torch.Tensor],
+                         labels: List[torch.Tensor],
+                         scores: Optional[List[torch.Tensor]] = None) -> Dict:
+        """Transforms and returns all cached targets or predictions in COCO format.
+
+        Format is defined at https://cocodataset.org/#format-data.
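+
+        The returned dict has roughly the following shape (illustrative sketch):
+
+        .. code-block:: python
+
+            {
+                'images': [{'id': 0}, ...],
+                'annotations': [{'id': 1, 'image_id': 0, 'bbox': [x, y, w, h],
+                                 'category_id': 1, 'area': w * h, 'iscrowd': 0,
+                                 'score': 0.9},  # 'score' is present only for predictions
+                                ...],
+                'categories': [{'id': 1, 'name': '1'}, ...],
+            }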
+ """ + images = [] + annotations = [] + annotation_id = 1 # has to start with 1, otherwise COCOEval results are wrong + + boxes = [box_convert(box, in_fmt='xyxy', out_fmt='xywh') for box in boxes] # type: ignore + for image_id, (image_boxes, image_labels) in enumerate(zip(boxes, labels)): + image_boxes = image_boxes.cpu().tolist() + image_labels = image_labels.cpu().tolist() + + images.append({'id': image_id}) + for k, (image_box, image_label) in enumerate(zip(image_boxes, image_labels)): + if len(image_box) != 4: + raise ValueError( + f'Invalid input box of sample {image_id}, element {k} (expected 4 values, got {len(image_box)})' + ) + + if type(image_label) != int: + raise ValueError(f'Invalid input class of sample {image_id}, element {k}' + f' (expected value of type integer, got type {type(image_label)})') + + annotation = { + 'id': annotation_id, + 'image_id': image_id, + 'bbox': image_box, + 'category_id': image_label, + 'area': image_box[2] * image_box[3], + 'iscrowd': 0, + } + if scores is not None: # type: ignore + score = scores[image_id][k].cpu().tolist() # type: ignore + if type(score) != float: # type: ignore + raise ValueError(f'Invalid input score of sample {image_id}, element {k}' + f' (expected value of type float, got type {type(score)})') + annotation['score'] = score + annotations.append(annotation) + annotation_id += 1 + + classes = [{ + 'id': i, + 'name': str(i) + } for i in torch.cat(self.detection_labels + self.groundtruth_labels).unique().cpu().tolist()] # type: ignore + return {'images': images, 'annotations': annotations, 'categories': classes} diff --git a/composer/models/__init__.py b/composer/models/__init__.py index e6aa442c5f..3e53fa517f 100644 --- a/composer/models/__init__.py +++ b/composer/models/__init__.py @@ -24,6 +24,7 @@ from composer.models.gpt2 import create_gpt2 as create_gpt2 from composer.models.huggingface import HuggingFaceModel as HuggingFaceModel from composer.models.initializers import Initializer as Initializer +from composer.models.mmdetection import MMDetModel as MMDetModel from composer.models.model_hparams import ModelHparams as ModelHparams from composer.models.resnet import ResNetHparams as ResNetHparams from composer.models.resnet import composer_resnet as composer_resnet diff --git a/composer/models/mmdetection.py b/composer/models/mmdetection.py new file mode 100644 index 0000000000..837b991cf3 --- /dev/null +++ b/composer/models/mmdetection.py @@ -0,0 +1,124 @@ +# Copyright 2022 MosaicML Composer authors +# SPDX-License-Identifier: Apache-2.0 + +"""A wrapper class that converts mmdet detection models to composer models""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, List, Optional + +import numpy as np +import torch +from torchmetrics import Metric +from torchmetrics.collections import MetricCollection + +from composer.models import ComposerModel + +if TYPE_CHECKING: + import mmdet + +__all__ = ['MMDetModel'] + + +class MMDetModel(ComposerModel): + """A wrapper class that adapts mmdetection detectors to composer models. + + Args: + model (mmdet.models.detectors.BaseDetector): An MMdetection Detector. + metrics (list[Metric], optional): list of torchmetrics to apply to the output of `eval_forward`. Default: ``None``. + + .. warning:: This wrapper is designed to work with mmdet datasets. + + Example: + + .. 
+        .. code-block:: python
+
+            from mmdet.models import build_detector
+            from mmcv import ConfigDict
+            from composer.models import MMDetModel
+
+            yolox_s_config = dict(
+                type='YOLOX',
+                input_size=(640, 640),
+                random_size_range=(15, 25),
+                random_size_interval=10,
+                backbone=dict(type='CSPDarknet', deepen_factor=0.33, widen_factor=0.5),
+                neck=dict(type='YOLOXPAFPN', in_channels=[128, 256, 512], out_channels=128, num_csp_blocks=1),
+                bbox_head=dict(type='YOLOXHead', num_classes=80, in_channels=128, feat_channels=128),
+                train_cfg=dict(assigner=dict(type='SimOTAAssigner', center_radius=2.5)),
+                test_cfg=dict(score_thr=0.01, nms=dict(type='nms', iou_threshold=0.65)))
+            yolox = build_detector(ConfigDict(yolox_s_config))
+            yolox.init_weights()
+            model = MMDetModel(yolox)
+    """
+
+    def __init__(
+            self,
+            model: mmdet.models.detectors.BaseDetector,  # type: ignore
+            metrics: Optional[List[Metric]] = None) -> None:
+        super().__init__()
+        self.model = model
+
+        self.train_metrics = None
+        self.val_metrics = None
+
+        if metrics:
+            metric_collection = MetricCollection(metrics)
+            self.train_metrics = metric_collection.clone(prefix='train_')
+            self.val_metrics = metric_collection.clone(prefix='val_')
+
+    def forward(self, batch):
+        # this will return a dictionary of losses in train mode and model outputs in test mode.
+        return self.model(**batch)
+
+    def loss(self, outputs, batch, **kwargs):
+        # mmdet models compute their losses inside forward(), so the outputs are the losses.
+        return outputs
+
+    def eval_forward(self, batch, outputs: Optional[Any] = None):
+        """Run the model in test mode and convert its outputs to the format torchmetrics MAP expects.
+
+        Args:
+            batch (dict): an eval batch of the format:
+
+                ``img`` (List[torch.Tensor]): list of image torch.Tensors of shape (batch, c, h, w).
+
+                ``img_metas`` (List[Dict]): (1, batch_size) list of ``image_meta`` dicts.
+
+        Returns:
+            model predictions: A batch_size length list of dictionaries containing detection boxes
+            in (x1, y1, x2, y2) format, class labels, and class probabilities.
+        """
+        device = batch['img'][0].device
+        batch.pop('gt_labels')
+        batch.pop('gt_bboxes')
+        results = self.model(return_loss=False, rescale=True, **batch)  # models behave differently in eval mode
+
+        # outputs are a list of bbox results (x, y, x2, y2, score)
+        # pack mmdet bounding boxes and labels into the format torchmetrics MAP expects
+        preds = []
+        for bbox_result in results:
+            boxes_scores = np.vstack(bbox_result)
+            boxes, scores = torch.from_numpy(boxes_scores[..., :-1]).to(device), torch.from_numpy(
+                boxes_scores[..., -1]).to(device)
+            labels = [np.full(result.shape[0], i, dtype=np.int32) for i, result in enumerate(bbox_result)]
+            pred = {
+                'labels': torch.from_numpy(np.concatenate(labels)).to(device).long(),
+                'boxes': boxes.float(),
+                'scores': scores.float()
+            }
+            preds.append(pred)
+        return preds
+
+    def get_metrics(self, is_train: bool = False):
+        if is_train:
+            metrics = self.train_metrics
+        else:
+            metrics = self.val_metrics
+        return metrics if metrics else {}
+
+    def update_metric(self, batch: Any, outputs: Any, metric: Metric):
+        # eval batches wrap each input in an extra list, hence the [0] unwrap below.
+        targets_box = batch.pop('gt_bboxes')[0]
+        targets_cls = batch.pop('gt_labels')[0]
+        targets = []
+        for i in range(len(targets_box)):
+            t = {'boxes': targets_box[i], 'labels': targets_cls[i]}
+            targets.append(t)
+        metric.update(outputs, targets)
diff --git a/docs/source/composer_model.rst b/docs/source/composer_model.rst
index b83aa4971a..06fc3a206c 100644
--- a/docs/source/composer_model.rst
+++ b/docs/source/composer_model.rst
@@ -233,8 +233,37 @@ and make it compatible with our trainer.
     # composer model, ready to be passed to our trainer
     composer_model = HuggingFaceModel(model, metrics=metrics)
 
+YOLOX Example with MMDetection
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+In this example, we create a YOLOX model loaded from MMDetection
+and make it compatible with our trainer.
+
+.. code:: python
+
+    from mmdet.models import build_detector
+    from mmcv import ConfigDict
+    from composer.models import MMDetModel
+
+    # yolox config from https://github.com/open-mmlab/mmdetection/blob/master/configs/yolox/yolox_s_8x8_300e_coco.py
+    yolox_s_config = dict(
+        type='YOLOX',
+        input_size=(640, 640),
+        random_size_range=(15, 25),
+        random_size_interval=10,
+        backbone=dict(type='CSPDarknet', deepen_factor=0.33, widen_factor=0.5),
+        neck=dict(type='YOLOXPAFPN', in_channels=[128, 256, 512], out_channels=128, num_csp_blocks=1),
+        bbox_head=dict(type='YOLOXHead', num_classes=80, in_channels=128, feat_channels=128),
+        train_cfg=dict(assigner=dict(type='SimOTAAssigner', center_radius=2.5)),
+        test_cfg=dict(score_thr=0.01, nms=dict(type='nms', iou_threshold=0.65)))
+
+    yolox = build_detector(ConfigDict(yolox_s_config))
+    yolox.init_weights()
+
+    # composer model, ready to be passed to our trainer
+    model = MMDetModel(yolox)
+
 .. |forward| replace:: :meth:`~.ComposerModel.forward`
 .. |loss| replace:: :meth:`~.ComposerModel.loss`
+.. _MMDetection: https://mmdetection.readthedocs.io/en/latest/
 .. _Transformers: https://huggingface.co/docs/transformers/index
 .. _TIMM: https://fastai.github.io/timmdocs/
 .. _torchvision: https://pytorch.org/vision/stable/models.html
diff --git a/tests/models/test_mmdet_model.py b/tests/models/test_mmdet_model.py
new file mode 100644
index 0000000000..fafeeb1ac5
--- /dev/null
+++ b/tests/models/test_mmdet_model.py
@@ -0,0 +1,200 @@
+# Copyright 2022 MosaicML Composer authors
+# SPDX-License-Identifier: Apache-2.0
+
+import numpy as np
+import pytest
+import torch
+
+
+@pytest.fixture
+def mmdet_detection_batch():
+    batch_size = 2
+    num_labels_per_image = 20
+    image_size = 224
+    return {
+        'img_metas': [{
+            'filename': '../../data/coco/train2017/fake_img.jpg',
+            'ori_filename': 'fake_image.jpg',
+            'img_shape': (image_size, image_size, 3),
+            'ori_shape': (image_size, image_size, 3),
+            'pad_shape': (image_size, image_size, 3),
+            'scale_factor': np.array([1., 1., 1., 1.], dtype=np.float32)
+        }] * batch_size,
+        'img':
+            torch.zeros(batch_size, 3, image_size, image_size, dtype=torch.float32),
+        'gt_bboxes': [torch.zeros(num_labels_per_image, 4, dtype=torch.float32)] * batch_size,
+        'gt_labels': [torch.zeros(num_labels_per_image, dtype=torch.int64)] * batch_size
+    }
+
+
+@pytest.fixture
+def mmdet_detection_eval_batch():
+    # Eval settings for mmdetection datasets have an extra list around inputs.
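+    # For example, 'img' becomes a single-element list holding the batched tensor,
+    # and 'img_metas'/'gt_bboxes'/'gt_labels' gain one more level of list nesting
+    # than in the training batch above.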
+    batch_size = 2
+    num_labels_per_image = 20
+    image_size = 224
+    return {
+        'img_metas': [[{
+            'filename': '../../data/coco/train2017/fake_img.jpg',
+            'ori_filename': 'fake_image.jpg',
+            'img_shape': (image_size, image_size, 3),
+            'ori_shape': (image_size, image_size, 3),
+            'pad_shape': (image_size, image_size, 3),
+            'scale_factor': np.array([1., 1., 1., 1.], dtype=np.float32),
+        }] * batch_size],
+        'img': [torch.zeros(batch_size, 3, image_size, image_size, dtype=torch.float32)],
+        'gt_bboxes': [[torch.zeros(num_labels_per_image, 4, dtype=torch.float32)] * batch_size],
+        'gt_labels': [[torch.zeros(num_labels_per_image, dtype=torch.int64)] * batch_size]
+    }
+
+
+@pytest.fixture
+def yolox_config():
+    # from https://github.com/open-mmlab/mmdetection/blob/master/configs/yolox/yolox_s_8x8_300e_coco.py
+    return dict(
+        type='YOLOX',
+        input_size=(640, 640),
+        random_size_range=(15, 25),
+        random_size_interval=10,
+        backbone=dict(type='CSPDarknet', deepen_factor=0.33, widen_factor=0.5),
+        neck=dict(type='YOLOXPAFPN', in_channels=[128, 256, 512], out_channels=128, num_csp_blocks=1),
+        bbox_head=dict(type='YOLOXHead', num_classes=80, in_channels=128, feat_channels=128),
+        train_cfg=dict(assigner=dict(type='SimOTAAssigner', center_radius=2.5)),
+        # In order to align the source code, the threshold of the val phase is
+        # 0.01, and the threshold of the test phase is 0.001.
+        test_cfg=dict(score_thr=0.01, nms=dict(type='nms', iou_threshold=0.65)))
+
+
+@pytest.fixture
+def faster_rcnn_config():
+    # modified from https://github.com/open-mmlab/mmdetection/blob/master/configs/_base_/models/faster_rcnn_r50_fpn.py
+    return dict(
+        type='FasterRCNN',
+        backbone=dict(type='ResNet',
+                      depth=50,
+                      num_stages=4,
+                      out_indices=(0, 1, 2, 3),
+                      frozen_stages=1,
+                      norm_cfg=dict(type='BN', requires_grad=True),
+                      norm_eval=True,
+                      style='pytorch'),
+        neck=dict(type='FPN', in_channels=[256, 512, 1024, 2048], out_channels=256, num_outs=5),
+        rpn_head=dict(type='RPNHead',
+                      in_channels=256,
+                      feat_channels=256,
+                      anchor_generator=dict(type='AnchorGenerator',
+                                            scales=[8],
+                                            ratios=[0.5, 1.0, 2.0],
+                                            strides=[4, 8, 16, 32, 64]),
+                      bbox_coder=dict(type='DeltaXYWHBBoxCoder',
+                                      target_means=[.0, .0, .0, .0],
+                                      target_stds=[1.0, 1.0, 1.0, 1.0]),
+                      loss_cls=dict(type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
+                      loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
+        roi_head=dict(type='StandardRoIHead',
+                      bbox_roi_extractor=dict(type='SingleRoIExtractor',
+                                              roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0),
+                                              out_channels=256,
+                                              featmap_strides=[4, 8, 16, 32]),
+                      bbox_head=dict(type='Shared2FCBBoxHead',
+                                     in_channels=256,
+                                     fc_out_channels=1024,
+                                     roi_feat_size=7,
+                                     num_classes=80,
+                                     bbox_coder=dict(type='DeltaXYWHBBoxCoder',
+                                                     target_means=[0., 0., 0., 0.],
+                                                     target_stds=[0.1, 0.1, 0.2, 0.2]),
+                                     reg_class_agnostic=False,
+                                     loss_cls=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
+                                     loss_bbox=dict(type='L1Loss', loss_weight=1.0))),
+        # model training and testing settings
+        train_cfg=dict(rpn=dict(assigner=dict(type='MaxIoUAssigner',
+                                              pos_iou_thr=0.7,
+                                              neg_iou_thr=0.3,
+                                              min_pos_iou=0.3,
+                                              match_low_quality=True,
+                                              ignore_iof_thr=-1),
+                                sampler=dict(type='RandomSampler',
+                                             num=256,
+                                             pos_fraction=0.5,
+                                             neg_pos_ub=-1,
+                                             add_gt_as_proposals=False),
+                                allowed_border=-1,
+                                pos_weight=-1,
+                                debug=False),
+                       rpn_proposal=dict(nms_pre=2000,
+                                         max_per_img=1000,
+                                         nms=dict(type='nms', iou_threshold=0.7),
+                                         min_bbox_size=0),
+                       rcnn=dict(assigner=dict(type='MaxIoUAssigner',
+                                               pos_iou_thr=0.5,
+                                               neg_iou_thr=0.5,
+                                               min_pos_iou=0.5,
+                                               match_low_quality=False,
+                                               ignore_iof_thr=-1),
+                                 sampler=dict(type='RandomSampler',
+                                              num=512,
+                                              pos_fraction=0.25,
+                                              neg_pos_ub=-1,
+                                              add_gt_as_proposals=True),
+                                 pos_weight=-1,
+                                 debug=False)),
+        test_cfg=dict(
+            rpn=dict(nms_pre=1000, max_per_img=1000, nms=dict(type='nms', iou_threshold=0.7), min_bbox_size=0),
+            rcnn=dict(score_thr=0.05, nms=dict(type='nms', iou_threshold=0.5), max_per_img=100)
+            # soft-nms is also supported for rcnn testing
+            # e.g., nms=dict(type='soft_nms', iou_threshold=0.5, min_score=0.05)
+        ))
+
+
+def test_mmdet_model_forward_yolox(mmdet_detection_batch, yolox_config):
+    pytest.importorskip('mmdet')
+
+    from mmcv import ConfigDict
+    from mmdet.models import build_detector
+
+    from composer.models import MMDetModel
+
+    config = ConfigDict(yolox_config)
+    # Use a randomly initialized model to avoid a slow test that downloads pretrained weights.
+    model = build_detector(config)
+    model.init_weights()
+    model = MMDetModel(model=model)
+    out = model(mmdet_detection_batch)
+    assert list(out.keys()) == ['loss_cls', 'loss_bbox', 'loss_obj']
+
+
+def test_mmdet_model_eval_forward_yolox(mmdet_detection_eval_batch, yolox_config):
+    pytest.importorskip('mmdet')
+
+    from mmcv import ConfigDict
+    from mmdet.models import build_detector
+
+    from composer.models import MMDetModel
+
+    config = ConfigDict(yolox_config)
+    # Use a randomly initialized model to avoid a slow test that downloads pretrained weights.
+    model = build_detector(config)
+    model.init_weights()
+    model = MMDetModel(model=model)
+    out = model.eval_forward(mmdet_detection_eval_batch)
+    assert len(out) == mmdet_detection_eval_batch['img'][0].shape[0]  # batch size
+    assert list(out[0].keys()) == ['labels', 'boxes', 'scores']
+
+
+def test_mmdet_model_forward_faster_rcnn(mmdet_detection_batch, faster_rcnn_config):
+    pytest.importorskip('mmdet')
+
+    from mmcv import ConfigDict
+    from mmdet.models import build_detector
+
+    from composer.models import MMDetModel
+
+    config = ConfigDict(faster_rcnn_config)
+
+    # Use a randomly initialized model to avoid a slow test that downloads pretrained weights.
+    model = build_detector(config)
+    model.init_weights()
+    model = MMDetModel(model=model)
+    out = model(mmdet_detection_batch)
+    assert list(out.keys()) == ['loss_rpn_cls', 'loss_rpn_bbox', 'loss_cls', 'acc', 'loss_bbox']
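+
+
+def test_mmdet_model_update_metric_yolox(mmdet_detection_eval_batch, yolox_config):
+    # Illustrative sketch (not part of the original test suite): wire eval_forward
+    # predictions into the MAP torchmetric via update_metric(). Assumes both mmdet
+    # and pycocotools are installed.
+    import copy
+
+    pytest.importorskip('mmdet')
+    pytest.importorskip('pycocotools')
+
+    from mmcv import ConfigDict
+    from mmdet.models import build_detector
+
+    from composer.metrics import MAP
+    from composer.models import MMDetModel
+
+    config = ConfigDict(yolox_config)
+    # Use a randomly initialized model to avoid a slow test that downloads pretrained weights.
+    model = build_detector(config)
+    model.init_weights()
+    model = MMDetModel(model=model)
+
+    metric = MAP()
+    # eval_forward and update_metric both pop the ground truth keys from the batch,
+    # so give each call its own copy.
+    preds = model.eval_forward(copy.deepcopy(mmdet_detection_eval_batch))
+    model.update_metric(copy.deepcopy(mmdet_detection_eval_batch), preds, metric)
+    results = metric.compute()
+    assert 'map' in results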