From 9c19029d0de6484cd32b30b7071b33852b55d59e Mon Sep 17 00:00:00 2001 From: Stefan Webb Date: Tue, 23 Oct 2018 23:14:20 +0100 Subject: [PATCH 01/20] Separated NN tests from flow tests --- tests/distributions/test_flows.py | 67 +++++++++++++++++++ tests/nn/__init__.py | 1 + .../test_iaf.py => nn/test_autoregressive.py} | 56 ---------------- 3 files changed, 68 insertions(+), 56 deletions(-) create mode 100644 tests/distributions/test_flows.py create mode 100644 tests/nn/__init__.py rename tests/{distributions/test_iaf.py => nn/test_autoregressive.py} (66%) diff --git a/tests/distributions/test_flows.py b/tests/distributions/test_flows.py new file mode 100644 index 0000000000..f6be0907a1 --- /dev/null +++ b/tests/distributions/test_flows.py @@ -0,0 +1,67 @@ +from __future__ import absolute_import, division, print_function + +from unittest import TestCase + +import numpy as np +import pytest +import torch + +import pyro.distributions as dist +from pyro.distributions.iaf import InverseAutoregressiveFlow +from pyro.nn import AutoRegressiveNN +from pyro.nn.auto_reg_nn import create_mask + +pytestmark = pytest.mark.init(rng_seed=123) + + +class AutoregressiveFlowTests(TestCase): + def setUp(self): + self.epsilon = 1.0e-3 + + def _test_jacobian(self, input_dim, hidden_dim): + jacobian = torch.zeros(input_dim, input_dim) + iaf = InverseAutoregressiveFlow(AutoRegressiveNN(input_dim, [40]), sigmoid_bias=0.5) + + def nonzero(x): + return torch.sign(torch.abs(x)) + + x = torch.randn(1, input_dim) + iaf_x = iaf(x) + analytic_ldt = iaf.log_abs_det_jacobian(x, iaf_x).data.sum() + + for j in range(input_dim): + for k in range(input_dim): + epsilon_vector = torch.zeros(1, input_dim) + epsilon_vector[0, j] = self.epsilon + delta = (iaf(x + 0.5 * epsilon_vector) - iaf(x - 0.5 * epsilon_vector)) / self.epsilon + jacobian[j, k] = float(delta[0, k].data.sum()) + + permutation = iaf.arn.get_permutation() + permuted_jacobian = jacobian.clone() + for j in range(input_dim): + for k in range(input_dim): + permuted_jacobian[j, k] = jacobian[permutation[j], permutation[k]] + numeric_ldt = torch.sum(torch.log(torch.diag(permuted_jacobian))) + ldt_discrepancy = np.fabs(analytic_ldt - numeric_ldt) + + diag_sum = torch.sum(torch.diag(nonzero(permuted_jacobian))) + lower_sum = torch.sum(torch.tril(nonzero(permuted_jacobian), diagonal=-1)) + + assert ldt_discrepancy < self.epsilon + assert diag_sum == float(input_dim) + assert lower_sum == float(0.0) + + def _test_shape(self, base_shape): + base_dist = dist.Normal(torch.zeros(base_shape), torch.ones(base_shape)) + last_dim = base_shape[-1] if isinstance(base_shape, tuple) else base_shape + iaf = InverseAutoregressiveFlow(AutoRegressiveNN(last_dim, [40])) + sample = dist.TransformedDistribution(base_dist, [iaf]).sample() + assert sample.shape == base_shape + + def test_jacobians(self): + for input_dim in [2, 3, 5, 7, 9, 11]: + self._test_jacobian(input_dim, 3 * input_dim + 1) + + def test_shapes(self): + for shape in [(3,), (3, 4), (3, 4, 2)]: + self._test_shape(shape) diff --git a/tests/nn/__init__.py b/tests/nn/__init__.py new file mode 100644 index 0000000000..42f7b0b7e9 --- /dev/null +++ b/tests/nn/__init__.py @@ -0,0 +1 @@ +from __future__ import absolute_import, division, print_function diff --git a/tests/distributions/test_iaf.py b/tests/nn/test_autoregressive.py similarity index 66% rename from tests/distributions/test_iaf.py rename to tests/nn/test_autoregressive.py index 23225487e0..863f6871bf 100644 --- a/tests/distributions/test_iaf.py +++ b/tests/nn/test_autoregressive.py @@ -2,71 +2,15 @@ from unittest import TestCase -import numpy as np import pytest import torch -import pyro.distributions as dist -from pyro.distributions.iaf import InverseAutoregressiveFlow from pyro.nn import AutoRegressiveNN from pyro.nn.auto_reg_nn import create_mask pytestmark = pytest.mark.init(rng_seed=123) -class InverseAutoregressiveFlowTests(TestCase): - def setUp(self): - self.epsilon = 1.0e-3 - - def _test_jacobian(self, input_dim, hidden_dim): - jacobian = torch.zeros(input_dim, input_dim) - iaf = InverseAutoregressiveFlow(AutoRegressiveNN(input_dim, [40]), sigmoid_bias=0.5) - - def nonzero(x): - return torch.sign(torch.abs(x)) - - x = torch.randn(1, input_dim) - iaf_x = iaf(x) - analytic_ldt = iaf.log_abs_det_jacobian(x, iaf_x).data.sum() - - for j in range(input_dim): - for k in range(input_dim): - epsilon_vector = torch.zeros(1, input_dim) - epsilon_vector[0, j] = self.epsilon - delta = (iaf(x + 0.5 * epsilon_vector) - iaf(x - 0.5 * epsilon_vector)) / self.epsilon - jacobian[j, k] = float(delta[0, k].data.sum()) - - permutation = iaf.arn.get_permutation() - permuted_jacobian = jacobian.clone() - for j in range(input_dim): - for k in range(input_dim): - permuted_jacobian[j, k] = jacobian[permutation[j], permutation[k]] - numeric_ldt = torch.sum(torch.log(torch.diag(permuted_jacobian))) - ldt_discrepancy = np.fabs(analytic_ldt - numeric_ldt) - - diag_sum = torch.sum(torch.diag(nonzero(permuted_jacobian))) - lower_sum = torch.sum(torch.tril(nonzero(permuted_jacobian), diagonal=-1)) - - assert ldt_discrepancy < self.epsilon - assert diag_sum == float(input_dim) - assert lower_sum == float(0.0) - - def _test_shape(self, base_shape): - base_dist = dist.Normal(torch.zeros(base_shape), torch.ones(base_shape)) - last_dim = base_shape[-1] if isinstance(base_shape, tuple) else base_shape - iaf = InverseAutoregressiveFlow(AutoRegressiveNN(last_dim, [40])) - sample = dist.TransformedDistribution(base_dist, [iaf]).sample() - assert sample.shape == base_shape - - def test_jacobians(self): - for input_dim in [2, 3, 5, 7, 9, 11]: - self._test_jacobian(input_dim, 3 * input_dim + 1) - - def test_shapes(self): - for shape in [(3,), (3, 4), (3, 4, 2)]: - self._test_shape(shape) - - class AutoRegressiveNNTests(TestCase): def setUp(self): self.epsilon = 1.0e-3 From 5fab7abbb4e5ca5109c5772d1973fb8c96a19510 Mon Sep 17 00:00:00 2001 From: Stefan Webb Date: Tue, 23 Oct 2018 23:23:28 +0100 Subject: [PATCH 02/20] PermutationFlow --- pyro/distributions/__init__.py | 2 + pyro/distributions/permutation_flow.py | 78 ++++++++++++++++++++++++++ 2 files changed, 80 insertions(+) create mode 100644 pyro/distributions/permutation_flow.py diff --git a/pyro/distributions/__init__.py b/pyro/distributions/__init__.py index 280f434de7..7ff2b61b85 100644 --- a/pyro/distributions/__init__.py +++ b/pyro/distributions/__init__.py @@ -14,6 +14,7 @@ from pyro.distributions.lowrank_mvn import LowRankMultivariateNormal from pyro.distributions.mixture import MaskedMixture from pyro.distributions.omt_mvn import OMTMultivariateNormal +from pyro.distributions.permutation_flow import PermutationFlow from pyro.distributions.rejector import Rejector from pyro.distributions.relaxed_straight_through import (RelaxedBernoulliStraightThrough, RelaxedOneHotCategoricalStraightThrough) @@ -42,6 +43,7 @@ "MixtureOfDiagNormalsSharedCovariance", "MixtureOfDiagNormals", "OMTMultivariateNormal", + "PermutationFlow", "Rejector", "RelaxedBernoulliStraightThrough", "RelaxedOneHotCategoricalStraightThrough", diff --git a/pyro/distributions/permutation_flow.py b/pyro/distributions/permutation_flow.py new file mode 100644 index 0000000000..358755f458 --- /dev/null +++ b/pyro/distributions/permutation_flow.py @@ -0,0 +1,78 @@ +from __future__ import absolute_import, division, print_function + +import torch +from torch.distributions.transforms import Transform +from torch.distributions import constraints + +from pyro.distributions.util import copy_docs_from + + +@copy_docs_from(Transform) +class PermutationFlow(Transform): + """ + A normalizing flow that reorders the input dimensions, that is, multiplies the input by a permutation matrix. + This is useful in between IAF transforms to increase the flexibility of the resulting distribution and + stabilize learning. Whilst not being an autoregressive flow, the log absolute determinate of the Jacobian is + easily calculable as 0. Note that reordering the input dimension between two layers of IAF is not equivalent + to reordering the dimension inside the MADE networks that those IAFs use; using a PermutationFlow results in a + distribution with more flexibility. + + Example usage: + + >>> from pyro.nn import AutoRegressiveNN + >>> from pyro.distributions import InverseAutoregressiveFlow, PermutationFlow + >>> base_dist = dist.Normal(torch.zeros(10), torch.ones(10)) + >>> iaf1 = InverseAutoregressiveFlow(AutoRegressiveNN(10, [40])) + >>> ff = PermutationFlow(torch.randperm(10, dtype=torch.long)) + >>> iaf2 = InverseAutoregressiveFlow(AutoRegressiveNN(10, [40])) + >>> iaf_dist = dist.TransformedDistribution(base_dist, [iaf1, ff, iaf2]) + >>> iaf_dist.sample() # doctest: +SKIP + tensor([-0.4071, -0.5030, 0.7924, -0.2366, -0.2387, -0.1417, 0.0868, + 0.1389, -0.4629, 0.0986]) + + :param permutation: a permutation ordering that is applied to the inputs. + :type permutation: torch.LongTensor + + """ + + codomain = constraints.real + + def __init__(self, permutation): + super(PermutationFlow, self).__init__() + + self.permutation = permutation + + # Calculate the inverse permutation order + self.inv_permutation = torch.empty(permutation.shape, dtype=torch.long) + self.inv_permutation[permutation] = torch.arange(permutation.size(0), dtype=torch.long) + + def _call(self, x): + """ + :param x: the input into the bijection + :type x: torch.Tensor + + Invokes the bijection x=>y; in the prototypical context of a TransformedDistribution `x` is a + sample from the base distribution (or the output of a previous flow) + """ + + return x[..., self.permutation] + + def _inverse(self, y): + """ + :param y: the output of the bijection + :type y: torch.Tensor + + Inverts y => x. + """ + + return y[..., self.inv_permutation] + + def log_abs_det_jacobian(self, x, y): + """ + Calculates the elementwise determinant of the log Jacobian, i.e. log(abs([dy_0/dx_0, ..., dy_{N-1}/dx_{N-1}])). + Note that this type of flow is not autoregressive, so the log Jacobian is not the sum of the previous + expression. However, it turns out it's always 0 (since the determinant is -1 or +1), and so returning a + vector of zeros works. + """ + + return torch.zeros_like(x) From e8719a495c28fae8ffe8a113be97e9c06b5a0302 Mon Sep 17 00:00:00 2001 From: Stefan Webb Date: Tue, 23 Oct 2018 23:46:20 +0100 Subject: [PATCH 03/20] Tests for PermutationFLow --- tests/distributions/test_flows.py | 49 ++++++++++++++++++++++++------- 1 file changed, 39 insertions(+), 10 deletions(-) diff --git a/tests/distributions/test_flows.py b/tests/distributions/test_flows.py index f6be0907a1..6f8f7d16ab 100644 --- a/tests/distributions/test_flows.py +++ b/tests/distributions/test_flows.py @@ -9,18 +9,21 @@ import pyro.distributions as dist from pyro.distributions.iaf import InverseAutoregressiveFlow from pyro.nn import AutoRegressiveNN -from pyro.nn.auto_reg_nn import create_mask pytestmark = pytest.mark.init(rng_seed=123) class AutoregressiveFlowTests(TestCase): def setUp(self): - self.epsilon = 1.0e-3 + # Epsilon is used to compare numerical gradient to analytical one + self.epsilon = 1e-3 - def _test_jacobian(self, input_dim, hidden_dim): + # Delta is tolerance for testing f(f^{-1}(x)) = x + self.delta = 1e-6 + + def _test_jacobian(self, input_dim, make_flow): jacobian = torch.zeros(input_dim, input_dim) - iaf = InverseAutoregressiveFlow(AutoRegressiveNN(input_dim, [40]), sigmoid_bias=0.5) + iaf = make_flow(input_dim) def nonzero(x): return torch.sign(torch.abs(x)) @@ -51,17 +54,43 @@ def nonzero(x): assert diag_sum == float(input_dim) assert lower_sum == float(0.0) - def _test_shape(self, base_shape): + def _test_inverse(self, input_dim, make_flow): + base_dist = dist.Normal(torch.zeros(input_dim), torch.ones(input_dim)) + iaf = make_flow(input_dim) + + x_true = base_dist.sample(torch.Size([10])) + y = iaf._call(x_true) + x_calculated = iaf._inverse(y) + + assert torch.norm(x_true - x_calculated, dim=-1).max().item() < self.delta + + def _test_shape(self, base_shape, make_flow): base_dist = dist.Normal(torch.zeros(base_shape), torch.ones(base_shape)) last_dim = base_shape[-1] if isinstance(base_shape, tuple) else base_shape - iaf = InverseAutoregressiveFlow(AutoRegressiveNN(last_dim, [40])) + iaf = make_flow(input_dim=last_dim) sample = dist.TransformedDistribution(base_dist, [iaf]).sample() assert sample.shape == base_shape - def test_jacobians(self): + def _make_iaf(self, input_dim): + arn = AutoRegressiveNN(input_dim, [3 * input_dim + 1]) + return dist.InverseAutoregressiveFlow(arn) + + def _make_flipflow(self, input_dim): + permutation = torch.randperm(input_dim, device='cpu').to(torch.Tensor().device) + return dist.FlipFlow(permutation) + + def test_iaf_jacobians(self): + for input_dim in [2, 3, 5, 7, 9, 11]: + self._test_jacobian(input_dim, self._make_iaf) + + def test_flipflow_inverses(self): for input_dim in [2, 3, 5, 7, 9, 11]: - self._test_jacobian(input_dim, 3 * input_dim + 1) + self._test_inverse(input_dim, self._make_flipflow) + + def test_iaf_shapes(self): + for shape in [(3,), (3, 4), (3, 4, 2)]: + self._test_shape(shape, self._make_iaf) - def test_shapes(self): + def test_flipflow_shapes(self): for shape in [(3,), (3, 4), (3, 4, 2)]: - self._test_shape(shape) + self._test_shape(shape, self._make_flipflow) From 18ec5bb8378d1bfbd2ce13ebb9f5ef549613e815 Mon Sep 17 00:00:00 2001 From: Stefan Webb Date: Tue, 23 Oct 2018 23:49:12 +0100 Subject: [PATCH 04/20] Bug fix --- tests/distributions/test_flows.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/distributions/test_flows.py b/tests/distributions/test_flows.py index 6f8f7d16ab..2de675c791 100644 --- a/tests/distributions/test_flows.py +++ b/tests/distributions/test_flows.py @@ -7,7 +7,6 @@ import torch import pyro.distributions as dist -from pyro.distributions.iaf import InverseAutoregressiveFlow from pyro.nn import AutoRegressiveNN pytestmark = pytest.mark.init(rng_seed=123) @@ -77,7 +76,7 @@ def _make_iaf(self, input_dim): def _make_flipflow(self, input_dim): permutation = torch.randperm(input_dim, device='cpu').to(torch.Tensor().device) - return dist.FlipFlow(permutation) + return dist.PermutationFlow(permutation) def test_iaf_jacobians(self): for input_dim in [2, 3, 5, 7, 9, 11]: From 5cf65bb0ca2579577066a9c77bd7177bf3855ba7 Mon Sep 17 00:00:00 2001 From: Stefan Webb Date: Wed, 24 Oct 2018 21:23:35 +0100 Subject: [PATCH 05/20] Renamed PermutationFlow to PermuteTransform --- pyro/distributions/__init__.py | 4 ++-- pyro/distributions/{permutation_flow.py => permute.py} | 10 +++++----- tests/distributions/test_flows.py | 2 +- 3 files changed, 8 insertions(+), 8 deletions(-) rename pyro/distributions/{permutation_flow.py => permute.py} (92%) diff --git a/pyro/distributions/__init__.py b/pyro/distributions/__init__.py index 7ff2b61b85..7176ef5c3d 100644 --- a/pyro/distributions/__init__.py +++ b/pyro/distributions/__init__.py @@ -14,7 +14,7 @@ from pyro.distributions.lowrank_mvn import LowRankMultivariateNormal from pyro.distributions.mixture import MaskedMixture from pyro.distributions.omt_mvn import OMTMultivariateNormal -from pyro.distributions.permutation_flow import PermutationFlow +from pyro.distributions.permutate import PermuteTransform from pyro.distributions.rejector import Rejector from pyro.distributions.relaxed_straight_through import (RelaxedBernoulliStraightThrough, RelaxedOneHotCategoricalStraightThrough) @@ -43,7 +43,7 @@ "MixtureOfDiagNormalsSharedCovariance", "MixtureOfDiagNormals", "OMTMultivariateNormal", - "PermutationFlow", + "PermuteTransform", "Rejector", "RelaxedBernoulliStraightThrough", "RelaxedOneHotCategoricalStraightThrough", diff --git a/pyro/distributions/permutation_flow.py b/pyro/distributions/permute.py similarity index 92% rename from pyro/distributions/permutation_flow.py rename to pyro/distributions/permute.py index 358755f458..8a31b81682 100644 --- a/pyro/distributions/permutation_flow.py +++ b/pyro/distributions/permute.py @@ -8,22 +8,22 @@ @copy_docs_from(Transform) -class PermutationFlow(Transform): +class PermuteTransform(Transform): """ A normalizing flow that reorders the input dimensions, that is, multiplies the input by a permutation matrix. This is useful in between IAF transforms to increase the flexibility of the resulting distribution and stabilize learning. Whilst not being an autoregressive flow, the log absolute determinate of the Jacobian is easily calculable as 0. Note that reordering the input dimension between two layers of IAF is not equivalent - to reordering the dimension inside the MADE networks that those IAFs use; using a PermutationFlow results in a + to reordering the dimension inside the MADE networks that those IAFs use; using a PermuteTransform results in a distribution with more flexibility. Example usage: >>> from pyro.nn import AutoRegressiveNN - >>> from pyro.distributions import InverseAutoregressiveFlow, PermutationFlow + >>> from pyro.distributions import InverseAutoregressiveFlow, PermuteTransform >>> base_dist = dist.Normal(torch.zeros(10), torch.ones(10)) >>> iaf1 = InverseAutoregressiveFlow(AutoRegressiveNN(10, [40])) - >>> ff = PermutationFlow(torch.randperm(10, dtype=torch.long)) + >>> ff = PermuteTransform(torch.randperm(10, dtype=torch.long)) >>> iaf2 = InverseAutoregressiveFlow(AutoRegressiveNN(10, [40])) >>> iaf_dist = dist.TransformedDistribution(base_dist, [iaf1, ff, iaf2]) >>> iaf_dist.sample() # doctest: +SKIP @@ -38,7 +38,7 @@ class PermutationFlow(Transform): codomain = constraints.real def __init__(self, permutation): - super(PermutationFlow, self).__init__() + super(PermuteTransform, self).__init__() self.permutation = permutation diff --git a/tests/distributions/test_flows.py b/tests/distributions/test_flows.py index 2de675c791..5195fe1fd0 100644 --- a/tests/distributions/test_flows.py +++ b/tests/distributions/test_flows.py @@ -76,7 +76,7 @@ def _make_iaf(self, input_dim): def _make_flipflow(self, input_dim): permutation = torch.randperm(input_dim, device='cpu').to(torch.Tensor().device) - return dist.PermutationFlow(permutation) + return dist.PermuteTransform(permutation) def test_iaf_jacobians(self): for input_dim in [2, 3, 5, 7, 9, 11]: From 3b3003a46acbc9de7c73a5e1da32318d2b54019c Mon Sep 17 00:00:00 2001 From: Stefan Webb Date: Wed, 24 Oct 2018 21:42:44 +0100 Subject: [PATCH 06/20] Added PermuteTransform to docs --- docs/source/distributions.rst | 7 +++++++ pyro/distributions/__init__.py | 2 +- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/docs/source/distributions.rst b/docs/source/distributions.rst index 8bcf8f2aa7..09e3a786bd 100644 --- a/docs/source/distributions.rst +++ b/docs/source/distributions.rst @@ -164,3 +164,10 @@ InverseAutoRegressiveFlow :members: :undoc-members: :show-inheritance: + +PermuteTransform +------------------------- +.. autoclass:: pyro.distributions.PermuteTransform + :members: + :undoc-members: + :show-inheritance: \ No newline at end of file diff --git a/pyro/distributions/__init__.py b/pyro/distributions/__init__.py index 7176ef5c3d..874c73ea0e 100644 --- a/pyro/distributions/__init__.py +++ b/pyro/distributions/__init__.py @@ -14,7 +14,7 @@ from pyro.distributions.lowrank_mvn import LowRankMultivariateNormal from pyro.distributions.mixture import MaskedMixture from pyro.distributions.omt_mvn import OMTMultivariateNormal -from pyro.distributions.permutate import PermuteTransform +from pyro.distributions.permute import PermuteTransform from pyro.distributions.rejector import Rejector from pyro.distributions.relaxed_straight_through import (RelaxedBernoulliStraightThrough, RelaxedOneHotCategoricalStraightThrough) From 3e1cdec4b763974d3b13099917f1f233bcce512c Mon Sep 17 00:00:00 2001 From: Stefan Webb Date: Wed, 24 Oct 2018 21:49:53 +0100 Subject: [PATCH 07/20] Added device to permutation vectors --- pyro/distributions/permute.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pyro/distributions/permute.py b/pyro/distributions/permute.py index 8a31b81682..3424af554a 100644 --- a/pyro/distributions/permute.py +++ b/pyro/distributions/permute.py @@ -36,6 +36,7 @@ class PermuteTransform(Transform): """ codomain = constraints.real + bijective = True def __init__(self, permutation): super(PermuteTransform, self).__init__() @@ -43,8 +44,8 @@ def __init__(self, permutation): self.permutation = permutation # Calculate the inverse permutation order - self.inv_permutation = torch.empty(permutation.shape, dtype=torch.long) - self.inv_permutation[permutation] = torch.arange(permutation.size(0), dtype=torch.long) + self.inv_permutation = torch.empty_like(permutation, dtype=torch.long) + self.inv_permutation[permutation] = torch.arange(permutation.size(0), dtype=torch.long, device=permutation.device) def _call(self, x): """ From 6fcd765b3bf774ef10bf5d963c1f748cc1c4e57d Mon Sep 17 00:00:00 2001 From: Stefan Webb Date: Wed, 24 Oct 2018 21:53:00 +0100 Subject: [PATCH 08/20] PEP8 --- pyro/distributions/permute.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pyro/distributions/permute.py b/pyro/distributions/permute.py index 3424af554a..4a88075479 100644 --- a/pyro/distributions/permute.py +++ b/pyro/distributions/permute.py @@ -45,7 +45,8 @@ def __init__(self, permutation): # Calculate the inverse permutation order self.inv_permutation = torch.empty_like(permutation, dtype=torch.long) - self.inv_permutation[permutation] = torch.arange(permutation.size(0), dtype=torch.long, device=permutation.device) + self.inv_permutation[permutation] = torch.arange(permutation.size(0), dtype=torch.long, + device=permutation.device) def _call(self, x): """ From 721ee632b753d24cbac0038483495280fc81699a Mon Sep 17 00:00:00 2001 From: Stefan Webb Date: Thu, 25 Oct 2018 13:29:41 +0100 Subject: [PATCH 09/20] Removed 'flow', link to IAF in docs, fixed other bug in docs --- pyro/distributions/permute.py | 15 ++++++++------- pyro/distributions/torch_distribution.py | 2 +- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/pyro/distributions/permute.py b/pyro/distributions/permute.py index 4a88075479..b0313cc525 100644 --- a/pyro/distributions/permute.py +++ b/pyro/distributions/permute.py @@ -10,12 +10,13 @@ @copy_docs_from(Transform) class PermuteTransform(Transform): """ - A normalizing flow that reorders the input dimensions, that is, multiplies the input by a permutation matrix. - This is useful in between IAF transforms to increase the flexibility of the resulting distribution and - stabilize learning. Whilst not being an autoregressive flow, the log absolute determinate of the Jacobian is - easily calculable as 0. Note that reordering the input dimension between two layers of IAF is not equivalent - to reordering the dimension inside the MADE networks that those IAFs use; using a PermuteTransform results in a - distribution with more flexibility. + A bijection that reorders the input dimensions, that is, multiplies the input by a permutation matrix. + This is useful in between :class:`~pyro.distributions.InverseAutoregressiveFlow` transforms to increase the + flexibility of the resulting distribution and stabilize learning. Whilst not being an autoregressive flow, + the log absolute determinate of the Jacobian is easily calculable as 0. Note that reordering the input dimension + between two layers of :class:`~pyro.distributions.InverseAutoregressiveFlow` is not equivalent to reordering + the dimension inside the MADE networks that those IAFs use; using a PermuteTransform results in a distribution + with more flexibility. Example usage: @@ -72,7 +73,7 @@ def _inverse(self, y): def log_abs_det_jacobian(self, x, y): """ Calculates the elementwise determinant of the log Jacobian, i.e. log(abs([dy_0/dx_0, ..., dy_{N-1}/dx_{N-1}])). - Note that this type of flow is not autoregressive, so the log Jacobian is not the sum of the previous + Note that this type of transform is not autoregressive, so the log Jacobian is not the sum of the previous expression. However, it turns out it's always 0 (since the determinant is -1 or +1), and so returning a vector of zeros works. """ diff --git a/pyro/distributions/torch_distribution.py b/pyro/distributions/torch_distribution.py index 7276478b14..27828013ce 100644 --- a/pyro/distributions/torch_distribution.py +++ b/pyro/distributions/torch_distribution.py @@ -199,7 +199,7 @@ class TorchDistribution(torch.distributions.Distribution, TorchDistributionMixin assert d.shape(sample_shape) == sample_shape + d.batch_shape + d.event_shape Distributions provide a vectorized - :meth`~torch.distributions.distribution.Distribution.log_prob` method that + :meth:`~torch.distributions.distribution.Distribution.log_prob` method that evaluates the log probability density of each event in a batch independently, returning a tensor of shape ``sample_shape + d.batch_shape``:: From f64dd91c6938b1c63a666401190ce324adc6e5d0 Mon Sep 17 00:00:00 2001 From: Stefan Webb Date: Thu, 25 Oct 2018 13:32:12 +0100 Subject: [PATCH 10/20] Removed more 'flow's --- pyro/distributions/permute.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyro/distributions/permute.py b/pyro/distributions/permute.py index b0313cc525..282991ea9e 100644 --- a/pyro/distributions/permute.py +++ b/pyro/distributions/permute.py @@ -12,7 +12,7 @@ class PermuteTransform(Transform): """ A bijection that reorders the input dimensions, that is, multiplies the input by a permutation matrix. This is useful in between :class:`~pyro.distributions.InverseAutoregressiveFlow` transforms to increase the - flexibility of the resulting distribution and stabilize learning. Whilst not being an autoregressive flow, + flexibility of the resulting distribution and stabilize learning. Whilst not being an autoregressive transform, the log absolute determinate of the Jacobian is easily calculable as 0. Note that reordering the input dimension between two layers of :class:`~pyro.distributions.InverseAutoregressiveFlow` is not equivalent to reordering the dimension inside the MADE networks that those IAFs use; using a PermuteTransform results in a distribution @@ -55,7 +55,7 @@ def _call(self, x): :type x: torch.Tensor Invokes the bijection x=>y; in the prototypical context of a TransformedDistribution `x` is a - sample from the base distribution (or the output of a previous flow) + sample from the base distribution (or the output of a previous transform) """ return x[..., self.permutation] From 52979b8a60c519b1ae68ec69056c157b2054f9e6 Mon Sep 17 00:00:00 2001 From: Stefan Webb Date: Thu, 25 Oct 2018 18:29:08 +0100 Subject: [PATCH 11/20] Added lazy_property to inv_permutation of PermuteTransform --- pyro/distributions/permute.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/pyro/distributions/permute.py b/pyro/distributions/permute.py index 282991ea9e..19b79ac08b 100644 --- a/pyro/distributions/permute.py +++ b/pyro/distributions/permute.py @@ -2,6 +2,7 @@ import torch from torch.distributions.transforms import Transform +from torch.distributions.utils import lazy_property from torch.distributions import constraints from pyro.distributions.util import copy_docs_from @@ -44,10 +45,13 @@ def __init__(self, permutation): self.permutation = permutation - # Calculate the inverse permutation order - self.inv_permutation = torch.empty_like(permutation, dtype=torch.long) - self.inv_permutation[permutation] = torch.arange(permutation.size(0), dtype=torch.long, - device=permutation.device) + @lazy_property + def inv_permutation(self): + result = torch.empty_like(self.permutation, dtype=torch.long) + result[self.permutation] = torch.arange(self.permutation.size(0), + dtype=torch.long, + device=self.permutation.device) + return result def _call(self, x): """ From 7b92c38358970b310fb8e09a7053b459155f6e04 Mon Sep 17 00:00:00 2001 From: Stefan Webb Date: Thu, 25 Oct 2018 21:21:13 +0100 Subject: [PATCH 12/20] Inverse operations for IAF and alternative version --- docs/source/distributions.rst | 7 ++ pyro/contrib/autoguide/__init__.py | 7 +- pyro/distributions/__init__.py | 3 +- pyro/distributions/iaf.py | 193 +++++++++++++++++++++++++---- tests/distributions/test_flows.py | 25 ++++ 5 files changed, 204 insertions(+), 31 deletions(-) diff --git a/docs/source/distributions.rst b/docs/source/distributions.rst index 09e3a786bd..c2c18d54a7 100644 --- a/docs/source/distributions.rst +++ b/docs/source/distributions.rst @@ -165,6 +165,13 @@ InverseAutoRegressiveFlow :undoc-members: :show-inheritance: +InverseAutoRegressiveFlowStable +------------------------- +.. autoclass:: pyro.distributions.InverseAutoregressiveFlowStable + :members: + :undoc-members: + :show-inheritance: + PermuteTransform ------------------------- .. autoclass:: pyro.distributions.PermuteTransform diff --git a/pyro/contrib/autoguide/__init__.py b/pyro/contrib/autoguide/__init__.py index a3fcc5efbd..a107a3921a 100644 --- a/pyro/contrib/autoguide/__init__.py +++ b/pyro/contrib/autoguide/__init__.py @@ -602,11 +602,9 @@ class AutoIAFNormal(AutoContinuous): :param callable model: a generative model :param int hidden_dim: number of hidden dimensions in the IAF - :param float sigmoid_bias: sigmoid bias in the IAF. Defaults to ``2.0`` :param str prefix: a prefix that will be prefixed to all param internal sites """ - def __init__(self, model, hidden_dim=None, sigmoid_bias=2.0, prefix="auto"): - self.sigmoid_bias = sigmoid_bias + def __init__(self, model, hidden_dim=None, prefix="auto"): self.hidden_dim = hidden_dim super(AutoIAFNormal, self).__init__(model, prefix) @@ -619,8 +617,7 @@ def get_posterior(self, *args, **kwargs): raise ValueError('latent dim = 1. Consider using AutoDiagonalNormal instead') if self.hidden_dim is None: self.hidden_dim = self.latent_dim - iaf = dist.InverseAutoregressiveFlow(AutoRegressiveNN(self.latent_dim, [self.hidden_dim]), - sigmoid_bias=self.sigmoid_bias) + iaf = dist.InverseAutoregressiveFlow(AutoRegressiveNN(self.latent_dim, [self.hidden_dim])) pyro.module("{}_iaf".format(self.prefix), iaf.module) iaf_dist = dist.TransformedDistribution(dist.Normal(0., 1.).expand([self.latent_dim]), [iaf]) return iaf_dist.independent(1) diff --git a/pyro/distributions/__init__.py b/pyro/distributions/__init__.py index 874c73ea0e..2f6f43c43e 100644 --- a/pyro/distributions/__init__.py +++ b/pyro/distributions/__init__.py @@ -10,7 +10,7 @@ from pyro.distributions.empirical import Empirical from pyro.distributions.gaussian_scale_mixture import GaussianScaleMixture from pyro.distributions.half_cauchy import HalfCauchy -from pyro.distributions.iaf import InverseAutoregressiveFlow +from pyro.distributions.iaf import InverseAutoregressiveFlow, InverseAutoregressiveFlowStable from pyro.distributions.lowrank_mvn import LowRankMultivariateNormal from pyro.distributions.mixture import MaskedMixture from pyro.distributions.omt_mvn import OMTMultivariateNormal @@ -38,6 +38,7 @@ "GaussianScaleMixture", "HalfCauchy", "InverseAutoregressiveFlow", + "InverseAutoregressiveFlowStable", "LowRankMultivariateNormal", "MaskedMixture", "MixtureOfDiagNormalsSharedCovariance", diff --git a/pyro/distributions/iaf.py b/pyro/distributions/iaf.py index 2bf84faa90..ec2a5f5188 100644 --- a/pyro/distributions/iaf.py +++ b/pyro/distributions/iaf.py @@ -7,12 +7,19 @@ from pyro.distributions.util import copy_docs_from +# This helper function clamps gradients but still passes through the gradient in clamped regions +# NOTE: Not sure how necessary this is, but I was copying the design of the TensorFlow implementation + + +def clamp_preserve_gradients(x, min, max): + return x + (x.clamp(min, max) - x).detach() + @copy_docs_from(Transform) class InverseAutoregressiveFlow(Transform): """ - An implementation of an Inverse Autoregressive Flow. Together with the `TransformedDistribution` this - provides a way to create richer variational approximations. + An implementation of an Inverse Autoregressive Flow, using Eq (10) from Kingma Et Al., 2016. + Together with `TransformedDistribution` this provides a way to create richer variational approximations. Example usage: @@ -25,12 +32,137 @@ class InverseAutoregressiveFlow(Transform): tensor([-0.4071, -0.5030, 0.7924, -0.2366, -0.2387, -0.1417, 0.0868, 0.1389, -0.4629, 0.0986]) - Note that this implementation is only meant to be used in settings where the inverse of the Bijector - is never explicitly computed (rather the result is cached from the forward call). In the context of - variational inference, this means that the InverseAutoregressiveFlow should only be used in the guide, - i.e. in the variational distribution. In other contexts the inverse could in principle be computed but - this would be a (potentially) costly computation that scales with the dimension of the input (and in - any case support for this is not included in this implementation). + The inverse of the Bijector is required when, e.g., scoring the log density of a sample with + `TransformedDistribution`. This implementation caches the inverse of the Bijector when its forward + operation is called, e.g., when sampling from `TransformedDistribution`. However, if the cached value + isn't available, either because it was already popped from the cache, or an arbitary value is being + scored, it will calculate it manually. Note that this is an operation that scales as O(D) where D is + the input dimension, and so should be avoided for large dimensional uses. So in general, it is cheap + to sample from IAF and score a value that was sampled by IAF, but expensive to score an arbitrary value. + + :param autoregressive_nn: an autoregressive neural network whose forward call returns a real-valued + mean and logit-scale as a tuple + :type autoregressive_nn: nn.Module + :param log_scale_min_clip: The minimum value for clipping the log(scale) from the autoregressive NN + :type log_scale_min_clip: float + :param log_scale_max_clip: The maximum value for clipping the log(scale) from the autoregressive NN + :type log_scale_max_clip: float + + References: + + 1. Improving Variational Inference with Inverse Autoregressive Flow [arXiv:1606.04934] + Diederik P. Kingma, Tim Salimans, Rafal Jozefowicz, Xi Chen, Ilya Sutskever, Max Welling + + 2. Variational Inference with Normalizing Flows [arXiv:1505.05770] + Danilo Jimenez Rezende, Shakir Mohamed + + 3. MADE: Masked Autoencoder for Distribution Estimation [arXiv:1502.03509] + Mathieu Germain, Karol Gregor, Iain Murray, Hugo Larochelle + """ + + codomain = constraints.real + + def __init__(self, autoregressive_nn, log_scale_min_clip=-5., log_scale_max_clip=3.): + super(InverseAutoregressiveFlow, self).__init__() + self.module = nn.Module() + self.module.arn = autoregressive_nn + self._intermediates_cache = {} + self.add_inverse_to_cache = True + self.log_scale_min_clip = log_scale_min_clip + self.log_scale_max_clip = log_scale_max_clip + + @property + def arn(self): + """ + :rtype: pyro.nn.AutoRegressiveNN + + Return the AutoRegressiveNN associated with the InverseAutoregressiveFlow + """ + return self.module.arn + + def _call(self, x): + """ + :param x: the input into the bijection + :type x: torch.Tensor + + Invokes the bijection x=>y; in the prototypical context of a TransformedDistribution `x` is a + sample from the base distribution (or the output of a previous flow) + """ + mean, log_scale = self.module.arn(x) + log_scale = clamp_preserve_gradients(log_scale, self.log_scale_min_clip, self.log_scale_max_clip) + scale = torch.exp(log_scale) + + y = scale * x + mean + self._add_intermediate_to_cache(x, y, 'x') + self._add_intermediate_to_cache(log_scale, y, 'log_scale') + return y + + def _inverse(self, y): + """ + :param y: the output of the bijection + :type y: torch.Tensor + + Inverts y => x. Uses a previously cached inverse if available, otherwise performs the inversion afresh. + """ + if (y, 'x') in self._intermediates_cache: + x = self._intermediates_cache.pop((y, 'x')) + return x + else: + x_size = y.size()[:-1] + perm = self.module.arn.permutation + input_dim = y.size(-1) + x = [torch.zeros(x_size, device=y.device)] * input_dim + + # NOTE: Inversion is an expensive operation that scales in the dimension of the input + for idx in perm: + mean, log_scale = self.module.arn(torch.stack(x, dim=-1)) + inverse_scale = torch.exp(-clamp_preserve_gradients( + log_scale[..., idx], min=self.log_scale_min_clip, max=self.log_scale_max_clip)) + mean = mean[..., idx] + x[idx] = (y[..., idx] - mean) * inverse_scale + + x = torch.stack(x, dim=-1) + return x + + def _add_intermediate_to_cache(self, intermediate, y, name): + """ + Internal function used to cache intermediate results computed during the forward call + """ + assert((y, name) not in self._intermediates_cache),\ + "key collision in _add_intermediate_to_cache" + self._intermediates_cache[(y, name)] = intermediate + + def log_abs_det_jacobian(self, x, y): + """ + Calculates the elementwise determinant of the log jacobian + """ + if (y, 'log_scale') in self._intermediates_cache: + log_scale = self._intermediates_cache.pop((y, 'log_scale')) + else: + _, log_scale = self.module.arn(x) + log_scale = clamp_preserve_gradients(log_scale, min=self.log_scale_min_clip, max=self.log_scale_max_clip) + return log_scale + + +@copy_docs_from(Transform) +class InverseAutoregressiveFlowStable(Transform): + """ + An implementation of an Inverse Autoregressive Flow, using Eqs (13)/(14) from Kingma Et Al., 2016. + This variant of IAF is claimed by the authors to be more numerically stable than one using Eq (10), + although in practice it leads to a restriction on the distributions that can be represented. + + Example usage: + + >>> from pyro.nn import AutoRegressiveNN + >>> base_dist = dist.Normal(torch.zeros(10), torch.ones(10)) + >>> iaf = InverseAutoregressiveFlowStable(AutoRegressiveNN(10, [40])) + >>> iaf_module = pyro.module("my_iaf", iaf.module) + >>> iaf_dist = dist.TransformedDistribution(base_dist, [iaf]) + >>> iaf_dist.sample() # doctest: +SKIP + tensor([-0.4071, -0.5030, 0.7924, -0.2366, -0.2387, -0.1417, 0.0868, + 0.1389, -0.4629, 0.0986]) + + See `InverseAutoregressiveFlow` docs for a discussion of the running cost. :param autoregressive_nn: an autoregressive neural network whose forward call returns a real-valued mean and logit-scale as a tuple @@ -53,11 +185,12 @@ class InverseAutoregressiveFlow(Transform): codomain = constraints.real def __init__(self, autoregressive_nn, sigmoid_bias=2.0): - super(InverseAutoregressiveFlow, self).__init__() + super(InverseAutoregressiveFlowStable, self).__init__() self.module = nn.Module() self.module.arn = autoregressive_nn - self.module.sigmoid = nn.Sigmoid() - self.module.sigmoid_bias = torch.tensor(sigmoid_bias) + self.sigmoid = nn.Sigmoid() + self.logsigmoid = nn.LogSigmoid() + self.sigmoid_bias = sigmoid_bias self._intermediates_cache = {} self.add_inverse_to_cache = True @@ -78,13 +211,14 @@ def _call(self, x): Invokes the bijection x=>y; in the prototypical context of a TransformedDistribution `x` is a sample from the base distribution (or the output of a previous flow) """ - mean, scale = self.module.arn(x) - scale = self.module.sigmoid(scale + self.module.sigmoid_bias.to(dtype=x.dtype, - device=x.device)) + mean, logit_scale = self.module.arn(x) + logit_scale = logit_scale + self.sigmoid_bias + scale = self.sigmoid(logit_scale) + log_scale = self.logsigmoid(logit_scale) y = scale * x + (1 - scale) * mean self._add_intermediate_to_cache(x, y, 'x') - self._add_intermediate_to_cache(scale, y, 'scale') + self._add_intermediate_to_cache(log_scale, y, 'log_scale') return y def _inverse(self, y): @@ -92,16 +226,25 @@ def _inverse(self, y): :param y: the output of the bijection :type y: torch.Tensor - Inverts y => x. As noted above, this implementation is incapable of inverting arbitrary values - `y`; rather it assumes `y` is the result of a previously computed application of the bijector - to some `x` (which was cached on the forward call) + Inverts y => x. Uses a previously cached inverse if available, otherwise performs the inversion afresh. """ if (y, 'x') in self._intermediates_cache: x = self._intermediates_cache.pop((y, 'x')) return x else: - raise KeyError("InverseAutoregressiveFlow expected to find " - "key in intermediates cache but didn't") + x_size = y.size()[:-1] + perm = self.module.arn.permutation + input_dim = y.size(-1) + x = [torch.zeros(x_size, device=y.device)] * input_dim + + # NOTE: Inversion is an expensive operation that scales in the dimension of the input + for idx in perm: + mean, logit_scale = self.module.arn(torch.stack(x, dim=-1)) + inverse_scale = 1 + torch.exp(-logit_scale[..., idx] - self.sigmoid_bias) + x[idx] = inverse_scale * y[..., idx] + (1 - inverse_scale) * mean[..., idx] + + x = torch.stack(x, dim=-1) + return x def _add_intermediate_to_cache(self, intermediate, y, name): """ @@ -115,9 +258,9 @@ def log_abs_det_jacobian(self, x, y): """ Calculates the elementwise determinant of the log jacobian """ - if (y, 'scale') in self._intermediates_cache: - scale = self._intermediates_cache.pop((y, 'scale')) + if (y, 'log_scale') in self._intermediates_cache: + log_scale = self._intermediates_cache.pop((y, 'log_scale')) else: - raise KeyError("Bijector InverseAutoregressiveFlow expected to find" + - "key in intermediates cache but didn't") - return scale.log() + _, logit_scale = self.module.arn(x) + log_scale = self.logsigmoid(logit_scale + self.sigmoid_bias) + return log_scale diff --git a/tests/distributions/test_flows.py b/tests/distributions/test_flows.py index 5195fe1fd0..6d7c23bcb8 100644 --- a/tests/distributions/test_flows.py +++ b/tests/distributions/test_flows.py @@ -59,6 +59,11 @@ def _test_inverse(self, input_dim, make_flow): x_true = base_dist.sample(torch.Size([10])) y = iaf._call(x_true) + + # This line empties the inverse cache, if the flow uses it + iaf._inverse(y) + + # Cache is empty, hence must be calculating inverse afresh x_calculated = iaf._inverse(y) assert torch.norm(x_true - x_calculated, dim=-1).max().item() < self.delta @@ -74,6 +79,10 @@ def _make_iaf(self, input_dim): arn = AutoRegressiveNN(input_dim, [3 * input_dim + 1]) return dist.InverseAutoregressiveFlow(arn) + def _make_iaf_stable(self, input_dim): + arn = AutoRegressiveNN(input_dim, [3 * input_dim + 1]) + return dist.InverseAutoregressiveFlowStable(arn, sigmoid_bias=0.5) + def _make_flipflow(self, input_dim): permutation = torch.randperm(input_dim, device='cpu').to(torch.Tensor().device) return dist.PermuteTransform(permutation) @@ -82,6 +91,18 @@ def test_iaf_jacobians(self): for input_dim in [2, 3, 5, 7, 9, 11]: self._test_jacobian(input_dim, self._make_iaf) + def test_iaf_stable_jacobians(self): + for input_dim in [2, 3, 5, 7, 9, 11]: + self._test_jacobian(input_dim, self._make_iaf_stable) + + def test_iaf_inverses(self): + for input_dim in [2, 3, 5, 7, 9, 11]: + self._test_inverse(input_dim, self._make_iaf) + + def test_iaf_stable_inverses(self): + for input_dim in [2, 3, 5, 7, 9, 11]: + self._test_inverse(input_dim, self._make_iaf_stable) + def test_flipflow_inverses(self): for input_dim in [2, 3, 5, 7, 9, 11]: self._test_inverse(input_dim, self._make_flipflow) @@ -90,6 +111,10 @@ def test_iaf_shapes(self): for shape in [(3,), (3, 4), (3, 4, 2)]: self._test_shape(shape, self._make_iaf) + def test_iaf_stable_shapes(self): + for shape in [(3,), (3, 4), (3, 4, 2)]: + self._test_shape(shape, self._make_iaf_stable) + def test_flipflow_shapes(self): for shape in [(3,), (3, 4), (3, 4, 2)]: self._test_shape(shape, self._make_flipflow) From 109712b7bba5b6a8e3c2dec5c22fa7b231bdf4ec Mon Sep 17 00:00:00 2001 From: Stefan Webb Date: Thu, 25 Oct 2018 23:43:36 +0100 Subject: [PATCH 13/20] Fixed docs error --- docs/source/distributions.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/distributions.rst b/docs/source/distributions.rst index c2c18d54a7..a671588309 100644 --- a/docs/source/distributions.rst +++ b/docs/source/distributions.rst @@ -166,14 +166,14 @@ InverseAutoRegressiveFlow :show-inheritance: InverseAutoRegressiveFlowStable -------------------------- +------------------------------- .. autoclass:: pyro.distributions.InverseAutoregressiveFlowStable :members: :undoc-members: :show-inheritance: PermuteTransform -------------------------- +---------------- .. autoclass:: pyro.distributions.PermuteTransform :members: :undoc-members: From df8527999147729b6e6b65b37565d9842313ce1f Mon Sep 17 00:00:00 2001 From: Stefan Webb Date: Fri, 26 Oct 2018 18:38:38 +0100 Subject: [PATCH 14/20] Equations in docs --- pyro/distributions/iaf.py | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/pyro/distributions/iaf.py b/pyro/distributions/iaf.py index ec2a5f5188..c993db0f06 100644 --- a/pyro/distributions/iaf.py +++ b/pyro/distributions/iaf.py @@ -18,7 +18,13 @@ def clamp_preserve_gradients(x, min, max): @copy_docs_from(Transform) class InverseAutoregressiveFlow(Transform): """ - An implementation of an Inverse Autoregressive Flow, using Eq (10) from Kingma Et Al., 2016. + An implementation of Inverse Autoregressive Flow, using Eq (10) from Kingma Et Al., 2016, + + :math:`\mathbf{y} = \mu_t + \sigma_t\odot\mathbf{x}` + + where :math:`\mathbf{x}` are the inputs, :math:`\mathbf{y}` are the outputs, :math:`\mu_t,\sigma_t` + are calculated from an autoregressive network on :math:`\mathbf{x}`, and :math:`\sigma_t>0`. + Together with `TransformedDistribution` this provides a way to create richer variational approximations. Example usage: @@ -147,9 +153,17 @@ def log_abs_det_jacobian(self, x, y): @copy_docs_from(Transform) class InverseAutoregressiveFlowStable(Transform): """ - An implementation of an Inverse Autoregressive Flow, using Eqs (13)/(14) from Kingma Et Al., 2016. + An implementation of an Inverse Autoregressive Flow, using Eqs (13)/(14) from Kingma Et Al., 2016, + + :math:`\mathbf{y} = \sigma_t\odot\mathbf{x} + (1-\sigma_t)\odot\mu_t` + + where :math:`\mathbf{x}` are the inputs, :math:`\mathbf{y}` are the outputs, :math:`\mu_t,\sigma_t` + are calculated from an autoregressive network on :math:`\mathbf{x}`, and :math:`\sigma_t` is + restricted to :math:`[0,1]`. + This variant of IAF is claimed by the authors to be more numerically stable than one using Eq (10), - although in practice it leads to a restriction on the distributions that can be represented. + although in practice it leads to a restriction on the distributions that can be represented, + presumably since the input is restricted to rescaling by a number on :math:`[0,1]`. Example usage: From c98f6a074268b41324e0a07949ef0045de338c13 Mon Sep 17 00:00:00 2001 From: Stefan Webb Date: Fri, 26 Oct 2018 19:12:19 +0100 Subject: [PATCH 15/20] Fixed docstrings --- pyro/distributions/iaf.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/pyro/distributions/iaf.py b/pyro/distributions/iaf.py index c993db0f06..6520881817 100644 --- a/pyro/distributions/iaf.py +++ b/pyro/distributions/iaf.py @@ -20,10 +20,10 @@ class InverseAutoregressiveFlow(Transform): """ An implementation of Inverse Autoregressive Flow, using Eq (10) from Kingma Et Al., 2016, - :math:`\mathbf{y} = \mu_t + \sigma_t\odot\mathbf{x}` + :math:`\\mathbf{y} = \\mu_t + \\sigma_t\\odot\\mathbf{x}` - where :math:`\mathbf{x}` are the inputs, :math:`\mathbf{y}` are the outputs, :math:`\mu_t,\sigma_t` - are calculated from an autoregressive network on :math:`\mathbf{x}`, and :math:`\sigma_t>0`. + where :math:`\\mathbf{x}` are the inputs, :math:`\\mathbf{y}` are the outputs, :math:`\\mu_t,\\sigma_t` + are calculated from an autoregressive network on :math:`\\mathbf{x}`, and :math:`\\sigma_t>0`. Together with `TransformedDistribution` this provides a way to create richer variational approximations. @@ -155,10 +155,10 @@ class InverseAutoregressiveFlowStable(Transform): """ An implementation of an Inverse Autoregressive Flow, using Eqs (13)/(14) from Kingma Et Al., 2016, - :math:`\mathbf{y} = \sigma_t\odot\mathbf{x} + (1-\sigma_t)\odot\mu_t` + :math:`\\mathbf{y} = \\sigma_t\\odot\\mathbf{x} + (1-\\sigma_t)\\odot\\mu_t` - where :math:`\mathbf{x}` are the inputs, :math:`\mathbf{y}` are the outputs, :math:`\mu_t,\sigma_t` - are calculated from an autoregressive network on :math:`\mathbf{x}`, and :math:`\sigma_t` is + where :math:`\\mathbf{x}` are the inputs, :math:`\\mathbf{y}` are the outputs, :math:`\\mu_t,\\sigma_t` + are calculated from an autoregressive network on :math:`\\mathbf{x}`, and :math:`\\sigma_t` is restricted to :math:`[0,1]`. This variant of IAF is claimed by the authors to be more numerically stable than one using Eq (10), From a75fed1e55fcc02862cdcbf9fad7093c07ca5339 Mon Sep 17 00:00:00 2001 From: Stefan Webb Date: Mon, 5 Nov 2018 00:51:51 +0000 Subject: [PATCH 16/20] Planar flow (untested) --- pyro/distributions/iaf.py | 8 +-- pyro/distributions/planar.py | 116 +++++++++++++++++++++++++++++++++++ 2 files changed, 119 insertions(+), 5 deletions(-) create mode 100644 pyro/distributions/planar.py diff --git a/pyro/distributions/iaf.py b/pyro/distributions/iaf.py index 6520881817..9c457636e0 100644 --- a/pyro/distributions/iaf.py +++ b/pyro/distributions/iaf.py @@ -128,6 +128,8 @@ def _inverse(self, y): x[idx] = (y[..., idx] - mean) * inverse_scale x = torch.stack(x, dim=-1) + log_scale = clamp_preserve_gradients(log_scale, min=self.log_scale_min_clip, max=self.log_scale_max_clip) + self._add_intermediate_to_cache(log_scale, y, 'log_scale') return x def _add_intermediate_to_cache(self, intermediate, y, name): @@ -142,11 +144,7 @@ def log_abs_det_jacobian(self, x, y): """ Calculates the elementwise determinant of the log jacobian """ - if (y, 'log_scale') in self._intermediates_cache: - log_scale = self._intermediates_cache.pop((y, 'log_scale')) - else: - _, log_scale = self.module.arn(x) - log_scale = clamp_preserve_gradients(log_scale, min=self.log_scale_min_clip, max=self.log_scale_max_clip) + log_scale = self._intermediates_cache.pop((y, 'log_scale')) return log_scale diff --git a/pyro/distributions/planar.py b/pyro/distributions/planar.py new file mode 100644 index 0000000000..18fa638511 --- /dev/null +++ b/pyro/distributions/planar.py @@ -0,0 +1,116 @@ +from __future__ import absolute_import, division, print_function + +import math + +import torch +import torch.nn as nn +from torch.distributions.transforms import Transform +from torch.distributions import constraints + +from pyro.distributions.util import copy_docs_from + +# This helper function clamps gradients but still passes through the gradient in clamped regions +# NOTE: Not sure how necessary this is, but I was copying the design of the TensorFlow implementation + + +def clamp_preserve_gradients(x, min, max): + return x + (x.clamp(min, max) - x).detach() + + +@copy_docs_from(Transform) +class PlanarFlow(Transform): + """ + A 'planar' normalizing flow that uses the transformation + + :math:`\\mathbf{y} = \\mathbf{x} + \\mathbf{u}\\tanh(\\mathbf{w}^T\\mathbf{z}+b)` + + where :math:`\\mathbf{x}` are the inputs, :math:`\\mathbf{y}` are the outputs, and the learnable parameters + are :math:`b\\in\\mathbb{R}`, :math:`\\mathbf{u}\\in\\mathbb{R}^D`, :math:`\\mathbf{w}\\in\\mathbb{R}^D` for input + dimension :math:`D`. For this to be an invertible transformation, the condition + :math:`\\mathbf{w}^T\\mathbf{u}>-1` is enforced. + + Together with `TransformedDistribution` this provides a way to create richer variational approximations. + + Example usage: + + >>> base_dist = dist.Normal(torch.zeros(10), torch.ones(10)) + >>> plf = PlanarFlow(10) + >>> plf_module = pyro.module("my_plf", plf.module) + >>> plf_dist = dist.TransformedDistribution(base_dist, [plf]) + >>> plf_dist.sample() # doctest: +SKIP + tensor([-0.4071, -0.5030, 0.7924, -0.2366, -0.2387, -0.1417, 0.0868, + 0.1389, -0.4629, 0.0986]) + + The inverse of this transform does not possess an analytical solution and is left unimplemented. However, + the inverse is cached when the forward operation is called during sampling, and so samples drawn using + planar flow can be scored. + + References: + + Variational Inference with Normalizing Flows [arXiv:1505.05770] + Danilo Jimenez Rezende, Shakir Mohamed + + """ + + codomain = constraints.real + + def __init__(self, input_dim): + super(PlanarFlow, self).__init__() + self.input_dim = input_dim + self.module = nn.Module() + self.module.lin = nn.Linear(input_dim, 1) + self.module.u = nn.Parameter(torch.Tensor(input_dim)) + self.reset_parameters() + + def reset_parameters(self): + stdv = 1. / math.sqrt(self.module.u.size(1)) + self.module.lin.data.uniform_(-stdv, stdv) + + def _call(self, x): + """ + :param x: the input into the bijection + :type x: torch.Tensor + + Invokes the bijection x=>y; in the prototypical context of a TransformedDistribution `x` is a + sample from the base distribution (or the output of a previous flow) + """ + y = x + self.module.u * torch.tanh(self.module.lin(x)) + + self._add_intermediate_to_cache(x, y, 'x') + return y + + def _inverse(self, y): + """ + :param y: the output of the bijection + :type y: torch.Tensor + + Inverts y => x. As noted above, this implementation is incapable of inverting arbitrary values + `y`; rather it assumes `y` is the result of a previously computed application of the bijector + to some `x` (which was cached on the forward call) + """ + if (y, 'x') in self._intermediates_cache: + x = self._intermediates_cache.pop((y, 'x')) + return x + else: + raise KeyError("PlanarFlow expected to find " + "key in intermediates cache but didn't") + + def _add_intermediate_to_cache(self, intermediate, y, name): + """ + Internal function used to cache intermediate results computed during the forward call + """ + assert((y, name) not in self._intermediates_cache),\ + "key collision in _add_intermediate_to_cache" + self._intermediates_cache[(y, name)] = intermediate + + def log_abs_det_jacobian(self, x, y): + """ + Calculates the elementwise determinant of the log jacobian + """ + psi_z = (1 - torch.tanh(self.module.lin(x)).pow(2))*self.module.lin.W + + # TODO: Check that dimensions of W broadcast properly! + print('W', self.module.lin.W.size(), 'psi_z', psi_z.size(), 'u', self.u.size()) + raise Exception() + + return torch.abs(1 + torch.dot(self.u, psi_z)) From f7be7366d31f88f6eee15c8840c8aca174dee0b6 Mon Sep 17 00:00:00 2001 From: Stefan Webb Date: Mon, 5 Nov 2018 16:53:14 +0000 Subject: [PATCH 17/20] Debugging planar flow --- pyro/distributions/__init__.py | 2 ++ pyro/distributions/planar.py | 32 ++++++++++++++++++------ tests/distributions/test_flows.py | 41 +++++++++++++++++++++---------- 3 files changed, 55 insertions(+), 20 deletions(-) diff --git a/pyro/distributions/__init__.py b/pyro/distributions/__init__.py index 2f6f43c43e..bf2ffcea69 100644 --- a/pyro/distributions/__init__.py +++ b/pyro/distributions/__init__.py @@ -15,6 +15,7 @@ from pyro.distributions.mixture import MaskedMixture from pyro.distributions.omt_mvn import OMTMultivariateNormal from pyro.distributions.permute import PermuteTransform +from pyro.distributions.planar import PlanarFlow from pyro.distributions.rejector import Rejector from pyro.distributions.relaxed_straight_through import (RelaxedBernoulliStraightThrough, RelaxedOneHotCategoricalStraightThrough) @@ -45,6 +46,7 @@ "MixtureOfDiagNormals", "OMTMultivariateNormal", "PermuteTransform", + "PlanarFlow", "Rejector", "RelaxedBernoulliStraightThrough", "RelaxedOneHotCategoricalStraightThrough", diff --git a/pyro/distributions/planar.py b/pyro/distributions/planar.py index 18fa638511..6703d1807d 100644 --- a/pyro/distributions/planar.py +++ b/pyro/distributions/planar.py @@ -6,6 +6,7 @@ import torch.nn as nn from torch.distributions.transforms import Transform from torch.distributions import constraints +import torch.nn.functional as F from pyro.distributions.util import copy_docs_from @@ -61,10 +62,23 @@ def __init__(self, input_dim): self.module.lin = nn.Linear(input_dim, 1) self.module.u = nn.Parameter(torch.Tensor(input_dim)) self.reset_parameters() + self._intermediates_cache = {} + self.add_inverse_to_cache = True def reset_parameters(self): - stdv = 1. / math.sqrt(self.module.u.size(1)) - self.module.lin.data.uniform_(-stdv, stdv) + stdv = 1. / math.sqrt(self.module.u.size(0)) + self.module.lin.weight.data.uniform_(-stdv, stdv) + + def u_hat(self): + u = self.module.u + + # TODO: Reshape W? + w = self.module.lin.weight.squeeze(0) + + alpha = torch.dot(u, w) + a_prime = -1 + F.softplus(alpha) + + return u + (a_prime - alpha) * w.div(w.norm()) def _call(self, x): """ @@ -74,7 +88,8 @@ def _call(self, x): Invokes the bijection x=>y; in the prototypical context of a TransformedDistribution `x` is a sample from the base distribution (or the output of a previous flow) """ - y = x + self.module.u * torch.tanh(self.module.lin(x)) + + y = x + self.u_hat() * torch.tanh(self.module.lin(x)) self._add_intermediate_to_cache(x, y, 'x') return y @@ -107,10 +122,13 @@ def log_abs_det_jacobian(self, x, y): """ Calculates the elementwise determinant of the log jacobian """ - psi_z = (1 - torch.tanh(self.module.lin(x)).pow(2))*self.module.lin.W + psi_z = (1 - torch.tanh(self.module.lin(x)).pow(2))*self.module.lin.weight # TODO: Check that dimensions of W broadcast properly! - print('W', self.module.lin.W.size(), 'psi_z', psi_z.size(), 'u', self.u.size()) - raise Exception() + #print('W', self.module.lin.weight.size(), 'psi_z', psi_z.size(), 'u', self.module.u.size()) + #raise Exception() - return torch.abs(1 + torch.dot(self.u, psi_z)) + # TODO: Continue from here, 5/11/2018! + # *** Need to take account of fact that psi_z has a batch dimension + #return torch.abs(1 + torch.dot(self.u_hat(), psi_z)) + return torch.abs(1 + psi_z * self.u_hat()) diff --git a/tests/distributions/test_flows.py b/tests/distributions/test_flows.py index 6d7c23bcb8..63cf2d3755 100644 --- a/tests/distributions/test_flows.py +++ b/tests/distributions/test_flows.py @@ -12,7 +12,7 @@ pytestmark = pytest.mark.init(rng_seed=123) -class AutoregressiveFlowTests(TestCase): +class FlowTests(TestCase): def setUp(self): # Epsilon is used to compare numerical gradient to analytical one self.epsilon = 1e-3 @@ -22,32 +22,36 @@ def setUp(self): def _test_jacobian(self, input_dim, make_flow): jacobian = torch.zeros(input_dim, input_dim) - iaf = make_flow(input_dim) + flow = make_flow(input_dim) def nonzero(x): return torch.sign(torch.abs(x)) x = torch.randn(1, input_dim) - iaf_x = iaf(x) - analytic_ldt = iaf.log_abs_det_jacobian(x, iaf_x).data.sum() + flow_x = flow(x) + analytic_ldt = flow.log_abs_det_jacobian(x, flow_x).data.sum() for j in range(input_dim): for k in range(input_dim): epsilon_vector = torch.zeros(1, input_dim) epsilon_vector[0, j] = self.epsilon - delta = (iaf(x + 0.5 * epsilon_vector) - iaf(x - 0.5 * epsilon_vector)) / self.epsilon + delta = (flow(x + 0.5 * epsilon_vector) - flow(x - 0.5 * epsilon_vector)) / self.epsilon jacobian[j, k] = float(delta[0, k].data.sum()) - permutation = iaf.arn.get_permutation() - permuted_jacobian = jacobian.clone() - for j in range(input_dim): - for k in range(input_dim): - permuted_jacobian[j, k] = jacobian[permutation[j], permutation[k]] - numeric_ldt = torch.sum(torch.log(torch.diag(permuted_jacobian))) + # Apply permutation for autoregressive flows + if hasattr(flow, 'arn'): + permutation = flow.arn.get_permutation() + permuted_jacobian = jacobian.clone() + for j in range(input_dim): + for k in range(input_dim): + permuted_jacobian[j, k] = jacobian[permutation[j], permutation[k]] + jacobian = permuted_jacobian + + numeric_ldt = torch.sum(torch.log(torch.diag(jacobian))) ldt_discrepancy = np.fabs(analytic_ldt - numeric_ldt) - diag_sum = torch.sum(torch.diag(nonzero(permuted_jacobian))) - lower_sum = torch.sum(torch.tril(nonzero(permuted_jacobian), diagonal=-1)) + diag_sum = torch.sum(torch.diag(nonzero(jacobian))) + lower_sum = torch.sum(torch.tril(nonzero(jacobian), diagonal=-1)) assert ldt_discrepancy < self.epsilon assert diag_sum == float(input_dim) @@ -87,6 +91,9 @@ def _make_flipflow(self, input_dim): permutation = torch.randperm(input_dim, device='cpu').to(torch.Tensor().device) return dist.PermuteTransform(permutation) + def _make_planar(self, input_dim): + return dist.PlanarFlow(input_dim) + def test_iaf_jacobians(self): for input_dim in [2, 3, 5, 7, 9, 11]: self._test_jacobian(input_dim, self._make_iaf) @@ -95,6 +102,10 @@ def test_iaf_stable_jacobians(self): for input_dim in [2, 3, 5, 7, 9, 11]: self._test_jacobian(input_dim, self._make_iaf_stable) + def test_planar_jacobians(self): + for input_dim in [2, 3, 5, 7, 9, 11]: + self._test_jacobian(input_dim, self._make_planar) + def test_iaf_inverses(self): for input_dim in [2, 3, 5, 7, 9, 11]: self._test_inverse(input_dim, self._make_iaf) @@ -118,3 +129,7 @@ def test_iaf_stable_shapes(self): def test_flipflow_shapes(self): for shape in [(3,), (3, 4), (3, 4, 2)]: self._test_shape(shape, self._make_flipflow) + + def test_planar_shapes(self): + for shape in [(3,), (3, 4), (3, 4, 2)]: + self._test_shape(shape, self._make_planar) From e47cb8f8700f855c514dcb505141189c27ade3dc Mon Sep 17 00:00:00 2001 From: Stefan Webb Date: Mon, 5 Nov 2018 18:54:38 +0000 Subject: [PATCH 18/20] Working now! --- pyro/distributions/planar.py | 38 ++++++++---------------- tests/distributions/test_flows.py | 48 ++++++++++++++++++------------- 2 files changed, 40 insertions(+), 46 deletions(-) diff --git a/pyro/distributions/planar.py b/pyro/distributions/planar.py index 6703d1807d..66f7f53f2a 100644 --- a/pyro/distributions/planar.py +++ b/pyro/distributions/planar.py @@ -10,13 +10,6 @@ from pyro.distributions.util import copy_docs_from -# This helper function clamps gradients but still passes through the gradient in clamped regions -# NOTE: Not sure how necessary this is, but I was copying the design of the TensorFlow implementation - - -def clamp_preserve_gradients(x, min, max): - return x + (x.clamp(min, max) - x).detach() - @copy_docs_from(Transform) class PlanarFlow(Transform): @@ -27,7 +20,7 @@ class PlanarFlow(Transform): where :math:`\\mathbf{x}` are the inputs, :math:`\\mathbf{y}` are the outputs, and the learnable parameters are :math:`b\\in\\mathbb{R}`, :math:`\\mathbf{u}\\in\\mathbb{R}^D`, :math:`\\mathbf{w}\\in\\mathbb{R}^D` for input - dimension :math:`D`. For this to be an invertible transformation, the condition + dimension :math:`D`. For this to be an invertible transformation, the condition :math:`\\mathbf{w}^T\\mathbf{u}>-1` is enforced. Together with `TransformedDistribution` this provides a way to create richer variational approximations. @@ -68,18 +61,16 @@ def __init__(self, input_dim): def reset_parameters(self): stdv = 1. / math.sqrt(self.module.u.size(0)) self.module.lin.weight.data.uniform_(-stdv, stdv) + self.module.u.data.uniform_(-stdv, stdv) + # This method ensures that torch(u_hat, w) > -1, required for invertibility def u_hat(self): - u = self.module.u + u = self.module.u + w = self.module.lin.weight.squeeze(0) + alpha = torch.dot(u, w) + a_prime = -1 + F.softplus(alpha) + return u + (a_prime - alpha) * w.div(w.norm()) - # TODO: Reshape W? - w = self.module.lin.weight.squeeze(0) - - alpha = torch.dot(u, w) - a_prime = -1 + F.softplus(alpha) - - return u + (a_prime - alpha) * w.div(w.norm()) - def _call(self, x): """ :param x: the input into the bijection @@ -122,13 +113,8 @@ def log_abs_det_jacobian(self, x, y): """ Calculates the elementwise determinant of the log jacobian """ - psi_z = (1 - torch.tanh(self.module.lin(x)).pow(2))*self.module.lin.weight - - # TODO: Check that dimensions of W broadcast properly! - #print('W', self.module.lin.weight.size(), 'psi_z', psi_z.size(), 'u', self.module.u.size()) - #raise Exception() + psi_z = (1 - torch.tanh(self.module.lin(x)).pow(2)) * self.module.lin.weight - # TODO: Continue from here, 5/11/2018! - # *** Need to take account of fact that psi_z has a batch dimension - #return torch.abs(1 + torch.dot(self.u_hat(), psi_z)) - return torch.abs(1 + psi_z * self.u_hat()) + # TODO: Simplify following line once using multivariate base distributions for multivariate flows + return torch.log(torch.abs(1 + torch.matmul(psi_z, self.u_hat())).unsqueeze(-1)) * \ + torch.ones_like(x) / x.size(-1) diff --git a/tests/distributions/test_flows.py b/tests/distributions/test_flows.py index 63cf2d3755..f2a89837ed 100644 --- a/tests/distributions/test_flows.py +++ b/tests/distributions/test_flows.py @@ -40,43 +40,51 @@ def nonzero(x): # Apply permutation for autoregressive flows if hasattr(flow, 'arn'): - permutation = flow.arn.get_permutation() - permuted_jacobian = jacobian.clone() - for j in range(input_dim): - for k in range(input_dim): - permuted_jacobian[j, k] = jacobian[permutation[j], permutation[k]] - jacobian = permuted_jacobian - - numeric_ldt = torch.sum(torch.log(torch.diag(jacobian))) - ldt_discrepancy = np.fabs(analytic_ldt - numeric_ldt) - - diag_sum = torch.sum(torch.diag(nonzero(jacobian))) - lower_sum = torch.sum(torch.tril(nonzero(jacobian), diagonal=-1)) + permutation = flow.arn.get_permutation() + permuted_jacobian = jacobian.clone() + for j in range(input_dim): + for k in range(input_dim): + permuted_jacobian[j, k] = jacobian[permutation[j], permutation[k]] + jacobian = permuted_jacobian + + # For autoregressive flow, Jacobian is sum of diagonal, otherwise need full determinate + if hasattr(flow, 'arn'): + numeric_ldt = torch.sum(torch.log(torch.diag(jacobian))) + else: + numeric_ldt = torch.log(torch.abs(jacobian.det())) + ldt_discrepancy = np.fabs(analytic_ldt - numeric_ldt) assert ldt_discrepancy < self.epsilon - assert diag_sum == float(input_dim) - assert lower_sum == float(0.0) + + # Test that lower triangular with unit diagonal for autoregressive flows + if hasattr(flow, 'arn'): + diag_sum = torch.sum(torch.diag(nonzero(jacobian))) + lower_sum = torch.sum(torch.tril(nonzero(jacobian), diagonal=-1)) + assert diag_sum == float(input_dim) + assert lower_sum == float(0.0) def _test_inverse(self, input_dim, make_flow): base_dist = dist.Normal(torch.zeros(input_dim), torch.ones(input_dim)) - iaf = make_flow(input_dim) + flow = make_flow(input_dim) x_true = base_dist.sample(torch.Size([10])) - y = iaf._call(x_true) + y = flow._call(x_true) # This line empties the inverse cache, if the flow uses it - iaf._inverse(y) + if hasattr(flow, '_intermediates_cache'): + flow._intermediates_cache.pop((y, 'log_scale')) + flow._intermediates_cache.pop((y, 'x')) # Cache is empty, hence must be calculating inverse afresh - x_calculated = iaf._inverse(y) + x_calculated = flow._inverse(y) assert torch.norm(x_true - x_calculated, dim=-1).max().item() < self.delta def _test_shape(self, base_shape, make_flow): base_dist = dist.Normal(torch.zeros(base_shape), torch.ones(base_shape)) last_dim = base_shape[-1] if isinstance(base_shape, tuple) else base_shape - iaf = make_flow(input_dim=last_dim) - sample = dist.TransformedDistribution(base_dist, [iaf]).sample() + flow = make_flow(input_dim=last_dim) + sample = dist.TransformedDistribution(base_dist, [flow]).sample() assert sample.shape == base_shape def _make_iaf(self, input_dim): From f26bfbe1a79c4e592c891bf67df0f1b48ae95aed Mon Sep 17 00:00:00 2001 From: Stefan Webb Date: Mon, 5 Nov 2018 19:01:46 +0000 Subject: [PATCH 19/20] Docs for PlanarFlow --- docs/source/distributions.rst | 7 +++++++ pyro/distributions/planar.py | 3 +++ 2 files changed, 10 insertions(+) diff --git a/docs/source/distributions.rst b/docs/source/distributions.rst index a671588309..58364e16ed 100644 --- a/docs/source/distributions.rst +++ b/docs/source/distributions.rst @@ -175,6 +175,13 @@ InverseAutoRegressiveFlowStable PermuteTransform ---------------- .. autoclass:: pyro.distributions.PermuteTransform + :members: + :undoc-members: + :show-inheritance: + +PlanarFlow +---------------- +.. autoclass:: pyro.distributions.PlanarFlow :members: :undoc-members: :show-inheritance: \ No newline at end of file diff --git a/pyro/distributions/planar.py b/pyro/distributions/planar.py index 66f7f53f2a..c8250c20c9 100644 --- a/pyro/distributions/planar.py +++ b/pyro/distributions/planar.py @@ -39,6 +39,9 @@ class PlanarFlow(Transform): the inverse is cached when the forward operation is called during sampling, and so samples drawn using planar flow can be scored. + :param input_dim: the dimension of the input (and output) variable. + :type autoregressive_nn: int + References: Variational Inference with Normalizing Flows [arXiv:1505.05770] From 65c0100a312e05c94dea9c25e6280ff74fcb4a45 Mon Sep 17 00:00:00 2001 From: Stefan Webb Date: Mon, 19 Nov 2018 22:19:20 +0000 Subject: [PATCH 20/20] Made PlanarFlow hashable, removed .module attribute hack --- pyro/distributions/planar.py | 30 +++++++++++++++++------------- 1 file changed, 17 insertions(+), 13 deletions(-) diff --git a/pyro/distributions/planar.py b/pyro/distributions/planar.py index c8250c20c9..35a6d1199d 100644 --- a/pyro/distributions/planar.py +++ b/pyro/distributions/planar.py @@ -12,7 +12,7 @@ @copy_docs_from(Transform) -class PlanarFlow(Transform): +class PlanarFlow(Transform, nn.Module): """ A 'planar' normalizing flow that uses the transformation @@ -29,7 +29,7 @@ class PlanarFlow(Transform): >>> base_dist = dist.Normal(torch.zeros(10), torch.ones(10)) >>> plf = PlanarFlow(10) - >>> plf_module = pyro.module("my_plf", plf.module) + >>> plf_module = pyro.module("my_plf", plf) >>> plf_dist = dist.TransformedDistribution(base_dist, [plf]) >>> plf_dist.sample() # doctest: +SKIP tensor([-0.4071, -0.5030, 0.7924, -0.2366, -0.2387, -0.1417, 0.0868, @@ -52,24 +52,28 @@ class PlanarFlow(Transform): codomain = constraints.real def __init__(self, input_dim): - super(PlanarFlow, self).__init__() + Transform.__init__(self) + nn.Module.__init__(self) + self.input_dim = input_dim - self.module = nn.Module() - self.module.lin = nn.Linear(input_dim, 1) - self.module.u = nn.Parameter(torch.Tensor(input_dim)) + self.lin = nn.Linear(input_dim, 1) + self.u = nn.Parameter(torch.Tensor(input_dim)) self.reset_parameters() self._intermediates_cache = {} self.add_inverse_to_cache = True def reset_parameters(self): - stdv = 1. / math.sqrt(self.module.u.size(0)) - self.module.lin.weight.data.uniform_(-stdv, stdv) - self.module.u.data.uniform_(-stdv, stdv) + stdv = 1. / math.sqrt(self.u.size(0)) + self.lin.weight.data.uniform_(-stdv, stdv) + self.u.data.uniform_(-stdv, stdv) + + def __hash__(self): + return super(nn.Module, self).__hash__() # This method ensures that torch(u_hat, w) > -1, required for invertibility def u_hat(self): - u = self.module.u - w = self.module.lin.weight.squeeze(0) + u = self.u + w = self.lin.weight.squeeze(0) alpha = torch.dot(u, w) a_prime = -1 + F.softplus(alpha) return u + (a_prime - alpha) * w.div(w.norm()) @@ -83,7 +87,7 @@ def _call(self, x): sample from the base distribution (or the output of a previous flow) """ - y = x + self.u_hat() * torch.tanh(self.module.lin(x)) + y = x + self.u_hat() * torch.tanh(self.lin(x)) self._add_intermediate_to_cache(x, y, 'x') return y @@ -116,7 +120,7 @@ def log_abs_det_jacobian(self, x, y): """ Calculates the elementwise determinant of the log jacobian """ - psi_z = (1 - torch.tanh(self.module.lin(x)).pow(2)) * self.module.lin.weight + psi_z = (1 - torch.tanh(self.lin(x)).pow(2)) * self.lin.weight # TODO: Simplify following line once using multivariate base distributions for multivariate flows return torch.log(torch.abs(1 + torch.matmul(psi_z, self.u_hat())).unsqueeze(-1)) * \