main_mp.py
import hydra
import torch
import logging
from utils import config_logging, get_log_dict
from core import train
import torch.multiprocessing as mp

device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
# device = 'cpu'

logger = logging.getLogger(__name__)
config_logging("main_mp.log")
torch.set_float32_matmul_precision('high')


@hydra.main(config_path="cfgs", config_name="config", version_base="1.3")
def main(cfg):
    logger = logging.getLogger(__name__)
    # Manager-backed objects are shareable across processes: the workers write
    # into a common log dict and synchronize on a common barrier.
    manager = mp.Manager()
    num_seeds = len(cfg.seeds)
    barrier = manager.Barrier(num_seeds)
    log_dict = get_log_dict(cfg.agent._target_, manager, num_seeds)

    # One worker process per seed; each runs core.train with its own seed.
    pool = mp.Pool(num_seeds)
    pool.starmap(train, [(cfg, seed, log_dict, idx, logger, barrier)
                         for idx, seed in enumerate(cfg.seeds)])
    pool.close()
    pool.join()


if __name__ == "__main__":
    mp.set_start_method('spawn')  # Linux defaults to fork, which breaks CUDA in subprocesses
    main()
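
# Example invocation (a sketch): the `seeds` list and the `agent` group with its
# `_target_` field are assumed to be defined in cfgs/config.yaml, as referenced above.
# Hydra overrides can change the seed list from the command line, e.g.:
#   python main_mp.py seeds="[0, 1, 2]"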