-
Notifications
You must be signed in to change notification settings - Fork 7
/
epoch_promptdet_runner.py
110 lines (94 loc) · 4.12 KB
/
epoch_promptdet_runner.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import os.path as osp
import platform
import shutil
import time
import warnings
import torch
import mmcv
from mmcv.runner import EpochBasedRunner
from mmcv.runner.base_runner import BaseRunner
from mmcv.runner.builder import RUNNERS
from mmcv.runner.checkpoint import save_checkpoint
from mmcv.runner.utils import get_host_info
@RUNNERS.register_module()
class EpochPromptDetRunner(EpochBasedRunner):
    """Epoch-based runner that trains on several data loaders per epoch.

    Every training epoch draws each batch of every loader in ``data_loaders``
    exactly once, interleaving the loaders so that each one is consumed at a
    rate proportional to its length (see :meth:`_interleave_order`).
    """

    def _total_train_iters(self, data_loaders):
        """Return the number of training iterations across all epochs.

        One epoch visits every batch of every loader, so the total is
        ``max_epochs * sum(len(loader) for loader in data_loaders)``.
        """
        return self._max_epochs * sum(len(loader) for loader in data_loaders)

    @staticmethod
    def _interleave_order(lengths):
        """Yield loader indices for one epoch of proportional interleaving.

        Args:
            lengths (list[int]): Number of batches in each loader.

        Yields:
            int: Index of the loader to draw the next batch from. Loader ``j``
            is yielded exactly ``lengths[j]`` times. At each step the loader
            with the smallest consumed fraction ``consumed / length`` is
            chosen, ties going to the lowest index. Empty loaders are never
            chosen.
        """
        consumed = [0] * len(lengths)
        for _ in range(sum(lengths)):
            # Only loaders with batches left are candidates. This also skips
            # zero-length loaders, which would otherwise divide by zero.
            # The fraction is recomputed from exact integer counts each step
            # (no accumulated float error), so equal fractions tie exactly.
            index = min(
                (j for j, size in enumerate(lengths) if consumed[j] < size),
                key=lambda j: consumed[j] / lengths[j])
            consumed[index] += 1
            yield index

    def train(self, data_loaders, **kwargs):
        """Run one training epoch over all ``data_loaders``.

        Args:
            data_loaders (list[:obj:`DataLoader`]): Loaders whose batches are
                interleaved within the epoch. Only ``data_loaders[0]`` is
                exposed to hooks as ``self.data_loader``.
            **kwargs: Forwarded to :meth:`run_iter`.
        """
        self.model.train()
        self.mode = 'train'
        self.data_loader = data_loaders[0]
        self._max_iters = self._total_train_iters(data_loaders)
        lengths = [len(loader) for loader in data_loaders]

        self.call_hook('before_train_epoch')
        time.sleep(2)  # Prevent possible deadlock during epoch transition
        # Create the iterators only after ``before_train_epoch``, as the
        # parent EpochBasedRunner.train does. Creating a DataLoader iterator
        # can already pull indices from its sampler, so hooks that reseed
        # samplers must run first to affect this epoch.
        data_iters = [iter(loader) for loader in data_loaders]

        for i, index in enumerate(self._interleave_order(lengths)):
            data_batch = next(data_iters[index])
            self._inner_iter = i
            self.call_hook('before_train_iter')
            self.run_iter(data_batch, train_mode=True, **kwargs)
            self.call_hook('after_train_iter')
            self._iter += 1

        self.call_hook('after_train_epoch')
        self._epoch += 1

    def run(self, data_loaders, workflow, max_epochs=None, **kwargs):
        """Start running.

        Args:
            data_loaders (list[:obj:`DataLoader`]): Dataloaders for training.
                Unlike the parent runner, the whole list is passed to every
                workflow phase rather than one loader per phase, so the
                number of loaders need not match ``len(workflow)``.
            workflow (list[tuple]): A list of (phase, epochs) to specify the
                running order and epochs. E.g, [('train', 2), ('val', 1)] means
                running 2 epochs for training and 1 epoch for validation,
                iteratively.
            max_epochs (int, optional): Deprecated. When given, it overrides
                the runner's ``max_epochs`` and a DeprecationWarning is issued.
            **kwargs: Forwarded to the method named by each workflow phase.
        """
        assert isinstance(data_loaders, list)
        assert mmcv.is_list_of(workflow, tuple)
        if max_epochs is not None:
            warnings.warn(
                'setting max_epochs in run is deprecated, '
                'please set max_epochs in runner_config', DeprecationWarning)
            self._max_epochs = max_epochs

        assert self._max_epochs is not None, (
            'max_epochs must be specified during instantiation')

        if any(mode == 'train' for mode, _ in workflow):
            self._max_iters = self._total_train_iters(data_loaders)

        work_dir = self.work_dir if self.work_dir is not None else 'NONE'
        self.logger.info('Start running, host: %s, work_dir: %s',
                         get_host_info(), work_dir)
        self.logger.info('Hooks will be executed in the following order:\n%s',
                         self.get_hook_info())
        self.logger.info('workflow: %s, max: %d epochs', workflow,
                         self._max_epochs)
        self.call_hook('before_run')

        while self.epoch < self._max_epochs:
            for mode, epochs in workflow:
                if isinstance(mode, str):  # self.train()
                    if not hasattr(self, mode):
                        raise ValueError(
                            f'runner has no method named "{mode}" to run an '
                            'epoch')
                    epoch_runner = getattr(self, mode)
                else:
                    raise TypeError(
                        'mode in workflow must be a str, but got {}'.format(
                            type(mode)))

                for _ in range(epochs):
                    if mode == 'train' and self.epoch >= self._max_epochs:
                        break
                    # NOTE(review): every phase receives the full loader list.
                    # A non-train phase (e.g. an inherited ``val``) presumably
                    # expects a single loader — confirm before adding one.
                    epoch_runner(data_loaders, **kwargs)

        time.sleep(1)  # wait for some hooks like loggers to finish
        self.call_hook('after_run')