Skip to content

Commit

Permalink
fixed other dqn
Browse files Browse the repository at this point in the history
  • Loading branch information
teddykoker committed Nov 24, 2020
1 parent b334c49 commit 8174530
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 24 deletions.
14 changes: 2 additions & 12 deletions pl_bolts/models/rl/double_dqn_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,27 +65,17 @@ def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor], _) -> OrderedD
if self.global_step % self.sync_rate == 0:
self.target_net.load_state_dict(self.net.state_dict())

log = {
self.log_dict({
"total_reward": self.total_rewards[-1],
"avg_reward": self.avg_rewards,
"train_loss": loss,
# "episodes": self.total_episode_steps,
}
status = {
"steps": self.global_step,
"avg_reward": self.avg_rewards,
"total_reward": self.total_rewards[-1],
"episodes": self.done_episodes,
# "episode_steps": self.episode_steps,
"epsilon": self.agent.epsilon,
}
})

return OrderedDict(
{
"loss": loss,
"avg_reward": self.avg_rewards,
"log": log,
"progress_bar": status,
}
)

Expand Down
14 changes: 2 additions & 12 deletions pl_bolts/models/rl/per_dqn_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,27 +130,17 @@ def training_step(self, batch, _) -> OrderedDict:
if self.global_step % self.sync_rate == 0:
self.target_net.load_state_dict(self.net.state_dict())

log = {
self.log_dict({
"total_reward": self.total_rewards[-1],
"avg_reward": self.avg_rewards,
"train_loss": loss,
# "episodes": self.total_episode_steps,
}
status = {
"steps": self.global_step,
"avg_reward": self.avg_rewards,
"total_reward": self.total_rewards[-1],
"episodes": self.done_episodes,
# "episode_steps": self.episode_steps,
"epsilon": self.agent.epsilon,
}
})

return OrderedDict(
{
"loss": loss,
"avg_reward": self.avg_rewards,
"log": log,
"progress_bar": status,
}
)

Expand Down

0 comments on commit 8174530

Please sign in to comment.