Add first draft of Bernstein polynomial flow #32

Closed · wants to merge 1 commit
2 changes: 1 addition & 1 deletion tests/test_flows.py
@@ -12,7 +12,7 @@
 torch.set_default_dtype(torch.float64)


-@pytest.mark.parametrize('F', [GMM, NICE, MAF, NSF, SOSPF, NAF, UNAF, CNF, GF])
+@pytest.mark.parametrize('F', [GMM, NICE, MAF, NSF, SOSPF, NAF, UNAF, CNF, GF, BERN])
 def test_flows(tmp_path: Path, F: callable):
     flow = F(3, 5)
1 change: 1 addition & 0 deletions tests/test_transforms.py
@@ -28,6 +28,7 @@ def test_univariate_transforms(batched: bool):
         GaussianizationTransform(randn(*batch, 8), randn(*batch, 8)),
         UnconstrainedMonotonicTransform(lambda x: torch.exp(-x**2) + 1e-2, randn(batch)),
         SOSPolynomialTransform(randn(*batch, 3, 5), randn(batch)),
+        BernTransform(randn(*batch, 5)),
     ]

     for t in ts:
50 changes: 48 additions & 2 deletions zuko/flows/spline.py
@@ -1,8 +1,9 @@
 r"""Spline flows."""

 __all__ = [
-    'NSF',
-    'NCSF',
+    "NSF",
+    "NCSF",
+    "BERN",
 ]

 import torch
@@ -59,6 +60,51 @@ def __init__(
         )


+class BERN(MAF):
+    r"""Creates a Bernstein flow (BERN) with a monotonic Bernstein polynomial transformation.
+
+    By default, transformations are fully autoregressive. Coupling transformations
+    can be obtained by setting :py:`passes=2`.
+
+    Warning:
+        The Bernstein polynomial transformation is defined over the domain
+        :math:`[-10, 10]`. Any feature outside of this domain is not transformed.
+        It is recommended to standardize features (zero mean, unit variance)
+        before training. Internally, features are rescaled to :math:`[0, 1]`,
+        the natural domain of the Bernstein polynomial.
+
+    See also:
+        :class:`zuko.transforms.BernTransform`
+
+    References:
+        | Bernstein-Flows (Sick et al., 2020)
+        | https://arxiv.org/abs/2004.00464
+
+    Arguments:
+        features: The number of features.
+        context: The number of context features.
+        degree: The number of Bernstein basis polynomials :math:`M`.
+        kwargs: Keyword arguments passed to :class:`zuko.flows.autoregressive.MAF`.
+    """

+    def __init__(
+        self,
+        features: int,
+        context: int = 0,
+        degree: int = 30,
+        **kwargs,
+    ):
+        super().__init__(
+            features=features,
+            context=context,
+            univariate=BernTransform,
+            shapes=[(degree,)],
+            **kwargs,
+        )
+
+        transforms = self.transform.transforms
+        for i in reversed(range(1, len(transforms))):
+            transforms.insert(i, Unconditional(SoftclipTransform, bound=10))


 def CircularRQSTransform(*phi) -> Transform:
     r"""Creates a circular rational-quadratic spline (RQS) transformation."""

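For context, a minimal usage sketch of the new flow (a sketch, not part of the PR; it assumes BERN is re-exported from zuko.flows like the other spline flows, and the feature/context sizes are arbitrary):

import torch
import zuko

flow = zuko.flows.BERN(3, 5, degree=30)  # 3 features, 5 context features

x = torch.randn(3)
c = torch.randn(5)

log_p = flow(c).log_prob(x)  # log-density of x under the conditional flow
y = flow(c).sample()         # a sample from the conditional flow

The SoftclipTransform layers inserted between the transformations above keep intermediate features inside the bounded domain [-10, 10] that the Bernstein transformation expects.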
70 changes: 70 additions & 0 deletions zuko/transforms.py
@@ -21,9 +21,11 @@
     'PermutationTransform',
     'RotationTransform',
     'LULinearTransform',
+    'BernTransform',
 ]

 import math
+import warnings
 import torch
 import torch.nn.functional as F

@@ -599,6 +601,74 @@ def call_and_ladj(self, x: Tensor) -> Tuple[Tensor, Tensor]:
         jacobian = torch.autograd.grad(y, x, torch.ones_like(y), create_graph=True)[0]

         return y, jacobian.log()
+
+
+class BernTransform(MonotonicTransform):
r"""Creates a monotonic transformation based on Bernstein Polynomials.

References:
| Bernstein-Flows (Sick et al., 2020)
| https://arxiv.org/abs/2004.00464

Arguments:
theta_un: The unconstrained coefficients with shape :math:`(*, K)`.
bound: The spline's (co)domain bound :math:`B`.
eps: The error bound for the bisection algorithm.

The implemented transformation is given by Eq. (4) in the paper.
Note that in constrast to the Bernstein Polynomials in the paper ($\tilde{y} \in [0,1]$), here the domain of the data is [-bound, bound].
"""

+    bijective = True
+    sign = +1
+    codomain = constraints.real
+
+    eps_bern = 1e-6
+
+    def __init__(
+        self,
+        theta_un: Tensor,
+        eps: float = 1e-6,
+        bound: float = 10.0,
+        **kwargs,
+    ):
+        super().__init__(self.f, phi=theta_un, bound=bound, **kwargs)
Member comment:
Here is the bug 🐛 phi should be a list of parameters (in this case it should be (theta_un,)). Otherwise I think everything works, good job! I will refactor the code a bit to make it compliant with Zuko's conventions.
+        self.theta = BernTransform.to_theta(theta_un)
+        self.eps = eps
+        self.domain = constraints.interval(-bound, bound)
+
+        # Define the Beta distributions whose densities form the Bernstein basis.
+        len_theta = theta_un.shape[-1]
+        alpha = torch.arange(1, len_theta + 1, dtype=torch.float32)
+        beta = torch.arange(len_theta, 0, -1, dtype=torch.float32)
+        self.beta_dist_for_h = torch.distributions.Beta(alpha, beta)


+    @staticmethod
+    def to_theta(pre_theta):
+        r"""Converts the unconstrained output of the network to constrained
+        coefficients theta such that the Bernstein polynomial is monotonically
+        increasing. We assume that pre_theta is centered around 0 and thus
+        subtract softplus(0) * M / 2 to center theta around 0 as well.
+        """
+        spo = torch.log(torch.tensor(2.0))  # softplus(0) = log(2)
+        softplus_tensor = torch.nn.functional.softplus(pre_theta[..., 1:])
+        d = torch.cat((pre_theta[..., :1], softplus_tensor), dim=-1)
+        return torch.cumsum(d, dim=-1) - spo * pre_theta.shape[-1] / 2
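As a quick sanity check (a sketch, not part of the PR): every increment after the first coefficient passes through softplus and is therefore strictly positive, so theta is strictly increasing along its last axis, which is what makes the Bernstein polynomial monotone:

import torch

pre_theta = torch.randn(4, 10)  # unconstrained network outputs
theta = BernTransform.to_theta(pre_theta)

# softplus(pre_theta[..., 1:]) > 0, so the cumulative sum is strictly increasing.
assert (theta.diff(dim=-1) > 0).all()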

+    def f(self, x: Tensor) -> Tensor:
+        # Warn if some of the data lies outside the domain [-bound, bound].
+        if torch.any(x < -self.bound) or torch.any(x > self.bound):
+            warnings.warn(f"Some values of x are outside the domain [{-self.bound}, {self.bound}].")
+
+        # The data lies in [-bound, bound]; rescale it to the unit interval [0, 1].
+        x = (x + self.bound) / (2 * self.bound)
+        x = x.unsqueeze(-1)
+        x = torch.clamp(x, self.eps_bern, 1 - self.eps_bern)
+
+        # Evaluate the Beta densities (the Bernstein basis) and average them,
+        # weighted by the coefficients theta (Eq. (4) in the paper).
+        f_im = self.beta_dist_for_h.log_prob(x).exp()
+        h = torch.mean(f_im * self.theta, dim=-1)
+
+        return h
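Up to the factor M, the Beta densities used above are exactly the Bernstein basis polynomials of degree M - 1, i.e. Be_{i, M-i+1}(x) = M * binom(M-1, i-1) * x^(i-1) * (1-x)^(M-i). A small standalone check of this identity (a sketch; M and the sample size are arbitrary):

import math
import torch

M = 5
x = torch.rand(100)

# Beta(i, M - i + 1) densities for i = 1, ..., M, as in BernTransform.
alpha = torch.arange(1, M + 1, dtype=x.dtype)
beta = torch.arange(M, 0, -1, dtype=x.dtype)
densities = torch.distributions.Beta(alpha, beta).log_prob(x.unsqueeze(-1)).exp()

# Bernstein basis of degree M - 1, scaled by M.
bernstein = torch.stack(
    [
        M * math.comb(M - 1, i - 1) * x ** (i - 1) * (1 - x) ** (M - i)
        for i in range(1, M + 1)
    ],
    dim=-1,
)

assert torch.allclose(densities, bernstein)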


 class GaussianizationTransform(MonotonicTransform):