-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
140 lines (122 loc) · 5.7 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
# Standard-library imports.
import sys
import os
# Make the project package importable from its install location.
# NOTE(review): dirname('/code/connectomics/') resolves to '/code/connectomics',
# so '/code/connectomics' itself (not its parent) is put on sys.path — confirm
# this is the intended import root.
sys.path.append(os.path.dirname('/code/connectomics/'))
import argparse
import random
import numpy as np
import torch
from cilog import create_logger
# Force the 'spawn' start method before any workers are created.
# NOTE(review): presumably required for CUDA use inside DataLoader worker
# processes — confirm.
torch.multiprocessing.set_start_method('spawn', force=True)
import torch.distributed as dist
import torch.backends.cudnn as cudnn
# Project-local modules (rely on the sys.path tweak above).
from connectomics.config import load_cfg, save_all_cfg
from connectomics.data.utils import get_blocknames_from_points, save_feature
from connectomics.engine import Trainer, SSL_Trainer
def get_args():
    """Parse command-line arguments for training and inference runs.

    Returns the argparse Namespace; any trailing positional tokens are
    collected into ``opts`` as free-form config overrides.
    """
    cli = argparse.ArgumentParser(description="Model Training & Inference")
    cli.add_argument('--config-file', type=str,
                     help='configuration file (yaml)')
    cli.add_argument('--config-base', type=str, default=None,
                     help='base configuration file (yaml)')
    cli.add_argument('--inference', action='store_true',
                     help='inference mode')
    cli.add_argument('--distributed', action='store_true',
                     help='distributed training')
    cli.add_argument('--local_rank', type=int, default=None,
                     help='node rank for distributed training')
    cli.add_argument('--checkpoint', type=str, default=None,
                     help='path to load the checkpoint')
    cli.add_argument('--debug', action='store_true',
                     help='run the scripts in debug mode')
    # Merge configs from command line (e.g., add 'SYSTEM.NUM_GPUS 8').
    # REMAINDER swallows everything left on the command line into a list.
    cli.add_argument(
        "opts",
        help="Modify config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER,
    )
    return cli.parse_args()
def init_seed(seed):
    """Seed the Python, NumPy and PyTorch RNGs for reproducible runs."""
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
def main():
    """Entry point: parse CLI args, load configs, set up logging and the
    compute device, build a Trainer, and dispatch to the requested
    training or inference routine.
    """
    args = get_args()
    # torchrun/torch.distributed exports LOCAL_RANK; single-process runs
    # default to rank 0.
    args.local_rank = int(os.environ["LOCAL_RANK"]) if args.distributed else 0
    # NOTE(review): after the line above local_rank is always an int, so the
    # `is None` branches below look unreachable — confirm whether a None rank
    # path is still intended.
    if args.local_rank == 0 or args.local_rank is None:
        print("Command line arguments: ", args)
    # Seed per-rank so each distributed worker gets a distinct RNG stream.
    manual_seed = 0 if args.local_rank is None else args.local_rank
    init_seed(manual_seed)
    cfg = load_cfg(args)
    cfg_image_model = None
    args_image_model = None
    if cfg.INFERENCE.GET_PC_FEATURE == 'GPT':
        # Block names keyed by point coordinates; used when saving features
        # at the end of the GPT branch below.
        seq_data = get_blocknames_from_points(cfg.DATASET.INPUT_PATH)
    if cfg.MODEL.IMAGE_MODEL_CFG is not None:
        # Build a second, inference-only config for the auxiliary image model.
        # Re-parsing the CLI gives a fresh Namespace to overwrite.
        args_image_model = get_args()
        args_image_model.config_file = cfg.MODEL.IMAGE_MODEL_CFG
        args_image_model.checkpoint = cfg.MODEL.IMAGE_MODEL_CKPT
        assert args_image_model.checkpoint is not None
        args_image_model.inference = True
        cfg_image_model = load_cfg(args_image_model, merge_cmd=False)
    # Log file lives next to (named after) the output path.
    log_name = cfg.DATASET.OUTPUT_PATH + '.log'
    create_logger(name='l1', file=log_name, sub_print=True,
                  file_level='DEBUG')
    if args.local_rank == 0 or args.local_rank is None:
        # In distributed training, only print and save the
        # configurations using the node with local_rank=0.
        print("PyTorch: ", torch.__version__)
        print(cfg)
        if not os.path.exists(cfg.DATASET.OUTPUT_PATH):
            print('Output directory: ', cfg.DATASET.OUTPUT_PATH)
            os.makedirs(cfg.DATASET.OUTPUT_PATH)
        save_all_cfg(cfg, cfg.DATASET.OUTPUT_PATH)
    if args.distributed:
        # One GPU per process, indexed by local rank; NCCL backend reads
        # rendezvous info from the environment.
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        assert torch.cuda.is_available(), \
            "Distributed training without GPUs is not supported!"
        dist.init_process_group("nccl", init_method='env://')
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Rank: {}. Device: {}".format(args.local_rank, device))
    # cudnn.benchmark autotunes conv kernels; assumes fixed input shapes.
    cudnn.enabled = True
    cudnn.benchmark = True
    mode = 'test' if args.inference else 'train'
    if cfg.MODEL.SSL == 'none':
        # Supervised trainer, optionally paired with the auxiliary image model.
        trainer = Trainer(cfg, device, mode,
                          rank=args.local_rank,
                          checkpoint=args.checkpoint, cfg_image_model=cfg_image_model,
                          checkpoint_image_model=args_image_model.checkpoint if args_image_model is not None else None)
    else:
        # Self-supervised learning trainer.
        trainer = SSL_Trainer(cfg, device, mode,
                              rank=args.local_rank,
                              checkpoint=args.checkpoint)
    # Start training or inference:
    if cfg.DATASET.DO_CHUNK_TITLE == 0 and not cfg.DATASET.DO_MULTI_VOLUME:
        # Single-volume path: pick the inference routine, then dispatch.
        if cfg.DATASET.BIOLOGICAL_DATSET:
            test_func = trainer.test_biological
        elif cfg.INFERENCE.DO_SINGLY:
            test_func = trainer.test_one_neuron
        else:
            test_func = trainer.test
        test_func() if args.inference else trainer.train()
    elif cfg.DATASET.DO_MULTI_VOLUME:
        # Multi-volume path: feature-extraction modes take priority over the
        # generic run loop.
        if cfg.INFERENCE.GET_PC_FEATURE != 'None':
            if cfg.INFERENCE.GET_PC_FEATURE == 'Tracing':
                trainer.get_pc_feature_test(mode, rank=args.local_rank)
            elif cfg.INFERENCE.GET_PC_FEATURE == 'Train':
                trainer.get_pc_feature(mode, rank=args.local_rank)
            elif cfg.INFERENCE.GET_PC_FEATURE == 'GPT':
                result_center_cord_dict, result_fafb_cord_dict, result_embedding_dict = trainer.get_pc_feature_gpt(mode, rank=args.local_rank)
                # NOTE(review): presumably block-wise outputs (paths containing
                # 'blocks') are saved elsewhere — confirm.
                if not ('blocks' in cfg.DATASET.OUTPUT_PATH):
                    save_feature([result_center_cord_dict, result_fafb_cord_dict, result_embedding_dict, list(seq_data.keys())], cfg.DATASET.OUTPUT_PATH)
        elif cfg.INFERENCE.GET_PATCH_FEATURE:
            trainer.test_patch(mode, rank=args.local_rank)
        else:
            trainer.run_multivolume(mode, rank=args.local_rank)
    print("Rank: {}. Device: {}. Process is finished!".format(
        args.local_rank, device))
# Script entry point: run training/inference when executed directly.
if __name__ == "__main__":
    main()