run.py (forked from Janspiry/Palette-Image-to-Image-Diffusion-Models)
import argparse
import os
import warnings
import torch
import torch.multiprocessing as mp
import core.praser as Praser
import core.util as Util
from core.logger import VisualWriter, InfoLogger
from data import define_dataloader
from models import create_model, define_network, define_loss, define_metric, define_optimizer, define_scheduler


def main_worker(gpu, ngpus_per_node, opt):
    """ worker process running on each GPU """
    if 'local_rank' not in opt:
        opt['local_rank'] = opt['global_rank'] = gpu
    if opt['distributed']:
        torch.cuda.set_device(int(opt['local_rank']))
        print('using GPU {} for training'.format(int(opt['local_rank'])))
        torch.distributed.init_process_group(backend='nccl',
                                             init_method=opt['init_method'],
                                             world_size=opt['world_size'],
                                             rank=opt['global_rank'],
                                             group_name='mtorch')

    ''' set seed and cuDNN environment '''
    torch.backends.cudnn.enabled = True
    warnings.warn('You have chosen to use cudnn for acceleration. torch.backends.cudnn.enabled=True')
    Util.set_seed(opt['seed'])

    ''' set logger '''
    phase_logger = InfoLogger(opt)
    phase_writer = VisualWriter(opt, phase_logger)
    phase_logger.info('Create the log file in directory {}.\n'.format(opt['path']['experiments_root']))

    ''' set networks and dataset '''
    phase_loader, val_loader = define_dataloader(phase_logger, opt)  # val_loader is None if phase is test.
    networks = [define_network(phase_logger, opt, item_opt) for item_opt in opt['model']['which_networks']]

    ''' set metrics, loss, optimizer and schedulers '''
    metrics = [define_metric(phase_logger, item_opt) for item_opt in opt['model']['which_metrics']]
    losses = [define_loss(phase_logger, item_opt) for item_opt in opt['model']['which_losses']]
    train_params = [list(filter(lambda p: p.requires_grad, network.parameters())) for network in networks]
    optimizers = [define_optimizer(train_params[_idx], phase_logger, item_opt)
                  for _idx, item_opt in enumerate(opt['model']['which_optimizers'])]
    optimizers = [optimizer for optimizer in optimizers if optimizer is not None]
    lr_schedulers = [define_scheduler(optimizers[_idx], phase_logger, item_opt)
                     for _idx, item_opt in enumerate(opt['model']['which_lr_schedulers'])]
    lr_schedulers = [lr_scheduler for lr_scheduler in lr_schedulers if lr_scheduler is not None]

    model = create_model(
        opt=opt,
        networks=networks,
        phase_loader=phase_loader,
        val_loader=val_loader,
        optimizers=optimizers,
        lr_schedulers=lr_schedulers,
        losses=losses,
        metrics=metrics,
        logger=phase_logger,
        writer=phase_writer
    )

    phase_logger.info('Begin model {}.'.format(opt['phase']))
    try:
        if opt['phase'] == 'train':
            model.train()
        else:
            model.test()
    finally:
        phase_writer.close()


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-c', '--config', type=str, default='config/colorization_mirflickr25k.json', help='JSON file for configuration')
    parser.add_argument('-p', '--phase', type=str, choices=['train', 'test'], help='Run train or test', default='train')
    parser.add_argument('-b', '--batch', type=int, default=None, help='Batch size on each GPU')
    parser.add_argument('-gpu', '--gpu_ids', type=str, default=None)
    parser.add_argument('-d', '--debug', action='store_true')
    parser.add_argument('-P', '--port', default='21012', type=str)

    ''' parse configs '''
    args = parser.parse_args()
    opt = Praser.parse(args)

    ''' cuda devices '''
    gpu_str = ','.join(str(x) for x in opt['gpu_ids'])
    os.environ['CUDA_VISIBLE_DEVICES'] = gpu_str
    print('export CUDA_VISIBLE_DEVICES={}'.format(gpu_str))

    ''' use DistributedDataParallel(DDP) and multiprocessing for multi-gpu training '''
    # [Todo]: multi GPU on multi machine
    if opt['distributed']:
        ngpus_per_node = len(opt['gpu_ids'])  # or torch.cuda.device_count()
        opt['world_size'] = ngpus_per_node
        opt['init_method'] = 'tcp://127.0.0.1:' + args.port
        # spawn one worker process per GPU; each one runs main_worker with its own gpu index
        mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, opt))
    else:
        opt['world_size'] = 1
        main_worker(0, 1, opt)
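
# Example invocation (illustrative only; assumes the default config file and the dataset
# paths it references exist locally, and that the config's 'distributed' setting matches
# the GPU list; the batch size of 8 is an arbitrary example value):
#   python run.py -p train -c config/colorization_mirflickr25k.json -gpu 0,1 -b 8
#   python run.py -p test -c config/colorization_mirflickr25k.json -gpu 0
# When opt['distributed'] is enabled, mp.spawn launches one main_worker process per listed
# GPU and they synchronize over NCCL; otherwise main_worker runs once in the current process.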