Added dim keyword arg #2472

Merged · 12 commits · May 22, 2020
98 changes: 75 additions & 23 deletions pyro/distributions/transforms/affine_coupling.py
@@ -1,10 +1,12 @@
# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0

- from functools import partial
+ import operator
+ from functools import partial, reduce

import torch
from torch.distributions import constraints
+ from torch.distributions.utils import _sum_rightmost

from pyro.distributions.conditional import ConditionalTransformModule
from pyro.distributions.torch_transform import TransformModule
@@ -66,6 +68,9 @@ class AffineCoupling(TransformModule):
dimension split_dim and the output final dimension input_dim-split_dim for
each member of the tuple.
:type hypernet: callable
+ :param dim: the tensor dimension on which to split. This value must be negative
+     and defines the event dim as `abs(dim)`.
+ :type dim: int
:param log_scale_min_clip: The minimum value for clipping the log(scale) from
the autoregressive NN
:type log_scale_min_clip: float
@@ -83,12 +88,16 @@ class AffineCoupling(TransformModule):
domain = constraints.real
codomain = constraints.real
bijective = True
- event_dim = 1

- def __init__(self, split_dim, hypernet, log_scale_min_clip=-5., log_scale_max_clip=3.):
+ def __init__(self, split_dim, hypernet, *, dim=-1, log_scale_min_clip=-5., log_scale_max_clip=3.):
super().__init__(cache_size=1)
+ if dim >= 0:
+     raise ValueError("'dim' keyword argument must be negative")

self.split_dim = split_dim
self.nn = hypernet
+ self.dim = dim
+ self.event_dim = -dim
self._cached_log_scale = None
self.log_scale_min_clip = log_scale_min_clip
self.log_scale_max_clip = log_scale_max_clip
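For illustration, a minimal sketch of constructing the transform directly with the new keyword; the sizes are hypothetical and mirror what the affine_coupling() helper below computes for an event shape of (5, 4) split along dim=-2:

    import torch
    from pyro.nn import DenseNN
    from pyro.distributions.transforms import AffineCoupling

    # Hypothetical sizes: event shape (5, 4), split into 2 + 3 rows along dim=-2.
    # The hypernet consumes the flattened first half (2 * 4 inputs) and emits
    # mean and log_scale for the flattened second half (3 * 4 outputs each).
    split_dim, extra_dims = 2, 4
    hypernet = DenseNN(split_dim * extra_dims, [200],
                       [(5 - split_dim) * extra_dims, (5 - split_dim) * extra_dims])
    transform = AffineCoupling(split_dim, hypernet, dim=-2)
    assert transform.event_dim == 2          # event_dim == abs(dim)
    y = transform(torch.randn(10, 5, 4))     # batch of 10, output shape (10, 5, 4)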
@@ -102,15 +111,19 @@ def _call(self, x):
:class:`~pyro.distributions.TransformedDistribution` `x` is a sample from
the base distribution (or the output of a previous transform)
"""
- x1, x2 = x[..., :self.split_dim], x[..., self.split_dim:]
+ x1, x2 = x.split([self.split_dim, x.size(self.dim) - self.split_dim], dim=self.dim)

- mean, log_scale = self.nn(x1)
+ # Now that we can split on an arbitrary dimension, we have to do a bit of reshaping...
+ mean, log_scale = self.nn(x1.reshape(x1.shape[:-self.event_dim] + (-1,)))
+ mean = mean.reshape(mean.shape[:-1] + x2.shape[-self.event_dim:])
+ log_scale = log_scale.reshape(log_scale.shape[:-1] + x2.shape[-self.event_dim:])
log_scale = clamp_preserve_gradients(log_scale, self.log_scale_min_clip, self.log_scale_max_clip)
self._cached_log_scale = log_scale

y1 = x1
y2 = torch.exp(log_scale) * x2 + mean
- return torch.cat([y1, y2], dim=-1)
+ return torch.cat([y1, y2], dim=self.dim)
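To make the reshaping concrete, a minimal standalone sketch under assumed shapes (batch 10, event shape (5, 4), split_dim=2, dim=-2):

    import torch

    x = torch.randn(10, 5, 4)                  # batch 10, event shape (5, 4)
    dim, split_dim, event_dim = -2, 2, 2
    x1, x2 = x.split([split_dim, x.size(dim) - split_dim], dim=dim)
    # x1.shape == (10, 2, 4); x2.shape == (10, 3, 4)
    flat = x1.reshape(x1.shape[:-event_dim] + (-1,))
    # flat.shape == (10, 8): the hypernet sees one flat vector per batch element,
    # and its outputs are reshaped back to x2.shape[-event_dim:] == (3, 4)
    # before the affine update of x2.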

def _inverse(self, y):
"""
@@ -120,14 +133,19 @@ def _inverse(self, y):
Inverts y => x. Uses a previously cached inverse if available, otherwise
performs the inversion afresh.
"""
- y1, y2 = y[..., :self.split_dim], y[..., self.split_dim:]
+ y1, y2 = y.split([self.split_dim, y.size(self.dim) - self.split_dim], dim=self.dim)
x1 = y1
- mean, log_scale = self.nn(x1)

+ # Now that we can split on an arbitrary dimension, we have to do a bit of reshaping...
+ mean, log_scale = self.nn(x1.reshape(x1.shape[:-self.event_dim] + (-1,)))
+ mean = mean.reshape(mean.shape[:-1] + y2.shape[-self.event_dim:])
+ log_scale = log_scale.reshape(log_scale.shape[:-1] + y2.shape[-self.event_dim:])

log_scale = clamp_preserve_gradients(log_scale, self.log_scale_min_clip, self.log_scale_max_clip)
self._cached_log_scale = log_scale

x2 = (y2 - mean) * torch.exp(-log_scale)
- return torch.cat([x1, x2], dim=-1)
+ return torch.cat([x1, x2], dim=self.dim)

def log_abs_det_jacobian(self, x, y):
"""
@@ -137,10 +155,11 @@ def log_abs_det_jacobian(self, x, y):
if self._cached_log_scale is not None and x is x_old and y is y_old:
log_scale = self._cached_log_scale
else:
- x1 = x[..., :self.split_dim]
- _, log_scale = self.nn(x1)
+ x1, x2 = x.split([self.split_dim, x.size(self.dim) - self.split_dim], dim=self.dim)
+ _, log_scale = self.nn(x1.reshape(x1.shape[:-self.event_dim] + (-1,)))
+ log_scale = log_scale.reshape(log_scale.shape[:-1] + x2.shape[-self.event_dim:])
log_scale = clamp_preserve_gradients(log_scale, self.log_scale_min_clip, self.log_scale_max_clip)
- return log_scale.sum(-1)
+ return _sum_rightmost(log_scale, self.event_dim)
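The switch to `_sum_rightmost` sums the log-scale over all event dimensions rather than only the last one; a quick sketch under assumed shapes:

    import torch
    from torch.distributions.utils import _sum_rightmost

    log_scale = torch.randn(10, 3, 4)               # batch 10, event shape (3, 4)
    out = _sum_rightmost(log_scale, 2)              # shape (10,)
    assert torch.allclose(out, log_scale.reshape(10, -1).sum(-1))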


@copy_docs_from(ConditionalTransformModule)
@@ -234,20 +253,24 @@ def condition(self, context):
return AffineCoupling(self.split_dim, cond_nn, **self.kwargs)


- def affine_coupling(input_dim, hidden_dims=None, split_dim=None, **kwargs):
+ def affine_coupling(input_dim, hidden_dims=None, split_dim=None, dim=-1, **kwargs):
"""
A helper function to create an
:class:`~pyro.distributions.transforms.AffineCoupling` object that takes care of
constructing a dense network with the correct input/output dimensions.

- :param input_dim: Dimension of input variable
+ :param input_dim: Dimension(s) of the input variable. Note that when
+     `dim < -1` this must be a tuple corresponding to the event shape.
:type input_dim: int
:param hidden_dims: The desired hidden dimensions of the dense network. Defaults
to using [10*input_dim]
:type hidden_dims: list[int]
:param split_dim: The dimension to split the input on for the coupling
transform. Defaults to using input_dim // 2
:type split_dim: int
+ :param dim: the tensor dimension on which to split. This value must be negative
+     and defines the event dim as `abs(dim)`.
+ :type dim: int
:param log_scale_min_clip: The minimum value for clipping the log(scale) from
the autoregressive NN
:type log_scale_min_clip: float
@@ -256,15 +279,29 @@ def affine_coupling(input_dim, hidden_dims=None, split_dim=None, **kwargs):
:type log_scale_max_clip: float

"""
+ if not isinstance(input_dim, int):
+     if len(input_dim) != -dim:
+         raise ValueError('event shape {} must have same length as event_dim {}'.format(input_dim, -dim))
+     event_shape = input_dim
+     extra_dims = reduce(operator.mul, event_shape[(dim + 1):], 1)
+ else:
+     event_shape = [input_dim]
+     extra_dims = 1
+ event_shape = list(event_shape)

if split_dim is None:
- split_dim = input_dim // 2
+ split_dim = event_shape[dim] // 2
if hidden_dims is None:
- hidden_dims = [10 * input_dim]
+ hidden_dims = [10 * event_shape[dim] * extra_dims]

- hypernet = DenseNN(split_dim, hidden_dims, [input_dim - split_dim, input_dim - split_dim])
- return AffineCoupling(split_dim, hypernet, **kwargs)
+ hypernet = DenseNN(split_dim * extra_dims,
+                    hidden_dims,
+                    [(event_shape[dim] - split_dim) * extra_dims,
+                     (event_shape[dim] - split_dim) * extra_dims])
+ return AffineCoupling(split_dim, hypernet, dim=dim, **kwargs)
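A usage sketch of the updated helper, with hypothetical shapes; when `dim < -1`, `input_dim` is the event shape tuple:

    import torch
    import pyro.distributions as dist
    from pyro.distributions.transforms import affine_coupling

    base_dist = dist.Normal(torch.zeros(5, 4), torch.ones(5, 4)).to_event(2)
    transform = affine_coupling((5, 4), dim=-2)      # splits the 5 rows into 2 + 3
    flow_dist = dist.TransformedDistribution(base_dist, [transform])
    x = flow_dist.rsample()                          # shape (5, 4)
    lp = flow_dist.log_prob(x)                       # scalar, since event_dim is 2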


- def conditional_affine_coupling(input_dim, context_dim, hidden_dims=None, split_dim=None, **kwargs):
+ def conditional_affine_coupling(input_dim, context_dim, hidden_dims=None, split_dim=None, dim=-1, **kwargs):
"""
A helper function to create an
:class:`~pyro.distributions.transforms.ConditionalAffineCoupling` object that
@@ -281,6 +318,9 @@ def conditional_affine_coupling(input_dim, context_dim, hidden_dims=None, split_
:param split_dim: The dimension to split the input on for the coupling
transform. Defaults to using input_dim // 2
:type split_dim: int
+ :param dim: the tensor dimension on which to split. This value must be negative
+     and defines the event dim as `abs(dim)`.
+ :type dim: int
:param log_scale_min_clip: The minimum value for clipping the log(scale) from
the autoregressive NN
:type log_scale_min_clip: float
@@ -289,9 +329,21 @@ def conditional_affine_coupling(input_dim, context_dim, hidden_dims=None, split_
:type log_scale_max_clip: float

"""
+ if not isinstance(input_dim, int):
+     if len(input_dim) != -dim:
+         raise ValueError('event shape {} must have same length as event_dim {}'.format(input_dim, -dim))
+     event_shape = input_dim
+     extra_dims = reduce(operator.mul, event_shape[(dim + 1):], 1)
+ else:
+     event_shape = [input_dim]
+     extra_dims = 1
+ event_shape = list(event_shape)

if split_dim is None:
- split_dim = input_dim // 2
+ split_dim = event_shape[dim] // 2
if hidden_dims is None:
- hidden_dims = [10 * input_dim]
+ hidden_dims = [10 * event_shape[dim] * extra_dims]

- nn = ConditionalDenseNN(split_dim, context_dim, hidden_dims, [input_dim - split_dim, input_dim - split_dim])
- return ConditionalAffineCoupling(split_dim, nn, **kwargs)
+ nn = ConditionalDenseNN(split_dim * extra_dims, context_dim, hidden_dims,
+                         [(event_shape[dim] - split_dim) * extra_dims, (event_shape[dim] - split_dim) * extra_dims])
+ return ConditionalAffineCoupling(split_dim, nn, dim=dim, **kwargs)
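Similarly for the conditional variant: conditioning on a context tensor yields an ordinary AffineCoupling that can be used as in the sketch above (the context size here is hypothetical):

    import torch
    from pyro.distributions.transforms import conditional_affine_coupling

    cond_transform = conditional_affine_coupling((5, 4), context_dim=3, dim=-2)
    transform = cond_transform.condition(torch.randn(3))   # a plain AffineCoupling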
32 changes: 23 additions & 9 deletions pyro/distributions/transforms/permute.py
@@ -37,18 +37,25 @@ class Permute(Transform):

:param permutation: a permutation ordering that is applied to the inputs.
:type permutation: torch.LongTensor
+ :param dim: the tensor dimension to permute. This value must be negative and
+     defines the event dim as `abs(dim)`.
+ :type dim: int

"""

codomain = constraints.real
bijective = True
- event_dim = 1
volume_preserving = True

- def __init__(self, permutation, cache_size=1):
+ def __init__(self, permutation, *, dim=-1, cache_size=1):
super().__init__(cache_size=cache_size)

+ if dim >= 0:
+     raise ValueError("'dim' keyword argument must be negative")

self.permutation = permutation
+ self.dim = dim
+ self.event_dim = -dim

@lazy_property
def inv_permutation(self):
@@ -68,7 +75,7 @@ def _call(self, x):
the base distribution (or the output of a previous transform)
"""

- return x[..., self.permutation]
+ return x.index_select(self.dim, self.permutation)
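The move to `index_select` generalizes `x[..., self.permutation]` to inner event dimensions; a small equivalence check under assumed shapes:

    import torch

    x = torch.randn(10, 5, 4)
    perm = torch.randperm(5)
    assert torch.equal(x.index_select(-2, perm), x[..., perm, :])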

def _inverse(self, y):
"""
@@ -77,8 +84,7 @@ def _inverse(self, y):

Inverts y => x.
"""

- return y[..., self.inv_permutation]
+ return y.index_select(self.dim, self.inv_permutation)

def log_abs_det_jacobian(self, x, y):
"""
@@ -89,27 +95,35 @@ def log_abs_det_jacobian(self, x, y):
determinant is -1 or +1), and so returning a vector of zeros works.
"""

- return torch.zeros(x.size()[:-1], dtype=x.dtype, layout=x.layout, device=x.device)
+ return torch.zeros(x.size()[:-self.event_dim], dtype=x.dtype, layout=x.layout, device=x.device)

def with_cache(self, cache_size=1):
if self._cache_size == cache_size:
return self
return Permute(self.permutation, cache_size=cache_size)


- def permute(input_dim, permutation=None):
+ def permute(input_dim, permutation=None, dim=-1):
"""
A helper function to create a :class:`~pyro.distributions.transforms.Permute`
object for consistency with other helpers.

- :param input_dim: Dimension of input variable
+ :param input_dim: Dimension(s) of input variable to permute. Note that when
+     `dim < -1` this must be a tuple corresponding to the event shape.
:type input_dim: int
:param permutation: Torch tensor of integer indices representing permutation.
Defaults to a random permutation.
:type permutation: torch.LongTensor
+ :param dim: the tensor dimension to permute. This value must be negative and
+     defines the event dim as `abs(dim)`.
+ :type dim: int

"""
+ if dim < -1 or not isinstance(input_dim, int):
+     if len(input_dim) != -dim:
+         raise ValueError('event shape {} must have same length as event_dim {}'.format(input_dim, -dim))
+     input_dim = input_dim[dim]

if permutation is None:
permutation = torch.randperm(input_dim)
- return Permute(permutation)
+ return Permute(permutation, dim=dim)
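Finally, a usage sketch of the updated permute helper under hypothetical shapes:

    import torch
    from pyro.distributions.transforms import permute

    transform = permute((5, 4), dim=-2)    # shuffles the 5 rows of each event
    x = torch.randn(10, 5, 4)              # batch of 10 events of shape (5, 4)
    y = transform(x)
    x_back = transform.inv(y)              # exact inverse; the flow is volume preserving
    assert transform.log_abs_det_jacobian(x, y).shape == (10,)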