tester.py
import collections

import gym
import numpy as np
import ray

from model import EmbeddingNet, LifeLongNet, QNetwork
from utils import UCB, create_beta_list, get_preprocess_func, play_episode


@ray.remote(num_cpus=1)
class Tester:
    """
    Calculate scores to evaluate performance.

    Attributes:
        env_name (str): name of the environment
        n_frames (int): number of frames to be stacked
        env (gym object): environment
        action_space (int): dimension of the action space
        frame_process_func : function used to preprocess frames
        in_q_network : Q network for the intrinsic reward
        ex_q_network : Q network for the extrinsic reward
        embedding_net : embedding network used to compute the episodic reward
        embedding_classifier : classifies actions based on the embedding representation
        original_lifelong_net : lifelong network that is not trained
        trained_lifelong_net : lifelong network that is trained
        ucb : UCB object that solves the multi-armed bandit problem
        betas (list): list of betas that weight intrinsic Q values against extrinsic Q values
        k (int): number of neighbors referenced when calculating the episodic reward
        L (int): upper limit of curiosity
        error_list : list of errors accumulated when calculating the lifelong reward
        switch_test_cycle (int): how often to switch between the test cycle and the UCB data collection cycle
        is_test (bool): flag indicating whether the current cycle is a test
        count (int): number of episodes played
    """
    def __init__(self,
                 env_name,
                 n_frames,
                 k,
                 L,
                 num_arms,
                 window_size,
                 ucb_epsilon,
                 ucb_beta,
                 switch_test_cycle,
                 original_lifelong_weight):
        """
        Args:
            env_name (str): name of the environment
            n_frames (int): number of frames to be stacked
            k (int): number of neighbors referenced when calculating the episodic reward
            L (int): upper limit of curiosity
            num_arms (int): number of arms used in the multi-armed bandit problem
            window_size (int): window size used in the multi-armed bandit problem
            ucb_epsilon (float): probability of selecting an arm at random in the multi-armed bandit problem
            ucb_beta (float): weight between visit frequency and mean reward
            switch_test_cycle (int): how often to switch between the test cycle and the UCB data collection cycle
            original_lifelong_weight : initial weights of the lifelong network
        """
        self.env_name = env_name
        self.env = gym.make(self.env_name)
        self.frame_process_func = get_preprocess_func(env_name)
        self.n_frames = n_frames
        self.action_space = self.env.action_space.n

        self.in_q_network = QNetwork(self.action_space, n_frames)
        self.ex_q_network = QNetwork(self.action_space, n_frames)
        self.embedding_net = EmbeddingNet(n_frames)
        self.original_lifelong_net = LifeLongNet(n_frames)
        self.trained_lifelong_net = LifeLongNet(n_frames)

        self.ucb = UCB(num_arms, window_size, ucb_epsilon, ucb_beta)
        self.betas = create_beta_list(num_arms)

        self.error_list = collections.deque(maxlen=int(1e4))
        self.k = k
        self.L = L
        self.switch_test_cycle = switch_test_cycle

        self.original_lifelong_net.load_state_dict(original_lifelong_weight)

        self.is_test = False
        self.count = 0
    def test_play(self, in_q_weight, ex_q_weight, embed_weight, lifelong_weight):
        """
        Load weights and compute the score, i.e. the average episode reward over a test cycle.

        Args:
            in_q_weight : weights of the intrinsic Q network
            ex_q_weight : weights of the extrinsic Q network
            embed_weight : weights of the embedding network
            lifelong_weight : weights of the lifelong network

        Returns:
            np.mean(self.episode_reward) (np.float64): average of the episode rewards collected during the test cycle
        """
        if self.count % (self.switch_test_cycle // 2) == 0:
            self.in_q_network.load_state_dict(in_q_weight)
            self.ex_q_network.load_state_dict(ex_q_weight)
            self.embedding_net.load_state_dict(embed_weight)
            self.trained_lifelong_net.load_state_dict(lifelong_weight)

        j = self.ucb.pull_index()
        beta = self.betas[j]

        # get the episode reward
        if self.is_test:
            _, episode_reward, self.error_list = play_episode(frame_process_func=self.frame_process_func,
                                                              env_name=self.env_name,
                                                              n_frames=self.n_frames,
                                                              action_space=self.action_space,
                                                              j=j,
                                                              epsilon=0.0,
                                                              k=self.k,
                                                              error_list=self.error_list,
                                                              L=self.L,
                                                              beta=beta,
                                                              in_q_network=self.in_q_network,
                                                              ex_q_network=self.ex_q_network,
                                                              embedding_net=self.embedding_net,
                                                              original_lifelong_net=self.original_lifelong_net,
                                                              trained_lifelong_net=self.trained_lifelong_net,
                                                              is_test=True)
            self.episode_reward.append(episode_reward)
            self.count += 1

        # collect UCB data
        else:
            ucb_datas, _, self.error_list = play_episode(frame_process_func=self.frame_process_func,
                                                         env_name=self.env_name,
                                                         n_frames=self.n_frames,
                                                         action_space=self.action_space,
                                                         j=j,
                                                         epsilon=0.01,
                                                         k=self.k,
                                                         error_list=self.error_list,
                                                         L=self.L,
                                                         beta=beta,
                                                         in_q_network=self.in_q_network,
                                                         ex_q_network=self.ex_q_network,
                                                         embedding_net=self.embedding_net,
                                                         original_lifelong_net=self.original_lifelong_net,
                                                         trained_lifelong_net=self.trained_lifelong_net,
                                                         is_test=True)
            self.ucb.push_data(ucb_datas)
            self.count += 1

        if self.count % self.switch_test_cycle == (self.switch_test_cycle // 2):
            self.is_test = True
            self.episode_reward = []
        elif self.count % self.switch_test_cycle == 0:
            self.is_test = False
            return np.mean(self.episode_reward)
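

# A minimal usage sketch, assuming Ray has been initialized and that network weights
# are available as state_dicts (normally they would come from a learner process).
# The environment name, hyperparameter values, and the freshly constructed networks
# below are illustrative assumptions, not the repository's defaults.
if __name__ == "__main__":
    ray.init()

    env_name = "BreakoutDeterministic-v4"  # hypothetical environment
    n_frames = 4
    action_space = gym.make(env_name).action_space.n

    tester = Tester.remote(env_name=env_name,
                           n_frames=n_frames,
                           k=10,
                           L=5,
                           num_arms=32,
                           window_size=90,
                           ucb_epsilon=0.5,
                           ucb_beta=1.0,
                           switch_test_cycle=10,
                           original_lifelong_weight=LifeLongNet(n_frames).state_dict())

    # Weights would normally be pulled from the learner; fresh networks are used here
    # only so the call is self-contained.
    score = ray.get(tester.test_play.remote(in_q_weight=QNetwork(action_space, n_frames).state_dict(),
                                            ex_q_weight=QNetwork(action_space, n_frames).state_dict(),
                                            embed_weight=EmbeddingNet(n_frames).state_dict(),
                                            lifelong_weight=LifeLongNet(n_frames).state_dict()))

    # None while the tester is still collecting UCB data; the mean test reward once a
    # test cycle completes.
    print(score)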