
Implement ProjectedNormal distribution and reparametrizer #2736

Merged: 17 commits into dev, Jan 13, 2021

Conversation

@fritzo (Member) commented Jan 10, 2021

This implements an isotropic ProjectedNormal distribution for inference of latent directional data of arbitrary dimension.

Currently the VonMises and VonMises3d distributions implement .sample() but not .rsample(), and use a bogus real support constraint. This PR adds .rsample() support, implements a reparametrizer that makes the distribution compatible with autoguides, and adds a sphere constraint.
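For context, a minimal sketch of the reparametrized sampler under the isotropic parametrization described above (the function name and details are illustrative, not the PR's exact code): sample from an ambient Gaussian centered at the concentration parameter, then radially project onto the unit sphere.

```python
import torch

def projected_normal_rsample(concentration, sample_shape=()):
    # Draw an ambient Gaussian sample centered at the concentration
    # parameter, then radially project it onto the unit sphere. Both
    # steps are differentiable, so gradients flow to `concentration`.
    shape = torch.Size(sample_shape) + concentration.shape
    z = concentration + torch.randn(shape, dtype=concentration.dtype)
    return z / z.norm(dim=-1, keepdim=True).clamp_min(torch.finfo(z.dtype).tiny)
```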

Density computation

I have implemented .log_prob() only for dim=2 and dim=3. Note that .log_prob() is required to use this distribution as a likelihood, or as a latent variable without poutine.reparam (in either the model or the guide). Many papers, such as Hernandez-Stumpfhauser et al. (2017), generalize to arbitrary covariance matrices. However, I'm choosing to standardize to a unit covariance, thus guaranteeing unimodality ("let the guide handle multiple modes, and the distribution focus on a single mode"). The density integral reduces to a definite integral of the Gaussian density times a simple polynomial. We can evaluate this integral using Wolfram Alpha, but the result involves Kummer's confluent hypergeometric function:
[image: Wolfram Alpha output giving the integral in terms of Kummer's confluent hypergeometric function ₁F₁]
I suspect the result simplifies to a pair of expressions, one for even n and one for odd n.
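For reference, the integral in question can be written as the radial marginalization of the isotropic Gaussian (a sketch, writing μ for the location parameter and n for the dimension):

```latex
p(x) = \int_0^\infty \mathcal{N}(r x \mid \mu, I_n)\, r^{n-1}\, \mathrm{d}r,
\qquad \lVert x \rVert = 1,
```

where the factor $r^{n-1}$ is the Jacobian of polar coordinates; expanding the Gaussian density gives exactly the "Gaussian density times a simple polynomial" integrand mentioned above.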

Tasks

  • implement a constraints.sphere (see the sketch after this list)
  • implement ProjectedNormal.rsample()
  • implement a ProjectedNormal.log_prob() for use as a likelihood
  • implement a ProjectedNormalReparam for use with autoguides
  • test ProjectedNormal
  • test constraints.sphere
  • test ProjectedNormalReparam
  • add helpful error messages to autoguides
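As a rough illustration of the constraints.sphere task (a sketch only; the class name and tolerance are illustrative, not the PR's exact code), a sphere constraint just needs to check unit norm along the event dimension:

```python
import torch
from torch.distributions import constraints

class _Sphere(constraints.Constraint):
    """Sketch: constrain values to the unit (n-1)-sphere, i.e. vectors
    of unit Euclidean norm in the rightmost dimension."""
    event_dim = 1  # the rightmost dimension is the event dimension

    def check(self, value):
        # Allow a small tolerance, since normalized float samples
        # are only approximately unit-norm.
        norm = value.norm(dim=-1)
        return (norm - 1).abs() < 1e-6

sphere = _Sphere()
```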

@fritzo added the WIP label Jan 10, 2021


On the new helper

```python
@contextmanager
def helpful_support_errors(site):
```

a Collaborator commented: 💯

@fritzo (Member, Author) commented Jan 13, 2021

Thanks for reviewing @martinjankowiak!

@martinjankowiak marked this pull request as ready for review January 13, 2021 23:02
@martinjankowiak merged commit ef27d10 into dev Jan 13, 2021
@fritzo (Member, Author) commented Jan 14, 2021

@ahmadsalim @LysSanzMoreta do you have any applications where you're learning latent directional variables? If so, I'd be happy to help you try these new distributions to test the new reparametrizers. My original motivation was a 2D model, but I figured you might have 3D models involving protein folding.

@LysSanzMoreta (Contributor) commented:
Hi @fritzo! Thanks so much for your interest in our models :) Right now I only have normal distributions on my latent spaces, but I will discuss with my supervisor and the research group and see if we can come up with something :)

@LysSanzMoreta (Contributor) commented:
Hi @fritzo, I am back! See if this implementation of TorusDBN ("A generative, probabilistic model of local protein structure", W. Boomsma et al.) in NumPyro serves your purposes: https://github.com/aleatory-science/numpyro/blob/feature/stein-vi/examples/stein_vi/torus_dbn.py

The means and the kappas are sampled from a von Mises distribution.

@fritzo (Member, Author) commented Jan 19, 2021

Hi @LysSanzMoreta, thanks, that's a great example! I'm not certain whether inference will suffer from the support bug this PR aims to address, but I'll think about it and maybe submit a PR to fix it (with a NumPyro implementation of ProjectedNormal).

@LysSanzMoreta (Contributor) commented:
@fritzo Ok, glad we could help :)

@RylanSchaeffer commented:
@fritzo, quick question: I'm interested in using the mean of ProjectedNormal, but I want to make sure I understand the note "Note this is the mean in the sense of a centroid in the submanifold that minimizes expected squared geodesic distance."

What submanifold are you referring to here? The surface of the hypersphere? If so, just to confirm: that would mean the mean is not the expected value of the distribution?

Looking at the code, I suspect the mean is not the expected value of the distribution because the mean and mode call the same function, but the mode must lie on the hypersphere whereas the mean does not.

@RylanSchaeffer commented:
If the mean is not the expected value of the distribution, do you know how (within the PyTorch ecosystem) I can compute the expected value of vMF(mu, kappa)?

@fritzo (Member, Author) commented Mar 11, 2022

Hi @RylanSchaeffer, I do not know how to analytically compute the 3d mean vector of either distribution, but if you do find a formula, feel free to contribute it!

Note 'mean' can mean a few different things. In convex subsets of vector spaces, the mean can be defined as a limiting weighted linear combination of values. In metric spaces, the mean can be defined as the minimizer of expected squared distance. In Euclidean space these two definitions coincide, but on the surface of the sphere only the latter definition makes sense. Because we try to be manifold-aware in torch.distributions (via the constraints library), we chose the latter definition of mean. Well, and also because it's easier to compute 😄
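In symbols, the second definition is the Fréchet mean (a standard formulation, added here for reference): for a random variable $X$ in a metric space $(M, d)$,

```latex
\operatorname{mean}(X) \;=\; \operatorname*{arg\,min}_{m \in M}\; \mathbb{E}\!\left[\, d(m, X)^2 \,\right],
```

with $d$ taken to be the geodesic (great-circle) distance on the sphere; in Euclidean space with $d(x, y) = \lVert x - y \rVert$ the minimizer is the familiar $\mathbb{E}[X]$.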

@RylanSchaeffer commented Mar 11, 2022

@fritzo, thanks for the prompt response, and for clarifying the different meanings of "mean".

By expected value, I meant E[X] := \int x p(x) dx for a continuous random variable X. I'm not sure which of your two definitions this applies to; I suspect the first? To check that I'm understanding you correctly, when you say "distance", I'm guessing you did not mean an Lp metric on R^D?

Regardless, this is the expression for the expected value of the von Mises-Fisher distribution: https://stats.stackexchange.com/a/116911/62060, which TensorFlow has implemented (https://www.tensorflow.org/probability/api_docs/python/tfp/distributions/VonMisesFisher#mean), and I'm trying to see if anyone has implemented it in PyTorch. I'd try implementing it myself, but I'm sure there are lots of tricks one needs to be aware of that I'm not.
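For reference, the standard expression (which the linked answer derives): for $X \sim \mathrm{vMF}(\mu, \kappa)$ on the unit sphere in $\mathbb{R}^d$,

```latex
\mathbb{E}[X] = A_d(\kappa)\,\mu,
\qquad
A_d(\kappa) = \frac{I_{d/2}(\kappa)}{I_{d/2 - 1}(\kappa)},
```

where $I_\nu$ is the modified Bessel function of the first kind.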

@RylanSchaeffer commented:
To ask a slightly different question: if I wanted to implement the expected value E[X] := \int x p(x) dx for the vMF distribution following the expression here (https://stats.stackexchange.com/a/116911/62060), what do I need to be aware of when computing it? Are there any numerical instabilities associated with spherical distributions?

@fritzo (Member, Author) commented Mar 11, 2022

@RylanSchaeffer From Wikipedia, it looks like you'll need to implement a separate function. This extended discussion probably deserves a new issue on either Pyro or PyTorch.
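For what it's worth, in $d = 3$ the Bessel ratio has a closed form, $A_3(\kappa) = \coth\kappa - 1/\kappa$, which avoids general-order Bessel functions entirely. A hypothetical sketch of such a separate function (not part of Pyro; the small-κ cutoff is illustrative):

```python
import torch

def vmf_mean_3d(mu, kappa):
    """Hypothetical sketch: E[X] for a 3D von Mises-Fisher distribution.

    E[X] = A_3(kappa) * mu, where A_3(kappa) = coth(kappa) - 1/kappa
    (= I_{3/2}(kappa) / I_{1/2}(kappa)). Near kappa = 0 the closed form
    is 0/0-unstable, so we switch to its Taylor expansion kappa / 3.
    """
    k = kappa.clamp(min=1e-4)  # avoid division by zero in the closed form
    a = torch.where(
        kappa < 1e-4,
        kappa / 3.0,                    # coth(k) - 1/k = k/3 - k^3/45 + ...
        1.0 / torch.tanh(k) - 1.0 / k,  # closed form for d = 3
    )
    return a.unsqueeze(-1) * mu
```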
