Add FunsorDistribution to wrap funsors for use in Pyro (#170)
fritzo authored Jul 24, 2019
1 parent a2aea43 commit 7c0b068
Showing 11 changed files with 300 additions and 0 deletions.
1 change: 1 addition & 0 deletions docs/source/index.rst
@@ -24,6 +24,7 @@ Welcome to Funsor's documentation!
:caption: Interfaces:

distributions
pyro
minipyro
einsum

24 changes: 24 additions & 0 deletions docs/source/pyro.rst
@@ -0,0 +1,24 @@
Pyro-Compatible Distributions
-----------------------------
This interface provides a number of PyTorch-style distributions that use
funsors internally to perform inference. These high-level objects are based
on a wrapping class, :class:`~funsor.pyro.distribution.FunsorDistribution`,
which wraps a funsor in a PyTorch-distributions-compatible interface.
:class:`~funsor.pyro.distribution.FunsorDistribution` objects can be used
directly in Pyro models (using the standard Pyro backend).

FunsorDistribution Base Class
=============================
.. automodule:: funsor.pyro.distribution
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource

Conversion Utilities
====================
.. automodule:: funsor.pyro.convert
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource
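
An illustrative sketch of the usage the docs above describe (assuming a
concrete FunsorDistribution subclass such as the toy Categorical defined in
test/pyro/test_distribution.py below; the sample site name and logits are
made up):

import pyro
import torch

def model():
    logits = torch.tensor([0.5, 0.3, 0.2]).log()
    # Categorical: the toy FunsorDistribution subclass from the tests
    # (assumed in scope). It acts like any other TorchDistribution,
    # so it can be passed straight to pyro.sample.
    return pyro.sample("x", Categorical(logits=logits))
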
5 changes: 5 additions & 0 deletions funsor/pyro/__init__.py
@@ -0,0 +1,5 @@
from funsor.pyro.distribution import FunsorDistribution

__all__ = [
"FunsorDistribution",
]
81 changes: 81 additions & 0 deletions funsor/pyro/convert.py
@@ -0,0 +1,81 @@
from __future__ import absolute_import, division, print_function

from collections import OrderedDict

import pyro.distributions as dist
import torch

from funsor.domains import bint
from funsor.torch import Tensor

# Conversion functions use fixed names for Pyro batch dims.
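# For example, DIM_TO_NAME[-1] == "_pyro_dim_-1" and NAME_TO_DIM["_pyro_dim_-1"] == -1,
# so each negative batch dim always maps to the same funsor input name.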
DIM_TO_NAME = tuple(map("_pyro_dim_{}".format, range(-100, 0)))
NAME_TO_DIM = dict(zip(DIM_TO_NAME, range(-100, 0)))


def tensor_to_funsor(tensor, event_dim=0, dtype="real"):
"""
    Convert a :class:`torch.Tensor` to a :class:`funsor.torch.Tensor`.
"""
assert isinstance(tensor, torch.Tensor)
batch_shape = tensor.shape[:tensor.dim() - event_dim]
event_shape = tensor.shape[tensor.dim() - event_dim:]

# Squeeze batch_shape.
inputs = OrderedDict()
squeezed_batch_shape = []
for dim, size in enumerate(batch_shape):
if size > 1:
name = DIM_TO_NAME[dim - len(batch_shape)]
inputs[name] = bint(size)
squeezed_batch_shape.append(size)
squeezed_batch_shape = torch.Size(squeezed_batch_shape)
if squeezed_batch_shape != batch_shape:
batch_shape = squeezed_batch_shape
tensor = tensor.reshape(batch_shape + event_shape)

return Tensor(tensor, inputs, dtype)
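
# Illustrative example (not part of this commit): with batch dims (2, 1) and
# one event dim,
#   tensor_to_funsor(torch.randn(2, 1, 3), event_dim=1)
# returns a funsor Tensor with single input {"_pyro_dim_-2": bint(2)} and
# output shape (3,); the size-1 batch dim is squeezed away.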


def funsor_to_tensor(funsor_, ndims):
"""
    Convert a :class:`funsor.torch.Tensor` to a :class:`torch.Tensor`.
"""
assert isinstance(funsor_, Tensor)
assert all(k.startswith("_pyro_dim_") for k in funsor_.inputs)
names = tuple(sorted(funsor_.inputs, key=NAME_TO_DIM.__getitem__))
tensor = funsor_.align(names).data
if names:
# Unsqueeze batch_shape.
dims = list(map(NAME_TO_DIM.__getitem__, names))
batch_shape = [1] * (-dims[0])
for dim, size in zip(dims, tensor.shape):
batch_shape[dim] = size
batch_shape = torch.Size(batch_shape)
tensor = tensor.reshape(batch_shape + funsor_.output.shape)
if ndims != tensor.dim():
tensor = tensor.reshape((1,) * (ndims - tensor.dim()) + tensor.shape)
assert tensor.dim() == ndims
return tensor
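
# Illustrative example (not part of this commit): inverting the example
# above,
#   funsor_to_tensor(f, ndims=3)
# re-inserts the squeezed unit dim, recovering shape (2, 1, 3).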


def dist_to_funsor(pyro_dist, reinterpreted_batch_ndims=0):
"""
Convert a :class:`torch.distributions.Distribution` to a
    :class:`~funsor.terms.Funsor`.
"""
assert isinstance(pyro_dist, torch.distributions.Distribution)
while isinstance(pyro_dist, dist.Independent):
reinterpreted_batch_ndims += pyro_dist.reinterpreted_batch_ndims
pyro_dist = pyro_dist.base_dist
event_dim = pyro_dist.event_dim + reinterpreted_batch_ndims

if isinstance(pyro_dist, dist.Categorical):
return tensor_to_funsor(pyro_dist.logits, event_dim=event_dim + 1)
if isinstance(pyro_dist, dist.Normal):
raise NotImplementedError("TODO")
if isinstance(pyro_dist, dist.MultivariateNormal):
raise NotImplementedError("TODO")

raise ValueError("Cannot convert {} distribution to a Funsor"
.format(type(pyro_dist).__name__))
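
A minimal round-trip sketch of these helpers, mirroring
test/pyro/test_convert.py below (the shapes here are illustrative):

import pyro.distributions as dist
import torch

from funsor.pyro.convert import dist_to_funsor, funsor_to_tensor, tensor_to_funsor

t = torch.randn(2, 1, 3)                 # batch (2, 1), event (3,)
f = tensor_to_funsor(t, event_dim=1)     # the size-1 batch dim is squeezed
t2 = funsor_to_tensor(f, ndims=t.dim())  # ndims restores the unit dim
assert (t == t2).all()

d = dist.Categorical(logits=torch.randn(4, 3))
g = dist_to_funsor(d)                    # a funsor Tensor of normalized logits
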
82 changes: 82 additions & 0 deletions funsor/pyro/distribution.py
@@ -0,0 +1,82 @@
from __future__ import absolute_import, division, print_function

from collections import OrderedDict

import pyro.distributions as dist
import torch

from funsor.delta import Delta
from funsor.domains import bint
from funsor.interpreter import interpretation, reinterpret
from funsor.joint import Joint
from funsor.optimizer import apply_optimizer
from funsor.pyro.convert import DIM_TO_NAME, funsor_to_tensor, tensor_to_funsor
from funsor.terms import Funsor, lazy


class FunsorDistribution(dist.TorchDistribution):
"""
:class:`~torch.distributions.Distribution` wrapper around a
:class:`~funsor.terms.Funsor` for use in Pyro code. This is typically used
as a base class for specific funsor inference algorithms wrapped in a
    distribution interface.

:param funsor.terms.Funsor funsor_dist: A funsor with an input named
"value" that is treated as a random variable. The distribution should
be normalized over "value".
    :param torch.Size batch_shape: The distribution's batch shape. This must
        be in the same order as the inputs of ``funsor_dist``, but may
        contain extra dims of size 1.
    :param torch.Size event_shape: The distribution's event shape.
"""
arg_constraints = {}

def __init__(self, funsor_dist, batch_shape=torch.Size(), event_shape=torch.Size(),
dtype="real"):
assert isinstance(funsor_dist, Funsor)
assert isinstance(batch_shape, tuple)
assert isinstance(event_shape, tuple)
assert "value" in funsor_dist.inputs
super(FunsorDistribution, self).__init__(batch_shape, event_shape)
self.funsor_dist = funsor_dist
self.dtype = dtype

def log_prob(self, value):
ndims = max(len(self.batch_shape), value.dim() - self.event_dim)
value = tensor_to_funsor(value, event_dim=self.event_dim, dtype=self.dtype)
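        # Build the expression lazily so apply_optimizer can reorder it,
        # then reinterpret eagerly to get a concrete result.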
with interpretation(lazy):
log_prob = apply_optimizer(self.funsor_dist(value=value))
log_prob = reinterpret(log_prob)
log_prob = funsor_to_tensor(log_prob, ndims=ndims)
return log_prob

def _sample_delta(self, sample_shape):
sample_inputs = None
if sample_shape:
sample_inputs = OrderedDict()
shape = sample_shape + self.batch_shape
for dim in range(-len(shape), -len(self.batch_shape)):
if shape[dim] > 1:
sample_inputs[DIM_TO_NAME[dim]] = bint(shape[dim])
delta = self.funsor_dist.sample(frozenset({"value"}), sample_inputs)
if isinstance(delta, Joint):
delta, = delta.deltas
assert isinstance(delta, Delta)
return delta

@torch.no_grad()
def sample(self, sample_shape=torch.Size()):
delta = self._sample_delta(sample_shape)
ndims = len(sample_shape) + len(self.batch_shape) + len(self.event_shape)
value = funsor_to_tensor(delta.point, ndims=ndims)
return value.detach()

def rsample(self, sample_shape=torch.Size()):
delta = self._sample_delta(sample_shape)
assert not delta.log_prob.requires_grad, "distribution is not fully reparametrized"
ndims = len(sample_shape) + len(self.batch_shape) + len(self.event_shape)
value = funsor_to_tensor(delta.point, ndims=ndims)
return value

def expand(self, batch_shape, _instance=None):
raise NotImplementedError("TODO")
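
With a concrete subclass (such as the toy Categorical in
test/pyro/test_distribution.py below), the wrapper behaves like any other
TorchDistribution; a hedged sketch:

logits = torch.randn(2, 3)
logits -= logits.logsumexp(dim=-1, keepdim=True)
d = Categorical(logits=logits)  # toy subclass from the tests, not library code
x = d.sample((5,))              # shape (5, 2); drawn by sampling a funsor Delta
lp = d.log_prob(x)              # shape (5, 2); evaluated via the lazy optimizer
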
4 changes: 4 additions & 0 deletions funsor/terms.py
@@ -272,6 +272,10 @@ def item(self):
"only one element Funsors can be converted to Python scalars")
raise NotImplementedError

@property
def requires_grad(self):
return False

def reduce(self, op, reduced_vars=None):
"""
Reduce along all or a subset of inputs.
1 change: 1 addition & 0 deletions funsor/testing.py
@@ -73,6 +73,7 @@ def assert_close(actual, expected, atol=1e-6, rtol=1e-6):
assert_close(actual.gaussian, expected.gaussian, atol=atol, rtol=rtol)
elif isinstance(actual, torch.Tensor):
assert actual.dtype == expected.dtype, msg
assert actual.shape == expected.shape, msg
if actual.dtype in (torch.long, torch.uint8):
assert (actual == expected).all(), msg
else:
4 changes: 4 additions & 0 deletions funsor/torch.py
@@ -144,6 +144,10 @@ def __bool__(self):
def item(self):
return self.data.item()

@property
def requires_grad(self):
return self.data.requires_grad

def align(self, names):
assert isinstance(names, tuple)
assert all(name in self.inputs for name in names)
8 changes: 8 additions & 0 deletions test/pyro/conftest.py
@@ -0,0 +1,8 @@
from __future__ import absolute_import, division, print_function

import pyro


def pytest_runtest_setup(item):
pyro.set_rng_seed(0)
pyro.enable_validation(True)
34 changes: 34 additions & 0 deletions test/pyro/test_convert.py
@@ -0,0 +1,34 @@
from __future__ import absolute_import, division, print_function

import pyro.distributions as dist
import pytest
import torch

from funsor.pyro.convert import dist_to_funsor, funsor_to_tensor, tensor_to_funsor
from funsor.testing import assert_close
from funsor.torch import Tensor

EVENT_SHAPES = [(), (1,), (5,), (4, 3)]
BATCH_SHAPES = [(), (1,), (4,), (2, 3), (1, 2, 1, 3, 1)]


@pytest.mark.parametrize("event_shape", EVENT_SHAPES, ids=str)
@pytest.mark.parametrize("batch_shape", BATCH_SHAPES, ids=str)
def test_tensor_funsor_tensor(batch_shape, event_shape):
event_dim = len(event_shape)
t = torch.randn(batch_shape + event_shape)
f = tensor_to_funsor(t, event_dim=event_dim)
t2 = funsor_to_tensor(f, ndims=t.dim())
assert_close(t2, t)


@pytest.mark.parametrize("batch_shape", BATCH_SHAPES, ids=str)
@pytest.mark.parametrize("cardinality", [2, 3, 5])
def test_dist_to_funsor_categorical(batch_shape, cardinality):
logits = torch.randn(batch_shape + (cardinality,))
logits -= logits.logsumexp(dim=-1, keepdim=True)
d = dist.Categorical(logits=logits)
f = dist_to_funsor(d)
assert isinstance(f, Tensor)
expected = tensor_to_funsor(logits, event_dim=1)
assert_close(f, expected)
56 changes: 56 additions & 0 deletions test/pyro/test_distribution.py
@@ -0,0 +1,56 @@
from __future__ import absolute_import, division, print_function

import pyro.distributions as dist
import pytest
import torch

from funsor.pyro.convert import tensor_to_funsor
from funsor.pyro.distribution import FunsorDistribution
from funsor.testing import assert_close

SHAPES = [(), (1,), (4,), (2, 3), (1, 2, 1, 3, 1)]


class Categorical(FunsorDistribution):
def __init__(self, logits):
batch_shape = logits.shape[:-1]
event_shape = torch.Size()
funsor_dist = tensor_to_funsor(logits, event_dim=1)["value"]
dtype = int(logits.size(-1))
super(Categorical, self).__init__(
funsor_dist, batch_shape, event_shape, dtype)


@pytest.mark.parametrize("cardinality", [2, 3])
@pytest.mark.parametrize("sample_shape", SHAPES, ids=str)
@pytest.mark.parametrize("batch_shape", SHAPES, ids=str)
def test_categorical_log_prob(sample_shape, batch_shape, cardinality):
logits = torch.randn(batch_shape + (cardinality,))
logits -= logits.logsumexp(dim=-1, keepdim=True)
actual = Categorical(logits=logits)
expected = dist.Categorical(logits=logits)
assert actual.batch_shape == expected.batch_shape
assert actual.event_shape == expected.event_shape

value = expected.sample(sample_shape)
actual_log_prob = actual.log_prob(value)
expected_log_prob = expected.log_prob(value)
assert_close(actual_log_prob, expected_log_prob)


@pytest.mark.parametrize("cardinality", [2, 3])
@pytest.mark.parametrize("sample_shape", SHAPES, ids=str)
@pytest.mark.parametrize("batch_shape", SHAPES, ids=str)
def test_categorical_sample(sample_shape, batch_shape, cardinality):
logits = torch.randn(batch_shape + (cardinality,))
logits -= logits.logsumexp(dim=-1, keepdim=True)
actual = Categorical(logits=logits)
expected = dist.Categorical(logits=logits)
assert actual.batch_shape == expected.batch_shape
assert actual.event_shape == expected.event_shape

actual_sample = actual.sample(sample_shape)
expected_sample = expected.sample(sample_shape)
assert actual_sample.dtype == expected_sample.dtype
assert actual_sample.shape == expected_sample.shape
expected.log_prob(actual_sample) # validates sample
