Skip to content

Commit

Permalink
[CodeCamp2023-491]Add new configuration files for MaskRCNN algorithm …
Browse files Browse the repository at this point in the history
…in mmdetection. (#10905)
  • Loading branch information
ZhaoCake authored Sep 12, 2023
1 parent dece858 commit 59b0fc5
Show file tree
Hide file tree
Showing 43 changed files with 2,163 additions and 1 deletion.
158 changes: 158 additions & 0 deletions mmdet/configs/_base_/models/mask_rcnn_r50_caffe_c4.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.ops import RoIAlign, nms
from mmengine.model.weight_init import PretrainedInit
from torch.nn import BatchNorm2d

from mmdet.models.backbones.resnet import ResNet
from mmdet.models.data_preprocessors.data_preprocessor import \
DetDataPreprocessor
from mmdet.models.dense_heads.rpn_head import RPNHead
from mmdet.models.detectors.mask_rcnn import MaskRCNN
from mmdet.models.layers import ResLayer
from mmdet.models.losses.cross_entropy_loss import CrossEntropyLoss
from mmdet.models.losses.smooth_l1_loss import L1Loss
from mmdet.models.roi_heads.bbox_heads.bbox_head import BBoxHead
from mmdet.models.roi_heads.mask_heads.fcn_mask_head import FCNMaskHead
from mmdet.models.roi_heads.roi_extractors.single_level_roi_extractor import \
SingleRoIExtractor
from mmdet.models.roi_heads.standard_roi_head import StandardRoIHead
from mmdet.models.task_modules.assigners.max_iou_assigner import MaxIoUAssigner
from mmdet.models.task_modules.coders.delta_xywh_bbox_coder import \
DeltaXYWHBBoxCoder
from mmdet.models.task_modules.prior_generators.anchor_generator import \
AnchorGenerator
from mmdet.models.task_modules.samplers.random_sampler import RandomSampler

# model settings
norm_cfg = dict(type=BatchNorm2d, requires_grad=False)
# model settings
model = dict(
type=MaskRCNN,
data_preprocessor=dict(
type=DetDataPreprocessor,
mean=[103.530, 116.280, 123.675],
std=[1.0, 1.0, 1.0],
bgr_to_rgb=False,
pad_mask=True,
pad_size_divisor=32),
backbone=dict(
type=ResNet,
depth=50,
num_stages=3,
strides=(1, 2, 2),
dilations=(1, 1, 1),
out_indices=(2, ),
frozen_stages=1,
norm_cfg=dict(type=BatchNorm2d, requires_grad=True),
norm_eval=True,
style='caffe',
init_cfg=dict(
type=PretrainedInit,
checkpoint='open-mmlab://detectron2/resnet50_caffe')),
rpn_head=dict(
type=RPNHead,
in_channels=1024,
feat_channels=1024,
anchor_generator=dict(
type=AnchorGenerator,
scales=[2, 4, 8, 16, 32],
ratios=[0.5, 1.0, 2.0],
strides=[16]),
bbox_coder=dict(
type=DeltaXYWHBBoxCoder,
target_means=[.0, .0, .0, .0],
target_stds=[1.0, 1.0, 1.0, 1.0]),
loss_cls=dict(
type=CrossEntropyLoss, use_sigmoid=True, loss_weight=1.0),
loss_bbox=dict(type=L1Loss, loss_weight=1.0)),
roi_head=dict(
type=StandardRoIHead,
shared_head=dict(
type=ResLayer,
depth=50,
stage=3,
stride=2,
dilation=1,
style='caffe',
norm_cfg=norm_cfg,
norm_eval=True),
bbox_roi_extractor=dict(
type=SingleRoIExtractor,
roi_layer=dict(type=RoIAlign, output_size=14, sampling_ratio=0),
out_channels=1024,
featmap_strides=[16]),
bbox_head=dict(
type=BBoxHead,
with_avg_pool=True,
roi_feat_size=7,
in_channels=2048,
num_classes=80,
bbox_coder=dict(
type=DeltaXYWHBBoxCoder,
target_means=[0., 0., 0., 0.],
target_stds=[0.1, 0.1, 0.2, 0.2]),
reg_class_agnostic=False,
loss_cls=dict(
type=CrossEntropyLoss, use_sigmoid=False, loss_weight=1.0),
loss_bbox=dict(type=L1Loss, loss_weight=1.0)),
mask_roi_extractor=None,
mask_head=dict(
type=FCNMaskHead,
num_convs=0,
in_channels=2048,
conv_out_channels=256,
num_classes=80,
loss_mask=dict(
type=CrossEntropyLoss, use_mask=True, loss_weight=1.0))),
# model training and testing settings
train_cfg=dict(
rpn=dict(
assigner=dict(
type=MaxIoUAssigner,
pos_iou_thr=0.7,
neg_iou_thr=0.3,
min_pos_iou=0.3,
match_low_quality=True,
ignore_iof_thr=-1),
sampler=dict(
type=RandomSampler,
num=256,
pos_fraction=0.5,
neg_pos_ub=-1,
add_gt_as_proposals=False),
allowed_border=0,
pos_weight=-1,
debug=False),
rpn_proposal=dict(
nms_pre=12000,
max_per_img=2000,
nms=dict(type=nms, iou_threshold=0.7),
min_bbox_size=0),
rcnn=dict(
assigner=dict(
type=MaxIoUAssigner,
pos_iou_thr=0.5,
neg_iou_thr=0.5,
min_pos_iou=0.5,
match_low_quality=False,
ignore_iof_thr=-1),
sampler=dict(
type=RandomSampler,
num=512,
pos_fraction=0.25,
neg_pos_ub=-1,
add_gt_as_proposals=True),
mask_size=14,
pos_weight=-1,
debug=False)),
test_cfg=dict(
rpn=dict(
nms_pre=6000,
max_per_img=1000,
nms=dict(type=nms, iou_threshold=0.7),
min_bbox_size=0),
rcnn=dict(
score_thr=0.05,
nms=dict(type=nms, iou_threshold=0.5),
max_per_img=100,
mask_thr_binary=0.5)))
4 changes: 3 additions & 1 deletion mmdet/configs/_base_/models/mask_rcnn_r50_fpn.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.ops import RoIAlign, nms
from mmengine.model.weight_init import PretrainedInit
from torch.nn import BatchNorm2d

from mmdet.models.backbones.resnet import ResNet
Expand Down Expand Up @@ -42,7 +43,8 @@
norm_cfg=dict(type=BatchNorm2d, requires_grad=True),
norm_eval=True,
style='pytorch',
init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')),
init_cfg=dict(
type=PretrainedInit, checkpoint='torchvision://resnet50')),
neck=dict(
type=FPN,
in_channels=[256, 512, 1024, 2048],
Expand Down
33 changes: 33 additions & 0 deletions mmdet/configs/_base_/schedules/schedule_2x.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmengine.optim.optimizer.optimizer_wrapper import OptimWrapper
from mmengine.optim.scheduler.lr_scheduler import LinearLR, MultiStepLR
from mmengine.runner.loops import EpochBasedTrainLoop, TestLoop, ValLoop
from torch.optim.sgd import SGD

# training schedule for 1x
train_cfg = dict(type=EpochBasedTrainLoop, max_epochs=24, val_interval=1)
val_cfg = dict(type=ValLoop)
test_cfg = dict(type=TestLoop)

# learning rate
param_scheduler = [
dict(type=LinearLR, start_factor=0.001, by_epoch=False, begin=0, end=500),
dict(
type=MultiStepLR,
begin=0,
end=24,
by_epoch=True,
milestones=[16, 22],
gamma=0.1)
]

# optimizer
optim_wrapper = dict(
type=OptimWrapper,
optimizer=dict(type=SGD, lr=0.02, momentum=0.9, weight_decay=0.0001))

# Default setting for scaling LR automatically
# - `enable` means enable scaling LR automatically
# or not by default.
# - `base_batch_size` = (8 GPUs) x (2 samples per GPU).
auto_scale_lr = dict(enable=False, base_batch_size=16)
134 changes: 134 additions & 0 deletions mmdet/configs/common/lsj_100e_coco_detection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
# Copyright (c) OpenMMLab. All rights reserved.

# Please refer to https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html#a-pure-python-style-configuration-file-beta for more details. # noqa
# mmcv >= 2.0.1
# mmengine >= 0.8.0

from mmengine.config import read_base

with read_base():
from .._base_.default_runtime import *

from mmengine.dataset.sampler import DefaultSampler
from mmengine.optim import OptimWrapper
from mmengine.optim.scheduler.lr_scheduler import LinearLR, MultiStepLR
from mmengine.runner.loops import EpochBasedTrainLoop, TestLoop, ValLoop
from torch.optim import SGD

from mmdet.datasets import CocoDataset, RepeatDataset
from mmdet.datasets.transforms.formatting import PackDetInputs
from mmdet.datasets.transforms.loading import (FilterAnnotations,
LoadAnnotations,
LoadImageFromFile)
from mmdet.datasets.transforms.transforms import (CachedMixUp, CachedMosaic,
Pad, RandomCrop, RandomFlip,
RandomResize, Resize)
from mmdet.evaluation import CocoMetric

# dataset settings
dataset_type = CocoDataset
data_root = 'data/coco/'
image_size = (1024, 1024)

backend_args = None

train_pipeline = [
dict(type=LoadImageFromFile, backend_args=backend_args),
dict(type=LoadAnnotations, with_bbox=True, with_mask=True),
dict(
type=RandomResize,
scale=image_size,
ratio_range=(0.1, 2.0),
keep_ratio=True),
dict(
type=RandomCrop,
crop_type='absolute_range',
crop_size=image_size,
recompute_bbox=True,
allow_negative_crop=True),
dict(type=FilterAnnotations, min_gt_bbox_wh=(1e-2, 1e-2)),
dict(type=RandomFlip, prob=0.5),
dict(type=PackDetInputs)
]
test_pipeline = [
dict(type=LoadImageFromFile, backend_args=backend_args),
dict(type=Resize, scale=(1333, 800), keep_ratio=True),
dict(type=LoadAnnotations, with_bbox=True),
dict(
type=PackDetInputs,
meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
'scale_factor'))
]

