Add FunsorDistribution to wrap funsors for use in Pyro (#170)
Showing 11 changed files with 300 additions and 0 deletions.
The docs index gains a new entry in its Interfaces toctree:

```rst
@@ -24,6 +24,7 @@ Welcome to Funsor's documentation!
    :caption: Interfaces:

    distributions
    pyro
    minipyro
    einsum
```
A new docs page documents the interface:

```rst
@@ -0,0 +1,24 @@
Pyro-Compatible Distributions
-----------------------------

This interface provides a number of PyTorch-style distributions that use
funsors internally to perform inference. These high-level objects are based
on a wrapping class, :class:`~funsor.pyro.distribution.FunsorDistribution`,
which wraps a funsor in a PyTorch-distributions-compatible interface.
:class:`~funsor.pyro.distribution.FunsorDistribution` objects can be used
directly in Pyro models (using the standard Pyro backend).

FunsorDistribution Base Class
=============================

.. automodule:: funsor.pyro.distribution
    :members:
    :undoc-members:
    :show-inheritance:
    :member-order: bysource

Conversion Utilities
====================

.. automodule:: funsor.pyro.convert
    :members:
    :undoc-members:
    :show-inheritance:
    :member-order: bysource
```
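To make the usage concrete, a `FunsorDistribution` subclass drops into an ordinary Pyro model like any `TorchDistribution`. This is only a sketch, assuming the funsor-backed `Categorical` subclass defined in this commit's tests (not `pyro.distributions.Categorical`):

```python
# Sketch only: `Categorical` here is assumed to be the FunsorDistribution
# subclass from this commit's tests, not pyro.distributions.Categorical.
import pyro
import torch

def model():
    logits = torch.zeros(3)  # uniform logits over 3 classes
    # sample() and log_prob() are routed through the wrapped funsor.
    return pyro.sample("x", Categorical(logits=logits))

x = model()  # a scalar sample in {0, 1, 2}
```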
The new `funsor.pyro` package exports the base class from its `__init__`:

```python
@@ -0,0 +1,5 @@
from funsor.pyro.distribution import FunsorDistribution

__all__ = [
    "FunsorDistribution",
]
```
New conversion utilities in `funsor.pyro.convert`:

```python
@@ -0,0 +1,81 @@
from __future__ import absolute_import, division, print_function

from collections import OrderedDict

import pyro.distributions as dist
import torch

from funsor.domains import bint
from funsor.torch import Tensor

# Conversion functions use fixed names for Pyro batch dims.
DIM_TO_NAME = tuple(map("_pyro_dim_{}".format, range(-100, 0)))
NAME_TO_DIM = dict(zip(DIM_TO_NAME, range(-100, 0)))


def tensor_to_funsor(tensor, event_dim=0, dtype="real"):
    """
    Convert a :class:`torch.Tensor` to a :class:`funsor.torch.Tensor`.
    """
    assert isinstance(tensor, torch.Tensor)
    batch_shape = tensor.shape[:tensor.dim() - event_dim]
    event_shape = tensor.shape[tensor.dim() - event_dim:]

    # Squeeze batch_shape: drop size-1 dims; the remaining dims become
    # named funsor inputs keyed by their negative batch position.
    inputs = OrderedDict()
    squeezed_batch_shape = []
    for dim, size in enumerate(batch_shape):
        if size > 1:
            name = DIM_TO_NAME[dim - len(batch_shape)]
            inputs[name] = bint(size)
            squeezed_batch_shape.append(size)
    squeezed_batch_shape = torch.Size(squeezed_batch_shape)
    if squeezed_batch_shape != batch_shape:
        batch_shape = squeezed_batch_shape
        tensor = tensor.reshape(batch_shape + event_shape)

    return Tensor(tensor, inputs, dtype)


def funsor_to_tensor(funsor_, ndims):
    """
    Convert a :class:`funsor.torch.Tensor` to a :class:`torch.Tensor`.
    """
    assert isinstance(funsor_, Tensor)
    assert all(k.startswith("_pyro_dim_") for k in funsor_.inputs)
    names = tuple(sorted(funsor_.inputs, key=NAME_TO_DIM.__getitem__))
    tensor = funsor_.align(names).data
    if names:
        # Unsqueeze batch_shape: restore size-1 dims at unnamed positions.
        dims = list(map(NAME_TO_DIM.__getitem__, names))
        batch_shape = [1] * (-dims[0])
        for dim, size in zip(dims, tensor.shape):
            batch_shape[dim] = size
        batch_shape = torch.Size(batch_shape)
        tensor = tensor.reshape(batch_shape + funsor_.output.shape)
    if ndims != tensor.dim():
        tensor = tensor.reshape((1,) * (ndims - tensor.dim()) + tensor.shape)
    assert tensor.dim() == ndims
    return tensor


def dist_to_funsor(pyro_dist, reinterpreted_batch_ndims=0):
    """
    Convert a :class:`torch.distributions.Distribution` to a
    :class:`~funsor.terms.Funsor`.
    """
    assert isinstance(pyro_dist, torch.distributions.Distribution)
    # Unwrap nested Independent distributions, accumulating event dims.
    while isinstance(pyro_dist, dist.Independent):
        reinterpreted_batch_ndims += pyro_dist.reinterpreted_batch_ndims
        pyro_dist = pyro_dist.base_dist
    event_dim = pyro_dist.event_dim + reinterpreted_batch_ndims

    if isinstance(pyro_dist, dist.Categorical):
        return tensor_to_funsor(pyro_dist.logits, event_dim=event_dim + 1)
    if isinstance(pyro_dist, dist.Normal):
        raise NotImplementedError("TODO")
    if isinstance(pyro_dist, dist.MultivariateNormal):
        raise NotImplementedError("TODO")

    raise ValueError("Cannot convert {} distribution to a Funsor"
                     .format(type(pyro_dist).__name__))
```
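To make the dim-naming convention concrete, here is a small round-trip sketch (not part of the commit). Size-1 batch dims are squeezed on the way in and restored on the way out, and each surviving batch dim is keyed by its negative position:

```python
import torch

from funsor.pyro.convert import funsor_to_tensor, tensor_to_funsor

t = torch.randn(2, 1, 3, 5)           # batch_shape=(2, 1, 3), event_shape=(5,)
f = tensor_to_funsor(t, event_dim=1)
# The size-1 batch dim is dropped; the others become named inputs.
assert set(f.inputs) == {"_pyro_dim_-3", "_pyro_dim_-1"}
assert f.output.shape == (5,)
t2 = funsor_to_tensor(f, ndims=4)     # size-1 dims are restored
assert t2.shape == (2, 1, 3, 5)
```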
The new `funsor.pyro.distribution` module defines the wrapper class:

```python
@@ -0,0 +1,82 @@
from __future__ import absolute_import, division, print_function

from collections import OrderedDict

import pyro.distributions as dist
import torch

from funsor.delta import Delta
from funsor.domains import bint
from funsor.interpreter import interpretation, reinterpret
from funsor.joint import Joint
from funsor.optimizer import apply_optimizer
from funsor.pyro.convert import DIM_TO_NAME, funsor_to_tensor, tensor_to_funsor
from funsor.terms import Funsor, lazy


class FunsorDistribution(dist.TorchDistribution):
    """
    :class:`~torch.distributions.Distribution` wrapper around a
    :class:`~funsor.terms.Funsor` for use in Pyro code. This is typically
    used as a base class for specific funsor inference algorithms wrapped
    in a distribution interface.

    :param funsor.terms.Funsor funsor_dist: A funsor with an input named
        "value" that is treated as a random variable. The distribution
        should be normalized over "value".
    :param torch.Size batch_shape: The distribution's batch shape. This
        must be in the same order as the inputs of the ``funsor_dist``,
        but may contain extra dims of size 1.
    :param event_shape: The distribution's event shape.
    :param dtype: The output dtype, either "real" or a positive integer
        cardinality for discrete distributions.
    """
    arg_constraints = {}

    def __init__(self, funsor_dist, batch_shape=torch.Size(),
                 event_shape=torch.Size(), dtype="real"):
        assert isinstance(funsor_dist, Funsor)
        assert isinstance(batch_shape, tuple)
        assert isinstance(event_shape, tuple)
        assert "value" in funsor_dist.inputs
        super(FunsorDistribution, self).__init__(batch_shape, event_shape)
        self.funsor_dist = funsor_dist
        self.dtype = dtype

    def log_prob(self, value):
        ndims = max(len(self.batch_shape), value.dim() - self.event_dim)
        value = tensor_to_funsor(value, event_dim=self.event_dim, dtype=self.dtype)
        # Substitute the value lazily, then optimize before evaluating.
        with interpretation(lazy):
            log_prob = apply_optimizer(self.funsor_dist(value=value))
        log_prob = reinterpret(log_prob)
        log_prob = funsor_to_tensor(log_prob, ndims=ndims)
        return log_prob

    def _sample_delta(self, sample_shape):
        sample_inputs = None
        if sample_shape:
            sample_inputs = OrderedDict()
            shape = sample_shape + self.batch_shape
            for dim in range(-len(shape), -len(self.batch_shape)):
                if shape[dim] > 1:
                    sample_inputs[DIM_TO_NAME[dim]] = bint(shape[dim])
        delta = self.funsor_dist.sample(frozenset({"value"}), sample_inputs)
        if isinstance(delta, Joint):
            delta, = delta.deltas
        assert isinstance(delta, Delta)
        return delta

    @torch.no_grad()
    def sample(self, sample_shape=torch.Size()):
        delta = self._sample_delta(sample_shape)
        ndims = len(sample_shape) + len(self.batch_shape) + len(self.event_shape)
        value = funsor_to_tensor(delta.point, ndims=ndims)
        return value.detach()

    def rsample(self, sample_shape=torch.Size()):
        delta = self._sample_delta(sample_shape)
        assert not delta.log_prob.requires_grad, "distribution is not fully reparametrized"
        ndims = len(sample_shape) + len(self.batch_shape) + len(self.event_shape)
        value = funsor_to_tensor(delta.point, ndims=ndims)
        return value

    def expand(self, batch_shape, _instance=None):
        raise NotImplementedError("TODO")
```
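A minimal end-to-end sketch of the base class (mirroring the `Categorical` subclass in the tests below): wrap normalized logits as a funsor whose trailing dim becomes the "value" input, then sample and score through the standard distribution API:

```python
import torch

from funsor.pyro.convert import tensor_to_funsor
from funsor.pyro.distribution import FunsorDistribution

logits = torch.randn(2, 3)
logits -= logits.logsumexp(dim=-1, keepdim=True)  # normalize over "value"
# Rename the trailing event dim of the logits to the input "value".
funsor_dist = tensor_to_funsor(logits, event_dim=1)["value"]
d = FunsorDistribution(funsor_dist, batch_shape=logits.shape[:-1], dtype=3)
x = d.sample()      # shape (2,), values in {0, 1, 2}
lp = d.log_prob(x)  # shape (2,)
```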
(Three further changed files in this commit are not rendered in this view.)
A pytest `conftest` seeds the RNG and enables validation before each test:

```python
@@ -0,0 +1,8 @@
from __future__ import absolute_import, division, print_function

import pyro


def pytest_runtest_setup(item):
    pyro.set_rng_seed(0)
    pyro.enable_validation(True)
```
Tests for the conversion utilities:

```python
@@ -0,0 +1,34 @@
from __future__ import absolute_import, division, print_function

import pyro.distributions as dist
import pytest
import torch

from funsor.pyro.convert import dist_to_funsor, funsor_to_tensor, tensor_to_funsor
from funsor.testing import assert_close
from funsor.torch import Tensor

EVENT_SHAPES = [(), (1,), (5,), (4, 3)]
BATCH_SHAPES = [(), (1,), (4,), (2, 3), (1, 2, 1, 3, 1)]


@pytest.mark.parametrize("event_shape", EVENT_SHAPES, ids=str)
@pytest.mark.parametrize("batch_shape", BATCH_SHAPES, ids=str)
def test_tensor_funsor_tensor(batch_shape, event_shape):
    # Round trip tensor -> funsor -> tensor should be exact.
    event_dim = len(event_shape)
    t = torch.randn(batch_shape + event_shape)
    f = tensor_to_funsor(t, event_dim=event_dim)
    t2 = funsor_to_tensor(f, ndims=t.dim())
    assert_close(t2, t)


@pytest.mark.parametrize("batch_shape", BATCH_SHAPES, ids=str)
@pytest.mark.parametrize("cardinality", [2, 3, 5])
def test_dist_to_funsor_categorical(batch_shape, cardinality):
    logits = torch.randn(batch_shape + (cardinality,))
    logits -= logits.logsumexp(dim=-1, keepdim=True)
    d = dist.Categorical(logits=logits)
    f = dist_to_funsor(d)
    assert isinstance(f, Tensor)
    expected = tensor_to_funsor(logits, event_dim=1)
    assert_close(f, expected)
```
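The tests above only cover a plain `Categorical`; as a further hedged sketch (not part of the commit), the `Independent` unwrapping in `dist_to_funsor` folds reinterpreted batch dims into the funsor's output:

```python
import pyro.distributions as dist
import torch

from funsor.pyro.convert import dist_to_funsor

logits = torch.randn(4, 3)
d = dist.Independent(dist.Categorical(logits=logits), 1)
f = dist_to_funsor(d)
# event_dim becomes 0 + 1, so the logits convert with event_dim=2:
# both trailing dims of the logits land in the funsor's output.
assert f.output.shape == (4, 3)
```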
Tests for `FunsorDistribution` via a funsor-backed `Categorical`:

```python
@@ -0,0 +1,56 @@
from __future__ import absolute_import, division, print_function

import pyro.distributions as dist
import pytest
import torch

from funsor.pyro.convert import tensor_to_funsor
from funsor.pyro.distribution import FunsorDistribution
from funsor.testing import assert_close

SHAPES = [(), (1,), (4,), (2, 3), (1, 2, 1, 3, 1)]


class Categorical(FunsorDistribution):
    def __init__(self, logits):
        batch_shape = logits.shape[:-1]
        event_shape = torch.Size()
        # Convert logits to a funsor, renaming the final dim to "value".
        funsor_dist = tensor_to_funsor(logits, event_dim=1)["value"]
        dtype = int(logits.size(-1))
        super(Categorical, self).__init__(
            funsor_dist, batch_shape, event_shape, dtype)


@pytest.mark.parametrize("cardinality", [2, 3])
@pytest.mark.parametrize("sample_shape", SHAPES, ids=str)
@pytest.mark.parametrize("batch_shape", SHAPES, ids=str)
def test_categorical_log_prob(sample_shape, batch_shape, cardinality):
    logits = torch.randn(batch_shape + (cardinality,))
    logits -= logits.logsumexp(dim=-1, keepdim=True)
    actual = Categorical(logits=logits)
    expected = dist.Categorical(logits=logits)
    assert actual.batch_shape == expected.batch_shape
    assert actual.event_shape == expected.event_shape

    value = expected.sample(sample_shape)
    actual_log_prob = actual.log_prob(value)
    expected_log_prob = expected.log_prob(value)
    assert_close(actual_log_prob, expected_log_prob)


@pytest.mark.parametrize("cardinality", [2, 3])
@pytest.mark.parametrize("sample_shape", SHAPES, ids=str)
@pytest.mark.parametrize("batch_shape", SHAPES, ids=str)
def test_categorical_sample(sample_shape, batch_shape, cardinality):
    logits = torch.randn(batch_shape + (cardinality,))
    logits -= logits.logsumexp(dim=-1, keepdim=True)
    actual = Categorical(logits=logits)
    expected = dist.Categorical(logits=logits)
    assert actual.batch_shape == expected.batch_shape
    assert actual.event_shape == expected.event_shape

    actual_sample = actual.sample(sample_shape)
    expected_sample = expected.sample(sample_shape)
    assert actual_sample.dtype == expected_sample.dtype
    assert actual_sample.shape == expected_sample.shape
    expected.log_prob(actual_sample)  # validates sample
```