Upgrade DQN to use .log (Lightning-Universe#404)
* Upgrade DQN to use .log

* remove unused

* pep8

* fixed other dqn
teddykoker authored and chris-clem committed Dec 9, 2020
1 parent b17a5d8 commit 1bd9ae7
Showing 3 changed files with 8 additions and 38 deletions.
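
The change is mechanical: Lightning 1.0 deprecated returning metrics through the "log" and "progress_bar" keys of the training_step output in favour of calling self.log / self.log_dict inside the step, and this commit moves the three DQN variants over. Below is a minimal before/after sketch, not taken from the commit itself, with a hypothetical model and metric names (assumes pytorch_lightning >= 1.0):

# Minimal sketch contrasting the old dict-return logging with self.log_dict.
# DummyModel, its data, and the metric names are hypothetical.
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl


class DummyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.net = nn.Linear(4, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = nn.functional.mse_loss(self.net(x), y)

        # Old style (what this commit removes): metrics routed through special
        # "log" / "progress_bar" keys of the returned dict:
        # return {"loss": loss, "log": {"train_loss": loss},
        #         "progress_bar": {"train_loss": loss}}

        # New style: log directly from the step.
        self.log_dict({"train_loss": loss})
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)


if __name__ == "__main__":
    x = torch.randn(64, 4)
    y = torch.randn(64, 1)
    loader = DataLoader(TensorDataset(x, y), batch_size=16)
    pl.Trainer(max_epochs=1).fit(DummyModel(), loader)
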
pl_bolts/models/rl/double_dqn_model.py: 14 changes (2 additions, 12 deletions)
@@ -65,27 +65,17 @@ def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor], _) -> OrderedD
         if self.global_step % self.sync_rate == 0:
             self.target_net.load_state_dict(self.net.state_dict())

-        log = {
+        self.log_dict({
             "total_reward": self.total_rewards[-1],
             "avg_reward": self.avg_rewards,
             "train_loss": loss,
             # "episodes": self.total_episode_steps,
-        }
-        status = {
-            "steps": self.global_step,
-            "avg_reward": self.avg_rewards,
-            "total_reward": self.total_rewards[-1],
-            "episodes": self.done_episodes,
-            # "episode_steps": self.episode_steps,
-            "epsilon": self.agent.epsilon,
-        }
+        })

         return OrderedDict(
             {
                 "loss": loss,
                 "avg_reward": self.avg_rewards,
-                "log": log,
-                "progress_bar": status,
             }
         )

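One behavioural difference is worth flagging: the deleted status dict used to feed the progress bar, while a bare self.log_dict call only reaches the logger. If the progress-bar display is still wanted, log_dict takes a prog_bar flag. A sketch of the call only (hypothetical metric selection, not part of this commit):

# Inside a LightningModule.training_step; the chosen metrics are illustrative.
# prog_bar=True mirrors the old "progress_bar" routing; on_step / on_epoch control
# whether values are logged per step, averaged per epoch, or both.
self.log_dict(
    {
        "avg_reward": self.avg_rewards,
        "epsilon": self.agent.epsilon,
    },
    prog_bar=True,
    on_step=True,
    on_epoch=False,
)
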
pl_bolts/models/rl/dqn_model.py: 18 changes (4 additions, 14 deletions)
@@ -288,28 +288,18 @@ def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor], _) -> OrderedD
         if self.global_step % self.sync_rate == 0:
             self.target_net.load_state_dict(self.net.state_dict())

-        log = {
+        self.log_dict({
             "total_reward": self.total_rewards[-1],
             "avg_reward": self.avg_rewards,
             "train_loss": loss,
             "episodes": self.done_episodes,
             "episode_steps": self.total_episode_steps[-1]
-        }
-        status = {
-            "steps": self.global_step,
-            "avg_reward": self.avg_rewards,
-            "total_reward": self.total_rewards[-1],
-            "episodes": self.done_episodes,
-            "episode_steps": self.total_episode_steps[-1],
-            "epsilon": self.agent.epsilon,
-        }
+        })

         return OrderedDict(
             {
                 "loss": loss,
                 "avg_reward": self.avg_rewards,
-                "log": log,
-                "progress_bar": status,
             }
         )

@@ -323,8 +313,8 @@ def test_epoch_end(self, outputs) -> Dict[str, torch.Tensor]:
"""Log the avg of the test results"""
rewards = [x["test_reward"] for x in outputs]
avg_reward = sum(rewards) / len(rewards)
tensorboard_logs = {"avg_test_reward": avg_reward}
return {"avg_test_reward": avg_reward, "log": tensorboard_logs}
self.log("avg_test_reward", avg_reward)
return {"avg_test_reward": avg_reward}

def configure_optimizers(self) -> List[Optimizer]:
""" Initialize Adam optimizer"""
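The manual averaging in test_epoch_end still works, but the same result can be obtained by logging the per-episode reward from test_step with on_epoch=True and letting Lightning reduce it (the default reduction is the mean). A sketch under that assumption; the helper name is hypothetical and this is not part of the commit:

# Hypothetical test_step: "test_reward" is averaged over the epoch automatically
# because on_epoch=True uses the default reduce_fx (mean).
def test_step(self, batch, batch_idx):
    reward = self._run_test_episode(batch)  # hypothetical helper returning a float
    self.log("test_reward", reward, on_step=False, on_epoch=True)
    return {"test_reward": reward}
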
pl_bolts/models/rl/per_dqn_model.py: 14 changes (2 additions, 12 deletions)
@@ -130,27 +130,17 @@ def training_step(self, batch, _) -> OrderedDict:
         if self.global_step % self.sync_rate == 0:
             self.target_net.load_state_dict(self.net.state_dict())

-        log = {
+        self.log_dict({
             "total_reward": self.total_rewards[-1],
             "avg_reward": self.avg_rewards,
             "train_loss": loss,
             # "episodes": self.total_episode_steps,
-        }
-        status = {
-            "steps": self.global_step,
-            "avg_reward": self.avg_rewards,
-            "total_reward": self.total_rewards[-1],
-            "episodes": self.done_episodes,
-            # "episode_steps": self.episode_steps,
-            "epsilon": self.agent.epsilon,
-        }
+        })

         return OrderedDict(
             {
                 "loss": loss,
                 "avg_reward": self.avg_rewards,
-                "log": log,
-                "progress_bar": status,
             }
         )


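Where the values end up is unchanged by this commit: anything passed to self.log or self.log_dict goes to whatever logger the Trainer was constructed with. A hedged usage sketch; the DQN constructor call follows the Bolts examples of this period and is an assumption, not something shown in the diff:

# Hypothetical training run; assumes pl_bolts.models.rl.DQN accepts a Gym env id
# as its first argument, as the Bolts examples of this era do.
import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
from pl_bolts.models.rl import DQN

model = DQN("PongNoFrameskip-v4")
trainer = pl.Trainer(
    max_steps=10_000,
    logger=TensorBoardLogger("lightning_logs/"),
)
trainer.fit(model)  # total_reward, avg_reward, train_loss, ... appear under this logger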