Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Migrate utils from mmcv #447

Merged
merged 1 commit into from
Aug 22, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion mmengine/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,13 @@
from .parrots_wrapper import TORCH_VERSION
from .path import (check_file_exist, fopen, is_abs, is_filepath,
mkdir_or_exist, scandir, symlink)
from .progressbar import (ProgressBar, track_iter_progress,
track_parallel_progress, track_progress)
from .setup_env import set_multi_processing
from .sync_bn import revert_sync_batchnorm
from .timer import Timer, check_time
from .torch_ops import torch_meshgrid
from .trace import is_jit_tracing
from .version_utils import digit_version, get_git_hash

# TODO: creates intractable circular import issues
Expand All @@ -32,5 +37,8 @@
'digit_version', 'get_git_hash', 'TORCH_VERSION', 'load_url',
'ManagerMeta', 'ManagerMixin', 'set_multi_processing', 'has_batch_norm',
'is_abs', 'is_installed', 'call_command', 'get_installed_path',
'check_install_package', 'is_abs', 'revert_sync_batchnorm', 'collect_env'
'check_install_package', 'is_abs', 'revert_sync_batchnorm', 'collect_env',
'Timer', 'check_time', 'ProgressBar', 'track_iter_progress',
'track_parallel_progress', 'track_progress', 'torch_meshgrid',
'is_jit_tracing'
]
11 changes: 11 additions & 0 deletions mmengine/utils/parrots_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,3 +106,14 @@ def _get_norm() -> tuple:
BuildExtension, CppExtension, CUDAExtension = _get_extension()
_BatchNorm, _InstanceNorm, SyncBatchNorm_ = _get_norm()
_AdaptiveAvgPoolNd, _AdaptiveMaxPoolNd, _AvgPoolNd, _MaxPoolNd = _get_pool()


class SyncBatchNorm(SyncBatchNorm_): # type: ignore

def _check_input_dim(self, input):
if TORCH_VERSION == 'parrots':
if input.dim() < 2:
raise ValueError(
f'expected at least 2D input (got {input.dim()}D input)')
else:
super()._check_input_dim(input)
208 changes: 208 additions & 0 deletions mmengine/utils/progressbar.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,208 @@
# Copyright (c) OpenMMLab. All rights reserved.
import sys
from collections.abc import Iterable
from multiprocessing import Pool
from shutil import get_terminal_size

from .timer import Timer


class ProgressBar:
"""A progress bar which can print the progress."""

def __init__(self, task_num=0, bar_width=50, start=True, file=sys.stdout):
self.task_num = task_num
self.bar_width = bar_width
self.completed = 0
self.file = file
if start:
self.start()

@property
def terminal_width(self):
width, _ = get_terminal_size()
return width

def start(self):
if self.task_num > 0:
self.file.write(f'[{" " * self.bar_width}] 0/{self.task_num}, '
'elapsed: 0s, ETA:')
else:
self.file.write('completed: 0, elapsed: 0s')
self.file.flush()
self.timer = Timer()

def update(self, num_tasks=1):
assert num_tasks > 0
self.completed += num_tasks
elapsed = self.timer.since_start()
if elapsed > 0:
fps = self.completed / elapsed
else:
fps = float('inf')
if self.task_num > 0:
percentage = self.completed / float(self.task_num)
eta = int(elapsed * (1 - percentage) / percentage + 0.5)
msg = f'\r[{{}}] {self.completed}/{self.task_num}, ' \
f'{fps:.1f} task/s, elapsed: {int(elapsed + 0.5)}s, ' \
f'ETA: {eta:5}s'

bar_width = min(self.bar_width,
int(self.terminal_width - len(msg)) + 2,
int(self.terminal_width * 0.6))
bar_width = max(2, bar_width)
mark_width = int(bar_width * percentage)
bar_chars = '>' * mark_width + ' ' * (bar_width - mark_width)
self.file.write(msg.format(bar_chars))
else:
self.file.write(
f'completed: {self.completed}, elapsed: {int(elapsed + 0.5)}s,'
f' {fps:.1f} tasks/s')
self.file.flush()


