storage.py
import numpy as np


class ACStorage(object):
    """
    Class to efficiently store and operate on an actor-critic model's data.
    """

    def __init__(self, n_steps, states_shape):
        """
        Class constructor

        Keyword arguments:
        n_steps      -- number of agent steps in one batch
        states_shape -- shape of the variable describing a state (observation)
        """
        self.n_steps = n_steps
        self.states = np.zeros((self.n_steps,) + states_shape)
        self.actions = np.zeros(self.n_steps)
        self.rewards = np.zeros(self.n_steps)
        self.last_step = -1

    def insert(self, state, action, reward):
        """
        Inserts a new <state, action, reward> triple into the storage

        Keyword arguments:
        state  -- new state
        action -- new action
        reward -- new reward
        """
        self.last_step += 1
        assert self.last_step < self.n_steps, 'storage capacity exceeded'
        self.states[self.last_step] = state
        self.actions[self.last_step] = action
        self.rewards[self.last_step] = reward

    def get_states(self):
        """
        Get all states from storage
        """
        return self.states[:self.last_step + 1, ...]

    def get_actions(self):
        """
        Get all actions from storage
        """
        return self.actions[:self.last_step + 1]

    def get_rewards(self):
        """
        Get all rewards from storage
        """
        return self.rewards[:self.last_step + 1]

    def clear(self):
        """
        Clear storage
        """
        self.last_step = -1

    def calc_G_0(self, gamma):
        """
        Calculates the discounted return of the whole batch,
        G_0 = sum_{t=0}^{n-1} gamma^t * reward_t, where n is the number of stored steps

        Keyword arguments:
        gamma -- discount factor, float
        """
        n = self.last_step + 1
        g_0 = self.rewards[n - 1]
        for i in range(n - 2, -1, -1):
            g_0 = self.rewards[i] + g_0 * gamma
        return g_0
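
    # A hypothetical worked example of the recursion above (numbers made up for
    # illustration, not taken from the original file): with rewards [1.0, 0.0, 2.0]
    # and gamma = 0.9, the backward pass gives
    #     g_0 = 1.0 + 0.9 * (0.0 + 0.9 * 2.0) = 2.62,
    # the same as the direct sum 1.0 + 0.0 * 0.9 + 2.0 * 0.9 ** 2.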

    def calc_returns(self, gamma):
        """
        Calculates the discounted return G_t for every time step

        Keyword arguments:
        gamma -- discount factor, float
        """
        n = self.last_step + 1
        returns = np.zeros(n)
        returns[-1] = self.rewards[n - 1]
        for i in range(n - 2, -1, -1):
            returns[i] = self.rewards[i] + returns[i + 1] * gamma
        return returns

    def calc_gae(self, values, next_value, gamma, k=5):
        """
        Calculates k-step bootstrapped return targets (the n-step returns of
        https://arxiv.org/pdf/1602.01783.pdf) for every time step

        Keyword arguments:
        values     -- critic values for every time step in the batch (torch tensor)
        next_value -- critic value for the state following the current batch
        gamma      -- discount factor, float
        k          -- bootstrap horizon, int (default 5)
        """
        n = self.last_step + 1
        assert n == values.data.shape[0], 'shapes mismatch'
        k = int(min(n, k))
        gae = np.zeros(n)
        gae[n - 1] = self.rewards[n - 1]
        # backward pass: full discounted reward sums from each step to the end of the batch
        for i in range(n - 2, -1, -1):
            gae[i] = self.rewards[i] + gamma * gae[i + 1]
        # truncate each sum to k steps and bootstrap with the critic value k steps ahead
        if k < n:
            gae[:-k] += gamma ** k * (values[k:].cpu().data.numpy().ravel() - gae[k:])
        # the last k entries bootstrap with the value of the state right after the batch
        gamma_buf = 1.0
        for i in range(1, k + 1):
            gamma_buf *= gamma
            gae[-i] += gamma_buf * next_value
        return gae
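

# A minimal usage sketch, not part of the original file. It assumes PyTorch is
# installed (calc_gae calls .cpu().data.numpy() on the `values` argument) and
# uses made-up shapes, actions, and rewards purely for illustration.
if __name__ == '__main__':
    import torch

    storage = ACStorage(n_steps=4, states_shape=(3,))
    for t in range(4):
        # dummy transition: a 3-dimensional observation, a binary action, a unit reward
        storage.insert(state=np.ones(3) * t, action=t % 2, reward=1.0)

    print(storage.calc_G_0(gamma=0.9))      # discounted return of the whole batch
    print(storage.calc_returns(gamma=0.9))  # discounted return G_t for every stored step

    # critic outputs for the 4 stored states and for the state after the batch
    values = torch.zeros(4, 1)
    next_value = 0.0
    print(storage.calc_gae(values, next_value, gamma=0.9, k=2))

    storage.clear()  # reuse the buffer for the next batch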