Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PlanarFlow #1515

Merged
merged 24 commits into from
Nov 28, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
9c19029
Separated NN tests from flow tests
stefanwebb Oct 23, 2018
5fab7ab
PermutationFlow
stefanwebb Oct 23, 2018
e8719a4
Tests for PermutationFLow
stefanwebb Oct 23, 2018
18ec5bb
Bug fix
stefanwebb Oct 23, 2018
5cf65bb
Renamed PermutationFlow to PermuteTransform
stefanwebb Oct 24, 2018
3b3003a
Added PermuteTransform to docs
stefanwebb Oct 24, 2018
3e1cdec
Added device to permutation vectors
stefanwebb Oct 24, 2018
6fcd765
PEP8
stefanwebb Oct 24, 2018
eb719f0
Merge branch 'dev' of https://github.com/uber/pyro into new
stefanwebb Oct 24, 2018
721ee63
Removed 'flow', link to IAF in docs, fixed other bug in docs
stefanwebb Oct 25, 2018
f64dd91
Removed more 'flow's
stefanwebb Oct 25, 2018
52979b8
Added lazy_property to inv_permutation of PermuteTransform
stefanwebb Oct 25, 2018
7b92c38
Inverse operations for IAF and alternative version
stefanwebb Oct 25, 2018
0cf12cb
Merge
stefanwebb Oct 25, 2018
109712b
Fixed docs error
stefanwebb Oct 25, 2018
df85279
Equations in docs
stefanwebb Oct 26, 2018
c98f6a0
Fixed docstrings
stefanwebb Oct 26, 2018
a75fed1
Planar flow (untested)
stefanwebb Nov 5, 2018
f7be736
Debugging planar flow
stefanwebb Nov 5, 2018
e47cb8f
Working now!
stefanwebb Nov 5, 2018
f26bfbe
Docs for PlanarFlow
stefanwebb Nov 5, 2018
8f1be2f
Merge branch 'dev' of https://github.com/uber/pyro into new
stefanwebb Nov 5, 2018
65c0100
Made PlanarFlow hashable, removed .module attribute hack
stefanwebb Nov 19, 2018
33c8260
Merged in TransformModule changes
stefanwebb Nov 24, 2018
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions docs/source/distributions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,13 @@ PermuteTransform
:undoc-members:
:show-inheritance:

PlanarFlow
----------------
.. autoclass:: pyro.distributions.PlanarFlow
:members:
:undoc-members:
:show-inheritance:

TransformModule
----------------
.. autoclass:: pyro.distributions.TransformModule
Expand Down
2 changes: 2 additions & 0 deletions pyro/distributions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from pyro.distributions.mixture import MaskedMixture
from pyro.distributions.omt_mvn import OMTMultivariateNormal
from pyro.distributions.permute import PermuteTransform
from pyro.distributions.planar import PlanarFlow
from pyro.distributions.rejector import Rejector
from pyro.distributions.relaxed_straight_through import (RelaxedBernoulliStraightThrough,
RelaxedOneHotCategoricalStraightThrough)
Expand Down Expand Up @@ -46,6 +47,7 @@
"MixtureOfDiagNormals",
"OMTMultivariateNormal",
"PermuteTransform",
"PlanarFlow",
"Rejector",
"RelaxedBernoulliStraightThrough",
"RelaxedOneHotCategoricalStraightThrough",
Expand Down
8 changes: 3 additions & 5 deletions pyro/distributions/iaf.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,8 @@ def _inverse(self, y):
x[idx] = (y[..., idx] - mean) * inverse_scale

x = torch.stack(x, dim=-1)
log_scale = clamp_preserve_gradients(log_scale, min=self.log_scale_min_clip, max=self.log_scale_max_clip)
stefanwebb marked this conversation as resolved.
Show resolved Hide resolved
self._add_intermediate_to_cache(log_scale, y, 'log_scale')
return x

def _add_intermediate_to_cache(self, intermediate, y, name):
Expand All @@ -132,11 +134,7 @@ def log_abs_det_jacobian(self, x, y):
"""
Calculates the elementwise determinant of the log jacobian
"""
if (y, 'log_scale') in self._intermediates_cache:
log_scale = self._intermediates_cache.pop((y, 'log_scale'))
else:
_, log_scale = self.arn(x)
log_scale = clamp_preserve_gradients(log_scale, min=self.log_scale_min_clip, max=self.log_scale_max_clip)
log_scale = self._intermediates_cache.pop((y, 'log_scale'))
return log_scale


Expand Down
126 changes: 126 additions & 0 deletions pyro/distributions/planar.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
from __future__ import absolute_import, division, print_function

import math

import torch
import torch.nn as nn
from torch.distributions import constraints
import torch.nn.functional as F

from pyro.distributions.torch_transform import TransformModule
from pyro.distributions.util import copy_docs_from