def track_progress(func, tasks, bar_width=50, file=sys.stdout, **kwargs):
"""Track the progress of tasks execution with a progress bar.

Tasks are done with a simple for-loop.

Args:
func (callable): The function to be applied to each task.
tasks (list or tuple[Iterable, int]): A list of tasks or
(tasks, total num).
bar_width (int): Width of progress bar.

Returns:
list: The task results.
"""
if isinstance(tasks, tuple):
assert len(tasks) == 2
assert isinstance(tasks[0], Iterable)
assert isinstance(tasks[1], int)
task_num = tasks[1]
tasks = tasks[0]
elif isinstance(tasks, Iterable):
task_num = len(tasks)
else:
raise TypeError(
'"tasks" must be an iterable object or a (iterator, int) tuple')
prog_bar = ProgressBar(task_num, bar_width, file=file)
results = []
for task in tasks:
results.append(func(task, **kwargs))
prog_bar.update()
prog_bar.file.write('\n')
return results


def init_pool(process_num, initializer=None, initargs=None):
if initializer is None:
return Pool(process_num)
elif initargs is None:
return Pool(process_num, initializer)
else:
if not isinstance(initargs, tuple):
raise TypeError('"initargs" must be a tuple')
return Pool(process_num, initializer, initargs)


def track_parallel_progress(func,
tasks,
nproc,
initializer=None,
initargs=None,
bar_width=50,
chunksize=1,
skip_first=False,
keep_order=True,
file=sys.stdout):
"""Track the progress of parallel task execution with a progress bar.

The built-in :mod:`multiprocessing` module is used for process pools and
tasks are done with :func:`Pool.map` or :func:`Pool.imap_unordered`.

Args:
func (callable): The function to be applied to each task.
tasks (list or tuple[Iterable, int]): A list of tasks or
(tasks, total num).
nproc (int): Process (worker) number.
initializer (None or callable): Refer to :class:`multiprocessing.Pool`
for details.
initargs (None or tuple): Refer to :class:`multiprocessing.Pool` for
details.
chunksize (int): Refer to :class:`multiprocessing.Pool` for details.
bar_width (int): Width of progress bar.
skip_first (bool): Whether to skip the first sample for each worker
when estimating fps, since the initialization step may takes
longer.
keep_order (bool): If True, :func:`Pool.imap` is used, otherwise
:func:`Pool.imap_unordered` is used.

Returns:
list: The task results.
"""
if isinstance(tasks, tuple):
assert len(tasks) == 2
assert isinstance(tasks[0], Iterable)
assert isinstance(tasks[1], int)
task_num = tasks[1]
tasks = tasks[0]
elif isinstance(tasks, Iterable):
task_num = len(tasks)
else:
raise TypeError(
'"tasks" must be an iterable object or a (iterator, int) tuple')
pool = init_pool(nproc, initializer, initargs)
start = not skip_first
task_num -= nproc * chunksize * int(skip_first)
prog_bar = ProgressBar(task_num, bar_width, start, file=file)
results = []
if keep_order:
gen = pool.imap(func, tasks, chunksize)
else:
gen = pool.imap_unordered(func, tasks, chunksize)
for result in gen:
results.append(result)
if skip_first:
if len(results) < nproc * chunksize:
continue
elif len(results) == nproc * chunksize:
prog_bar.start()
continue
prog_bar.update()
prog_bar.file.write('\n')
pool.close()
pool.join()
return results


