-
Notifications
You must be signed in to change notification settings - Fork 8
/
embedder.py
140 lines (107 loc) · 4.09 KB
/
embedder.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import sys
import gym
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import random
import os
import os.path as osp
from tqdm import tqdm
from arguments import get_args
from rlf.rl import algo, utils
from method.emb_mem import EmbeddingMemory
from method.dist_mem import DistributionMemory
from method.embedder.embedder import Embedder
from rlf.rl.logger import Logger
import copy
from main import ExpRunSettings
from rlf.rl import utils
# Embedding Specific
from method.embedder.htvae import HTVAE
from method.embedder.embedder import Embedder
def _build_embedder(args, trial_log_dir, env_trans_fn, data_folder):
    """Construct an ``Embedder`` over every id in ``args.overall_aval_actions``.

    The data-generation, training and testing phases all build their embedder
    with identical arguments, so the construction is shared here.
    """
    return Embedder(
        args, trial_log_dir,
        env_trans_fn, data_folder,
        option_ids=args.overall_aval_actions
    )


def main():
    """Run the embedder pipeline selected by the parsed command-line flags.

    Phases, in order:
      1. ``save_dataset``: generate the dataset, load its params and set
         ``args.load_dataset`` so later phases skip regeneration.
      2. ``train_embeddings`` / ``resume_emb_training``: build an embedder,
         generate/load data unless ``load_emb_model_file`` is given, then call
         ``prepare_model``.
      3. otherwise, if ``test_embeddings``: generate/load data, call
         ``prepare_model`` and run the distribution/embedding evaluations and
         trajectory-embedding visualization.

    Phases 2 and 3 are mutually exclusive; phase 1 runs independently of both.

    Raises:
        ValueError: if the test phase is selected without
            ``args.load_emb_model_file``.
    """
    run_settings = ExpRunSettings()
    args = run_settings.get_args()
    log_args = run_settings.get_set_args()

    # Scratch log directory: the expanded --log_dir with a "_trial" suffix.
    log_dir = os.path.expanduser(args.log_dir)
    trial_log_dir = log_dir + "_trial"
    utils.cleanup_log_dir(trial_log_dir)

    # Seed the Python, NumPy and torch (CPU and all CUDA devices) RNGs.
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)

    if args.cuda and torch.cuda.is_available() and args.cuda_deterministic:
        # Give up cuDNN autotuning in exchange for deterministic kernels.
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True

    torch.set_num_threads(1)
    args.device = torch.device("cuda:0" if args.cuda else "cpu")
    # NOTE(review): grid_playing / both_train_test are read by code outside
    # this file; their exact effect is not visible here.
    args.grid_playing = True
    if args.save_dataset:
        args.both_train_test = True

    env_trans_fn = run_settings.get_env_trans_fn(args)

    emb_mem = EmbeddingMemory(cuda=args.cuda, args=args)
    data_folder = args.play_data_folder
    model_folder = args.emb_model_folder

    logger = Logger(log_args, './data/embedder/logs/')
    logger.set_prefix(args)

    # Data-generation phase: always regenerate, then flag the dataset as
    # available so the phases below load it instead of regenerating.
    if args.save_dataset:
        embedder = _build_embedder(args, trial_log_dir, env_trans_fn,
                                   data_folder)
        print('No data to load... Generating Dataset')
        embedder.generate_dataset(args, trial_log_dir, env_trans_fn)
        print('Loading data... ')
        embedder.load_data_params()
        print('Successfully Loaded data params... ')
        args.load_dataset = True

    # Training phase.
    if args.train_embeddings or args.resume_emb_training:
        args.train_split = True
        args.test_split = False
        embedder = _build_embedder(args, trial_log_dir, env_trans_fn,
                                   data_folder)
        # Only generate/load training data when no model file is given.
        if args.load_emb_model_file is None:
            if not args.load_dataset:
                print('No data to load... Generating Dataset')
                embedder.generate_dataset(args, trial_log_dir, env_trans_fn)
            print('Loading data... ')
            embedder.load_data_params()
            print('Loaded data params... ')
        embedder.prepare_model(args, model_folder, logger,
                               method=args.emb_method)

    # Testing phase.
    elif args.test_embeddings:
        # Validate before building the embedder or generating data so a
        # missing model file fails immediately (an `assert` would also be
        # stripped when Python runs with -O).
        if args.load_emb_model_file is None:
            raise ValueError(
                'test_embeddings requires args.load_emb_model_file to be set')
        embedder = _build_embedder(args, trial_log_dir, env_trans_fn,
                                   data_folder)
        if not args.load_dataset:
            print('No data to load... Generating Dataset')
            embedder.generate_dataset(args, trial_log_dir, env_trans_fn)
        print('Loading data...')
        embedder.load_data_params()

        # Only loading model in testing phase.
        embedder.prepare_model(args, model_folder, logger,
                               method=args.emb_method)

        dist_mem = DistributionMemory(cuda=args.cuda, args=args)
        embedder.eval_dists_from_ids(dist_mem, None, 1)
        embedder.eval_embs_from_ids(emb_mem, None, dist_mem=dist_mem)
        embedder.visualize_trajectory_embeddings(
            args.emb_method, reconstruction=False,
            emb_mem=emb_mem, dist_mem=dist_mem)


if __name__ == "__main__":
    main()