-
Notifications
You must be signed in to change notification settings - Fork 9
/
main.py
38 lines (30 loc) · 1.04 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
from __future__ import absolute_import, division, print_function, unicode_literals
import torch
from runner import Runner
import util
import options
import colored_traceback.always
from ipdb import set_trace as debug
# Print the startup banner (rule / title / rule) and then parse the run
# configuration into the module-level ``opt`` used by the entry point below.
_banner_rule = "======================================================="
for _banner_line in (
    _banner_rule,
    " Likelihood-Training of Schrodinger Bridge",
    _banner_rule,
):
    print(util.yellow(_banner_line))

print(util.magenta("setting configurations..."))
opt = options.set()
def main(opt):
    """Run the single training or evaluation routine selected by ``opt``.

    Exactly one routine is executed, chosen in this priority order:

    1. ``opt.train_method == 'alternate'`` -> ``Runner.sb_alternate_train(opt)``
    2. ``opt.train_method == 'joint'``     -> ``Runner.sb_joint_train(opt)``
    3. ``opt.compute_FID`` is truthy       -> ``Runner.evaluate`` with the
       ``'FID'`` and ``'snapshot'`` metrics, at the iteration obtained from
       ``util.get_load_it(opt.load)``
    4. ``opt.compute_NLL`` is truthy       -> ``Runner.compute_NLL(opt)``

    Args:
        opt: Run configuration; must expose ``train_method`` and, when no
            training method is selected, ``compute_FID`` / ``compute_NLL``
            (and ``load`` for the FID branch).

    Raises:
        RuntimeError: If ``opt`` selects none of the routines above.
    """
    # Fail fast, before constructing the Runner, so a misconfigured run is
    # reported immediately with an actionable message. The ``or`` chain
    # short-circuits in the same order as the dispatch below, so the test
    # flags are only read when no training method is selected.
    wants_training = opt.train_method in ('alternate', 'joint')
    if not (wants_training or opt.compute_FID or opt.compute_NLL):
        raise RuntimeError(
            "No action selected: expected train_method to be 'alternate' or "
            "'joint', or one of compute_FID / compute_NLL to be set "
            "(got train_method={!r}).".format(opt.train_method)
        )

    run = Runner(opt)

    # ====== Training functions ======
    if opt.train_method == 'alternate':
        run.sb_alternate_train(opt)
    elif opt.train_method == 'joint':
        run.sb_joint_train(opt)

    # ====== Test functions ======
    elif opt.compute_FID:
        run.evaluate(opt, util.get_load_it(opt.load), metrics=['FID', 'snapshot'])
    elif opt.compute_NLL:
        run.compute_NLL(opt)
# Script entry: run on the CPU when requested, otherwise make ``opt.gpu`` the
# current CUDA device for the duration of the run.
if opt.cpu:
    main(opt)
else:
    with torch.cuda.device(opt.gpu):
        main(opt)