def track_iter_progress(tasks, bar_width=50, file=sys.stdout):
"""Track the progress of tasks iteration or enumeration with a progress
bar.

Tasks are yielded with a simple for-loop.

Args:
tasks (list or tuple[Iterable, int]): A list of tasks or
(tasks, total num).
bar_width (int): Width of progress bar.

Yields:
list: The task results.
"""
if isinstance(tasks, tuple):
assert len(tasks) == 2
assert isinstance(tasks[0], Iterable)
assert isinstance(tasks[1], int)
task_num = tasks[1]
tasks = tasks[0]
elif isinstance(tasks, Iterable):
task_num = len(tasks)
else:
raise TypeError(
'"tasks" must be an iterable object or a (iterator, int) tuple')
prog_bar = ProgressBar(task_num, bar_width, file=file)
for task in tasks:
yield task
prog_bar.update()
prog_bar.file.write('\n')
118 changes: 118 additions & 0 deletions mmengine/utils/timer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
# Copyright (c) OpenMMLab. All rights reserved.
from time import time


class TimerError(Exception):

def __init__(self, message):
self.message = message
super().__init__(message)


class Timer:
"""A flexible Timer class.

Examples:
>>> import time
>>> import mmcv
>>> with mmcv.Timer():
>>> # simulate a code block that will run for 1s
>>> time.sleep(1)
1.000
>>> with mmcv.Timer(print_tmpl='it takes {:.1f} seconds'):
>>> # simulate a code block that will run for 1s
>>> time.sleep(1)
it takes 1.0 seconds
>>> timer = mmcv.Timer()
>>> time.sleep(0.5)
>>> print(timer.since_start())
0.500
>>> time.sleep(0.5)
>>> print(timer.since_last_check())
0.500
>>> print(timer.since_start())
1.000
"""

def __init__(self, start=True, print_tmpl=None):
self._is_running = False
self.print_tmpl = print_tmpl if print_tmpl else '{:.3f}'
if start:
self.start()

@property
def is_running(self):
"""bool: indicate whether the timer is running"""
return self._is_running

def __enter__(self):
self.start()
return self

def __exit__(self, type, value, traceback):
print(self.print_tmpl.format(self.since_last_check()))
self._is_running = False

def start(self):
"""Start the timer."""
if not self._is_running:
self._t_start = time()
self._is_running = True
self._t_last = time()

def since_start(self):
"""Total time since the timer is started.

Returns:
float: Time in seconds.
"""
if not self._is_running:
raise TimerError('timer is not running')
self._t_last = time()
return self._t_last - self._t_start

def since_last_check(self):
"""Time since the last checking.

Either :func:`since_start` or :func:`since_last_check` is a checking
operation.

Returns:
float: Time in seconds.
"""
if not self._is_running:
raise TimerError('timer is not running')
dur = time() - self._t_last
self._t_last = time()
return dur


_g_timers = {} # global timers


def check_time(timer_id):
"""Add check points in a single line.

This method is suitable for running a task on a list of items. A timer will
be registered when the method is called for the first time.

Examples:
>>> import time
>>> import mmcv
>>> for i in range(1, 6):
>>> # simulate a code block
>>> time.sleep(i)
>>> mmcv.check_time('task1')
2.000
3.000
4.000
5.000

Args:
str: Timer identifier.
"""
if timer_id not in _g_timers:
_g_timers[timer_id] = Timer()
return 0
else:
return _g_timers[timer_id].since_last_check()
29 changes: 29 additions & 0 deletions mmengine/utils/torch_ops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch

from .parrots_wrapper import TORCH_VERSION
from .version_utils import digit_version

_torch_version_meshgrid_indexing = (
'parrots' not in TORCH_VERSION
and digit_version(TORCH_VERSION) >= digit_version('1.10.0a0'))


def torch_meshgrid(*tensors):
"""A wrapper of torch.meshgrid to compat different PyTorch versions.

Since PyTorch 1.10.0a0, torch.meshgrid supports the arguments ``indexing``.
So we implement a wrapper here to avoid warning when using high-version
PyTorch and avoid compatibility issues when using previous versions of
PyTorch.

Args:
tensors (List[Tensor]): List of scalars or 1 dimensional tensors.

Returns:
Sequence[Tensor]: Sequence of meshgrid tensors.
"""
if _torch_version_meshgrid_indexing:
return torch.meshgrid(*tensors, indexing='ij')
else:
return torch.meshgrid(*tensors) # Uses indexing='ij' by default
Loading