Implement NanMaskedNormal, NanMaskedMultivariateNormal #3116

Merged Jul 10, 2022 (3 commits)
14 changes: 14 additions & 0 deletions docs/source/distributions.rst
@@ -281,6 +281,20 @@ MultivariateStudentT
     :undoc-members:
     :show-inheritance:
 
+NanMaskedNormal
+---------------
+.. autoclass:: pyro.distributions.NanMaskedNormal
+    :members:
+    :undoc-members:
+    :show-inheritance:
+
+NanMaskedMultivariateNormal
+---------------------------
+.. autoclass:: pyro.distributions.NanMaskedMultivariateNormal
+    :members:
+    :undoc-members:
+    :show-inheritance:
+
 OMTMultivariateNormal
 ---------------------
 .. autoclass:: pyro.distributions.OMTMultivariateNormal
9 changes: 6 additions & 3 deletions pyro/distributions/__init__.py
@@ -53,6 +53,7 @@
 from pyro.distributions.logistic import Logistic, SkewLogistic
 from pyro.distributions.mixture import MaskedMixture
 from pyro.distributions.multivariate_studentt import MultivariateStudentT
+from pyro.distributions.nanmasked import NanMaskedMultivariateNormal, NanMaskedNormal
 from pyro.distributions.omt_mvn import OMTMultivariateNormal
 from pyro.distributions.one_one_matching import OneOneMatching
 from pyro.distributions.one_two_matching import OneTwoMatching
@@ -92,9 +93,9 @@
 from . import constraints, kl, transforms
 
 __all__ = [
+    "AVFMultivariateNormal",
     "AffineBeta",
     "AsymmetricLaplace",
-    "AVFMultivariateNormal",
     "BetaBinomial",
     "CoalescentRateLikelihood",
     "CoalescentTimes",
@@ -124,13 +125,15 @@
     "LKJ",
     "LKJCorrCholesky",
     "LinearHMM",
-    "Logistic",
     "LogNormalNegativeBinomial",
+    "Logistic",
     "MaskedDistribution",
     "MaskedMixture",
     "MixtureOfDiagNormals",
     "MixtureOfDiagNormalsSharedCovariance",
     "MultivariateStudentT",
+    "NanMaskedMultivariateNormal",
+    "NanMaskedNormal",
     "OMTMultivariateNormal",
     "OneOneMatching",
     "OneTwoMatching",
@@ -142,8 +145,8 @@
     "SineBivariateVonMises",
     "SineSkewed",
     "SkewLogistic",
+    "SoftAsymmetricLaplace",
     "SoftLaplace",
-    "SoftAsymmetricLaplace",
     "SpanningTree",
     "Stable",
     "TorchDistribution",
99 changes: 99 additions & 0 deletions pyro/distributions/nanmasked.py
@@ -0,0 +1,99 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

import torch

from .torch import MultivariateNormal, Normal


class NanMaskedNormal(Normal):
    """
    Wrapper around :class:`~pyro.distributions.Normal` to allow partially
    observed data as specified by NAN elements of the argument to
    :meth:`log_prob`; the ``log_prob`` of these elements will be zero. This
    is useful for likelihoods with missing data.

    Example::

        from math import nan
        data = torch.tensor([0.5, 0.1, nan, 0.9])
        with pyro.plate("data", len(data)):
            pyro.sample("obs", NanMaskedNormal(0, 1), obs=data)
    """

    def log_prob(self, value: torch.Tensor) -> torch.Tensor:
        ok = value.isfinite()
        if ok.all():
            return super().log_prob(value)

        # Broadcast all tensors.
        value, ok, loc, scale = torch.broadcast_tensors(
            value, ok, self.loc, self.scale
        )
        result = value.new_zeros(value.shape)

        # Evaluate ok elements.
        if ok.any():
            marginal = Normal(loc[ok], scale[ok], validate_args=False)
            result[ok] = marginal.log_prob(value[ok])
        return result
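A quick standalone check of the behavior documented above (illustration only, not part of the diff; assumes pyro is installed): finite elements score exactly like an ordinary Normal, and NAN elements contribute zero.

import math
import torch
import pyro.distributions as dist

d = dist.NanMaskedNormal(torch.zeros(4), torch.ones(4))
data = torch.tensor([0.5, 0.1, math.nan, 0.9])
lp = d.log_prob(data)

# Finite elements match an ordinary Normal; the NAN element scores zero.
ref = dist.Normal(torch.zeros(4), torch.ones(4)).log_prob(data.nan_to_num())
assert torch.allclose(lp[[0, 1, 3]], ref[[0, 1, 3]])
assert lp[2].item() == 0.0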


class NanMaskedMultivariateNormal(MultivariateNormal):
    """
    Wrapper around :class:`~pyro.distributions.MultivariateNormal` to allow
    partially observed data as specified by NAN elements in the argument to
    :meth:`log_prob`. The ``log_prob`` of these events will marginalize over
    the NAN elements. This is useful for likelihoods with missing data.

    Example::

        from math import nan
        data = torch.tensor([
            [0.1, 0.2, 3.4],
            [0.5, 0.1, nan],
            [0.6, nan, nan],
            [nan, 0.5, nan],
            [nan, nan, nan],
        ])
        with pyro.plate("data", len(data)):
            pyro.sample(
                "obs",
                NanMaskedMultivariateNormal(torch.zeros(3), torch.eye(3)),
                obs=data,
            )
    """

    def log_prob(self, value: torch.Tensor) -> torch.Tensor:
        ok = value.isfinite()
        if ok.all():
            return super().log_prob(value)

        # Broadcast all tensors. This might waste some computation by eagerly
        # broadcasting, but the optimal implementation is quite complex.
        value, ok, loc = torch.broadcast_tensors(value, ok, self.loc)
        cov = self.covariance_matrix.expand(loc.shape + loc.shape[-1:])

        # Flatten.
        result_shape = value.shape[:-1]
        n = result_shape.numel()
        p = value.shape[-1]
        value = value.reshape(n, p)
        ok = ok.reshape(n, p)
        loc = loc.reshape(n, p)
        cov = cov.reshape(n, p, p)
        result = value.new_zeros(n)

        # Evaluate ok elements.
        for pattern in sorted(set(map(tuple, ok.tolist()))):
[Collaborator] oh i thought you were computing one big marginalized covariance with 0s/1s where appropriate so that everything could be vectorized (no for loop)

[Author] 😄 that's beyond my linear algebra skills / patience. In practice I'm working with 3 columns so there are at most 7 patterns.
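To make that count concrete: with p columns, each row's finite-mask is one of 2**p boolean tuples, and the all-missing tuple is skipped below, so the loop body runs at most 2**p - 1 times (7 for p = 3). A standalone Python sketch of the enumeration (illustration only, not part of the diff):

import math
import torch

# Toy data with p = 3 columns and assorted missingness.
data = torch.tensor([
    [0.1, 0.2, 3.4],
    [0.5, 0.1, math.nan],
    [0.6, math.nan, math.nan],
    [math.nan, 0.5, math.nan],
    [math.nan, math.nan, math.nan],
])
ok = data.isfinite()

# Distinct row-wise patterns, exactly as the loop below enumerates them.
patterns = sorted(set(map(tuple, ok.tolist())))
print(len(patterns))  # 5 here; never more than 2**3 = 8, and all-False is skipped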

            if not any(pattern):
                continue
            # Marginalize out NAN elements.
            col_mask = torch.tensor(pattern)
            row_mask = (ok == col_mask).all(-1)
            ok_value = value[row_mask][:, col_mask]
            ok_loc = loc[row_mask][:, col_mask]
            ok_cov = cov[row_mask][:, col_mask][:, :, col_mask]
            marginal = MultivariateNormal(ok_loc, ok_cov, validate_args=False)
[Collaborator] do these invocations not need covariance_matrix=?

[Collaborator] i guess one nice thing about this pattern is that you don't need to worry about factors of log 2pi explicitly...

[Author] covariance_matrix is the default first argument, so no kwarg is necessary.
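For reference, torch.distributions.MultivariateNormal's signature is MultivariateNormal(loc, covariance_matrix=None, precision_matrix=None, scale_tril=None), so the positional call above is equivalent to spelling out covariance_matrix=. And because the sub-distribution's log_prob includes its own normalizer, the log 2pi factors do come along for free. A small standalone sanity check (not part of the diff):

import torch
from torch.distributions import MultivariateNormal, Normal

loc = torch.zeros(2)
cov = torch.tensor([[2.0, 0.3], [0.3, 1.0]])
x = torch.tensor([0.7, -0.2])

# Positional covariance is the same as the explicit keyword.
d1 = MultivariateNormal(loc, cov)
d2 = MultivariateNormal(loc, covariance_matrix=cov)
assert torch.allclose(d1.log_prob(x), d2.log_prob(x))

# Marginalizing an MVN keeps the sub-mean and sub-covariance; the 1x1
# marginal's log_prob, normalizer included, matches the univariate Normal.
m = MultivariateNormal(loc[:1], cov[:1, :1])
n = Normal(loc[0], cov[0, 0].sqrt())
assert torch.allclose(m.log_prob(x[:1]), n.log_prob(x[0]))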

            result[row_mask] = marginal.log_prob(ok_value)

        # Unflatten.
        return result.reshape(result_shape)
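The same can be exercised end to end for the multivariate case: with an identity covariance, a partially observed row's log-density equals the marginal over its observed coordinates, and a fully missing row contributes zero. Another standalone sketch (not part of the diff; assumes pyro is installed):

import math
import torch
import pyro.distributions as dist

d = dist.NanMaskedMultivariateNormal(torch.zeros(3), torch.eye(3))

# One missing coordinate: log_prob equals the 2D marginal over observed dims.
row = torch.tensor([0.5, 0.1, math.nan])
marginal = dist.MultivariateNormal(torch.zeros(2), torch.eye(2))
assert torch.allclose(d.log_prob(row), marginal.log_prob(row[:2]))

# A fully missing row contributes zero log-density.
assert d.log_prob(torch.tensor([math.nan, math.nan, math.nan])).item() == 0.0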
97 changes: 97 additions & 0 deletions tests/distributions/test_nanmasked.py
@@ -0,0 +1,97 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

import math

import pytest
import torch

import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
from pyro.infer.autoguide import AutoNormal
from pyro.optim import Adam
from tests.common import assert_close


@pytest.mark.parametrize("batch_shape", [(), (40,), (11, 9)], ids=str)
def test_normal(batch_shape):
    # Test on full data.
    data = torch.randn(batch_shape)
    loc = torch.randn(batch_shape).requires_grad_()
    scale = torch.randn(batch_shape).exp().requires_grad_()
    d = dist.NanMaskedNormal(loc, scale)
    d2 = dist.Normal(loc, scale)
    actual = d.log_prob(data)
    expected = d2.log_prob(data)
    assert_close(actual, expected)

    # Test on partial data.
    ok = torch.rand(batch_shape) < 0.5
    data[~ok] = math.nan
    actual = d.log_prob(data)
    assert actual.shape == expected.shape
    assert actual.isfinite().all()
    loc_grad, scale_grad = torch.autograd.grad(actual.sum(), [loc, scale])
    assert loc_grad.isfinite().all()
    assert scale_grad.isfinite().all()

    # Check identity on fully observed and fully unobserved elements.
    assert_close(actual[ok], expected[ok])
    assert_close(actual[~ok], torch.zeros_like(actual[~ok]))


@pytest.mark.parametrize("batch_shape", [(), (40,), (11, 9)], ids=str)
@pytest.mark.parametrize("p", [1, 2, 3, 10], ids=str)
def test_multivariate_normal(batch_shape, p):
    # Test on full data.
    data = torch.randn(batch_shape + (p,))
    loc = torch.randn(batch_shape + (p,)).requires_grad_()
    scale_tril = torch.randn(batch_shape + (p, p))
    scale_tril.tril_()
    scale_tril.diagonal(dim1=-2, dim2=-1).exp_()
    scale_tril.requires_grad_()
    d = dist.NanMaskedMultivariateNormal(loc, scale_tril=scale_tril)
    d2 = dist.MultivariateNormal(loc, scale_tril=scale_tril)
    actual = d.log_prob(data)
    expected = d2.log_prob(data)
    assert_close(actual, expected)

    # Test on partial data.
    ok = torch.rand(batch_shape + (p,)) < 0.5
    data[~ok] = math.nan
    actual = d.log_prob(data)
    assert actual.shape == expected.shape
    assert actual.isfinite().all()
    loc_grad, scale_tril_grad = torch.autograd.grad(actual.sum(), [loc, scale_tril])
    assert loc_grad.isfinite().all()
    assert scale_tril_grad.isfinite().all()

    # Check identity on fully observed and fully unobserved rows.
    observed = ok.all(-1)
    assert_close(actual[observed], expected[observed])
    unobserved = ~ok.any(-1)
    assert_close(actual[unobserved], torch.zeros_like(actual[unobserved]))


def test_multivariate_normal_model():
    def model(data):
        loc = pyro.sample(
            "loc", dist.Normal(torch.zeros(3), torch.ones(3)).to_event(1)
        )
        scale_tril = torch.eye(3)
        with pyro.plate("data", len(data)):
            pyro.sample(
                "obs",
                dist.NanMaskedMultivariateNormal(loc, scale_tril=scale_tril),
                obs=data,
            )

    data = torch.randn(100, 3)
    ok = torch.rand(100, 3) < 0.5
    assert 100 < ok.long().sum() < 200, "weak test"
    data[~ok] = math.nan

    guide = AutoNormal(model)
    svi = SVI(model, guide, Adam({"lr": 1e-4}), Trace_ELBO())
    for step in range(3):
        loss = svi.step(data)
        assert math.isfinite(loss)