From 817453044f2c2065acec547199e782d47ec7a77e Mon Sep 17 00:00:00 2001
From: Teddy Koker
Date: Tue, 24 Nov 2020 18:11:55 -0500
Subject: [PATCH] fixed other dqn

---
 pl_bolts/models/rl/double_dqn_model.py | 14 ++------------
 pl_bolts/models/rl/per_dqn_model.py    | 14 ++------------
 2 files changed, 4 insertions(+), 24 deletions(-)

diff --git a/pl_bolts/models/rl/double_dqn_model.py b/pl_bolts/models/rl/double_dqn_model.py
index 284c328f2d..150ea14dd9 100644
--- a/pl_bolts/models/rl/double_dqn_model.py
+++ b/pl_bolts/models/rl/double_dqn_model.py
@@ -65,27 +65,17 @@ def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor], _) -> OrderedD
         if self.global_step % self.sync_rate == 0:
             self.target_net.load_state_dict(self.net.state_dict())
 
-        log = {
+        self.log_dict({
             "total_reward": self.total_rewards[-1],
             "avg_reward": self.avg_rewards,
             "train_loss": loss,
             # "episodes": self.total_episode_steps,
-        }
-        status = {
-            "steps": self.global_step,
-            "avg_reward": self.avg_rewards,
-            "total_reward": self.total_rewards[-1],
-            "episodes": self.done_episodes,
-            # "episode_steps": self.episode_steps,
-            "epsilon": self.agent.epsilon,
-        }
+        })
 
         return OrderedDict(
             {
                 "loss": loss,
                 "avg_reward": self.avg_rewards,
-                "log": log,
-                "progress_bar": status,
             }
         )
 
diff --git a/pl_bolts/models/rl/per_dqn_model.py b/pl_bolts/models/rl/per_dqn_model.py
index 69fe61bbe9..ec8636265e 100644
--- a/pl_bolts/models/rl/per_dqn_model.py
+++ b/pl_bolts/models/rl/per_dqn_model.py
@@ -130,27 +130,17 @@ def training_step(self, batch, _) -> OrderedDict:
         if self.global_step % self.sync_rate == 0:
             self.target_net.load_state_dict(self.net.state_dict())
 
-        log = {
+        self.log_dict({
             "total_reward": self.total_rewards[-1],
             "avg_reward": self.avg_rewards,
             "train_loss": loss,
             # "episodes": self.total_episode_steps,
-        }
-        status = {
-            "steps": self.global_step,
-            "avg_reward": self.avg_rewards,
-            "total_reward": self.total_rewards[-1],
-            "episodes": self.done_episodes,
-            # "episode_steps": self.episode_steps,
-            "epsilon": self.agent.epsilon,
-        }
+        })
 
         return OrderedDict(
             {
                 "loss": loss,
                 "avg_reward": self.avg_rewards,
-                "log": log,
-                "progress_bar": status,
             }
         )