Test for whether TransformModules work for density estimation (#2544)
stefanwebb authored Jul 7, 2020 · 1 parent 01cc072 · commit 939c04d
Showing 6 changed files with 113 additions and 61 deletions.
pyro/distributions/transforms/__init__.py (2 changes: 1 addition & 1 deletion)
@@ -32,8 +32,8 @@
 from pyro.distributions.transforms.planar import ConditionalPlanar, Planar, conditional_planar, planar
 from pyro.distributions.transforms.polynomial import Polynomial, polynomial
 from pyro.distributions.transforms.radial import ConditionalRadial, Radial, conditional_radial, radial
-from pyro.distributions.transforms.spline_autoregressive import SplineAutoregressive, spline_autoregressive
 from pyro.distributions.transforms.spline import ConditionalSpline, Spline, conditional_spline, spline
+from pyro.distributions.transforms.spline_autoregressive import SplineAutoregressive, spline_autoregressive
 from pyro.distributions.transforms.spline_coupling import SplineCoupling, spline_coupling
 from pyro.distributions.transforms.sylvester import Sylvester, sylvester

pyro/distributions/transforms/householder.py (12 changes: 9 additions & 3 deletions)
@@ -3,6 +3,7 @@

 import math
 import warnings
+from functools import partial

 import torch
 import torch.nn as nn
@@ -27,8 +28,9 @@ def __init__(self, u_unnormed=None):

     # Construct normalized vectors for Householder transform
     def u(self):
-        norm = torch.norm(self.u_unnormed, p=2, dim=-1, keepdim=True)
-        return torch.div(self.u_unnormed, norm)
+        u_unnormed = self.u_unnormed() if callable(self.u_unnormed) else self.u_unnormed
+        norm = torch.norm(u_unnormed, p=2, dim=-1, keepdim=True)
+        return torch.div(u_unnormed, norm)

     def _call(self, x):
         """
@@ -217,14 +219,18 @@ def __init__(self, input_dim, nn, count_transforms=1):
                           over-parametrization!".format(count_transforms, input_dim))
         self.count_transforms = count_transforms

-    def condition(self, context):
+    def _u_unnormed(self, context):
         # u_unnormed ~ (count_transforms, input_dim)
         # Hence, input_dim must divide
         u_unnormed = self.nn(context)
         if self.count_transforms == 1:
             u_unnormed = u_unnormed.unsqueeze(-2)
         else:
             u_unnormed = torch.stack(u_unnormed, dim=-2)
+        return u_unnormed
+
+    def condition(self, context):
+        u_unnormed = partial(self._u_unnormed, context)
         return ConditionedHouseholder(u_unnormed)


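The same pattern recurs in every file below: `condition` no longer runs the hypernetwork eagerly; it wraps the call in `functools.partial`, and the conditioned transform resolves its parameters lazily through the `callable(...)` check at the top of each method. A minimal sketch of the idea (the `nn.Linear` hypernetwork and the sizes here are illustrative assumptions, not code from this commit):

    from functools import partial

    import torch
    import torch.nn as nn

    # Hypothetical stand-in for the hypernetwork: maps a context vector to
    # one unnormalized Householder vector.
    hypernet = nn.Linear(3, 4)

    def _u_unnormed(context):
        # u_unnormed ~ (count_transforms=1, input_dim=4)
        return hypernet(context).unsqueeze(-2)

    context = torch.rand(3)
    u_fn = partial(_u_unnormed, context)  # deferred: no forward pass yet

    # Inside ConditionedHouseholder.u(), parameters are resolved at call time,
    # so every evaluation sees the hypernetwork's current weights:
    u_unnormed = u_fn() if callable(u_fn) else u_fn
    u = u_unnormed / torch.norm(u_unnormed, p=2, dim=-1, keepdim=True)

Because the closure re-runs the network on each call, a transform returned by `condition` stays tied to the module's live parameters instead of freezing one forward pass at conditioning time.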
pyro/distributions/transforms/planar.py (27 changes: 17 additions & 10 deletions)
@@ -2,6 +2,7 @@
 # SPDX-License-Identifier: Apache-2.0

 import math
+from functools import partial

 import torch
 import torch.nn as nn
@@ -21,11 +22,9 @@ class ConditionedPlanar(Transform):
     bijective = True
     event_dim = 1

-    def __init__(self, bias=None, u=None, w=None):
+    def __init__(self, params):
         super().__init__(cache_size=1)
-        self.bias = bias
-        self.u = u
-        self.w = w
+        self._params = params
         self._cached_logDetJ = None

     # This method ensures that dot(u_hat, w) > -1, required for invertibility
@@ -42,15 +41,16 @@ def _call(self, x):
         :class:`~pyro.distributions.TransformedDistribution` `x` is a sample from
         the base distribution (or the output of a previous transform)
         """
+        bias, u, w = self._params() if callable(self._params) else self._params

         # x ~ (batch_size, dim_size, 1)
         # w ~ (batch_size, 1, dim_size)
         # bias ~ (batch_size, 1)
-        act = torch.tanh(torch.matmul(self.w.unsqueeze(-2), x.unsqueeze(-1)).squeeze(-1) + self.bias)
-        u_hat = self.u_hat(self.u, self.w)
+        act = torch.tanh(torch.matmul(w.unsqueeze(-2), x.unsqueeze(-1)).squeeze(-1) + bias)
+        u_hat = self.u_hat(u, w)
         y = x + u_hat * act

-        psi_z = (1. - act.pow(2)) * self.w
+        psi_z = (1. - act.pow(2)) * w
         self._cached_logDetJ = torch.log(
             torch.abs(1 + torch.matmul(psi_z.unsqueeze(-2), u_hat.unsqueeze(-1)).squeeze(-1).squeeze(-1)))

@@ -126,14 +126,18 @@ class Planar(ConditionedPlanar, TransformModule):
     event_dim = 1

     def __init__(self, input_dim):
-        super().__init__()
+        super().__init__(self._params)

         self.bias = nn.Parameter(torch.Tensor(1,))
         self.u = nn.Parameter(torch.Tensor(input_dim,))
         self.w = nn.Parameter(torch.Tensor(input_dim,))

         self.input_dim = input_dim
         self.reset_parameters()

+    def _params(self):
+        return self.bias, self.u, self.w
+
     def reset_parameters(self):
         stdv = 1. / math.sqrt(self.u.size(0))
         self.w.data.uniform_(-stdv, stdv)
@@ -198,9 +202,12 @@ def __init__(self, nn):
         super().__init__()
         self.nn = nn

+    def _params(self, context):
+        return self.nn(context)
+
     def condition(self, context):
-        bias, u, w = self.nn(context)
-        return ConditionedPlanar(bias, u, w)
+        params = partial(self._params, context)
+        return ConditionedPlanar(params)


 def planar(input_dim):
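With parameters now routed through `_params`, a `Planar` module can be fit by maximum likelihood, which is roughly what the new density-estimation test exercises. A minimal sketch (the toy dataset, learning rate, and step count are assumptions, not values from the test suite):

    import torch
    import pyro.distributions as dist
    from pyro.distributions.transforms import Planar

    # Toy correlated 2D data standing in for a real dataset
    x = torch.randn(1000, 2) @ torch.tensor([[1.0, 0.0], [0.8, 0.6]])

    base_dist = dist.Normal(torch.zeros(2), torch.ones(2))
    transform = Planar(2)
    flow_dist = dist.TransformedDistribution(base_dist, [transform])

    optimizer = torch.optim.Adam(transform.parameters(), lr=1e-2)
    for _ in range(500):
        optimizer.zero_grad()
        loss = -flow_dist.log_prob(x).mean()  # negative log-likelihood
        loss.backward()
        optimizer.step()
        flow_dist.clear_cache()  # discard (x, y) pairs cached under the old parameters

For `Planar` itself the change is mostly API unification, since `self.u`, `self.w`, and `self.bias` were already live `nn.Parameter`s; the lazy indirection pays off for the conditional variants and for `Spline` below, where derived tensors used to be computed once at construction.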
pyro/distributions/transforms/radial.py (29 changes: 18 additions & 11 deletions)
@@ -2,6 +2,7 @@
 # SPDX-License-Identifier: Apache-2.0

 import math
+from functools import partial

 import torch
 import torch.nn as nn
@@ -21,11 +22,9 @@ class ConditionedRadial(Transform):
     bijective = True
     event_dim = 1

-    def __init__(self, x0=None, alpha_prime=None, beta_prime=None):
+    def __init__(self, params):
         super().__init__(cache_size=1)
-        self.x0 = x0
-        self.alpha_prime = alpha_prime
-        self.beta_prime = beta_prime
+        self._params = params
         self._cached_logDetJ = None

     # This method ensures that dot(u_hat, w) > -1, required for invertibility
@@ -43,18 +42,20 @@ def _call(self, x):
         :class:`~pyro.distributions.TransformedDistribution` `x` is a sample from the base distribution (or the output
         of a previous transform)
         """
+        x0, alpha_prime, beta_prime = self._params() if callable(self._params) else self._params
+
         # Ensure invertibility using approach in appendix A.2
-        alpha = F.softplus(self.alpha_prime)
-        beta = -alpha + F.softplus(self.beta_prime)
+        alpha = F.softplus(alpha_prime)
+        beta = -alpha + F.softplus(beta_prime)

         # Compute y and logDet using Equation 14.
-        diff = x - self.x0
+        diff = x - x0
         r = diff.norm(dim=-1, keepdim=True)
         h = (alpha + r).reciprocal()
         h_prime = - (h ** 2)
         beta_h = beta * h

-        self._cached_logDetJ = ((self.x0.size(-1) - 1) * torch.log1p(beta_h) +
+        self._cached_logDetJ = ((x0.size(-1) - 1) * torch.log1p(beta_h) +
                                 torch.log1p(beta_h + beta * h_prime * r)).sum(-1)
         return x + beta_h * diff

@@ -129,14 +130,17 @@ class Radial(ConditionedRadial, TransformModule):
     event_dim = 1

     def __init__(self, input_dim):
-        super().__init__()
+        super().__init__(self._params)

         self.x0 = nn.Parameter(torch.Tensor(input_dim,))
         self.alpha_prime = nn.Parameter(torch.Tensor(1,))
         self.beta_prime = nn.Parameter(torch.Tensor(1,))
         self.input_dim = input_dim
         self.reset_parameters()

+    def _params(self):
+        return self.x0, self.alpha_prime, self.beta_prime
+
     def reset_parameters(self):
         stdv = 1. / math.sqrt(self.x0.size(0))
         self.alpha_prime.data.uniform_(-stdv, stdv)
@@ -195,9 +199,12 @@ def __init__(self, nn):
         super().__init__()
         self.nn = nn

+    def _params(self, context):
+        return self.nn(context)
+
     def condition(self, context):
-        x0, alpha_prime, beta_prime = self.nn(context)
-        return ConditionedRadial(x0, alpha_prime, beta_prime)
+        params = partial(self._params, context)
+        return ConditionedRadial(params)


 def radial(input_dim):
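The conditional variant now follows the same recipe: `condition` returns a `ConditionedRadial` whose parameters are recomputed from the hypernetwork on every call, so gradients from a conditioned distribution reach the network's weights. A usage sketch (the context size and sample shape are illustrative):

    import torch
    import pyro.distributions as dist
    from pyro.distributions.transforms import conditional_radial

    base_dist = dist.Normal(torch.zeros(2), torch.ones(2))
    transform = conditional_radial(2, context_dim=3)  # builds a dense hypernetwork internally

    context = torch.rand(3)
    flow_dist = dist.ConditionalTransformedDistribution(base_dist, [transform]).condition(context)
    y = flow_dist.sample(torch.Size([10]))  # 10 draws conditioned on `context`

Training proceeds as in the unconditional case; `-flow_dist.log_prob(...)` is now differentiable with respect to the hypernetwork because `partial(self._params, context)` defers its forward pass until the transform is applied.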
pyro/distributions/transforms/spline.py (49 changes: 25 additions & 24 deletions)
@@ -10,6 +10,8 @@
 # * https://github.com/bayesiains/nsf/blob/master/nde/transforms/splines/rational_quadratic.py
 # under the MIT license.

+from functools import partial
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -269,15 +271,13 @@ class ConditionedSpline(Transform):
     bijective = True
     event_dim = 0

-    def __init__(self, widths=None, heights=None, derivatives=None, lambdas=None, bound=3.0, order='linear'):
+    def __init__(self, params, bound=3.0, order='linear'):
         super().__init__(cache_size=1)

+        self._params = params
         self.order = order
         self.bound = bound
-        self.widths = widths
-        self.heights = heights
-        self.derivatives = derivatives
-        self.lambdas = lambdas
+        self._cache_log_detJ = None

     def _call(self, x):
         y, log_detJ = self.spline_op(x)
@@ -309,14 +309,8 @@ def log_abs_det_jacobian(self, x, y):
         return self._cache_log_detJ

     def spline_op(self, x, **kwargs):
-        y, log_detJ = _monotonic_rational_spline(
-            x,
-            self.widths,
-            self.heights,
-            self.derivatives,
-            self.lambdas,
-            bound=self.bound,
-            **kwargs)
+        w, h, d, l = self._params() if callable(self._params) else self._params
+        y, log_detJ = _monotonic_rational_spline(x, w, h, d, l, bound=self.bound, **kwargs)
         return y, log_detJ


@@ -376,7 +370,7 @@ class Spline(ConditionedSpline, TransformModule):
     event_dim = 0

     def __init__(self, input_dim, count_bins=8, bound=3., order='linear'):
-        super(Spline, self).__init__()
+        super(Spline, self).__init__(self._params)

         self.input_dim = input_dim
         self.count_bins = count_bins
@@ -390,18 +384,21 @@ def __init__(self, input_dim, count_bins=8, bound=3., order='linear'):
         # Rational linear splines have additional lambda parameters
         if self.order == "linear":
             self.unnormalized_lambdas = nn.Parameter(torch.rand(self.input_dim, self.count_bins))
-            self.lambdas = torch.sigmoid(self.unnormalized_lambdas)
-        elif self.order == "quadratic":
-            self.lambdas = None
-        else:
+        elif self.order != "quadratic":
             raise ValueError(
                 "Keyword argument 'order' must be one of ['linear', 'quadratic'], but '{}' was found!".format(
                     self.order))

-        self.widths = F.softmax(self.unnormalized_widths, dim=-1)
-        self.heights = F.softmax(self.unnormalized_heights, dim=-1)
-        self.derivatives = F.softplus(self.unnormalized_derivatives)
-        self._cache_log_detJ = None
+    def _params(self):
+        # widths, unnormalized_widths ~ (input_dim, num_bins)
+        w = F.softmax(self.unnormalized_widths, dim=-1)
+        h = F.softmax(self.unnormalized_heights, dim=-1)
+        d = F.softplus(self.unnormalized_derivatives)
+        if self.order == 'linear':
+            l = torch.sigmoid(self.unnormalized_lambdas)
+        else:
+            l = None
+        return w, h, d, l


 @copy_docs_from(ConditionalTransformModule)
@@ -481,7 +478,7 @@ def __init__(self, nn, input_dim, count_bins, bound=3.0, order='linear'):
         self.bound = bound
         self.order = order

-    def condition(self, context):
+    def _params(self, context):
         # Rational linear splines have additional lambda parameters
         if self.order == "linear":
             w, h, d, l = self.nn(context)
@@ -514,7 +511,11 @@ def condition(self, context):
         w = F.softmax(w, dim=-1)
         h = F.softmax(h, dim=-1)
         d = F.softplus(d)
-        return ConditionedSpline(w, h, d, l, bound=self.bound, order=self.order)
+        return w, h, d, l
+
+    def condition(self, context):
+        params = partial(self._params, context)
+        return ConditionedSpline(params, bound=self.bound, order=self.order)


 def spline(input_dim, **kwargs):
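This file is where lazy parameters matter most: previously `Spline.__init__` normalized everything once (`self.widths = F.softmax(self.unnormalized_widths, dim=-1)`, and so on), so optimizer steps on the unnormalized parameters never reached the cached tensors, and a second `backward` could raise a RuntimeError because the graph that produced them had been freed. Recomputing inside `_params` on every call fixes both. A minimal density-estimation sketch under the same assumptions as the `Planar` example above:

    import torch
    import pyro.distributions as dist
    from pyro.distributions.transforms import Spline

    base_dist = dist.Normal(torch.zeros(2), torch.ones(2))
    transform = Spline(2, count_bins=16)
    flow_dist = dist.TransformedDistribution(base_dist, [transform])

    x = 2.0 * torch.randn(1000, 2).tanh()  # toy data kept inside the spline's bound
    optimizer = torch.optim.Adam(transform.parameters(), lr=5e-3)
    for _ in range(500):
        optimizer.zero_grad()
        loss = -flow_dist.log_prob(x).mean()
        loss.backward()
        optimizer.step()
        flow_dist.clear_cache()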