-
Notifications
You must be signed in to change notification settings - Fork 1.6k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Feature]Add loader args to config (#1388)
* update pr * add ut * fix ut * fix ut err * fix ut name * resolve comment
- Loading branch information
Showing
5 changed files
with
307 additions
and
26 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,139 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
import copy | ||
import warnings | ||
|
||
from mmcv import ConfigDict | ||
|
||
|
||
def compat_cfg(cfg): | ||
"""This function would modify some filed to keep the compatibility of | ||
config. | ||
For example, it will move some args which will be deprecated to the correct | ||
fields. | ||
""" | ||
cfg = copy.deepcopy(cfg) | ||
cfg = compat_imgs_per_gpu(cfg) | ||
cfg = compat_loader_args(cfg) | ||
cfg = compat_runner_args(cfg) | ||
return cfg | ||
|
||
|
||
def compat_runner_args(cfg): | ||
if 'runner' not in cfg: | ||
cfg.runner = ConfigDict({ | ||
'type': 'EpochBasedRunner', | ||
'max_epochs': cfg.total_epochs | ||
}) | ||
warnings.warn( | ||
'config is now expected to have a `runner` section, ' | ||
'please set `runner` in your config.', UserWarning) | ||
else: | ||
if 'total_epochs' in cfg: | ||
assert cfg.total_epochs == cfg.runner.max_epochs | ||
return cfg | ||
|
||
|
||
def compat_imgs_per_gpu(cfg): | ||
cfg = copy.deepcopy(cfg) | ||
if 'imgs_per_gpu' in cfg.data: | ||
warnings.warn('"imgs_per_gpu" is deprecated in MMDet V2.0. ' | ||
'Please use "samples_per_gpu" instead') | ||
if 'samples_per_gpu' in cfg.data: | ||
warnings.warn( | ||
f'Got "imgs_per_gpu"={cfg.data.imgs_per_gpu} and ' | ||
f'"samples_per_gpu"={cfg.data.samples_per_gpu}, "imgs_per_gpu"' | ||
f'={cfg.data.imgs_per_gpu} is used in this experiments') | ||
else: | ||
warnings.warn('Automatically set "samples_per_gpu"="imgs_per_gpu"=' | ||
f'{cfg.data.imgs_per_gpu} in this experiments') | ||
cfg.data.samples_per_gpu = cfg.data.imgs_per_gpu | ||
return cfg | ||
|
||
|
||
def compat_loader_args(cfg): | ||
"""Deprecated sample_per_gpu in cfg.data.""" | ||
|
||
cfg = copy.deepcopy(cfg) | ||
if 'train_dataloader' not in cfg.data: | ||
cfg.data['train_dataloader'] = ConfigDict() | ||
if 'val_dataloader' not in cfg.data: | ||
cfg.data['val_dataloader'] = ConfigDict() | ||
if 'test_dataloader' not in cfg.data: | ||
cfg.data['test_dataloader'] = ConfigDict() | ||
|
||
# special process for train_dataloader | ||
if 'samples_per_gpu' in cfg.data: | ||
|
||
samples_per_gpu = cfg.data.pop('samples_per_gpu') | ||
assert 'samples_per_gpu' not in \ | ||
cfg.data.train_dataloader, ('`samples_per_gpu` are set ' | ||
'in `data` field and ` ' | ||
'data.train_dataloader` ' | ||
'at the same time. ' | ||
'Please only set it in ' | ||
'`data.train_dataloader`. ') | ||
cfg.data.train_dataloader['samples_per_gpu'] = samples_per_gpu | ||
|
||
if 'persistent_workers' in cfg.data: | ||
|
||
persistent_workers = cfg.data.pop('persistent_workers') | ||
assert 'persistent_workers' not in \ | ||
cfg.data.train_dataloader, ('`persistent_workers` are set ' | ||
'in `data` field and ` ' | ||
'data.train_dataloader` ' | ||
'at the same time. ' | ||
'Please only set it in ' | ||
'`data.train_dataloader`. ') | ||
cfg.data.train_dataloader['persistent_workers'] = persistent_workers | ||
|
||
if 'workers_per_gpu' in cfg.data: | ||
|
||
workers_per_gpu = cfg.data.pop('workers_per_gpu') | ||
cfg.data.train_dataloader['workers_per_gpu'] = workers_per_gpu | ||
cfg.data.val_dataloader['workers_per_gpu'] = workers_per_gpu | ||
cfg.data.test_dataloader['workers_per_gpu'] = workers_per_gpu | ||
|
||
# special process for val_dataloader | ||
if 'samples_per_gpu' in cfg.data.val: | ||
# keep default value of `sample_per_gpu` is 1 | ||
assert 'samples_per_gpu' not in \ | ||
cfg.data.val_dataloader, ('`samples_per_gpu` are set ' | ||
'in `data.val` field and ` ' | ||
'data.val_dataloader` at ' | ||
'the same time. ' | ||
'Please only set it in ' | ||
'`data.val_dataloader`. ') | ||
cfg.data.val_dataloader['samples_per_gpu'] = \ | ||
cfg.data.val.pop('samples_per_gpu') | ||
# special process for val_dataloader | ||
|
||
# in case the test dataset is concatenated | ||
if isinstance(cfg.data.test, dict): | ||
if 'samples_per_gpu' in cfg.data.test: | ||
assert 'samples_per_gpu' not in \ | ||
cfg.data.test_dataloader, ('`samples_per_gpu` are set ' | ||
'in `data.test` field and ` ' | ||
'data.test_dataloader` ' | ||
'at the same time. ' | ||
'Please only set it in ' | ||
'`data.test_dataloader`. ') | ||
|
||
cfg.data.test_dataloader['samples_per_gpu'] = \ | ||
cfg.data.test.pop('samples_per_gpu') | ||
|
||
elif isinstance(cfg.data.test, list): | ||
for ds_cfg in cfg.data.test: | ||
if 'samples_per_gpu' in ds_cfg: | ||
assert 'samples_per_gpu' not in \ | ||
cfg.data.test_dataloader, ('`samples_per_gpu` are set ' | ||
'in `data.test` field and ` ' | ||
'data.test_dataloader` at' | ||
' the same time. ' | ||
'Please only set it in ' | ||
'`data.test_dataloader`. ') | ||
samples_per_gpu = max( | ||
[ds_cfg.pop('samples_per_gpu', 1) for ds_cfg in cfg.data.test]) | ||
cfg.data.test_dataloader['samples_per_gpu'] = samples_per_gpu | ||
|
||
return cfg |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,126 @@ | ||
import pytest | ||
from mmcv import ConfigDict | ||
|
||
from mmdet3d.utils.compat_cfg import (compat_imgs_per_gpu, compat_loader_args, | ||
compat_runner_args) | ||
|
||
|
||
def test_compat_runner_args(): | ||
cfg = ConfigDict(dict(total_epochs=12)) | ||
with pytest.warns(None) as record: | ||
cfg = compat_runner_args(cfg) | ||
assert len(record) == 1 | ||
assert 'runner' in record.list[0].message.args[0] | ||
assert 'runner' in cfg | ||
assert cfg.runner.type == 'EpochBasedRunner' | ||
assert cfg.runner.max_epochs == cfg.total_epochs | ||
|
||
|
||
def test_compat_loader_args(): | ||
cfg = ConfigDict(dict(data=dict(val=dict(), test=dict(), train=dict()))) | ||
cfg = compat_loader_args(cfg) | ||
# auto fill loader args | ||
assert 'val_dataloader' in cfg.data | ||
assert 'train_dataloader' in cfg.data | ||
assert 'test_dataloader' in cfg.data | ||
cfg = ConfigDict( | ||
dict( | ||
data=dict( | ||
samples_per_gpu=1, | ||
persistent_workers=True, | ||
workers_per_gpu=1, | ||
val=dict(samples_per_gpu=3), | ||
test=dict(samples_per_gpu=2), | ||
train=dict()))) | ||
with pytest.warns(None) as record: | ||
cfg = compat_loader_args(cfg) | ||
# 5 warning | ||
assert len(record) == 5 | ||
# assert the warning message | ||
assert 'train_dataloader' in record.list[0].message.args[0] | ||
assert 'samples_per_gpu' in record.list[0].message.args[0] | ||
assert 'persistent_workers' in record.list[1].message.args[0] | ||
assert 'train_dataloader' in record.list[1].message.args[0] | ||
assert 'workers_per_gpu' in record.list[2].message.args[0] | ||
assert 'train_dataloader' in record.list[2].message.args[0] | ||
assert cfg.data.train_dataloader.workers_per_gpu == 1 | ||
assert cfg.data.train_dataloader.samples_per_gpu == 1 | ||
assert cfg.data.train_dataloader.persistent_workers | ||
assert cfg.data.val_dataloader.workers_per_gpu == 1 | ||
assert cfg.data.val_dataloader.samples_per_gpu == 3 | ||
assert cfg.data.test_dataloader.workers_per_gpu == 1 | ||
assert cfg.data.test_dataloader.samples_per_gpu == 2 | ||
|
||
# test test is a list | ||
cfg = ConfigDict( | ||
dict( | ||
data=dict( | ||
samples_per_gpu=1, | ||
persistent_workers=True, | ||
workers_per_gpu=1, | ||
val=dict(samples_per_gpu=3), | ||
test=[dict(samples_per_gpu=2), | ||
dict(samples_per_gpu=3)], | ||
train=dict()))) | ||
|
||
with pytest.warns(None) as record: | ||
cfg = compat_loader_args(cfg) | ||
# 6 warning | ||
assert len(record) == 6 | ||
assert cfg.data.test_dataloader.samples_per_gpu == 3 | ||
|
||
# assert can not set args at the same time | ||
cfg = ConfigDict( | ||
dict( | ||
data=dict( | ||
samples_per_gpu=1, | ||
persistent_workers=True, | ||
workers_per_gpu=1, | ||
val=dict(samples_per_gpu=3), | ||
test=dict(samples_per_gpu=2), | ||
train=dict(), | ||
train_dataloader=dict(samples_per_gpu=2)))) | ||
# samples_per_gpu can not be set in `train_dataloader` | ||
# and data field at the same time | ||
with pytest.raises(AssertionError): | ||
compat_loader_args(cfg) | ||
cfg = ConfigDict( | ||
dict( | ||
data=dict( | ||
samples_per_gpu=1, | ||
persistent_workers=True, | ||
workers_per_gpu=1, | ||
val=dict(samples_per_gpu=3), | ||
test=dict(samples_per_gpu=2), | ||
train=dict(), | ||
val_dataloader=dict(samples_per_gpu=2)))) | ||
# samples_per_gpu can not be set in `val_dataloader` | ||
# and data field at the same time | ||
with pytest.raises(AssertionError): | ||
compat_loader_args(cfg) | ||
cfg = ConfigDict( | ||
dict( | ||
data=dict( | ||
samples_per_gpu=1, | ||
persistent_workers=True, | ||
workers_per_gpu=1, | ||
val=dict(samples_per_gpu=3), | ||
test=dict(samples_per_gpu=2), | ||
test_dataloader=dict(samples_per_gpu=2)))) | ||
# samples_per_gpu can not be set in `test_dataloader` | ||
# and data field at the same time | ||
with pytest.raises(AssertionError): | ||
compat_loader_args(cfg) | ||
|
||
|
||
def test_compat_imgs_per_gpu(): | ||
cfg = ConfigDict( | ||
dict( | ||
data=dict( | ||
imgs_per_gpu=1, | ||
samples_per_gpu=2, | ||
val=dict(), | ||
test=dict(), | ||
train=dict()))) | ||
cfg = compat_imgs_per_gpu(cfg) | ||
assert cfg.data.samples_per_gpu == cfg.data.imgs_per_gpu |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters