-
-
Notifications
You must be signed in to change notification settings - Fork 986
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add an AutoStructured guide and StructuredReparam #2812
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks clean to me! I can confirm that this is equivalent to if we learn arrowhead matrix directly:
A = [[L @ L.t + w @ D @ w.T, w @ D], [D @ w.T, D]]
where L is scale_tril of x_aux
, D is the variance of y_aux
, w is dep.weight.
pyro/infer/autoguide/guides.py
Outdated
scale_tril = scale[..., None] * scale_tril | ||
aux_value = pyro.sample( | ||
name + "_aux", | ||
dist.MultivariateNormal(zero, scale_tril=scale_tril), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If we factor this out to scale_tril @ Normal(0, 1)
, I guess HMC will be a bit happier.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Great point, I guess that is equivalent to reparametrizing. I've also changed Normal(0,scale)
to Normal(0,1) * scale
in the "normal" case.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think you will need to add logdet of those affine transforms. How about using dist.TransformedDistribution(dist.Normal(...), LowerCholeskyAffine(...))
so that we can use TransformReparam
in the reparam?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, I've added the logdet terms by hand here since it's simpler. Does it look right now?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, it looks correct to me.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hmm, I'm seeing very different results with the two versions, and this change seems to have broken my SVI inference. I've been staring at these two versions and I can't seem to see the difference:
# Version 1. This works.
aux_value = pyro.sample(..., Normal(zero, scale).to_event(1))
# Version 2. This is in pyro dev, but no longer works.
aux_value = pyro.sample(..., Normal(zero, 1).to_event(1))
aux_value = aux_value * scale
log_density = log_density - scale.log().sum(-1)
Any ideas @fehiepsi?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I believe two versions are equivalent... Not sure what's going on. Let me play with some tests to see if elbo is the same for the two versions.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks I'll do the same, at least to create a unit test I can run locally (not on some huge model on a GPU cloud machine)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think I find the issue. Here log_density
is calculated as the sum over all dimensions of the site. However, the ldj
term, which is used to calculate the logdet of unconstrained->constrained values, maintains the batch dimension. So sum of them will give wrong result if this site is under some plate. I guess we should use pyro.factor
for those log_density
terms. What do you think?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@fehiepsi thanks, yes I now see the error. I'll think about this and submit a fix ASAP.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is great! I'll try to find some time next week to port this to NumPyro (hopefully it will be straightforward).
Addresses #2813
This adds a flexible
AutoStructured
guide that allows a variety of distributions modeling each latent site (Delta, Normal, or MultivariateNormal), together with a mechanism to declare (link-)linear dependencies between latent variables. As discussed with @fehiepsi this aims to (1) generalize guides with arrowhead covariance structure while (2) learning parameters that can be cheaply used to precondition NUTS via a reparameterizerStructuredReparam
.This also adds a simple
StructuredReparam
that uses a trainedAutoStructured
guide to precondition a model for use in HMC. This new (guide,reparam) pair can be seen as a structured version of the monolithic (AutoContinuous
,NeuTraReparam
) pair in the same sense thatAutoNormal
is a structured version of the monolithicAutoDiagonalNormal
guide.My main motivation is to use this for high-dimensional models (e.g. 300000 latent variables) with a structured precision matrix, and then use that structured precision matrix as a preconditioner for NUTS.
(Note this does not implement Automatic structured variational inference, a variational family whose stricture is severely limited to dependencies in the model. Nor does this first PR implement automatic suggestion of the guide structure as in Faithful inversion of generative models for effective amortized inference.)
Tested
AutoStructured
StructuredReparam