diff --git a/mmengine/utils/__init__.py b/mmengine/utils/__init__.py index ae11fdad9d..f193a22e51 100644 --- a/mmengine/utils/__init__.py +++ b/mmengine/utils/__init__.py @@ -14,8 +14,13 @@ from .parrots_wrapper import TORCH_VERSION from .path import (check_file_exist, fopen, is_abs, is_filepath, mkdir_or_exist, scandir, symlink) +from .progressbar import (ProgressBar, track_iter_progress, + track_parallel_progress, track_progress) from .setup_env import set_multi_processing from .sync_bn import revert_sync_batchnorm +from .timer import Timer, check_time +from .torch_ops import torch_meshgrid +from .trace import is_jit_tracing from .version_utils import digit_version, get_git_hash # TODO: creates intractable circular import issues @@ -32,5 +37,8 @@ 'digit_version', 'get_git_hash', 'TORCH_VERSION', 'load_url', 'ManagerMeta', 'ManagerMixin', 'set_multi_processing', 'has_batch_norm', 'is_abs', 'is_installed', 'call_command', 'get_installed_path', - 'check_install_package', 'is_abs', 'revert_sync_batchnorm', 'collect_env' + 'check_install_package', 'is_abs', 'revert_sync_batchnorm', 'collect_env', + 'Timer', 'check_time', 'ProgressBar', 'track_iter_progress', + 'track_parallel_progress', 'track_progress', 'torch_meshgrid', + 'is_jit_tracing' ] diff --git a/mmengine/utils/parrots_wrapper.py b/mmengine/utils/parrots_wrapper.py index 7f23b86df5..edda5bdcda 100644 --- a/mmengine/utils/parrots_wrapper.py +++ b/mmengine/utils/parrots_wrapper.py @@ -106,3 +106,14 @@ def _get_norm() -> tuple: BuildExtension, CppExtension, CUDAExtension = _get_extension() _BatchNorm, _InstanceNorm, SyncBatchNorm_ = _get_norm() _AdaptiveAvgPoolNd, _AdaptiveMaxPoolNd, _AvgPoolNd, _MaxPoolNd = _get_pool() + + +class SyncBatchNorm(SyncBatchNorm_): # type: ignore + + def _check_input_dim(self, input): + if TORCH_VERSION == 'parrots': + if input.dim() < 2: + raise ValueError( + f'expected at least 2D input (got {input.dim()}D input)') + else: + super()._check_input_dim(input) diff --git a/mmengine/utils/progressbar.py b/mmengine/utils/progressbar.py new file mode 100644 index 0000000000..0062f670dd --- /dev/null +++ b/mmengine/utils/progressbar.py @@ -0,0 +1,208 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import sys +from collections.abc import Iterable +from multiprocessing import Pool +from shutil import get_terminal_size + +from .timer import Timer + + +class ProgressBar: + """A progress bar which can print the progress.""" + + def __init__(self, task_num=0, bar_width=50, start=True, file=sys.stdout): + self.task_num = task_num + self.bar_width = bar_width + self.completed = 0 + self.file = file + if start: + self.start() + + @property + def terminal_width(self): + width, _ = get_terminal_size() + return width + + def start(self): + if self.task_num > 0: + self.file.write(f'[{" " * self.bar_width}] 0/{self.task_num}, ' + 'elapsed: 0s, ETA:') + else: + self.file.write('completed: 0, elapsed: 0s') + self.file.flush() + self.timer = Timer() + + def update(self, num_tasks=1): + assert num_tasks > 0 + self.completed += num_tasks + elapsed = self.timer.since_start() + if elapsed > 0: + fps = self.completed / elapsed + else: + fps = float('inf') + if self.task_num > 0: + percentage = self.completed / float(self.task_num) + eta = int(elapsed * (1 - percentage) / percentage + 0.5) + msg = f'\r[{{}}] {self.completed}/{self.task_num}, ' \ + f'{fps:.1f} task/s, elapsed: {int(elapsed + 0.5)}s, ' \ + f'ETA: {eta:5}s' + + bar_width = min(self.bar_width, + int(self.terminal_width - len(msg)) + 2, + int(self.terminal_width * 0.6)) + bar_width = max(2, bar_width) + mark_width = int(bar_width * percentage) + bar_chars = '>' * mark_width + ' ' * (bar_width - mark_width) + self.file.write(msg.format(bar_chars)) + else: + self.file.write( + f'completed: {self.completed}, elapsed: {int(elapsed + 0.5)}s,' + f' {fps:.1f} tasks/s') + self.file.flush() + + +def track_progress(func, tasks, bar_width=50, file=sys.stdout, **kwargs): + """Track the progress of tasks execution with a progress bar. + + Tasks are done with a simple for-loop. + + Args: + func (callable): The function to be applied to each task. + tasks (list or tuple[Iterable, int]): A list of tasks or + (tasks, total num). + bar_width (int): Width of progress bar. + + Returns: + list: The task results. + """ + if isinstance(tasks, tuple): + assert len(tasks) == 2 + assert isinstance(tasks[0], Iterable) + assert isinstance(tasks[1], int) + task_num = tasks[1] + tasks = tasks[0] + elif isinstance(tasks, Iterable): + task_num = len(tasks) + else: + raise TypeError( + '"tasks" must be an iterable object or a (iterator, int) tuple') + prog_bar = ProgressBar(task_num, bar_width, file=file) + results = [] + for task in tasks: + results.append(func(task, **kwargs)) + prog_bar.update() + prog_bar.file.write('\n') + return results + + +def init_pool(process_num, initializer=None, initargs=None): + if initializer is None: + return Pool(process_num) + elif initargs is None: + return Pool(process_num, initializer) + else: + if not isinstance(initargs, tuple): + raise TypeError('"initargs" must be a tuple') + return Pool(process_num, initializer, initargs) + + +def track_parallel_progress(func, + tasks, + nproc, + initializer=None, + initargs=None, + bar_width=50, + chunksize=1, + skip_first=False, + keep_order=True, + file=sys.stdout): + """Track the progress of parallel task execution with a progress bar. + + The built-in :mod:`multiprocessing` module is used for process pools and + tasks are done with :func:`Pool.map` or :func:`Pool.imap_unordered`. + + Args: + func (callable): The function to be applied to each task. + tasks (list or tuple[Iterable, int]): A list of tasks or + (tasks, total num). + nproc (int): Process (worker) number. + initializer (None or callable): Refer to :class:`multiprocessing.Pool` + for details. + initargs (None or tuple): Refer to :class:`multiprocessing.Pool` for + details. + chunksize (int): Refer to :class:`multiprocessing.Pool` for details. + bar_width (int): Width of progress bar. + skip_first (bool): Whether to skip the first sample for each worker + when estimating fps, since the initialization step may takes + longer. + keep_order (bool): If True, :func:`Pool.imap` is used, otherwise + :func:`Pool.imap_unordered` is used. + + Returns: + list: The task results. + """ + if isinstance(tasks, tuple): + assert len(tasks) == 2 + assert isinstance(tasks[0], Iterable) + assert isinstance(tasks[1], int) + task_num = tasks[1] + tasks = tasks[0] + elif isinstance(tasks, Iterable): + task_num = len(tasks) + else: + raise TypeError( + '"tasks" must be an iterable object or a (iterator, int) tuple') + pool = init_pool(nproc, initializer, initargs) + start = not skip_first + task_num -= nproc * chunksize * int(skip_first) + prog_bar = ProgressBar(task_num, bar_width, start, file=file) + results = [] + if keep_order: + gen = pool.imap(func, tasks, chunksize) + else: + gen = pool.imap_unordered(func, tasks, chunksize) + for result in gen: + results.append(result) + if skip_first: + if len(results) < nproc * chunksize: + continue + elif len(results) == nproc * chunksize: + prog_bar.start() + continue + prog_bar.update() + prog_bar.file.write('\n') + pool.close() + pool.join() + return results + + +def track_iter_progress(tasks, bar_width=50, file=sys.stdout): + """Track the progress of tasks iteration or enumeration with a progress + bar. + + Tasks are yielded with a simple for-loop. + + Args: + tasks (list or tuple[Iterable, int]): A list of tasks or + (tasks, total num). + bar_width (int): Width of progress bar. + + Yields: + list: The task results. + """ + if isinstance(tasks, tuple): + assert len(tasks) == 2 + assert isinstance(tasks[0], Iterable) + assert isinstance(tasks[1], int) + task_num = tasks[1] + tasks = tasks[0] + elif isinstance(tasks, Iterable): + task_num = len(tasks) + else: + raise TypeError( + '"tasks" must be an iterable object or a (iterator, int) tuple') + prog_bar = ProgressBar(task_num, bar_width, file=file) + for task in tasks: + yield task + prog_bar.update() + prog_bar.file.write('\n') diff --git a/mmengine/utils/timer.py b/mmengine/utils/timer.py new file mode 100644 index 0000000000..087a969cfa --- /dev/null +++ b/mmengine/utils/timer.py @@ -0,0 +1,118 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from time import time + + +class TimerError(Exception): + + def __init__(self, message): + self.message = message + super().__init__(message) + + +class Timer: + """A flexible Timer class. + + Examples: + >>> import time + >>> import mmcv + >>> with mmcv.Timer(): + >>> # simulate a code block that will run for 1s + >>> time.sleep(1) + 1.000 + >>> with mmcv.Timer(print_tmpl='it takes {:.1f} seconds'): + >>> # simulate a code block that will run for 1s + >>> time.sleep(1) + it takes 1.0 seconds + >>> timer = mmcv.Timer() + >>> time.sleep(0.5) + >>> print(timer.since_start()) + 0.500 + >>> time.sleep(0.5) + >>> print(timer.since_last_check()) + 0.500 + >>> print(timer.since_start()) + 1.000 + """ + + def __init__(self, start=True, print_tmpl=None): + self._is_running = False + self.print_tmpl = print_tmpl if print_tmpl else '{:.3f}' + if start: + self.start() + + @property + def is_running(self): + """bool: indicate whether the timer is running""" + return self._is_running + + def __enter__(self): + self.start() + return self + + def __exit__(self, type, value, traceback): + print(self.print_tmpl.format(self.since_last_check())) + self._is_running = False + + def start(self): + """Start the timer.""" + if not self._is_running: + self._t_start = time() + self._is_running = True + self._t_last = time() + + def since_start(self): + """Total time since the timer is started. + + Returns: + float: Time in seconds. + """ + if not self._is_running: + raise TimerError('timer is not running') + self._t_last = time() + return self._t_last - self._t_start + + def since_last_check(self): + """Time since the last checking. + + Either :func:`since_start` or :func:`since_last_check` is a checking + operation. + + Returns: + float: Time in seconds. + """ + if not self._is_running: + raise TimerError('timer is not running') + dur = time() - self._t_last + self._t_last = time() + return dur + + +_g_timers = {} # global timers + + +def check_time(timer_id): + """Add check points in a single line. + + This method is suitable for running a task on a list of items. A timer will + be registered when the method is called for the first time. + + Examples: + >>> import time + >>> import mmcv + >>> for i in range(1, 6): + >>> # simulate a code block + >>> time.sleep(i) + >>> mmcv.check_time('task1') + 2.000 + 3.000 + 4.000 + 5.000 + + Args: + str: Timer identifier. + """ + if timer_id not in _g_timers: + _g_timers[timer_id] = Timer() + return 0 + else: + return _g_timers[timer_id].since_last_check() diff --git a/mmengine/utils/torch_ops.py b/mmengine/utils/torch_ops.py new file mode 100644 index 0000000000..b4f2213a43 --- /dev/null +++ b/mmengine/utils/torch_ops.py @@ -0,0 +1,29 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch + +from .parrots_wrapper import TORCH_VERSION +from .version_utils import digit_version + +_torch_version_meshgrid_indexing = ( + 'parrots' not in TORCH_VERSION + and digit_version(TORCH_VERSION) >= digit_version('1.10.0a0')) + + +def torch_meshgrid(*tensors): + """A wrapper of torch.meshgrid to compat different PyTorch versions. + + Since PyTorch 1.10.0a0, torch.meshgrid supports the arguments ``indexing``. + So we implement a wrapper here to avoid warning when using high-version + PyTorch and avoid compatibility issues when using previous versions of + PyTorch. + + Args: + tensors (List[Tensor]): List of scalars or 1 dimensional tensors. + + Returns: + Sequence[Tensor]: Sequence of meshgrid tensors. + """ + if _torch_version_meshgrid_indexing: + return torch.meshgrid(*tensors, indexing='ij') + else: + return torch.meshgrid(*tensors) # Uses indexing='ij' by default diff --git a/mmengine/utils/trace.py b/mmengine/utils/trace.py new file mode 100644 index 0000000000..c8d40595a4 --- /dev/null +++ b/mmengine/utils/trace.py @@ -0,0 +1,24 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import warnings + +import torch + +from .version_utils import digit_version + + +def is_jit_tracing() -> bool: + if (torch.__version__ != 'parrots' + and digit_version(torch.__version__) >= digit_version('1.6.0')): + on_trace = torch.jit.is_tracing() + # In PyTorch 1.6, torch.jit.is_tracing has a bug. + # Refers to https://github.com/pytorch/pytorch/issues/42448 + if isinstance(on_trace, bool): + return on_trace + else: + return torch._C._is_tracing() + else: + warnings.warn( + 'torch.jit.is_tracing is only supported after v1.6.0. ' + 'Therefore is_tracing returns False automatically. Please ' + 'set on_trace manually if you are using trace.', UserWarning) + return False diff --git a/tests/test_utils/test_progressbar.py b/tests/test_utils/test_progressbar.py new file mode 100644 index 0000000000..982aa247f7 --- /dev/null +++ b/tests/test_utils/test_progressbar.py @@ -0,0 +1,163 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os +import time +from io import StringIO +from unittest.mock import patch + +import mmcv + + +def reset_string_io(io): + io.truncate(0) + io.seek(0) + + +class TestProgressBar: + + def test_start(self): + out = StringIO() + bar_width = 20 + # without total task num + prog_bar = mmcv.ProgressBar(bar_width=bar_width, file=out) + assert out.getvalue() == 'completed: 0, elapsed: 0s' + reset_string_io(out) + prog_bar = mmcv.ProgressBar(bar_width=bar_width, start=False, file=out) + assert out.getvalue() == '' + reset_string_io(out) + prog_bar.start() + assert out.getvalue() == 'completed: 0, elapsed: 0s' + # with total task num + reset_string_io(out) + prog_bar = mmcv.ProgressBar(10, bar_width=bar_width, file=out) + assert out.getvalue() == f'[{" " * bar_width}] 0/10, elapsed: 0s, ETA:' + reset_string_io(out) + prog_bar = mmcv.ProgressBar( + 10, bar_width=bar_width, start=False, file=out) + assert out.getvalue() == '' + reset_string_io(out) + prog_bar.start() + assert out.getvalue() == f'[{" " * bar_width}] 0/10, elapsed: 0s, ETA:' + + def test_update(self): + out = StringIO() + bar_width = 20 + # without total task num + prog_bar = mmcv.ProgressBar(bar_width=bar_width, file=out) + time.sleep(1) + reset_string_io(out) + prog_bar.update() + assert out.getvalue() == 'completed: 1, elapsed: 1s, 1.0 tasks/s' + reset_string_io(out) + # with total task num + prog_bar = mmcv.ProgressBar(10, bar_width=bar_width, file=out) + time.sleep(1) + reset_string_io(out) + prog_bar.update() + assert out.getvalue() == f'\r[{">" * 2 + " " * 18}] 1/10, 1.0 ' \ + 'task/s, elapsed: 1s, ETA: 9s' + + def test_adaptive_length(self): + with patch.dict('os.environ', {'COLUMNS': '80'}): + out = StringIO() + bar_width = 20 + prog_bar = mmcv.ProgressBar(10, bar_width=bar_width, file=out) + time.sleep(1) + reset_string_io(out) + prog_bar.update() + assert len(out.getvalue()) == 66 + + os.environ['COLUMNS'] = '30' + reset_string_io(out) + prog_bar.update() + assert len(out.getvalue()) == 48 + + os.environ['COLUMNS'] = '60' + reset_string_io(out) + prog_bar.update() + assert len(out.getvalue()) == 60 + + +def sleep_1s(num): + time.sleep(1) + return num + + +def test_track_progress_list(): + out = StringIO() + ret = mmcv.track_progress(sleep_1s, [1, 2, 3], bar_width=3, file=out) + assert out.getvalue() == ( + '[ ] 0/3, elapsed: 0s, ETA:' + '\r[> ] 1/3, 1.0 task/s, elapsed: 1s, ETA: 2s' + '\r[>> ] 2/3, 1.0 task/s, elapsed: 2s, ETA: 1s' + '\r[>>>] 3/3, 1.0 task/s, elapsed: 3s, ETA: 0s\n') + assert ret == [1, 2, 3] + + +def test_track_progress_iterator(): + out = StringIO() + ret = mmcv.track_progress( + sleep_1s, ((i for i in [1, 2, 3]), 3), bar_width=3, file=out) + assert out.getvalue() == ( + '[ ] 0/3, elapsed: 0s, ETA:' + '\r[> ] 1/3, 1.0 task/s, elapsed: 1s, ETA: 2s' + '\r[>> ] 2/3, 1.0 task/s, elapsed: 2s, ETA: 1s' + '\r[>>>] 3/3, 1.0 task/s, elapsed: 3s, ETA: 0s\n') + assert ret == [1, 2, 3] + + +def test_track_iter_progress(): + out = StringIO() + ret = [] + for num in mmcv.track_iter_progress([1, 2, 3], bar_width=3, file=out): + ret.append(sleep_1s(num)) + assert out.getvalue() == ( + '[ ] 0/3, elapsed: 0s, ETA:' + '\r[> ] 1/3, 1.0 task/s, elapsed: 1s, ETA: 2s' + '\r[>> ] 2/3, 1.0 task/s, elapsed: 2s, ETA: 1s' + '\r[>>>] 3/3, 1.0 task/s, elapsed: 3s, ETA: 0s\n') + assert ret == [1, 2, 3] + + +def test_track_enum_progress(): + out = StringIO() + ret = [] + count = [] + for i, num in enumerate( + mmcv.track_iter_progress([1, 2, 3], bar_width=3, file=out)): + ret.append(sleep_1s(num)) + count.append(i) + assert out.getvalue() == ( + '[ ] 0/3, elapsed: 0s, ETA:' + '\r[> ] 1/3, 1.0 task/s, elapsed: 1s, ETA: 2s' + '\r[>> ] 2/3, 1.0 task/s, elapsed: 2s, ETA: 1s' + '\r[>>>] 3/3, 1.0 task/s, elapsed: 3s, ETA: 0s\n') + assert ret == [1, 2, 3] + assert count == [0, 1, 2] + + +def test_track_parallel_progress_list(): + out = StringIO() + results = mmcv.track_parallel_progress( + sleep_1s, [1, 2, 3, 4], 2, bar_width=4, file=out) + # The following cannot pass CI on Github Action + # assert out.getvalue() == ( + # '[ ] 0/4, elapsed: 0s, ETA:' + # '\r[> ] 1/4, 1.0 task/s, elapsed: 1s, ETA: 3s' + # '\r[>> ] 2/4, 2.0 task/s, elapsed: 1s, ETA: 1s' + # '\r[>>> ] 3/4, 1.5 task/s, elapsed: 2s, ETA: 1s' + # '\r[>>>>] 4/4, 2.0 task/s, elapsed: 2s, ETA: 0s\n') + assert results == [1, 2, 3, 4] + + +def test_track_parallel_progress_iterator(): + out = StringIO() + results = mmcv.track_parallel_progress( + sleep_1s, ((i for i in [1, 2, 3, 4]), 4), 2, bar_width=4, file=out) + # The following cannot pass CI on Github Action + # assert out.getvalue() == ( + # '[ ] 0/4, elapsed: 0s, ETA:' + # '\r[> ] 1/4, 1.0 task/s, elapsed: 1s, ETA: 3s' + # '\r[>> ] 2/4, 2.0 task/s, elapsed: 1s, ETA: 1s' + # '\r[>>> ] 3/4, 1.5 task/s, elapsed: 2s, ETA: 1s' + # '\r[>>>>] 4/4, 2.0 task/s, elapsed: 2s, ETA: 0s\n') + assert results == [1, 2, 3, 4] diff --git a/tests/test_utils/test_timer.py b/tests/test_utils/test_timer.py new file mode 100644 index 0000000000..e9f591352f --- /dev/null +++ b/tests/test_utils/test_timer.py @@ -0,0 +1,39 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import time + +import mmcv +import pytest + + +def test_timer_init(): + timer = mmcv.Timer(start=False) + assert not timer.is_running + timer.start() + assert timer.is_running + timer = mmcv.Timer() + assert timer.is_running + + +def test_timer_run(): + timer = mmcv.Timer() + time.sleep(1) + assert abs(timer.since_start() - 1) < 1e-2 + time.sleep(1) + assert abs(timer.since_last_check() - 1) < 1e-2 + assert abs(timer.since_start() - 2) < 1e-2 + timer = mmcv.Timer(False) + with pytest.raises(mmcv.TimerError): + timer.since_start() + with pytest.raises(mmcv.TimerError): + timer.since_last_check() + + +def test_timer_context(capsys): + with mmcv.Timer(): + time.sleep(1) + out, _ = capsys.readouterr() + assert abs(float(out) - 1) < 1e-2 + with mmcv.Timer(print_tmpl='time: {:.1f}s'): + time.sleep(1) + out, _ = capsys.readouterr() + assert out == 'time: 1.0s\n' diff --git a/tests/test_utils/test_torch_ops.py b/tests/test_utils/test_torch_ops.py new file mode 100644 index 0000000000..a3790c0939 --- /dev/null +++ b/tests/test_utils/test_torch_ops.py @@ -0,0 +1,15 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import pytest +import torch + +from mmengine.utils import torch_meshgrid + + +def test_torch_meshgrid(): + # torch_meshgrid should not throw warning + with pytest.warns(None) as record: + x = torch.tensor([1, 2, 3]) + y = torch.tensor([4, 5, 6]) + grid_x, grid_y = torch_meshgrid(x, y) + + assert len(record) == 0 diff --git a/tests/test_utils/test_trace.py b/tests/test_utils/test_trace.py new file mode 100644 index 0000000000..b723f75037 --- /dev/null +++ b/tests/test_utils/test_trace.py @@ -0,0 +1,25 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import pytest +import torch + +from mmengine.utils import digit_version, is_jit_tracing + + +@pytest.mark.skipif( + digit_version(torch.__version__) < digit_version('1.6.0'), + reason='torch.jit.is_tracing is not available before 1.6.0') +def test_is_jit_tracing(): + + def foo(x): + if is_jit_tracing(): + return x + else: + return x.tolist() + + x = torch.rand(3) + # test without trace + assert isinstance(foo(x), list) + + # test with trace + traced_foo = torch.jit.trace(foo, (torch.rand(1), )) + assert isinstance(traced_foo(x), torch.Tensor)