-
-
Notifications
You must be signed in to change notification settings - Fork 986
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Change coordinatization of AutoMultivariateNormal #2963
Conversation
is there any reason to think that this works better than the non-overparameterized version i implemented in pyro-ppl/numpyro#1146? i would be mildly surprised if that were the case |
I would expect this PR to work about the same as the non-overreparametrized version, but with only a fraction of the coding effort. You could consider this PR a stepping stone, in case you want to implement the fancier version in a subsequent PR. |
) | ||
|
||
def get_base_dist(self): | ||
return dist.Normal( | ||
torch.zeros_like(self.loc), torch.zeros_like(self.loc) | ||
torch.zeros_like(self.loc), torch.ones_like(self.loc) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hmm, what a bug!? It is surprised to me that we didn't catch this earlier.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
well, apparently nobody uses AutoMultivariateNormal
with NeuTraReparam
. My motivation for this PR is to write a tutorial about autoguides, so hopefully they'll get more exposure
|
||
def _loc_scale(self, *args, **kwargs): | ||
return self.loc, self.scale_tril.diag() | ||
return self.loc, self.scale * self.scale_tril.diag() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
does scale
need an unsqueeze here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nope, scale
is already the correct shape
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
oh i see this is only for marginals
Thanks for reviewing @martinjankowiak! |
Addresses #2924
This adds a row-wise
.scale
parameter toAutoLowRankMultivariateNormal
. The resulting overparametrized triple (loc,scale,scale_tril) has been observed to speed up learning inAutoLowRankMultivariateNormal
and inAutoStructured
.This also fixes
.get_base_dist()
and adds tests of use inNeuTraReparam
, and moves two tests from tests/contrib/autoguide to tests/infer/autoguide.Tested
.get_base_dist()
covered by a new regression test