diff --git a/bsuite/environments/cartpole.py b/bsuite/environments/cartpole.py index 3dfb93e..f5c283e 100644 --- a/bsuite/environments/cartpole.py +++ b/bsuite/environments/cartpole.py @@ -145,9 +145,9 @@ def step(self, action): reward = 1. if is_reward else 0. self._raw_return += reward self._episode_return += reward - self._best_episode = max(self._episode_return, self._best_episode) if self._state.time_elapsed > self._max_time or not is_reward: + self._best_episode = max(self._episode_return, self._best_episode) self._reset_next_step = True return dm_env.termination(reward=reward, observation=self.observation) return dm_env.transition(reward=reward, observation=self.observation) diff --git a/bsuite/experiments/cartpole_swingup/cartpole_swingup.py b/bsuite/experiments/cartpole_swingup/cartpole_swingup.py index 0c2eb38..da04af4 100644 --- a/bsuite/experiments/cartpole_swingup/cartpole_swingup.py +++ b/bsuite/experiments/cartpole_swingup/cartpole_swingup.py @@ -111,11 +111,11 @@ def step(self, action): self._total_upright += 1 self._raw_return += reward self._episode_return += reward - self._best_episode = max(self._episode_return, self._best_episode) is_end_of_episode = (self._state.time_elapsed > self._max_time or np.abs(self._state.x) > self._x_threshold) if is_end_of_episode: + self._best_episode = max(self._episode_return, self._best_episode) self._reset_next_step = True return dm_env.termination(reward=reward, observation=self.observation) else: # continuing transition.