Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Mmdet adapter #1545

Merged
merged 30 commits into from
Oct 12, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
1a4aed8
finish pre-hooks
A-Jacobson Sep 21, 2022
9f9a420
mmdet adapter, map, tests, docs
A-Jacobson Sep 21, 2022
dc72ae0
Merge branch 'dev' into mmdet-adapter
A-Jacobson Sep 21, 2022
c5f458f
fix docs
A-Jacobson Sep 23, 2022
42039d6
fix docs
A-Jacobson Sep 23, 2022
b0dd17b
Merge branch 'dev' into mmdet-adapter
A-Jacobson Sep 23, 2022
f1c29e5
fix map imports
A-Jacobson Sep 23, 2022
45cf64a
fix map docstring linting
A-Jacobson Sep 24, 2022
58ed739
testcode -> codeblock to avoid mmdet requirement
A-Jacobson Sep 26, 2022
b35f9ac
clarify model results format, update docs
A-Jacobson Sep 26, 2022
8b98280
Update composer/models/mmdetection.py
A-Jacobson Sep 26, 2022
abdd360
Update composer/models/mmdetection.py
A-Jacobson Sep 26, 2022
950463f
Update composer/metrics/map.py
A-Jacobson Sep 26, 2022
b91f957
Merge branch 'dev' into mmdet-adapter
A-Jacobson Sep 26, 2022
de590c3
Update composer/metrics/map.py
A-Jacobson Sep 26, 2022
32100ad
Merge branch 'dev' into mmdet-adapter
A-Jacobson Oct 3, 2022
b4b2ad9
Merge branch 'dev' into mmdet-adapter
A-Jacobson Oct 5, 2022
a22321a
cleanup map docs
A-Jacobson Oct 5, 2022
615d770
more docstring updates
A-Jacobson Oct 5, 2022
23d2130
Merge branch 'dev' into mmdet-adapter
Landanjs Oct 7, 2022
37c9a41
Update composer/metrics/map.py
A-Jacobson Oct 10, 2022
eac7d19
Update composer/metrics/map.py
A-Jacobson Oct 10, 2022
d031cbc
Update composer/models/mmdetection.py
A-Jacobson Oct 10, 2022
17b4c3b
Update composer/models/mmdetection.py
A-Jacobson Oct 10, 2022
4c5d0a5
Update composer/metrics/map.py
A-Jacobson Oct 10, 2022
a14465a
remove docs about import error
A-Jacobson Oct 10, 2022
ba6d730
Merge branch 'dev' into mmdet-adapter
A-Jacobson Oct 10, 2022
ffdf22a
format docstring
A-Jacobson Oct 10, 2022
ee1be91
reformat docstrings
A-Jacobson Oct 10, 2022
06cacd0
Merge branch 'dev' into mmdet-adapter
A-Jacobson Oct 12, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,6 @@ repos:
- id: check-added-large-files
- id: check-ast
- id: check-builtin-literals
args:
A-Jacobson marked this conversation as resolved.
Show resolved Hide resolved
- --no-allow-dict-kwargs
- id: check-case-conflict
- id: check-docstring-first
- id: check-executables-have-shebangs
Expand Down
3 changes: 2 additions & 1 deletion composer/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@

"""A collection of common torchmetrics."""

from composer.metrics.map import MAP
from composer.metrics.metrics import CrossEntropy, Dice, LossMetric, MIoU
from composer.metrics.nlp import BinaryF1Score, HFCrossEntropy, LanguageCrossEntropy, MaskedAccuracy, Perplexity

__all__ = [
'MIoU', 'Dice', 'CrossEntropy', 'LossMetric', 'Perplexity', 'BinaryF1Score', 'HFCrossEntropy',
'MAP', 'MIoU', 'Dice', 'CrossEntropy', 'LossMetric', 'Perplexity', 'BinaryF1Score', 'HFCrossEntropy',
'LanguageCrossEntropy', 'MaskedAccuracy'
]
363 changes: 363 additions & 0 deletions composer/metrics/map.py

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions composer/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from composer.models.gpt2 import create_gpt2 as create_gpt2
from composer.models.huggingface import HuggingFaceModel as HuggingFaceModel
from composer.models.initializers import Initializer as Initializer
from composer.models.mmdetection import MMDetModel as MMDetModel
from composer.models.model_hparams import ModelHparams as ModelHparams
from composer.models.resnet import ResNetHparams as ResNetHparams
from composer.models.resnet import composer_resnet as composer_resnet
Expand Down
124 changes: 124 additions & 0 deletions composer/models/mmdetection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

"""A wrapper class that converts mmdet detection models to composer models"""

from __future__ import annotations

