Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add conditional inverse and compose TransformModules #3185

Merged
merged 10 commits into from
Mar 4, 2023
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 65 additions & 0 deletions pyro/distributions/conditional.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import torch.nn

from .torch import TransformedDistribution
from .torch_transform import ComposeTransformModule


class ConditionalDistribution(ABC):
Expand Down Expand Up @@ -36,6 +37,70 @@ def __init__(self, *args, **kwargs):
def __hash__(self):
return super().__hash__()

@property
def inv(self) -> "ConditionalTransformModule":
return _ConditionalInverseTransformModule(self)


class _ConditionalInverseTransformModule(ConditionalTransformModule):
def __init__(self, transform: ConditionalTransform):
super().__init__()
self._transform = transform

@property
def inv(self) -> ConditionalTransform:
return self._transform

def condition(self, context: torch.Tensor):
return self._transform.condition(context).inv


class ConditionalComposeTransformModule(
ConditionalTransformModule, torch.nn.ModuleList
):
"""
Conditional analogue of :class:`~pyro.distributions.torch_transform.ComposeTransformModule` .

Useful as a base class for specifying complicated conditional distributions::

>>> class ConditionalFlowStack(dist.conditional.ConditionalComposeTransformModule):
... def __init__(self, input_dim, context_dim, hidden_dims, num_flows):
... super().__init__([
... dist.transforms.conditional_planar(input_dim, context_dim, hidden_dims)
... for _ in range(num_flows)
... ], cache_size=1)

>>> cond_dist = dist.conditional.ConditionalTransformedDistribution(
... dist.Normal(torch.zeros(3), torch.ones(3)).to_event(1),
... [ConditionalFlowStack(3, 2, [8, 8], num_flows=4).inv]
... )

>>> context = torch.rand(10, 2)
>>> data = torch.rand(10, 3)
>>> nll = -cond_dist.condition(context).log_prob(data)
"""

def __init__(self, transforms, cache_size: int = 0):
self.transforms = [
ConstantConditionalTransform(t)
if not isinstance(t, ConditionalTransform)
else t
for t in transforms
]
super().__init__()
if cache_size not in {0, 1}:
raise ValueError("cache_size must be 0 or 1")
self._cache_size = cache_size
# for parameter storage
for t in transforms:
if isinstance(t, torch.nn.Module):
self.append(t)

def condition(self, context: torch.Tensor) -> ComposeTransformModule:
return ComposeTransformModule(
[t.condition(context) for t in self.transforms]
).with_cache(self._cache_size)


class ConstantConditionalDistribution(ConditionalDistribution):
def __init__(self, base_dist):
Expand Down
12 changes: 9 additions & 3 deletions pyro/distributions/torch_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,16 @@ class ComposeTransformModule(torch.distributions.ComposeTransform, torch.nn.Modu
store when used in :class:`~pyro.nn.module.PyroModule` instances.
"""

def __init__(self, parts):
super().__init__(parts)
def __init__(self, parts, cache_size=0):
super().__init__(parts, cache_size=cache_size)
for part in parts:
self.append(part)
if isinstance(part, torch.nn.Module):
self.append(part)

def __hash__(self):
return super(torch.nn.Module, self).__hash__()

def with_cache(self, cache_size=1):
if cache_size == 0:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess this is cache_size == self._cache_size?

return self
return ComposeTransformModule(self.parts, cache_size=cache_size)
71 changes: 71 additions & 0 deletions tests/distributions/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -470,3 +470,74 @@ def test_lower_cholesky_transform(transform, batch_shape, dim):
y2 = transform(x2)
assert y2.shape == shape
assert_close(y, y2)


@pytest.mark.parametrize("batch_shape", [(), (7,), (6, 7)])
@pytest.mark.parametrize("input_dim", [2, 3, 5])
@pytest.mark.parametrize("context_dim", [2, 3, 5])
def test_inverse_conditional_transform_module(batch_shape, input_dim, context_dim):
cond_transform = T.conditional_spline(input_dim, context_dim, [6])

noise = torch.rand(batch_shape + (input_dim,))
context = torch.rand(batch_shape + (context_dim,))

assert_close(
cond_transform.inv.condition(context)(noise),
cond_transform.condition(context).inv(noise),
)

assert cond_transform.inv.inv is cond_transform
assert_close(
cond_transform.inv.condition(context).inv(noise),
cond_transform.condition(context).inv.inv(noise),
)


@pytest.mark.parametrize("batch_shape", [(), (7,), (6, 7)])
@pytest.mark.parametrize("input_dim", [2, 3, 5])
@pytest.mark.parametrize("context_dim", [2, 3, 5])
@pytest.mark.parametrize("cache_size", [0, 1])
def test_conditional_compose_transform_module(
batch_shape, input_dim, context_dim, cache_size
):
conditional_transforms = [
T.AffineTransform(1.0, 2.0),
T.Spline(input_dim),
T.conditional_spline(input_dim, context_dim, [5]),
T.SoftplusTransform(),
T.conditional_spline(input_dim, context_dim, [6]),
]
cond_transform = dist.conditional.ConditionalComposeTransformModule(
conditional_transforms, cache_size=cache_size
)

base_dist = dist.Normal(0, 1).expand(batch_shape + (input_dim,)).to_event(1)
cond_dist = dist.ConditionalTransformedDistribution(base_dist, [cond_transform])

context = torch.rand(batch_shape + (context_dim,))
d = cond_dist.condition(context)
transform = d.transforms[0]
assert isinstance(transform, T.ComposeTransformModule)

data = d.rsample()
assert data.shape == batch_shape + (input_dim,)
assert d.log_prob(data).shape == batch_shape

actual_params = set(cond_transform.parameters())
expected_params = set(
torch.nn.ModuleList(
[t for t in conditional_transforms if isinstance(t, torch.nn.Module)]
).parameters()
)
assert set() != actual_params == expected_params

noise = base_dist.rsample()
expected = noise
for t in conditional_transforms:
expected = (t.condition(context) if hasattr(t, "condition") else t)(expected)

actual = transform(noise)
assert_close(actual, expected)

assert_close(cond_transform.inv.condition(context)(actual), noise)
assert_close(cond_transform.condition(context).inv(expected), noise)