diff --git a/docs/source/distributions.rst b/docs/source/distributions.rst
index 95943af850..f3d497701d 100644
--- a/docs/source/distributions.rst
+++ b/docs/source/distributions.rst
@@ -179,6 +179,13 @@ PermuteTransform
     :undoc-members:
     :show-inheritance:
 
+PlanarFlow
+----------------
+.. autoclass:: pyro.distributions.PlanarFlow
+    :members:
+    :undoc-members:
+    :show-inheritance:
+
 TransformModule
 ----------------
 .. autoclass:: pyro.distributions.TransformModule
diff --git a/pyro/distributions/__init__.py b/pyro/distributions/__init__.py
index 02a0cfa674..995a2e9939 100644
--- a/pyro/distributions/__init__.py
+++ b/pyro/distributions/__init__.py
@@ -15,6 +15,7 @@
 from pyro.distributions.mixture import MaskedMixture
 from pyro.distributions.omt_mvn import OMTMultivariateNormal
 from pyro.distributions.permute import PermuteTransform
+from pyro.distributions.planar import PlanarFlow
 from pyro.distributions.rejector import Rejector
 from pyro.distributions.relaxed_straight_through import (RelaxedBernoulliStraightThrough,
                                                          RelaxedOneHotCategoricalStraightThrough)
@@ -46,6 +47,7 @@
     "MixtureOfDiagNormals",
     "OMTMultivariateNormal",
     "PermuteTransform",
+    "PlanarFlow",
     "Rejector",
     "RelaxedBernoulliStraightThrough",
     "RelaxedOneHotCategoricalStraightThrough",
diff --git a/pyro/distributions/iaf.py b/pyro/distributions/iaf.py
index 9c2810e2c5..e96e1787ba 100644
--- a/pyro/distributions/iaf.py
+++ b/pyro/distributions/iaf.py
@@ -118,6 +118,8 @@ def _inverse(self, y):
             x[idx] = (y[..., idx] - mean) * inverse_scale
 
         x = torch.stack(x, dim=-1)
+        log_scale = clamp_preserve_gradients(log_scale, min=self.log_scale_min_clip, max=self.log_scale_max_clip)
+        self._add_intermediate_to_cache(log_scale, y, 'log_scale')
         return x
 
     def _add_intermediate_to_cache(self, intermediate, y, name):
@@ -132,11 +134,7 @@ def log_abs_det_jacobian(self, x, y):
         """
         Calculates the elementwise determinant of the log jacobian
         """
-        if (y, 'log_scale') in self._intermediates_cache:
-            log_scale = self._intermediates_cache.pop((y, 'log_scale'))
-        else:
-            _, log_scale = self.arn(x)
-            log_scale = clamp_preserve_gradients(log_scale, min=self.log_scale_min_clip, max=self.log_scale_max_clip)
+        log_scale = self._intermediates_cache.pop((y, 'log_scale'))
         return log_scale
 
 
diff --git a/pyro/distributions/planar.py b/pyro/distributions/planar.py
new file mode 100644
index 0000000000..1ade156fbe
--- /dev/null
+++ b/pyro/distributions/planar.py
@@ -0,0 +1,126 @@
+from __future__ import absolute_import, division, print_function
+
+import math
+
+import torch
+import torch.nn as nn
+from torch.distributions import constraints
+import torch.nn.functional as F
+
+from pyro.distributions.torch_transform import TransformModule
+from pyro.distributions.util import copy_docs_from
+
+
+@copy_docs_from(TransformModule)
+class PlanarFlow(TransformModule):
+    """
+    A 'planar' normalizing flow that uses the transformation
+
+    :math:`\\mathbf{y} = \\mathbf{x} + \\mathbf{u}\\tanh(\\mathbf{w}^T\\mathbf{x}+b)`
+
+    where :math:`\\mathbf{x}` are the inputs, :math:`\\mathbf{y}` are the outputs, and the learnable parameters
+    are :math:`b\\in\\mathbb{R}`, :math:`\\mathbf{u}\\in\\mathbb{R}^D`, :math:`\\mathbf{w}\\in\\mathbb{R}^D` for input
+    dimension :math:`D`. For this to be an invertible transformation, the condition
+    :math:`\\mathbf{w}^T\\mathbf{u}>-1` is enforced.
+
+    Together with `TransformedDistribution` this provides a way to create richer variational approximations.
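+
+    The log absolute value of the determinant of the Jacobian of this transformation follows from the
+    matrix determinant lemma as :math:`\\log|1+\\mathbf{u}^T\\psi(\\mathbf{x})|`, where
+    :math:`\\psi(\\mathbf{x})=(1-\\tanh^2(\\mathbf{w}^T\\mathbf{x}+b))\\mathbf{w}`; this is the quantity computed
+    (spread evenly across the :math:`D` input dimensions) by `log_abs_det_jacobian`.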
+
+    Example usage:
+
+    >>> base_dist = dist.Normal(torch.zeros(10), torch.ones(10))
+    >>> plf = PlanarFlow(10)
+    >>> plf_module = pyro.module("my_plf", plf)
+    >>> plf_dist = dist.TransformedDistribution(base_dist, [plf])
+    >>> plf_dist.sample()  # doctest: +SKIP
+        tensor([-0.4071, -0.5030,  0.7924, -0.2366, -0.2387, -0.1417,  0.0868,
+                0.1389, -0.4629,  0.0986])
+
+    The inverse of this transform does not possess an analytical solution and is left unimplemented. However,
+    the inverse is cached when the forward operation is called during sampling, and so samples drawn using
+    planar flow can be scored.
+
+    :param input_dim: the dimension of the input (and output) variable.
+    :type input_dim: int
+
+    References:
+
+    Variational Inference with Normalizing Flows [arXiv:1505.05770]
+    Danilo Jimenez Rezende, Shakir Mohamed
+
+    """
+
+    codomain = constraints.real
+
+    def __init__(self, input_dim):
+        super(PlanarFlow, self).__init__()
+
+        self.input_dim = input_dim
+        self.lin = nn.Linear(input_dim, 1)
+        self.u = nn.Parameter(torch.Tensor(input_dim))
+        self.reset_parameters()
+        self._intermediates_cache = {}
+        self.add_inverse_to_cache = True
+
+    def reset_parameters(self):
+        stdv = 1. / math.sqrt(self.u.size(0))
+        self.lin.weight.data.uniform_(-stdv, stdv)
+        self.u.data.uniform_(-stdv, stdv)
+
+    def __hash__(self):
+        return super(nn.Module, self).__hash__()
+
+    # This method ensures that torch.dot(u_hat, w) > -1, required for invertibility
+    def u_hat(self):
+        u = self.u
+        w = self.lin.weight.squeeze(0)
+        alpha = torch.dot(u, w)
+        a_prime = -1 + F.softplus(alpha)
+        return u + (a_prime - alpha) * w.div(w.norm())
+
+    def _call(self, x):
+        """
+        :param x: the input into the bijection
+        :type x: torch.Tensor
+
+        Invokes the bijection x=>y; in the prototypical context of a TransformedDistribution `x` is a
+        sample from the base distribution (or the output of a previous flow)
+        """
+
+        y = x + self.u_hat() * torch.tanh(self.lin(x))
+
+        self._add_intermediate_to_cache(x, y, 'x')
+        return y
+
+    def _inverse(self, y):
+        """
+        :param y: the output of the bijection
+        :type y: torch.Tensor
+
+        Inverts y => x. As noted above, this implementation is incapable of inverting arbitrary values
+        `y`; rather it assumes `y` is the result of a previously computed application of the bijector
+        to some `x` (which was cached on the forward call)
+        """
+        if (y, 'x') in self._intermediates_cache:
+            x = self._intermediates_cache.pop((y, 'x'))
+            return x
+        else:
+            raise KeyError("PlanarFlow expected to find "
+                           "key in intermediates cache but didn't")
+
+    def _add_intermediate_to_cache(self, intermediate, y, name):
+        """
+        Internal function used to cache intermediate results computed during the forward call
+        """
+        assert((y, name) not in self._intermediates_cache),\
+            "key collision in _add_intermediate_to_cache"
+        self._intermediates_cache[(y, name)] = intermediate
+
+    def log_abs_det_jacobian(self, x, y):
+        """
+        Calculates the elementwise determinant of the log jacobian
+        """
+        psi_z = (1 - torch.tanh(self.lin(x)).pow(2)) * self.lin.weight
+
+        # TODO: Simplify following line once using multivariate base distributions for multivariate flows
+        return torch.log(torch.abs(1 + torch.matmul(psi_z, self.u_hat())).unsqueeze(-1)) * \
+            torch.ones_like(x) / x.size(-1)
diff --git a/tests/distributions/test_flows.py b/tests/distributions/test_flows.py
index 6d7c23bcb8..f2a89837ed 100644
--- a/tests/distributions/test_flows.py
+++ b/tests/distributions/test_flows.py
@@ -12,7 +12,7 @@
 pytestmark = pytest.mark.init(rng_seed=123)
 
 
-class AutoregressiveFlowTests(TestCase):
+class FlowTests(TestCase):
     def setUp(self):
         # Epsilon is used to compare numerical gradient to analytical one
         self.epsilon = 1e-3
@@ -22,57 +22,69 @@ def setUp(self):
 
     def _test_jacobian(self, input_dim, make_flow):
         jacobian = torch.zeros(input_dim, input_dim)
-        iaf = make_flow(input_dim)
+        flow = make_flow(input_dim)
 
         def nonzero(x):
             return torch.sign(torch.abs(x))
 
         x = torch.randn(1, input_dim)
-        iaf_x = iaf(x)
-        analytic_ldt = iaf.log_abs_det_jacobian(x, iaf_x).data.sum()
+        flow_x = flow(x)
+        analytic_ldt = flow.log_abs_det_jacobian(x, flow_x).data.sum()
 
         for j in range(input_dim):
             for k in range(input_dim):
                 epsilon_vector = torch.zeros(1, input_dim)
                 epsilon_vector[0, j] = self.epsilon
-                delta = (iaf(x + 0.5 * epsilon_vector) - iaf(x - 0.5 * epsilon_vector)) / self.epsilon
+                delta = (flow(x + 0.5 * epsilon_vector) - flow(x - 0.5 * epsilon_vector)) / self.epsilon
                 jacobian[j, k] = float(delta[0, k].data.sum())
 
-        permutation = iaf.arn.get_permutation()
-        permuted_jacobian = jacobian.clone()
-        for j in range(input_dim):
-            for k in range(input_dim):
-                permuted_jacobian[j, k] = jacobian[permutation[j], permutation[k]]
-        numeric_ldt = torch.sum(torch.log(torch.diag(permuted_jacobian)))
-        ldt_discrepancy = np.fabs(analytic_ldt - numeric_ldt)
-
-        diag_sum = torch.sum(torch.diag(nonzero(permuted_jacobian)))
-        lower_sum = torch.sum(torch.tril(nonzero(permuted_jacobian), diagonal=-1))
+        # Apply permutation for autoregressive flows
+        if hasattr(flow, 'arn'):
+            permutation = flow.arn.get_permutation()
+            permuted_jacobian = jacobian.clone()
+            for j in range(input_dim):
+                for k in range(input_dim):
+                    permuted_jacobian[j, k] = jacobian[permutation[j], permutation[k]]
+            jacobian = permuted_jacobian
+
+        # For autoregressive flows the log det is the sum of the log-diagonal; otherwise use the full determinant
+        if hasattr(flow, 'arn'):
+            numeric_ldt = torch.sum(torch.log(torch.diag(jacobian)))
+        else:
+            numeric_ldt = torch.log(torch.abs(jacobian.det()))
+        ldt_discrepancy = np.fabs(analytic_ldt - numeric_ldt)
 
         assert ldt_discrepancy < self.epsilon
-        assert diag_sum == float(input_dim)
-        assert lower_sum == float(0.0)
+
+        # Test that the Jacobian is triangular with a non-zero diagonal for autoregressive flows
+        if hasattr(flow, 'arn'):
+            diag_sum = torch.sum(torch.diag(nonzero(jacobian)))
+            lower_sum = torch.sum(torch.tril(nonzero(jacobian), diagonal=-1))
+            assert diag_sum == float(input_dim)
+            assert lower_sum == float(0.0)
 
     def _test_inverse(self, input_dim, make_flow):
         base_dist = dist.Normal(torch.zeros(input_dim), torch.ones(input_dim))
-        iaf = make_flow(input_dim)
+        flow = make_flow(input_dim)
 
         x_true = base_dist.sample(torch.Size([10]))
-        y = iaf._call(x_true)
+        y = flow._call(x_true)
 
         # This line empties the inverse cache, if the flow uses it
-        iaf._inverse(y)
+        if hasattr(flow, '_intermediates_cache'):
+            flow._intermediates_cache.pop((y, 'log_scale'))
+            flow._intermediates_cache.pop((y, 'x'))
 
         # Cache is empty, hence must be calculating inverse afresh
-        x_calculated = iaf._inverse(y)
+        x_calculated = flow._inverse(y)
 
         assert torch.norm(x_true - x_calculated, dim=-1).max().item() < self.delta
 
     def _test_shape(self, base_shape, make_flow):
         base_dist = dist.Normal(torch.zeros(base_shape), torch.ones(base_shape))
         last_dim = base_shape[-1] if isinstance(base_shape, tuple) else base_shape
-        iaf = make_flow(input_dim=last_dim)
-        sample = dist.TransformedDistribution(base_dist, [iaf]).sample()
+        flow = make_flow(input_dim=last_dim)
+        sample = dist.TransformedDistribution(base_dist, [flow]).sample()
         assert sample.shape == base_shape
 
     def _make_iaf(self, input_dim):
@@ -87,6 +99,9 @@ def _make_flipflow(self, input_dim):
         permutation = torch.randperm(input_dim, device='cpu').to(torch.Tensor().device)
         return dist.PermuteTransform(permutation)
 
+    def _make_planar(self, input_dim):
+        return dist.PlanarFlow(input_dim)
+
     def test_iaf_jacobians(self):
         for input_dim in [2, 3, 5, 7, 9, 11]:
             self._test_jacobian(input_dim, self._make_iaf)
@@ -95,6 +110,10 @@ def test_iaf_stable_jacobians(self):
         for input_dim in [2, 3, 5, 7, 9, 11]:
             self._test_jacobian(input_dim, self._make_iaf_stable)
 
+    def test_planar_jacobians(self):
+        for input_dim in [2, 3, 5, 7, 9, 11]:
+            self._test_jacobian(input_dim, self._make_planar)
+
     def test_iaf_inverses(self):
         for input_dim in [2, 3, 5, 7, 9, 11]:
             self._test_inverse(input_dim, self._make_iaf)
@@ -118,3 +137,7 @@ def test_iaf_stable_shapes(self):
     def test_flipflow_shapes(self):
         for shape in [(3,), (3, 4), (3, 4, 2)]:
             self._test_shape(shape, self._make_flipflow)
+
+    def test_planar_shapes(self):
+        for shape in [(3,), (3, 4), (3, 4, 2)]:
+            self._test_shape(shape, self._make_planar)
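
As an illustrative aside (a minimal sketch, not part of the patch above), the following shows the usage pattern the docstring and tests rely on: sampling from the transformed distribution caches the pre-image of the sample, so the same distribution object can score that sample and propagate gradients into the flow's parameters even though no analytic inverse exists. The module name "my_plf" and the 10-dimensional base distribution simply follow the docstring example.

    import torch

    import pyro
    import pyro.distributions as dist

    # 10-dimensional standard normal base distribution, as in the PlanarFlow docstring example
    base_dist = dist.Normal(torch.zeros(10), torch.ones(10))

    # pyro.module registers the flow's parameters (the nn.Linear weight/bias and u) in the param store
    plf = pyro.module("my_plf", dist.PlanarFlow(10))
    flow_dist = dist.TransformedDistribution(base_dist, [plf])

    # Sampling runs the forward pass and caches the pre-image x, so the sample can be
    # scored afterwards even though PlanarFlow._inverse has no analytic form
    z = flow_dist.sample()
    log_q = flow_dist.log_prob(z).sum()

    # The log-density is differentiable with respect to the flow's parameters
    log_q.backward()
    assert plf.u.grad is not None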