Skip to content

Commit

Permalink
[Feature]Add loader args to config (#1388)
Browse files Browse the repository at this point in the history
* update pr

* add ut

* fix ut

* fix ut err

* fix ut name

* resolve comment
  • Loading branch information
VVsssssk authored Apr 20, 2022
1 parent 047f790 commit ff159fe
Show file tree
Hide file tree
Showing 5 changed files with 307 additions and 26 deletions.
3 changes: 2 additions & 1 deletion mmdet3d/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@
from mmcv.utils import Registry, build_from_cfg, print_log

from .collect_env import collect_env
from .compat_cfg import compat_cfg
from .logger import get_root_logger
from .setup_env import setup_multi_processes

__all__ = [
'Registry', 'build_from_cfg', 'get_root_logger', 'collect_env',
'print_log', 'setup_multi_processes'
'print_log', 'setup_multi_processes', 'compat_cfg'
]
139 changes: 139 additions & 0 deletions mmdet3d/utils/compat_cfg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import warnings

from mmcv import ConfigDict


def compat_cfg(cfg):
    """Upgrade deprecated config fields to their modern equivalents.

    Applies, in order, the ``imgs_per_gpu``, loader-args and runner
    compatibility shims.  Operates on (and returns) a deep copy, so the
    caller's config object is never mutated.
    """
    upgraded = copy.deepcopy(cfg)
    for adapt in (compat_imgs_per_gpu, compat_loader_args, compat_runner_args):
        upgraded = adapt(upgraded)
    return upgraded


def compat_runner_args(cfg):
    """Ensure ``cfg`` carries a ``runner`` section.

    Configs predating the runner abstraction only define ``total_epochs``;
    for those an ``EpochBasedRunner`` entry is synthesized and a deprecation
    warning is emitted.  When both fields are present they must agree.
    """
    if 'runner' in cfg:
        # Old- and new-style fields given together: they have to be
        # consistent with each other.
        if 'total_epochs' in cfg:
            assert cfg.total_epochs == cfg.runner.max_epochs
    else:
        cfg.runner = ConfigDict(
            dict(type='EpochBasedRunner', max_epochs=cfg.total_epochs))
        warnings.warn(
            'config is now expected to have a `runner` section, '
            'please set `runner` in your config.', UserWarning)
    return cfg


def compat_imgs_per_gpu(cfg):
cfg = copy.deepcopy(cfg)
if 'imgs_per_gpu' in cfg.data:
warnings.warn('"imgs_per_gpu" is deprecated in MMDet V2.0. '
'Please use "samples_per_gpu" instead')
if 'samples_per_gpu' in cfg.data:
warnings.warn(
f'Got "imgs_per_gpu"={cfg.data.imgs_per_gpu} and '
f'"samples_per_gpu"={cfg.data.samples_per_gpu}, "imgs_per_gpu"'
f'={cfg.data.imgs_per_gpu} is used in this experiments')
else:
warnings.warn('Automatically set "samples_per_gpu"="imgs_per_gpu"='
f'{cfg.data.imgs_per_gpu} in this experiments')
cfg.data.samples_per_gpu = cfg.data.imgs_per_gpu
return cfg


def compat_loader_args(cfg):
    """Deprecated sample_per_gpu in cfg.data.

    Moves loader options (``samples_per_gpu``, ``persistent_workers``,
    ``workers_per_gpu``) from the legacy flat ``cfg.data`` / per-dataset
    fields into the ``train/val/test_dataloader`` sections.  Works on a
    deep copy of ``cfg`` and returns it.  Raises ``AssertionError`` when an
    option is given both in the legacy location and in the corresponding
    ``*_dataloader`` section.
    """

    cfg = copy.deepcopy(cfg)
    # Make sure all three dataloader sections exist so values can be moved
    # into them unconditionally below.
    if 'train_dataloader' not in cfg.data:
        cfg.data['train_dataloader'] = ConfigDict()
    if 'val_dataloader' not in cfg.data:
        cfg.data['val_dataloader'] = ConfigDict()
    if 'test_dataloader' not in cfg.data:
        cfg.data['test_dataloader'] = ConfigDict()

    # special process for train_dataloader
    if 'samples_per_gpu' in cfg.data:

        samples_per_gpu = cfg.data.pop('samples_per_gpu')
        assert 'samples_per_gpu' not in \
            cfg.data.train_dataloader, ('`samples_per_gpu` are set '
                                        'in `data` field and ` '
                                        'data.train_dataloader` '
                                        'at the same time. '
                                        'Please only set it in '
                                        '`data.train_dataloader`. ')
        cfg.data.train_dataloader['samples_per_gpu'] = samples_per_gpu

    if 'persistent_workers' in cfg.data:

        persistent_workers = cfg.data.pop('persistent_workers')
        assert 'persistent_workers' not in \
            cfg.data.train_dataloader, ('`persistent_workers` are set '
                                        'in `data` field and ` '
                                        'data.train_dataloader` '
                                        'at the same time. '
                                        'Please only set it in '
                                        '`data.train_dataloader`. ')
        cfg.data.train_dataloader['persistent_workers'] = persistent_workers

    if 'workers_per_gpu' in cfg.data:

        workers_per_gpu = cfg.data.pop('workers_per_gpu')
        # `workers_per_gpu` historically applied to every split, so it is
        # copied into all three dataloader sections.
        cfg.data.train_dataloader['workers_per_gpu'] = workers_per_gpu
        cfg.data.val_dataloader['workers_per_gpu'] = workers_per_gpu
        cfg.data.test_dataloader['workers_per_gpu'] = workers_per_gpu

    # special process for val_dataloader
    if 'samples_per_gpu' in cfg.data.val:
        # keep default value of `sample_per_gpu` is 1
        assert 'samples_per_gpu' not in \
            cfg.data.val_dataloader, ('`samples_per_gpu` are set '
                                      'in `data.val` field and ` '
                                      'data.val_dataloader` at '
                                      'the same time. '
                                      'Please only set it in '
                                      '`data.val_dataloader`. ')
        cfg.data.val_dataloader['samples_per_gpu'] = \
            cfg.data.val.pop('samples_per_gpu')
    # special process for test_dataloader

    # in case the test dataset is concatenated
    if isinstance(cfg.data.test, dict):
        if 'samples_per_gpu' in cfg.data.test:
            assert 'samples_per_gpu' not in \
                cfg.data.test_dataloader, ('`samples_per_gpu` are set '
                                           'in `data.test` field and ` '
                                           'data.test_dataloader` '
                                           'at the same time. '
                                           'Please only set it in '
                                           '`data.test_dataloader`. ')

            cfg.data.test_dataloader['samples_per_gpu'] = \
                cfg.data.test.pop('samples_per_gpu')

    elif isinstance(cfg.data.test, list):
        for ds_cfg in cfg.data.test:
            if 'samples_per_gpu' in ds_cfg:
                assert 'samples_per_gpu' not in \
                    cfg.data.test_dataloader, ('`samples_per_gpu` are set '
                                               'in `data.test` field and ` '
                                               'data.test_dataloader` at'
                                               ' the same time. '
                                               'Please only set it in '
                                               '`data.test_dataloader`. ')
        # Take the max over the concatenated datasets; defaults to 1 when
        # none of them sets `samples_per_gpu`.
        samples_per_gpu = max(
            [ds_cfg.pop('samples_per_gpu', 1) for ds_cfg in cfg.data.test])
        cfg.data.test_dataloader['samples_per_gpu'] = samples_per_gpu

    return cfg
10 changes: 8 additions & 2 deletions mmdet3d/utils/setup_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,13 @@ def setup_multi_processes(cfg):

# setup OMP threads
# This code is referred from https://github.com/pytorch/pytorch/blob/master/torch/distributed/run.py # noqa
if 'OMP_NUM_THREADS' not in os.environ and cfg.data.workers_per_gpu > 1:
workers_per_gpu = cfg.data.get('workers_per_gpu', 1)
if 'train_dataloader' in cfg.data:
workers_per_gpu = \
max(cfg.data.train_dataloader.get('workers_per_gpu', 1),
workers_per_gpu)

if 'OMP_NUM_THREADS' not in os.environ and workers_per_gpu > 1:
omp_num_threads = 1
warnings.warn(
f'Setting OMP_NUM_THREADS environment variable for each process '
Expand All @@ -37,7 +43,7 @@ def setup_multi_processes(cfg):
os.environ['OMP_NUM_THREADS'] = str(omp_num_threads)

# setup MKL threads
if 'MKL_NUM_THREADS' not in os.environ and cfg.data.workers_per_gpu > 1:
if 'MKL_NUM_THREADS' not in os.environ and workers_per_gpu > 1:
mkl_num_threads = 1
warnings.warn(
f'Setting MKL_NUM_THREADS environment variable for each process '
Expand Down
126 changes: 126 additions & 0 deletions tests/test_utils/test_compat_cfg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
import pytest
from mmcv import ConfigDict

from mmdet3d.utils.compat_cfg import (compat_imgs_per_gpu, compat_loader_args,
compat_runner_args)


def test_compat_runner_args():
    """`compat_runner_args` should synthesize a `runner` section, with a
    deprecation warning, for configs that only define `total_epochs`."""
    cfg = ConfigDict(dict(total_epochs=12))
    # `pytest.warns(None)` is deprecated since pytest 7 and removed in
    # pytest 8; expecting the concrete `UserWarning` category is the
    # supported replacement.
    with pytest.warns(UserWarning) as record:
        cfg = compat_runner_args(cfg)
    assert len(record) == 1
    assert 'runner' in record.list[0].message.args[0]
    assert 'runner' in cfg
    assert cfg.runner.type == 'EpochBasedRunner'
    assert cfg.runner.max_epochs == cfg.total_epochs


def test_compat_loader_args():
    """Check that `compat_loader_args` auto-creates the `*_dataloader`
    sections, moves legacy loader options into them, and rejects options
    duplicated between the legacy and new locations."""
    cfg = ConfigDict(dict(data=dict(val=dict(), test=dict(), train=dict())))
    cfg = compat_loader_args(cfg)
    # auto fill loader args
    assert 'val_dataloader' in cfg.data
    assert 'train_dataloader' in cfg.data
    assert 'test_dataloader' in cfg.data
    cfg = ConfigDict(
        dict(
            data=dict(
                samples_per_gpu=1,
                persistent_workers=True,
                workers_per_gpu=1,
                val=dict(samples_per_gpu=3),
                test=dict(samples_per_gpu=2),
                train=dict())))
    # `pytest.warns(None)` is deprecated since pytest 7 and removed in
    # pytest 8; expecting the base `Warning` category is the supported
    # replacement when warnings are required.
    with pytest.warns(Warning) as record:
        cfg = compat_loader_args(cfg)
    # 5 warning
    assert len(record) == 5
    # assert the warning message
    assert 'train_dataloader' in record.list[0].message.args[0]
    assert 'samples_per_gpu' in record.list[0].message.args[0]
    assert 'persistent_workers' in record.list[1].message.args[0]
    assert 'train_dataloader' in record.list[1].message.args[0]
    assert 'workers_per_gpu' in record.list[2].message.args[0]
    assert 'train_dataloader' in record.list[2].message.args[0]
    assert cfg.data.train_dataloader.workers_per_gpu == 1
    assert cfg.data.train_dataloader.samples_per_gpu == 1
    assert cfg.data.train_dataloader.persistent_workers
    assert cfg.data.val_dataloader.workers_per_gpu == 1
    assert cfg.data.val_dataloader.samples_per_gpu == 3
    assert cfg.data.test_dataloader.workers_per_gpu == 1
    assert cfg.data.test_dataloader.samples_per_gpu == 2

    # test test is a list
    cfg = ConfigDict(
        dict(
            data=dict(
                samples_per_gpu=1,
                persistent_workers=True,
                workers_per_gpu=1,
                val=dict(samples_per_gpu=3),
                test=[dict(samples_per_gpu=2),
                      dict(samples_per_gpu=3)],
                train=dict())))

    with pytest.warns(Warning) as record:
        cfg = compat_loader_args(cfg)
    # 6 warning
    assert len(record) == 6
    # the max `samples_per_gpu` over the concatenated test datasets wins
    assert cfg.data.test_dataloader.samples_per_gpu == 3

    # assert can not set args at the same time
    cfg = ConfigDict(
        dict(
            data=dict(
                samples_per_gpu=1,
                persistent_workers=True,
                workers_per_gpu=1,
                val=dict(samples_per_gpu=3),
                test=dict(samples_per_gpu=2),
                train=dict(),
                train_dataloader=dict(samples_per_gpu=2))))
    # samples_per_gpu can not be set in `train_dataloader`
    # and data field at the same time
    with pytest.raises(AssertionError):
        compat_loader_args(cfg)
    cfg = ConfigDict(
        dict(
            data=dict(
                samples_per_gpu=1,
                persistent_workers=True,
                workers_per_gpu=1,
                val=dict(samples_per_gpu=3),
                test=dict(samples_per_gpu=2),
                train=dict(),
                val_dataloader=dict(samples_per_gpu=2))))
    # samples_per_gpu can not be set in `val_dataloader`
    # and data field at the same time
    with pytest.raises(AssertionError):
        compat_loader_args(cfg)
    cfg = ConfigDict(
        dict(
            data=dict(
                samples_per_gpu=1,
                persistent_workers=True,
                workers_per_gpu=1,
                val=dict(samples_per_gpu=3),
                test=dict(samples_per_gpu=2),
                test_dataloader=dict(samples_per_gpu=2))))
    # samples_per_gpu can not be set in `test_dataloader`
    # and data field at the same time
    with pytest.raises(AssertionError):
        compat_loader_args(cfg)


def test_compat_imgs_per_gpu():
    """When both keys are present, the deprecated `imgs_per_gpu` value
    must become the effective `samples_per_gpu`."""
    data_cfg = dict(
        imgs_per_gpu=1,
        samples_per_gpu=2,
        val=dict(),
        test=dict(),
        train=dict())
    converted = compat_imgs_per_gpu(ConfigDict(dict(data=data_cfg)))
    assert converted.data.samples_per_gpu == converted.data.imgs_per_gpu
55 changes: 32 additions & 23 deletions tools/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,13 @@
except ImportError:
from mmdet3d.utils import setup_multi_processes

try:
# If mmdet version > 2.23.0, compat_cfg would be imported and
# used from mmdet instead of mmdet3d.
from mmdet.utils import compat_cfg
except ImportError:
from mmdet3d.utils import compat_cfg


def parse_args():
parser = argparse.ArgumentParser(
Expand Down Expand Up @@ -138,6 +145,8 @@ def main():
if args.cfg_options is not None:
cfg.merge_from_dict(args.cfg_options)

cfg = compat_cfg(cfg)

# set multi-process settings
setup_multi_processes(cfg)

Expand All @@ -146,23 +155,6 @@ def main():
torch.backends.cudnn.benchmark = True

cfg.model.pretrained = None
# in case the test dataset is concatenated
samples_per_gpu = 1
if isinstance(cfg.data.test, dict):
cfg.data.test.test_mode = True
samples_per_gpu = cfg.data.test.pop('samples_per_gpu', 1)
if samples_per_gpu > 1:
# Replace 'ImageToTensor' to 'DefaultFormatBundle'
cfg.data.test.pipeline = replace_ImageToTensor(
cfg.data.test.pipeline)
elif isinstance(cfg.data.test, list):
for ds_cfg in cfg.data.test:
ds_cfg.test_mode = True
samples_per_gpu = max(
[ds_cfg.pop('samples_per_gpu', 1) for ds_cfg in cfg.data.test])
if samples_per_gpu > 1:
for ds_cfg in cfg.data.test:
ds_cfg.pipeline = replace_ImageToTensor(ds_cfg.pipeline)

if args.gpu_ids is not None:
cfg.gpu_ids = args.gpu_ids[0:1]
Expand All @@ -180,18 +172,35 @@ def main():
distributed = True
init_dist(args.launcher, **cfg.dist_params)

test_dataloader_default_args = dict(
samples_per_gpu=1, workers_per_gpu=2, dist=distributed, shuffle=False)

# in case the test dataset is concatenated
if isinstance(cfg.data.test, dict):
cfg.data.test.test_mode = True
if cfg.data.test_dataloader.get('samples_per_gpu', 1) > 1:
# Replace 'ImageToTensor' to 'DefaultFormatBundle'
cfg.data.test.pipeline = replace_ImageToTensor(
cfg.data.test.pipeline)
elif isinstance(cfg.data.test, list):
for ds_cfg in cfg.data.test:
ds_cfg.test_mode = True
if cfg.data.test_dataloader.get('samples_per_gpu', 1) > 1:
for ds_cfg in cfg.data.test:
ds_cfg.pipeline = replace_ImageToTensor(ds_cfg.pipeline)

test_loader_cfg = {
**test_dataloader_default_args,
**cfg.data.get('test_dataloader', {})
}

# set random seeds
if args.seed is not None:
set_random_seed(args.seed, deterministic=args.deterministic)

# build the dataloader
dataset = build_dataset(cfg.data.test)
data_loader = build_dataloader(
dataset,
samples_per_gpu=samples_per_gpu,
workers_per_gpu=cfg.data.workers_per_gpu,
dist=distributed,
shuffle=False)
data_loader = build_dataloader(dataset, **test_loader_cfg)

# build the model and load checkpoint
cfg.model.train_cfg = None
Expand Down

0 comments on commit ff159fe

Please sign in to comment.