from typing import TYPE_CHECKING, Any, List, Optional

import numpy as np
import torch
from torchmetrics import Metric
from torchmetrics.collections import MetricCollection

from composer.models import ComposerModel

if TYPE_CHECKING:
import mmdet

__all__ = ['MMDetModel']


class MMDetModel(ComposerModel):
"""A wrapper class that adapts mmdetection detectors to composer models.

Args:
model (mmdet.models.detectors.BaseDetector): An MMdetection Detector.
metrics (list[Metric], optional): list of torchmetrics to apply to the output of `eval_forward`. Default: ``None``.

.. warning:: This wrapper is designed to work with mmdet datasets.

Example:

.. code-block:: python

from mmdet.models import build_model
from mmcv import ConfigDict
from composer.models import MMDetModel

yolox_s_config = dict(
type='YOLOX',
input_size=(640, 640),
random_size_range=(15, 25),
random_size_interval=10,
backbone=dict(type='CSPDarknet', deepen_factor=0.33, widen_factor=0.5),
neck=dict(type='YOLOXPAFPN', in_channels=[128, 256, 512], out_channels=128, num_csp_blocks=1),
bbox_head=dict(type='YOLOXHead', num_classes=num_classes, in_channels=128, feat_channels=128),
A-Jacobson marked this conversation as resolved.
Show resolved Hide resolved
train_cfg=dict(assigner=dict(type='SimOTAAssigner', center_radius=2.5)),
test_cfg=dict(score_thr=0.01, nms=dict(type='nms', iou_threshold=0.65)))
yolox = build_model(ConfigDict(yolox_s_config))
yolox.init_weights()
model = MMDetModel(yolox)
"""

def __init__(
self,
model: mmdet.models.detectors.BaseDetector, # type: ignore
metrics: Optional[List[Metric]] = None) -> None:
super().__init__()
self.model = model

self.train_metrics = None
self.val_metrics = None

if metrics:
metric_collection = MetricCollection(metrics)
self.train_metrics = metric_collection.clone(prefix='train_')
self.val_metrics = metric_collection.clone(prefix='val_')

def forward(self, batch):
# this will return a dictionary of losses in train mode and model outputs in test mode.
return self.model(**batch)

def loss(self, outputs, batch, **kwargs):
return outputs

def eval_forward(self, batch, outputs: Optional[Any] = None):
"""
Args:
batch (dict): a eval batch of the format:


``img`` (List[torch.Tensor]): list of image torch.Tensors of shape (batch, c, h , w).


``img_metas`` (List[Dict]): (1, batch_size) list of ``image_meta`` dicts.
Returns: model predictions: A batch_size length list of dictionaries containg detection boxes in (x,y, x2, y2) format, class labels, and class probabilities.
"""
device = batch['img'][0].device
batch.pop('gt_labels')
batch.pop('gt_bboxes')
results = self.model(return_loss=False, rescale=True, **batch) # models behave differently in eval mode

# outputs are a list of bbox results (x, y, x2, y2, score)
# pack mmdet bounding boxes and labels into the format for torchmetrics MAP expects
preds = []
for bbox_result in results:
boxes_scores = np.vstack(bbox_result)
A-Jacobson marked this conversation as resolved.
Show resolved Hide resolved
boxes, scores = torch.from_numpy(boxes_scores[..., :-1]).to(device), torch.from_numpy(
boxes_scores[..., -1]).to(device)
labels = [np.full(result.shape[0], i, dtype=np.int32) for i, result in enumerate(bbox_result)]
pred = {
'labels': torch.from_numpy(np.concatenate(labels)).to(device).long(),
'boxes': boxes.float(),
'scores': scores.float()
}
preds.append(pred)
return preds

def get_metrics(self, is_train: bool = False):
if is_train:
metrics = self.train_metrics
else:
metrics = self.val_metrics
return metrics if metrics else {}

def update_metric(self, batch: Any, outputs: Any, metric: Metric):
targets_box = batch.pop('gt_bboxes')[0]
targets_cls = batch.pop('gt_labels')[0]
targets = []
for i in range(len(targets_box)):
t = {'boxes': targets_box[i], 'labels': targets_cls[i]}
targets.append(t)
metric.update(outputs, targets)
29 changes: 29 additions & 0 deletions docs/source/composer_model.rst
Original file line number Diff line number Diff line change
Expand Up @@ -233,8 +233,37 @@ and make it compatible with our trainer.
# composer model, ready to be passed to our trainer
composer_model = HuggingFaceModel(model, metrics=metrics)

YOLOX Example with MMDetection
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

In this example, we create a YOLO model loaded from MMDetection
and make it compatible with our trainer.

.. code:: python