# Use RepeatDataset to speed up training
train_dataloader = dict(
batch_size=2,
num_workers=2,
persistent_workers=True,
sampler=dict(type=DefaultSampler, shuffle=True),
dataset=dict(
type=RepeatDataset,
times=4, # simply change this from 2 to 16 for 50e - 400e training.
dataset=dict(
type=dataset_type,
data_root=data_root,
ann_file='annotations/instances_train2017.json',
data_prefix=dict(img='train2017/'),
filter_cfg=dict(filter_empty_gt=True, min_size=32),
pipeline=train_pipeline,
backend_args=backend_args)))
val_dataloader = dict(
batch_size=1,
num_workers=2,
persistent_workers=True,
drop_last=False,
sampler=dict(type=DefaultSampler, shuffle=False),
dataset=dict(
type=dataset_type,
data_root=data_root,
ann_file='annotations/instances_val2017.json',
data_prefix=dict(img='val2017/'),
test_mode=True,
pipeline=test_pipeline,
backend_args=backend_args))
test_dataloader = val_dataloader

val_evaluator = dict(
type=CocoMetric,
ann_file=data_root + 'annotations/instances_val2017.json',
metric=['bbox', 'segm'],
format_only=False,
backend_args=backend_args)
test_evaluator = val_evaluator

max_epochs = 25

train_cfg = dict(
type=EpochBasedTrainLoop, max_epochs=max_epochs, val_interval=5)
val_cfg = dict(type=ValLoop)
test_cfg = dict(type=TestLoop)

# optimizer assumes bs=64
optim_wrapper = dict(
type=OptimWrapper,
optimizer=dict(type=SGD, lr=0.1, momentum=0.9, weight_decay=0.00004))

# learning rate
param_scheduler = [
dict(type=LinearLR, start_factor=0.067, by_epoch=False, begin=0, end=500),
dict(
type=MultiStepLR,
begin=0,
end=max_epochs,
by_epoch=True,
milestones=[22, 24],
gamma=0.1)
]

# only keep latest 2 checkpoints
default_hooks.update(dict(checkpoint=dict(max_keep_ckpts=2)))

# NOTE: `auto_scale_lr` is for automatically scaling LR,
# USER SHOULD NOT CHANGE ITS VALUES.
# base_batch_size = (32 GPUs) x (2 samples per GPU)
auto_scale_lr = dict(base_batch_size=64)
Loading

0 comments on commit 59b0fc5

Please sign in to comment.