Rework garage.torch.optimizers #2177

Open

krzentner wants to merge 6 commits into master

Conversation

krzentner (Contributor)
No description provided.

krzentner requested a review from a team as a code owner on November 16, 2020 at 18:34
krzentner requested reviews from ahtsan, irisliucy and ryanjulian, and removed the request for a team and ahtsan, on November 16, 2020 at 18:34
@krzentner (Contributor Author)
This change does not yet pass tests, but is 90% complete.

mergify bot requested review from a team and zequnyu, and removed the request for a team, on November 16, 2020 at 18:34
@ryanjulian (Member)
Can you add a little bit more explanation for the design here? I'm concerned about using an ADT as the blanket input to policies, which makes the interface pretty complicated even in the simplest use cases.

@krzentner (Contributor Author) commented Nov 18, 2020

The core motivation here is to provide a way for recurrent and non-recurrent policies to share the same API at optimization time.
However, I definitely agree: making this change has shown me that it significantly increases the complexity of the garage.torch APIs. It doesn't add much complexity to any individual algorithm (except for TutorialVPG), but it's noticeable.
In the future, I also intended for this datatype to play a role similar to state_info_spec in the TF branch (although with a very different design).

This PR only adds the bare minimum fields needed for recurrent policies to have reasonable .forward methods. However, we could replace the observation field on PolicyInput by having PolicyInput inherit from torch.Tensor itself.
Then, algorithms that only want to train stochastic non-recurrent policies (e.g. SAC) could just pass a torch.Tensor, as they do now. Alternatively, we could use a helper function at the start of every torch policy's .forward method to convert any torch.Tensor input into a PolicyInput (in SHUFFLED mode).
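To make that second option concrete, here is a minimal sketch of what such a helper could look like. The PolicyInput fields, the SEQUENCES mode name, and the as_policy_input helper are illustrative assumptions for this discussion, not the actual garage API; only the SHUFFLED mode and the observation field come from the comment above.

import enum

import torch


class PolicyInputMode(enum.Enum):
    """How the observations in a PolicyInput are batched."""

    SHUFFLED = 'shuffled'    # flat batch of independent timesteps
    SEQUENCES = 'sequences'  # whole episodes kept in order, for recurrent policies


class PolicyInput:
    """Bundles observations with the metadata recurrent policies need."""

    def __init__(self, observations, mode=PolicyInputMode.SHUFFLED,
                 lengths=None):
        self.observations = observations
        self.mode = mode
        # Per-episode lengths; only meaningful in SEQUENCES mode.
        self.lengths = lengths


def as_policy_input(value):
    """Promote a bare torch.Tensor to a PolicyInput in SHUFFLED mode."""
    if isinstance(value, PolicyInput):
        return value
    if isinstance(value, torch.Tensor):
        return PolicyInput(value, mode=PolicyInputMode.SHUFFLED)
    raise TypeError(
        f'Expected PolicyInput or torch.Tensor, got {type(value)!r}')

A non-recurrent policy's .forward could then begin with observations = as_policy_input(observations).observations, so algorithms like SAC keep passing plain tensors while recurrent policies construct a PolicyInput explicitly.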

cnn_output = self._cnn_module(observations)
mlp_output = self._mlp_module(cnn_output)[0]
logits = torch.softmax(mlp_output, axis=1)
dist = torch.probability.Categorical(logits=logits)
A Contributor left a review comment on the snippet above (from the diff):

It should be torch.distributions.Categorical?
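For reference, a version of the snippet with that fix applied (and with softmax's documented dim keyword), assuming self._cnn_module and self._mlp_module behave as in the original diff:

cnn_output = self._cnn_module(observations)
mlp_output = self._mlp_module(cnn_output)[0]
logits = torch.softmax(mlp_output, dim=1)  # 'dim' is the documented keyword
dist = torch.distributions.Categorical(logits=logits)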

mergify bot requested a review from a team on November 19, 2020 at 18:11
krzentner changed the title from "Rework StochasticPolicy to use PolicyInput" to "Rework garage.torch.optimizers" on Jan 19, 2021
Commits: WIP torch optimizer refactor, WIP torch optimizer refactor, WIP