from mmdet.models import build_model
from mmcv import ConfigDict
from composer.models import MMDetModel

# yolox config from https://github.com/open-mmlab/mmdetection/blob/master/configs/yolox/yolox_s_8x8_300e_coco.py
yolox_s_config = dict(
type='YOLOX',
input_size=(640, 640),
random_size_range=(15, 25),
random_size_interval=10,
backbone=dict(type='CSPDarknet', deepen_factor=0.33, widen_factor=0.5),
neck=dict(type='YOLOXPAFPN', in_channels=[128, 256, 512], out_channels=128, num_csp_blocks=1),
bbox_head=dict(type='YOLOXHead', num_classes=80, in_channels=128, feat_channels=128),
train_cfg=dict(assigner=dict(type='SimOTAAssigner', center_radius=2.5)),
test_cfg=dict(score_thr=0.01, nms=dict(type='nms', iou_threshold=0.65)))

yolox = build_model(ConfigDict(yolox_s_config))
yolox.init_weights()
model = MMDetModel(yolox)

.. |forward| replace:: :meth:`~.ComposerModel.forward`
.. |loss| replace:: :meth:`~.ComposerModel.loss`
.. _MMDetection: https://mmdetection.readthedocs.io/en/latest/
.. _Transformers: https://huggingface.co/docs/transformers/index
.. _TIMM: https://fastai.github.io/timmdocs/
.. _torchvision: https://pytorch.org/vision/stable/models.html
200 changes: 200 additions & 0 deletions tests/models/test_mmdet_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,200 @@
# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

import numpy as np
import pytest
import torch


@pytest.fixture
def mmdet_detection_batch():
batch_size = 2
num_labels_per_image = 20
image_size = 224
return {
'img_metas': [{
'filename': '../../data/coco/train2017/fake_img.jpg',
'ori_filename': 'fake_image.jpg',
'img_shape': (image_size, image_size, 3),
'ori_shape': (image_size, image_size, 3),
'pad_shape': (image_size, image_size, 3),
'scale_factor': np.array([1., 1., 1., 1.], dtype=np.float32)
}] * batch_size,
'img':
torch.zeros(batch_size, 3, image_size, image_size, dtype=torch.float32),
'gt_bboxes': [torch.zeros(num_labels_per_image, 4, dtype=torch.float32)] * batch_size,
'gt_labels': [torch.zeros(num_labels_per_image, dtype=torch.int64)] * batch_size
}


@pytest.fixture
def mmdet_detection_eval_batch():
# Eval settings for mmdetection datasets have an extra list around inputs.
batch_size = 2
num_labels_per_image = 20
image_size = 224
return {
'img_metas': [[{
'filename': '../../data/coco/train2017/fake_img.jpg',
'ori_filename': 'fake_image.jpg',
'img_shape': (image_size, image_size, 3),
'ori_shape': (image_size, image_size, 3),
'pad_shape': (image_size, image_size, 3),
'scale_factor': np.array([1., 1., 1., 1.], dtype=np.float32),
}] * batch_size],
'img': [torch.zeros(batch_size, 3, image_size, image_size, dtype=torch.float32)],
'gt_bboxes': [[torch.zeros(num_labels_per_image, 4, dtype=torch.float32)] * batch_size],
'gt_labels': [[torch.zeros(num_labels_per_image, dtype=torch.int64)] * batch_size]
}


@pytest.fixture
def yolox_config():
# from https://github.com/open-mmlab/mmdetection/blob/master/configs/yolox/yolox_s_8x8_300e_coco.py
return dict(
type='YOLOX',
input_size=(640, 640),
random_size_range=(15, 25),
random_size_interval=10,
backbone=dict(type='CSPDarknet', deepen_factor=0.33, widen_factor=0.5),
neck=dict(type='YOLOXPAFPN', in_channels=[128, 256, 512], out_channels=128, num_csp_blocks=1),
bbox_head=dict(type='YOLOXHead', num_classes=80, in_channels=128, feat_channels=128),
train_cfg=dict(assigner=dict(type='SimOTAAssigner', center_radius=2.5)),
# In order to align the source code, the threshold of the val phase is
# 0.01, and the threshold of the test phase is 0.001.
test_cfg=dict(score_thr=0.01, nms=dict(type='nms', iou_threshold=0.65)))


