Skip to content

Commit

Permalink
Merge pull request #261 from kuni-kuni/ppo_batch_states
Browse files Browse the repository at this point in the history
Allow batch_states to be customized in PPO
  • Loading branch information
muupan authored May 9, 2018
2 parents 5ac2f72 + dfca1f6 commit 61c44f4
Showing 1 changed file with 6 additions and 3 deletions.
9 changes: 6 additions & 3 deletions chainerrl/agents/ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ def __init__(self, model, optimizer,
clip_eps_vf=None,
standardize_advantages=True,
average_v_decay=0.999, average_loss_decay=0.99,
+ batch_states=batch_states,
):
self.model = model

Expand Down Expand Up @@ -98,6 +99,8 @@ def __init__(self, model, optimizer,
self.average_loss_entropy = 0
self.average_loss_decay = average_loss_decay

+ self.batch_states = batch_states
+
self.xp = self.model.xp
self.last_state = None

Expand All @@ -107,7 +110,7 @@ def __init__(self, model, optimizer,
def _act(self, state):
xp = self.xp
with chainer.using_config('train', False):
- b_state = batch_states([state], xp, self.phi)
+ b_state = self.batch_states([state], xp, self.phi)
with chainer.no_backprop_mode():
action_distrib, v = self.model(b_state)
action = action_distrib.sample()
Expand Down Expand Up @@ -200,7 +203,7 @@ def update(self):
dataset_iter.reset()
while dataset_iter.epoch < self.epochs:
batch = dataset_iter.__next__()
- states = batch_states([b['state'] for b in batch], xp, self.phi)
+ states = self.batch_states([b['state'] for b in batch], xp, self.phi)
actions = xp.array([b['action'] for b in batch])
distribs, vs_pred = self.model(states)
with chainer.no_backprop_mode():
Expand All @@ -224,7 +227,7 @@ def update(self):
def act_and_train(self, obs, reward):
if hasattr(self.model, 'obs_filter'):
xp = self.xp
- b_state = batch_states([obs], xp, self.phi)
+ b_state = self.batch_states([obs], xp, self.phi)
self.model.obs_filter.experience(b_state)

action, v = self._act(obs)
Expand Down

0 comments on commit 61c44f4

Please sign in to comment.