-
-
Notifications
You must be signed in to change notification settings - Fork 986
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Implement NanMaskedNormal, NanMaskedMultivariateNormal #3116
Conversation
Nice, I remember that this is requested by many forum users. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
lgtm. obviously there are various ways the computation could be sped-up in different regimes but since this is probably most useful in the relatively low dimensional setting anyway...
result = value.new_zeros(n) | ||
|
||
# Evaluate ok elements. | ||
for pattern in sorted(set(map(tuple, ok.tolist()))): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
oh i thought you were computing one big marginalized covariance with 0s/1s where appropriate so that everything could be vectorized (no for loop)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
😄 that's beyond my linear algebra skills / patience. In practice I'm working with 3 columns so there are at most 7 patterns.
ok_value = value[row_mask][:, col_mask] | ||
ok_loc = loc[row_mask][:, col_mask] | ||
ok_cov = cov[row_mask][:, col_mask][:, :, col_mask] | ||
marginal = MultivariateNormal(ok_loc, ok_cov, validate_args=False) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
do these invocation not need covariance_matrix=
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
i guess one nice thing about this pattern is that you don't need to worry about factors of log 2pi explicitly...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
covariance_matrix is the default first argument, so no kwarg is necessary.
* Implement NanMaskedNormal, NanMaskedMultivariateNormal * Fix test * Add test for fully-unobserved data
This implements two distributions to serve as likelihoods for partially observed data, where unobserved elements are specified as NAN values. This is new functionality beyond
pyro.mask()
andDistribution.mask()
in that it allows NAN values within an event of MultivariateNormal; in this case we can analytically marginalize out the missing value. TheNanMaskedNormal
is similar toNormal.mask(...)
, but I've included it for easier compatibility with the nontrivialNanMaskedMultivariateNormal
.My motivating example is a Bayesian multivariate linear regression model with learned multivariate noise distribution and partially observed response as specified in a pandas dataframe. Each of the response columns is differently partially observed.
Tested
NanMaskedNormal
NanMaskedMultivariateNormal
NanMaskedMultivariateNormal