
fix _loc_scale method in AutoMultivariateNormal #3233

Merged
merged 2 commits into dev on Jun 23, 2023
Conversation

martinjankowiak (Collaborator)

addresses #3232

cc @fritzo

pyro.sample("y", dist.LogNormal(0.0, 1.0))
pyro.sample("z", dist.Beta(2.0, 2.0).expand([2]).to_event(1))
pyro.sample("x", dist.Normal(0.0, 1.0))
martinjankowiak (Collaborator, Author)
move "x" so that it's not in the trivial top-left position of scale_tril

martinjankowiak (Collaborator, Author)

otherwise this test would still pass in master : )

if hasattr(auto_class, "get_posterior"):
posterior = guide.get_posterior()
posterior_scale = posterior.variance[-1].sqrt()
q = guide.quantiles([0.158655, 0.8413447])
Member

I assume [0.158655, 0.8413447] == Normal(0, 1).cdf(torch.tensor([-1.0, 1.0]))?

@martinjankowiak (Collaborator, Author) Jun 23, 2023

i can replace with this if you prefer

guide.quantiles(dist.Normal(0, 1).cdf(torch.tensor([-1.0, 1.0])).data.numpy().tolist())

clunky because quantiles wraps its input in a tensor, which raises a UserWarning if the input is already a tensor:

UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).

alternatively can add a comment
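For reference, the two levels are just the standard normal CDF evaluated at minus and plus one standard deviation; a quick standard-library check (no torch needed; the helper name is illustrative, not Pyro API):

```python
import math

def std_normal_cdf(x):
    """Standard normal CDF, Phi(x), via the error function."""
    return 0.5 * (1.0 + math.erf(x / math.sqrt(2.0)))

# The levels passed to guide.quantiles in the test:
print(std_normal_cdf(-1.0))  # ~0.158655
print(std_normal_cdf(+1.0))  # ~0.841345
```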

@fritzo (Member) left a comment

Wouldn't we need a correlated model to test this change? E.g.

def model():
    w = pyro.sample("w", dist.Normal(0.0, 1.0))
    x = pyro.sample("x", dist.Normal(w, 1.0))

@martinjankowiak (Collaborator, Author)

Wouldn't we need a correlated model to test this change? E.g.

def model():
    w = pyro.sample("w", dist.Normal(0.0, 1.0))
    x = pyro.sample("x", dist.Normal(w, 1.0))

i don't think so. we just need to make sure we're computing the correct diagonal element of the full covariance matrix. so as long as we do that in a case where the computation is non-trivial (because, e.g., it involves off-diagonal elements of scale_tril), we'll be good. as noted in the comment above, it's enough that "x" is not in the top-left position of scale_tril
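Concretely (a minimal numpy sketch with made-up values, not the Pyro code): for a lower-triangular factor L with covariance L Lᵀ, the marginal scale of coordinate i is the Euclidean norm of row i of L, and this coincides with L[i, i] only when the off-diagonal entries of that row vanish, which is always the case for the top-left coordinate:

```python
import numpy as np

# Illustrative lower-triangular scale_tril for three correlated latents.
L = np.array([
    [1.0, 0.0, 0.0],
    [0.5, 1.2, 0.0],
    [0.3, 0.4, 0.9],
])

# Correct marginal scales: sqrt of the diagonal of the covariance L @ L.T,
# i.e. the row norms of L.
correct = np.sqrt(np.diag(L @ L.T))  # [1.0, 1.3, sqrt(1.06)]

# A formula that reads off diag(L) instead agrees only at index 0,
# the trivial top-left position.
buggy = np.diag(L)                   # [1.0, 1.2, 0.9]
```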

@martinjankowiak (Collaborator, Author)

@fritzo afaict only quantiles was affected, and only in the AutoMultivariateNormal autoguide, so afaik this PR is complete, pending further comments

@fritzo (Member) commented Jun 23, 2023

Wouldn't we need a correlated model to test this change?

that computation is non-trivial (because e.g. it involves off-diagonal elements of scale_tril)

My concern is that in the existing test, all those off-diagonal elements happen to be zero, so the old and new formulae happen to agree. Am I missing something?

@martinjankowiak (Collaborator, Author) commented Jun 23, 2023

the covariance_matrix is not diagonal since we do 100 SVI steps (except in the Laplace case, where I believe the full Hessian is not computed). for example:

[[ 0.96436582  0.01351931 -0.01359293  0.10933437]
 [ 0.01351931  1.28534454  0.07159596  0.01766826]
 [-0.01359293  0.07159596  1.14718362  0.0288711 ]
 [ 0.10933437  0.01766826  0.0288711   0.96723489]]
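Using that printed covariance (a sketch in numpy rather than torch), the marginal scale the test compares against is the square root of the corresponding diagonal entry of the full covariance matrix, here for the last latent:

```python
import numpy as np

# Covariance matrix printed above.
cov = np.array([
    [ 0.96436582,  0.01351931, -0.01359293,  0.10933437],
    [ 0.01351931,  1.28534454,  0.07159596,  0.01766826],
    [-0.01359293,  0.07159596,  1.14718362,  0.0288711 ],
    [ 0.10933437,  0.01766826,  0.0288711 ,  0.96723489],
])

# Analogue of posterior.variance[-1].sqrt() in the test:
posterior_scale = np.sqrt(cov[-1, -1])
print(posterior_scale)  # ~0.9835
```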

@fritzo fritzo merged commit 7e3d62e into dev Jun 23, 2023