@copy_docs_from(TransformModule)
class PlanarFlow(TransformModule):
"""
A 'planar' normalizing flow that uses the transformation

:math:`\\mathbf{y} = \\mathbf{x} + \\mathbf{u}\\tanh(\\mathbf{w}^T\\mathbf{z}+b)`

where :math:`\\mathbf{x}` are the inputs, :math:`\\mathbf{y}` are the outputs, and the learnable parameters
are :math:`b\\in\\mathbb{R}`, :math:`\\mathbf{u}\\in\\mathbb{R}^D`, :math:`\\mathbf{w}\\in\\mathbb{R}^D` for input
dimension :math:`D`. For this to be an invertible transformation, the condition
:math:`\\mathbf{w}^T\\mathbf{u}>-1` is enforced.

Together with `TransformedDistribution` this provides a way to create richer variational approximations.

Example usage:

>>> base_dist = dist.Normal(torch.zeros(10), torch.ones(10))
>>> plf = PlanarFlow(10)
>>> plf_module = pyro.module("my_plf", plf)
>>> plf_dist = dist.TransformedDistribution(base_dist, [plf])
>>> plf_dist.sample() # doctest: +SKIP
tensor([-0.4071, -0.5030, 0.7924, -0.2366, -0.2387, -0.1417, 0.0868,
0.1389, -0.4629, 0.0986])

The inverse of this transform does not possess an analytical solution and is left unimplemented. However,
the inverse is cached when the forward operation is called during sampling, and so samples drawn using
planar flow can be scored.

:param input_dim: the dimension of the input (and output) variable.
:type autoregressive_nn: int

References:

Variational Inference with Normalizing Flows [arXiv:1505.05770]
Danilo Jimenez Rezende, Shakir Mohamed

"""

codomain = constraints.real

def __init__(self, input_dim):
super(PlanarFlow, self).__init__()

self.input_dim = input_dim
self.lin = nn.Linear(input_dim, 1)
self.u = nn.Parameter(torch.Tensor(input_dim))
self.reset_parameters()
self._intermediates_cache = {}
self.add_inverse_to_cache = True

def reset_parameters(self):
stdv = 1. / math.sqrt(self.u.size(0))
self.lin.weight.data.uniform_(-stdv, stdv)
self.u.data.uniform_(-stdv, stdv)

def __hash__(self):
stefanwebb marked this conversation as resolved.
Show resolved Hide resolved
return super(nn.Module, self).__hash__()

# This method ensures that torch(u_hat, w) > -1, required for invertibility
def u_hat(self):
u = self.u
w = self.lin.weight.squeeze(0)
alpha = torch.dot(u, w)
a_prime = -1 + F.softplus(alpha)
return u + (a_prime - alpha) * w.div(w.norm())

def _call(self, x):
"""
:param x: the input into the bijection
:type x: torch.Tensor

Invokes the bijection x=>y; in the prototypical context of a TransformedDistribution `x` is a
sample from the base distribution (or the output of a previous flow)
"""

y = x + self.u_hat() * torch.tanh(self.lin(x))

self._add_intermediate_to_cache(x, y, 'x')
return y

def _inverse(self, y):
"""
:param y: the output of the bijection
:type y: torch.Tensor

Inverts y => x. As noted above, this implementation is incapable of inverting arbitrary values
`y`; rather it assumes `y` is the result of a previously computed application of the bijector
to some `x` (which was cached on the forward call)
"""
if (y, 'x') in self._intermediates_cache:
x = self._intermediates_cache.pop((y, 'x'))
return x
else:
raise KeyError("PlanarFlow expected to find "
"key in intermediates cache but didn't")

def _add_intermediate_to_cache(self, intermediate, y, name):
"""
Internal function used to cache intermediate results computed during the forward call
"""
assert((y, name) not in self._intermediates_cache),\
"key collision in _add_intermediate_to_cache"
self._intermediates_cache[(y, name)] = intermediate

def log_abs_det_jacobian(self, x, y):
"""
Calculates the elementwise determinant of the log jacobian
"""
psi_z = (1 - torch.tanh(self.lin(x)).pow(2)) * self.lin.weight

# TODO: Simplify following line once using multivariate base distributions for multivariate flows
return torch.log(torch.abs(1 + torch.matmul(psi_z, self.u_hat())).unsqueeze(-1)) * \
torch.ones_like(x) / x.size(-1)
69 changes: 46 additions & 23 deletions tests/distributions/test_flows.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
pytestmark = pytest.mark.init(rng_seed=123)


class AutoregressiveFlowTests(TestCase):
class FlowTests(TestCase):
def setUp(self):
# Epsilon is used to compare numerical gradient to analytical one
self.epsilon = 1e-3
Expand All @@ -22,57 +22,69 @@ def setUp(self):

def _test_jacobian(self, input_dim, make_flow):
jacobian = torch.zeros(input_dim, input_dim)
iaf = make_flow(input_dim)
flow = make_flow(input_dim)

def nonzero(x):
return torch.sign(torch.abs(x))

x = torch.randn(1, input_dim)
iaf_x = iaf(x)
analytic_ldt = iaf.log_abs_det_jacobian(x, iaf_x).data.sum()
flow_x = flow(x)
analytic_ldt = flow.log_abs_det_jacobian(x, flow_x).data.sum()