@pytest.fixture
def faster_rcnn_config():
# modified from https://github.com/open-mmlab/mmdetection/blob/master/configs/_base_/models/faster_rcnn_r50_fpn.py
return dict(
type='FasterRCNN',
backbone=dict(type='ResNet',
depth=50,
num_stages=4,
out_indices=(0, 1, 2, 3),
frozen_stages=1,
norm_cfg=dict(type='BN', requires_grad=True),
norm_eval=True,
style='pytorch'),
neck=dict(type='FPN', in_channels=[256, 512, 1024, 2048], out_channels=256, num_outs=5),
rpn_head=dict(type='RPNHead',
in_channels=256,
feat_channels=256,
anchor_generator=dict(type='AnchorGenerator',
scales=[8],
ratios=[0.5, 1.0, 2.0],
strides=[4, 8, 16, 32, 64]),
bbox_coder=dict(type='DeltaXYWHBBoxCoder',
target_means=[.0, .0, .0, .0],
target_stds=[1.0, 1.0, 1.0, 1.0]),
loss_cls=dict(type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
roi_head=dict(type='StandardRoIHead',
bbox_roi_extractor=dict(type='SingleRoIExtractor',
roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0),
out_channels=256,
featmap_strides=[4, 8, 16, 32]),
bbox_head=dict(type='Shared2FCBBoxHead',
in_channels=256,
fc_out_channels=1024,
roi_feat_size=7,
num_classes=80,
bbox_coder=dict(type='DeltaXYWHBBoxCoder',
target_means=[0., 0., 0., 0.],
target_stds=[0.1, 0.1, 0.2, 0.2]),
reg_class_agnostic=False,
loss_cls=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
loss_bbox=dict(type='L1Loss', loss_weight=1.0))),
# model training and testing settings
train_cfg=dict(rpn=dict(assigner=dict(type='MaxIoUAssigner',
pos_iou_thr=0.7,
neg_iou_thr=0.3,
min_pos_iou=0.3,
match_low_quality=True,
ignore_iof_thr=-1),
sampler=dict(type='RandomSampler',
num=256,
pos_fraction=0.5,
neg_pos_ub=-1,
add_gt_as_proposals=False),
allowed_border=-1,
pos_weight=-1,
debug=False),
rpn_proposal=dict(nms_pre=2000,
max_per_img=1000,
nms=dict(type='nms', iou_threshold=0.7),
min_bbox_size=0),
rcnn=dict(assigner=dict(type='MaxIoUAssigner',
pos_iou_thr=0.5,
neg_iou_thr=0.5,
min_pos_iou=0.5,
match_low_quality=False,
ignore_iof_thr=-1),
sampler=dict(type='RandomSampler',
num=512,
pos_fraction=0.25,
neg_pos_ub=-1,
add_gt_as_proposals=True),
pos_weight=-1,
debug=False)),
test_cfg=dict(
rpn=dict(nms_pre=1000, max_per_img=1000, nms=dict(type='nms', iou_threshold=0.7), min_bbox_size=0),
rcnn=dict(score_thr=0.05, nms=dict(type='nms', iou_threshold=0.5), max_per_img=100)
# soft-nms is also supported for rcnn testing
# e.g., nms=dict(type='soft_nms', iou_threshold=0.5, min_score=0.05)
))


def test_mmdet_model_forward_yolox(mmdet_detection_batch, yolox_config):
pytest.importorskip('mmdet')

from mmcv import ConfigDict
from mmdet.models import build_detector

from composer.models import MMDetModel

config = ConfigDict(yolox_config)
# non pretrained model to avoid a slow test that downloads the weights.
model = build_detector(config)
model.init_weights()
model = MMDetModel(model=model)
out = model(mmdet_detection_batch)
assert list(out.keys()) == ['loss_cls', 'loss_bbox', 'loss_obj']


def test_mmdet_model_eval_forward_yolox(mmdet_detection_eval_batch, yolox_config):
pytest.importorskip('mmdet')

from mmcv import ConfigDict
from mmdet.models import build_detector

from composer.models import MMDetModel

config = ConfigDict(yolox_config)
# non pretrained model to avoid a slow test that downloads the weights.
model = build_detector(config)
model.init_weights()
model = MMDetModel(model=model)
out = model.eval_forward(mmdet_detection_eval_batch)
assert len(out) == mmdet_detection_eval_batch['img'][0].shape[0] # batch size
assert list(out[0].keys()) == ['labels', 'boxes', 'scores']


def test_mmdet_model_forward_faster_rcnn(mmdet_detection_batch, faster_rcnn_config):
pytest.importorskip('mmdet')

from mmcv import ConfigDict
from mmdet.models import build_detector

from composer.models import MMDetModel

config = ConfigDict(faster_rcnn_config)

# non pretrained model to avoid a slow test that downloads the weights.
model = build_detector(config)
model.init_weights()
model = MMDetModel(model=model)
out = model(mmdet_detection_batch)
assert list(out.keys()) == ['loss_rpn_cls', 'loss_rpn_bbox', 'loss_cls', 'acc', 'loss_bbox']