diff --git a/chainerrl/agents/ppo.py b/chainerrl/agents/ppo.py
index b65f9779b..f6f59cc72 100644
--- a/chainerrl/agents/ppo.py
+++ b/chainerrl/agents/ppo.py
@@ -71,6 +71,7 @@ def __init__(self, model, optimizer,
                  clip_eps_vf=None,
                  standardize_advantages=True,
                  average_v_decay=0.999, average_loss_decay=0.99,
+                 batch_states=batch_states,
                  ):

         self.model = model
@@ -98,6 +99,8 @@ def __init__(self, model, optimizer,
         self.average_loss_entropy = 0
         self.average_loss_decay = average_loss_decay

+        self.batch_states = batch_states
+
         self.xp = self.model.xp
         self.last_state = None

@@ -107,7 +110,7 @@ def __init__(self, model, optimizer,
     def _act(self, state):
         xp = self.xp
         with chainer.using_config('train', False):
-            b_state = batch_states([state], xp, self.phi)
+            b_state = self.batch_states([state], xp, self.phi)
             with chainer.no_backprop_mode():
                 action_distrib, v = self.model(b_state)
                 action = action_distrib.sample()
@@ -200,7 +203,7 @@ def update(self):
         dataset_iter.reset()
         while dataset_iter.epoch < self.epochs:
             batch = dataset_iter.__next__()
-            states = batch_states([b['state'] for b in batch], xp, self.phi)
+            states = self.batch_states([b['state'] for b in batch], xp, self.phi)
             actions = xp.array([b['action'] for b in batch])
             distribs, vs_pred = self.model(states)
             with chainer.no_backprop_mode():
@@ -224,7 +227,7 @@ def act_and_train(self, obs, reward):
         if hasattr(self.model, 'obs_filter'):
             xp = self.xp
-            b_state = batch_states([obs], xp, self.phi)
+            b_state = self.batch_states([obs], xp, self.phi)
             self.model.obs_filter.experience(b_state)

         action, v = self._act(obs)
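
For reference, a minimal Python sketch (not part of the patch) of how the new keyword can be used once this change lands: any callable with the same (states, xp, phi) signature as the default batch_states can be injected at construction time. The PPO model/optimizer setup is omitted and assumed to be the usual constructor inputs; the custom function, its float32 cast, and the variable names are illustrative only.

import numpy as np

# from chainerrl.agents import PPO  # PPO now accepts batch_states=...


def my_batch_states(states, xp, phi):
    # Same call signature as the default batch_states(states, xp, phi):
    # apply the feature extractor phi to each raw observation and stack
    # the results into a single xp (numpy or cupy) array.
    return xp.asarray([phi(s) for s in states], dtype=np.float32)


# Standalone check of the contract the agent relies on.
obs = [np.zeros(4), np.ones(4)]
batch = my_batch_states(obs, np, phi=lambda x: x)
assert batch.shape == (2, 4)

# Injection point added by this patch (model and optimizer are the usual
# PPO constructor arguments, omitted here):
# agent = PPO(model, optimizer, batch_states=my_batch_states)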