for j in range(input_dim):
for k in range(input_dim):
epsilon_vector = torch.zeros(1, input_dim)
epsilon_vector[0, j] = self.epsilon
delta = (iaf(x + 0.5 * epsilon_vector) - iaf(x - 0.5 * epsilon_vector)) / self.epsilon
delta = (flow(x + 0.5 * epsilon_vector) - flow(x - 0.5 * epsilon_vector)) / self.epsilon
jacobian[j, k] = float(delta[0, k].data.sum())

permutation = iaf.arn.get_permutation()
permuted_jacobian = jacobian.clone()
for j in range(input_dim):
for k in range(input_dim):
permuted_jacobian[j, k] = jacobian[permutation[j], permutation[k]]
numeric_ldt = torch.sum(torch.log(torch.diag(permuted_jacobian)))
ldt_discrepancy = np.fabs(analytic_ldt - numeric_ldt)

diag_sum = torch.sum(torch.diag(nonzero(permuted_jacobian)))
lower_sum = torch.sum(torch.tril(nonzero(permuted_jacobian), diagonal=-1))
# Apply permutation for autoregressive flows
if hasattr(flow, 'arn'):
permutation = flow.arn.get_permutation()
permuted_jacobian = jacobian.clone()
for j in range(input_dim):
for k in range(input_dim):
permuted_jacobian[j, k] = jacobian[permutation[j], permutation[k]]
jacobian = permuted_jacobian

# For autoregressive flow, Jacobian is sum of diagonal, otherwise need full determinate
if hasattr(flow, 'arn'):
numeric_ldt = torch.sum(torch.log(torch.diag(jacobian)))
else:
numeric_ldt = torch.log(torch.abs(jacobian.det()))

ldt_discrepancy = np.fabs(analytic_ldt - numeric_ldt)
assert ldt_discrepancy < self.epsilon
assert diag_sum == float(input_dim)
assert lower_sum == float(0.0)

# Test that lower triangular with unit diagonal for autoregressive flows
if hasattr(flow, 'arn'):
diag_sum = torch.sum(torch.diag(nonzero(jacobian)))
lower_sum = torch.sum(torch.tril(nonzero(jacobian), diagonal=-1))
assert diag_sum == float(input_dim)
assert lower_sum == float(0.0)

def _test_inverse(self, input_dim, make_flow):
base_dist = dist.Normal(torch.zeros(input_dim), torch.ones(input_dim))
iaf = make_flow(input_dim)
flow = make_flow(input_dim)

x_true = base_dist.sample(torch.Size([10]))
y = iaf._call(x_true)
y = flow._call(x_true)

# This line empties the inverse cache, if the flow uses it
iaf._inverse(y)
if hasattr(flow, '_intermediates_cache'):
flow._intermediates_cache.pop((y, 'log_scale'))
flow._intermediates_cache.pop((y, 'x'))

# Cache is empty, hence must be calculating inverse afresh
x_calculated = iaf._inverse(y)
x_calculated = flow._inverse(y)

assert torch.norm(x_true - x_calculated, dim=-1).max().item() < self.delta

def _test_shape(self, base_shape, make_flow):
base_dist = dist.Normal(torch.zeros(base_shape), torch.ones(base_shape))
last_dim = base_shape[-1] if isinstance(base_shape, tuple) else base_shape
iaf = make_flow(input_dim=last_dim)
sample = dist.TransformedDistribution(base_dist, [iaf]).sample()
flow = make_flow(input_dim=last_dim)
sample = dist.TransformedDistribution(base_dist, [flow]).sample()
assert sample.shape == base_shape

def _make_iaf(self, input_dim):
Expand All @@ -87,6 +99,9 @@ def _make_flipflow(self, input_dim):
permutation = torch.randperm(input_dim, device='cpu').to(torch.Tensor().device)
return dist.PermuteTransform(permutation)

def _make_planar(self, input_dim):
return dist.PlanarFlow(input_dim)

def test_iaf_jacobians(self):
for input_dim in [2, 3, 5, 7, 9, 11]:
self._test_jacobian(input_dim, self._make_iaf)
Expand All @@ -95,6 +110,10 @@ def test_iaf_stable_jacobians(self):
for input_dim in [2, 3, 5, 7, 9, 11]:
self._test_jacobian(input_dim, self._make_iaf_stable)

def test_planar_jacobians(self):
for input_dim in [2, 3, 5, 7, 9, 11]:
self._test_jacobian(input_dim, self._make_planar)

def test_iaf_inverses(self):
for input_dim in [2, 3, 5, 7, 9, 11]:
self._test_inverse(input_dim, self._make_iaf)
Expand All @@ -118,3 +137,7 @@ def test_iaf_stable_shapes(self):
def test_flipflow_shapes(self):
for shape in [(3,), (3, 4), (3, 4, 2)]:
self._test_shape(shape, self._make_flipflow)

def test_planar_shapes(self):
for shape in [(3,), (3, 4), (3, 4, 2)]:
self._test_shape(shape, self._make_planar)