Use Variable.array instead of Variable.data #336

Merged: 6 commits, Oct 17, 2018
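The diff below replaces every access to Variable.data with Variable.array. For context, a minimal sketch of the difference (not part of this PR; assumes Chainer >= 3.0, where Variable.array was introduced, plus NumPy): both attributes return the same underlying ndarray, but the .array spelling avoids confusion with numpy.ndarray.data, which is a raw memory buffer rather than the array itself.

    import numpy as np
    import chainer

    v = chainer.Variable(np.arange(6, dtype=np.float32).reshape(2, 3))

    # Variable.array and Variable.data refer to the same ndarray object.
    assert v.array is v.data

    # On a bare ndarray, .data is a memoryview of the underlying buffer,
    # which is the ambiguity that the .array name avoids.
    print(type(v.array))       # <class 'numpy.ndarray'>
    print(type(v.array.data))  # <class 'memoryview'>
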
15 changes: 5 additions & 10 deletions .travis.yml
@@ -4,18 +4,18 @@ python:
- "2.7"
- "3.5.1"
env:
- CHAINER_VERSION=2
- CHAINER_VERSION=3
- CHAINER_VERSION=stable
# command to install dependencies
install:
- pip install --upgrade pip setuptools wheel
- |
if [[ $CHAINER_VERSION == 2 ]]; then
pip install "chainer<3.0.0"
if [[ $CHAINER_VERSION == 3 ]]; then
pip install "chainer==3.1.0"
else
pip install chainer
pip install pytest-cov
fi
- pip install pytest-cov
- pip install -r requirements.txt --only-binary=numpy,scipy
- pip install jupyter
# atari_py==0.1.4 causes an error
@@ -38,12 +38,7 @@ script:
- flake8 examples
- autopep8 -r chainerrl tests examples --diff | tee check_autopep8
- test ! -s check_autopep8
- |
if [[ $CHAINER_VERSION == 2 ]]; then
nosetests -a '!gpu,!slow' -x tests --with-coverage --cover-package chainerrl
else
pytest -m "not gpu and not slow" -x tests --cov=chainerrl
fi
- pytest -m "not gpu and not slow" -x tests --cov=chainerrl
- ./test_examples.sh -1
- if [[ $TRAVIS_PYTHON_VERSION == 3.5.1 && $CHAINER_VERSION == stable ]]; then jupyter nbconvert --to notebook --execute examples/quickstart/quickstart.ipynb --ExecutePreprocessor.timeout=600; fi
after_success:
32 changes: 16 additions & 16 deletions chainerrl/action_value.py
@@ -62,15 +62,15 @@ class DiscreteActionValue(ActionValue):

def __init__(self, q_values, q_values_formatter=lambda x: x):
assert isinstance(q_values, chainer.Variable)
self.xp = cuda.get_array_module(q_values.data)
self.xp = cuda.get_array_module(q_values.array)
self.q_values = q_values
self.n_actions = q_values.data.shape[1]
self.n_actions = q_values.array.shape[1]
self.q_values_formatter = q_values_formatter

@cached_property
def greedy_actions(self):
return chainer.Variable(
self.q_values.data.argmax(axis=1).astype(np.int32))
self.q_values.array.argmax(axis=1).astype(np.int32))

@cached_property
def max(self):
@@ -92,8 +92,8 @@ def compute_expectation(self, beta):

def __repr__(self):
return 'DiscreteActionValue greedy_actions:{} q_values:{}'.format(
self.greedy_actions.data,
self.q_values_formatter(self.q_values.data))
self.greedy_actions.array,
self.q_values_formatter(self.q_values.array))

@property
def params(self):
@@ -117,17 +117,17 @@ def __init__(self, q_dist, z_values, q_values_formatter=lambda x: x):
assert z_values.ndim == 1
assert q_dist.shape[2] == z_values.shape[0]

self.xp = cuda.get_array_module(q_dist.data)
self.xp = cuda.get_array_module(q_dist.array)
self.z_values = z_values
self.q_values = F.sum(F.scale(q_dist, self.z_values, axis=2), axis=2)
self.q_dist = q_dist
self.n_actions = q_dist.data.shape[1]
self.n_actions = q_dist.array.shape[1]
self.q_values_formatter = q_values_formatter

@cached_property
def greedy_actions(self):
return chainer.Variable(
self.q_values.data.argmax(axis=1).astype(np.int32))
self.q_values.array.argmax(axis=1).astype(np.int32))

@cached_property
def max(self):
@@ -144,7 +144,7 @@ def max_as_distribution(self):
"""
with chainer.force_backprop_mode():
return self.q_dist[self.xp.arange(self.q_values.shape[0]),
self.greedy_actions.data]
self.greedy_actions.array]

def evaluate_actions(self, actions):
return F.select_item(self.q_values, actions)
@@ -174,8 +174,8 @@ def compute_expectation(self, beta):

def __repr__(self):
return 'DistributionalDiscreteActionValue greedy_actions:{} q_values:{}'.format( # NOQA
self.greedy_actions.data,
self.q_values_formatter(self.q_values.data))
self.greedy_actions.array,
self.q_values_formatter(self.q_values.array))

@property
def params(self):
@@ -202,7 +202,7 @@ class QuadraticActionValue(ActionValue):
"""

def __init__(self, mu, mat, v, min_action=None, max_action=None):
self.xp = cuda.get_array_module(mu.data)
self.xp = cuda.get_array_module(mu.array)
self.mu = mu
self.mat = mat
self.v = v
@@ -215,18 +215,18 @@ def __init__(self, mu, mat, v, min_action=None, max_action=None):
else:
self.max_action = self.xp.asarray(max_action, dtype=np.float32)

self.batch_size = self.mu.data.shape[0]
self.batch_size = self.mu.array.shape[0]

@cached_property
def greedy_actions(self):
with chainer.force_backprop_mode():
a = self.mu
if self.min_action is not None:
a = F.maximum(
self.xp.broadcast_to(self.min_action, a.data.shape), a)
self.xp.broadcast_to(self.min_action, a.array.shape), a)
if self.max_action is not None:
a = F.minimum(
self.xp.broadcast_to(self.max_action, a.data.shape), a)
self.xp.broadcast_to(self.max_action, a.array.shape), a)
return a

@cached_property
@@ -254,7 +254,7 @@ def compute_double_advantage(self, actions, argmax_actions):

def __repr__(self):
return 'QuadraticActionValue greedy_actions:{} v:{}'.format(
self.greedy_actions.data, self.v.data)
self.greedy_actions.array, self.v.array)

@property
def params(self):
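A note on the pattern in action_value.py above: the raw ndarray is read through .array exactly where a value is used outside the computational graph, e.g. greedy-action selection, where argmax is not differentiable and its result is rewrapped in a fresh Variable. A minimal standalone sketch (hypothetical Q-values, not the repository's class):

    import numpy as np
    import chainer

    # Hypothetical batch of Q-values: 2 states x 3 actions.
    q_values = chainer.Variable(np.array([[0.1, 0.5, 0.2],
                                          [0.9, 0.0, 0.3]], dtype=np.float32))

    # As in DiscreteActionValue.greedy_actions: argmax runs on the raw
    # ndarray (no gradient flows through it), then is wrapped again.
    greedy_actions = chainer.Variable(
        q_values.array.argmax(axis=1).astype(np.int32))
    print(greedy_actions.array)  # [1 0]
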
18 changes: 9 additions & 9 deletions chainerrl/agents/a3c.py
@@ -169,7 +169,7 @@ def update(self, statevar):
else:
with state_kept(self.model):
_, vout = self.model.pi_and_v(statevar)
R = float(vout.data)
R = float(vout.array)

pi_loss = 0
v_loss = 0
@@ -182,13 +182,13 @@ def update(self, statevar):
advantage = R - v
if self.use_average_reward:
self.average_reward += self.average_reward_tau * \
float(advantage.data)
float(advantage.array)
# Accumulate gradients of policy
log_prob = self.past_action_log_prob[i]
entropy = self.past_action_entropy[i]

# Log probability is increased proportionally to advantage
pi_loss -= log_prob * float(advantage.data)
pi_loss -= log_prob * float(advantage.array)
# Entropy is maximized
pi_loss -= self.beta * entropy
# Accumulate gradients of value function
@@ -213,7 +213,7 @@ def update(self, statevar):
v_loss /= self.t - self.t_start

if self.process_idx == 0:
logger.debug('pi_loss:%s v_loss:%s', pi_loss.data, v_loss.data)
logger.debug('pi_loss:%s v_loss:%s', pi_loss.array, v_loss.array)

total_loss = F.squeeze(pi_loss) + F.squeeze(v_loss)

@@ -256,7 +256,7 @@ def act_and_train(self, obs, reward):

self.past_states[self.t] = statevar
pout, vout = self.model.pi_and_v(statevar)
action = pout.sample().data # Do not backprop through sampled actions
action = pout.sample().array # Do not backprop through sampled actions
self.past_action_log_prob[self.t] = pout.log_prob(action)
self.past_action_entropy[self.t] = pout.entropy
self.past_values[self.t] = vout
@@ -268,10 +268,10 @@ def act_and_train(self, obs, reward):
# Update stats
self.average_value += (
(1 - self.average_value_decay) *
(float(vout.data[0]) - self.average_value))
(float(vout.array[0]) - self.average_value))
self.average_entropy += (
(1 - self.average_entropy_decay) *
(float(pout.entropy.data[0]) - self.average_entropy))
(float(pout.entropy.array[0]) - self.average_entropy))
return action

def act(self, obs):
@@ -280,9 +280,9 @@ def act(self, obs):
statevar = self.batch_states([obs], np, self.phi)
pout, _ = self.model.pi_and_v(statevar)
if self.act_deterministically:
return pout.most_probable.data[0]
return pout.most_probable.array[0]
else:
return pout.sample().data[0]
return pout.sample().array[0]

def stop_episode_and_train(self, state, reward, done=False):
self.past_rewards[self.t - 1] = reward
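In a3c.py, the recurring pattern is float(x.array): a scalar is pulled out of the graph so it acts as a constant, e.g. the advantage that scales the policy-gradient term, so gradients flow through log_prob and the value loss but not through the advantage itself. A minimal standalone sketch (hypothetical numbers, not the agent's full update loop):

    import numpy as np
    import chainer
    import chainer.functions as F

    v = chainer.Variable(np.array([[0.4]], dtype=np.float32))        # V(s)
    log_prob = chainer.Variable(np.array([-0.7], dtype=np.float32))  # log pi(a|s)
    R = 1.0                                                          # return

    advantage = R - v
    # float(advantage.array) detaches the advantage, so the policy loss
    # scales log_prob by a constant rather than backpropagating into V(s).
    pi_loss = -log_prob * float(advantage.array)
    v_loss = (v - R) ** 2 / 2
    total_loss = F.squeeze(pi_loss) + F.squeeze(v_loss)
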
50 changes: 25 additions & 25 deletions chainerrl/agents/acer.py
@@ -27,11 +27,11 @@


def compute_importance(pi, mu, x):
return np.nan_to_num(float(pi.prob(x).data) / float(mu.prob(x).data))
return np.nan_to_num(float(pi.prob(x).array) / float(mu.prob(x).array))


def compute_full_importance(pi, mu):
pimu = pi.all_prob.data / mu.all_prob.data
pimu = pi.all_prob.array / mu.all_prob.array
# NaN occurs when inf/inf or 0/0
pimu[np.isnan(pimu)] = 1
pimu = np.nan_to_num(pimu)
@@ -50,8 +50,8 @@ def compute_policy_gradient_full_correction(
correction_weight = (
np.maximum(1 - truncation_threshold * rho_all_inv,
np.zeros_like(rho_all_inv)) *
action_distrib.all_prob.data[0])
correction_advantage = action_value.q_values.data[0] - v
action_distrib.all_prob.array[0])
correction_advantage = action_value.q_values.array[0] - v
return -F.sum(correction_weight *
action_distrib.all_log_prob *
correction_advantage, axis=1)
@@ -64,15 +64,15 @@ def compute_policy_gradient_sample_correction(
assert np.isscalar(v)
assert truncation_threshold is not None
with chainer.no_backprop_mode():
sample_action = action_distrib.sample().data
sample_action = action_distrib.sample().array
rho_dash_inv = compute_importance(
action_distrib_mu, action_distrib, sample_action)
if (truncation_threshold > 0 and
rho_dash_inv >= 1 / truncation_threshold):
return chainer.Variable(np.asarray([0], dtype=np.float32))
correction_weight = max(0, 1 - truncation_threshold * rho_dash_inv)
assert correction_weight <= 1
q = float(action_value.evaluate_actions(sample_action).data[0])
q = float(action_value.evaluate_actions(sample_action).array[0])
correction_advantage = q - v
return -(correction_weight *
action_distrib.log_prob(sample_action) *
@@ -155,7 +155,7 @@ def __call__(self, obs):
v = self.v(obs)

def evaluator(action):
adv_mean = sum(self.adv(obs, action_distrib.sample().data)
adv_mean = sum(self.adv(obs, action_distrib.sample().array)
for _ in range(self.n)) / self.n
return v + self.adv(obs, action) - adv_mean

@@ -239,7 +239,7 @@ def compute_loss_with_kl_constraint(distrib, another_distrib, original_loss,
loss = 0
for p, zp in zip(distrib.params, z):
loss += F.sum(p * zp)
return F.reshape(loss, original_loss.shape), float(kl.data)
return F.reshape(loss, original_loss.shape), float(kl.array)


class ACER(agent.AttributeSavingMixin, agent.AsyncAgent):
@@ -446,16 +446,16 @@ def compute_loss(
assert np.isscalar(Q_ret)
assert np.isscalar(Q_opc)
if self.use_Q_opc:
advantage = Q_opc - float(v.data)
advantage = Q_opc - float(v.array)
else:
advantage = Q_ret - float(v.data)
advantage = Q_ret - float(v.array)
pi_loss += self.compute_one_step_pi_loss(
action=ba,
advantage=advantage,
action_distrib=action_distrib,
action_distrib_mu=action_distrib_mu,
action_value=action_value,
v=float(v.data),
v=float(v.array),
avg_action_distrib=avg_action_distrib)

# Accumulate gradients of value function
@@ -466,21 +466,21 @@ def compute_loss(
if not discrete:
assert isinstance(v, chainer.Variable), \
"v must be backprop-able"
v_target = (min(1, rho) * (Q_ret - float(Q.data)) +
float(v.data))
v_target = (min(1, rho) * (Q_ret - float(Q.array)) +
float(v.array))
Q_loss += (v_target - v) ** 2 / 2

if self.process_idx == 0:
self.logger.debug(
't:%s v:%s Q:%s Q_ret:%s Q_opc:%s',
i, float(v.data), float(Q.data), Q_ret, Q_opc)
i, float(v.array), float(Q.array), Q_ret, Q_opc)

if discrete:
c = min(1, rho)
else:
c = min(1, rho ** (1 / ba.size))
Q_ret = c * (Q_ret - float(Q.data)) + float(v.data)
Q_opc = Q_opc - float(Q.data) + float(v.data)
Q_ret = c * (Q_ret - float(Q.array)) + float(v.array)
Q_opc = Q_opc - float(Q.array) + float(v.array)

pi_loss *= self.pi_loss_coef
Q_loss *= self.Q_loss_coef
@@ -491,9 +491,9 @@ def compute_loss(

if self.process_idx == 0:
self.logger.debug('pi_loss:%s Q_loss:%s',
pi_loss.data, Q_loss.data)
pi_loss.array, Q_loss.array)

return pi_loss + F.reshape(Q_loss, pi_loss.data.shape)
return pi_loss + F.reshape(Q_loss, pi_loss.array.shape)

def update(self, t_start, t_stop, R, states, actions, rewards, values,
action_values, action_distribs, action_distribs_mu,
@@ -576,7 +576,7 @@ def update_from_replay(self):
last_s = last_transition['next_state']
action_distrib, action_value, last_v = self.model(
np.expand_dims(self.phi(last_s), 0))
R = float(last_v.data)
R = float(last_v.array)
return self.update(
R=R, t_start=0, t_stop=len(episode),
states=states, rewards=rewards,
@@ -597,7 +597,7 @@ def update_on_policy(self, statevar):
with chainer.no_backprop_mode():
with state_kept(self.model):
action_distrib, action_value, v = self.model(statevar)
R = float(v.data)
R = float(v.array)
self.update(
t_start=self.t_start, t_stop=self.t, R=R,
states=self.past_states,
@@ -625,7 +625,7 @@ def act_and_train(self, obs, reward):
self.past_states[self.t] = statevar
action_distrib, action_value, v = self.model(statevar)
self.past_action_values[self.t] = action_value
action = action_distrib.sample().data[0]
action = action_distrib.sample().array[0]

# Save values for a later update
self.past_values[self.t] = v
@@ -645,10 +645,10 @@ def act_and_train(self, obs, reward):
# Update stats
self.average_value += (
(1 - self.average_value_decay) *
(float(v.data[0]) - self.average_value))
(float(v.array[0]) - self.average_value))
self.average_entropy += (
(1 - self.average_entropy_decay) *
(float(action_distrib.entropy.data[0]) - self.average_entropy))
(float(action_distrib.entropy.array[0]) - self.average_entropy))

if self.replay_buffer is not None and self.last_state is not None:
assert self.last_action is not None
Expand Down Expand Up @@ -676,9 +676,9 @@ def act(self, obs):
statevar = np.expand_dims(self.phi(obs), 0)
action_distrib, _, _ = self.model(statevar)
if self.act_deterministically:
return action_distrib.most_probable.data[0]
return action_distrib.most_probable.array[0]
else:
return action_distrib.sample().data[0]
return action_distrib.sample().array[0]

def stop_episode_and_train(self, state, reward, done=False):
assert self.last_state is not None
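Likewise in acer.py, anything that must act as a constant in the loss (importance weights, the Retrace-style targets Q_ret and Q_opc) is unwrapped with float(x.array). A minimal standalone sketch of the backup step from compute_loss (hypothetical numbers, discrete-action case):

    import numpy as np
    import chainer

    Q = chainer.Variable(np.array([0.8], dtype=np.float32))  # Q(s, a)
    v = chainer.Variable(np.array([0.5], dtype=np.float32))  # V(s)
    Q_ret, Q_opc = 1.2, 1.1   # current targets (plain Python floats)
    rho = 2.0                 # importance weight pi(a|s) / mu(a|s)

    c = min(1, rho)           # truncated importance weight
    # Only the raw values leave the graph; the targets stay plain floats.
    Q_ret = c * (Q_ret - float(Q.array)) + float(v.array)
    Q_opc = Q_opc - float(Q.array) + float(v.array)
    print(Q_ret, Q_opc)       # ~0.9, ~0.8 (up to float32 rounding)
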
2 changes: 1 addition & 1 deletion chainerrl/agents/categorical_dqn.py
@@ -93,7 +93,7 @@ def _compute_target_values(self, exp_batch, gamma):
n_atoms = z_values.size

# next_q_max: (batch_size, n_atoms)
next_q_max = target_next_qout.max_as_distribution.data
next_q_max = target_next_qout.max_as_distribution.array
assert next_q_max.shape == (batch_size, n_atoms), next_q_max.shape

# Tz: (batch_size, n_atoms)
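And in categorical_dqn.py, the target network's output distribution is read as a plain ndarray before the distributional targets are built. A rough sketch of the shape of that computation (hypothetical atoms and rewards, terminal flags ignored; not the repository's exact code):

    import numpy as np

    # Hypothetical support of 3 atoms on [-1, 1] and a batch of 2 rewards.
    z_values = np.array([-1.0, 0.0, 1.0], dtype=np.float32)
    rewards = np.array([0.5, -0.2], dtype=np.float32)
    gamma = 0.99

    # next_q_max plays the role of target_next_qout.max_as_distribution.array:
    # a plain (batch_size, n_atoms) ndarray, outside the graph.
    next_q_max = np.array([[0.1, 0.6, 0.3],
                           [0.7, 0.2, 0.1]], dtype=np.float32)

    # Tz: (batch_size, n_atoms), the shifted support.
    Tz = np.clip(rewards[:, None] + gamma * z_values[None, :],
                 z_values[0], z_values[-1])

    # The probability mass in next_q_max is then redistributed onto the
    # atoms of z_values nearest to each entry of Tz (projection, omitted).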