main.py
import argparse
import logging
import os

import torch
import torch.distributed as dist
from streaming.base.util import clean_stale_shared_memory
from torch.distributed.elastic.multiprocessing.errors import record

from progen_conditional.composer import get_trainer

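# Initialize torch.distributed with the NCCL backend when the script is launched
# by a distributed launcher (which exports RANK/LOCAL_RANK); otherwise run
# single-process without creating a process group.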
def setup_dist():
    rank = int(os.environ.get("RANK", -1))
    if dist.is_available() and torch.cuda.is_available() and rank != -1:
        dist.init_process_group(backend="nccl")
        torch.cuda.set_device(int(os.environ.get("LOCAL_RANK", 0)))

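# CLI flags: --config selects the training YAML, --new-run forces a fresh run
# instead of resuming, and --disable-logging / --debug are forwarded to get_trainer.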
def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", default="config/default.yml")
    parser.add_argument("--new-run", action="store_true", default=False)
    parser.add_argument("--disable-logging", action="store_true", default=False)
    parser.add_argument("--debug", action="store_true", default=False)
    args = parser.parse_args()
    return args

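# @record propagates uncaught exceptions from this entrypoint to torchelastic's
# error handling, so failures in distributed runs are reported with a traceback.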
@record
def main():
    # Run from the directory containing this script so relative paths
    # (e.g. the default config path) resolve correctly.
    os.chdir(os.path.dirname(os.path.realpath(__file__)))
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    logging.basicConfig(
        format="[%(asctime)s] [%(levelname)s] [%(name)s:%(lineno)s:%(funcName)s] %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )

    setup_dist()
    clean_stale_shared_memory()

    args = parse_args()
    trainer = get_trainer(
        args.config,
        force_new_run=args.new_run,
        disable_logging=args.disable_logging,
        debug=args.debug,
    )

    # Only call fit() if the current timestamp has not passed max_duration,
    # which guards against re-running a run that already completed.
    if trainer.state.max_duration >= trainer.state.timestamp.get(trainer.state.max_duration.unit):
        trainer.fit()

    torch.cuda.empty_cache()

if __name__ == "__main__":
    main()
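Usage note (an assumption about the intended workflow, not stated in the file itself): because setup_dist only initializes torch.distributed when a RANK environment variable is present, the script can be run directly as python main.py --config config/default.yml for a single process, or under a distributed launcher such as torchrun, which exports RANK and LOCAL_RANK, e.g. torchrun --nproc_per_node=8 main.py --config config/default.yml.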