-
-
Notifications
You must be signed in to change notification settings - Fork 984
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Separated NN tests from flow tests * PermutationFlow * Tests for PermutationFlow * Bug fix * Renamed PermutationFlow to PermuteTransform * Added PermuteTransform to docs * Added device to permutation vectors * PEP8 * Removed 'flow', link to IAF in docs, fixed other bug in docs * Removed more 'flow's * Added lazy_property to inv_permutation of PermuteTransform * Inverse operations for IAF and alternative version * Fixed docs error * Equations in docs * Fixed docstrings * Planar flow (untested) * Debugging planar flow * Working now! * Docs for PlanarFlow * Made PlanarFlow hashable, removed .module attribute hack
- Loading branch information
1 parent
781416f
commit 43155bf
Showing
5 changed files
with
184 additions
and
28 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,126 @@ | ||
from __future__ import absolute_import, division, print_function | ||
|
||
import math | ||
|
||
import torch | ||
import torch.nn as nn | ||
from torch.distributions import constraints | ||
import torch.nn.functional as F | ||
|
||
from pyro.distributions.torch_transform import TransformModule | ||
from pyro.distributions.util import copy_docs_from | ||
|
||
|
||
@copy_docs_from(TransformModule)
class PlanarFlow(TransformModule):
    """
    A 'planar' normalizing flow that uses the transformation

        :math:`\\mathbf{y} = \\mathbf{x} + \\mathbf{u}\\tanh(\\mathbf{w}^T\\mathbf{x}+b)`

    where :math:`\\mathbf{x}` are the inputs, :math:`\\mathbf{y}` are the outputs, and the learnable parameters
    are :math:`b\\in\\mathbb{R}`, :math:`\\mathbf{u}\\in\\mathbb{R}^D`, :math:`\\mathbf{w}\\in\\mathbb{R}^D` for input
    dimension :math:`D`. For this to be an invertible transformation, the condition
    :math:`\\mathbf{w}^T\\mathbf{u}>-1` is enforced (see :meth:`u_hat`).

    Together with `TransformedDistribution` this provides a way to create richer variational approximations.

    Example usage:

    >>> base_dist = dist.Normal(torch.zeros(10), torch.ones(10))
    >>> plf = PlanarFlow(10)
    >>> plf_module = pyro.module("my_plf", plf)
    >>> plf_dist = dist.TransformedDistribution(base_dist, [plf])
    >>> plf_dist.sample()  # doctest: +SKIP
        tensor([-0.4071, -0.5030,  0.7924, -0.2366, -0.2387, -0.1417,  0.0868,
                0.1389, -0.4629,  0.0986])

    The inverse of this transform does not possess an analytical solution and is left unimplemented. However,
    the inverse is cached when the forward operation is called during sampling, and so samples drawn using
    planar flow can be scored.

    :param input_dim: the dimension of the input (and output) variable.
    :type input_dim: int

    References:

    Variational Inference with Normalizing Flows [arXiv:1505.05770]
    Danilo Jimenez Rezende, Shakir Mohamed
    """

    codomain = constraints.real

    def __init__(self, input_dim):
        super(PlanarFlow, self).__init__()

        self.input_dim = input_dim
        # `lin` holds both w (its weight, shape (1, input_dim)) and b (its bias).
        self.lin = nn.Linear(input_dim, 1)
        # Allocated uninitialized; reset_parameters() fills it in right below.
        self.u = nn.Parameter(torch.Tensor(input_dim))
        self.reset_parameters()
        # Maps (y, name) -> intermediate tensor recorded during the forward pass.
        self._intermediates_cache = {}
        self.add_inverse_to_cache = True

    def reset_parameters(self):
        """
        Re-initializes `w` and `u` uniformly in [-1/sqrt(D), 1/sqrt(D)], where D is the length of `u`.
        The bias of the linear layer is left untouched.
        """
        stdv = 1. / math.sqrt(self.u.size(0))
        self.lin.weight.data.uniform_(-stdv, stdv)
        self.u.data.uniform_(-stdv, stdv)

    def __hash__(self):
        # Resolve __hash__ from the classes that follow nn.Module in the MRO.
        # NOTE(review): presumably this restores identity-based hashing so that the module can be
        # used where a hashable object is required -- confirm against TransformModule's bases.
        return super(nn.Module, self).__hash__()

    def u_hat(self):
        """
        Returns a reparameterized `u` that satisfies dot(u_hat, w) > -1, required for invertibility.

        `u` is shifted along the direction of `w` so that
        w^T u_hat = -1 + softplus(w^T u), which is strictly greater than -1.
        """
        u = self.u
        w = self.lin.weight.squeeze(0)
        alpha = torch.dot(u, w)
        a_prime = -1 + F.softplus(alpha)
        # The shift must be scaled by w / ||w||^2 (not w / ||w||); only then does
        # w^T u_hat = alpha + (a_prime - alpha) * (w^T w) / ||w||^2 collapse to a_prime.
        return u + (a_prime - alpha) * w.div(torch.dot(w, w))

    def _call(self, x):
        """
        :param x: the input into the bijection
        :type x: torch.Tensor

        Invokes the bijection x=>y; in the prototypical context of a TransformedDistribution `x` is a
        sample from the base distribution (or the output of a previous flow)
        """
        y = x + self.u_hat() * torch.tanh(self.lin(x))

        # Remember which x produced this y so that _inverse(y) can return it later.
        self._add_intermediate_to_cache(x, y, 'x')
        return y

    def _inverse(self, y):
        """
        :param y: the output of the bijection
        :type y: torch.Tensor

        Inverts y => x. As noted above, this implementation is incapable of inverting arbitrary values
        `y`; rather it assumes `y` is the result of a previously computed application of the bijector
        to some `x` (which was cached on the forward call)

        :raises KeyError: if `y` has no cached pre-image (the entry is removed once it is retrieved).
        """
        key = (y, 'x')
        if key not in self._intermediates_cache:
            raise KeyError("PlanarFlow expected to find "
                           "key in intermediates cache but didn't")
        return self._intermediates_cache.pop(key)

    def _add_intermediate_to_cache(self, intermediate, y, name):
        """
        Internal function used to cache intermediate results computed during the forward call
        """
        assert((y, name) not in self._intermediates_cache),\
            "key collision in _add_intermediate_to_cache"
        self._intermediates_cache[(y, name)] = intermediate

    def log_abs_det_jacobian(self, x, y):
        """
        Calculates the log of the absolute value of the determinant of the Jacobian dy/dx,
        log|1 + psi(x)^T u_hat| with psi(x) = (1 - tanh^2(w^T x + b)) w.

        The scalar log-determinant is spread evenly over the last dimension of `x` (each entry holds
        1/D of the total) so that summing the result over that dimension recovers it.
        """
        psi_z = (1 - torch.tanh(self.lin(x)).pow(2)) * self.lin.weight

        # TODO: Simplify following line once using multivariate base distributions for multivariate flows
        return torch.log(torch.abs(1 + torch.matmul(psi_z, self.u_hat())).unsqueeze(-1)) * \
            torch.ones_like(x) / x.size(-1)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters