Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Stable distribution with numerically integrated log-probability calculation (StableWithLogProb). #3369

Merged
merged 24 commits into from
May 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
e94495e
Added Stable distribution with unsafe log-probability calculation.
May 19, 2024
665a092
Make Stable distribution log-probability calculation safe at values n…
May 19, 2024
695ee8c
Make Stable distribution log-probability calculation safe at alpha ne…
May 19, 2024
6d50cca
Make Stable log-probability method part of an independent class.
May 19, 2024
7661a6e
Added dynamic near zero value tolerance to the log-probability estima…
May 19, 2024
2e2b036
Reduce Stable log-probability calculation value near zero tolerance i…
May 19, 2024
cd6dde3
Cap range of Stable log-probability.
May 19, 2024
49657f7
Clamp log in order to make gradient continuous.
May 20, 2024
147a772
Code cleanup.
May 20, 2024
d493f8c
Don't reparametrize pyro.distributions.StableWithLogProb.
May 20, 2024
037f094
Add tests for Stable distribution with method for calculating the log…
May 20, 2024
1f0a696
Linting and formatting.
May 20, 2024
2d2b702
Moved definition of StableWithLogProb into pyro.distributions.stable.
May 21, 2024
daf04a0
Avoid importing scipy until StableWithLogProb.log_prob is called for …
May 21, 2024
110ea37
Don't allow reparameterization of StableWithLogProb.
May 21, 2024
2864c33
Linting and formatting.
May 21, 2024
a19fbee
Add iterations in order to assure convergence in parameter fit tests.
May 21, 2024
a79eb7b
Comment out test.
May 21, 2024
3602f00
Increase test error limit.
May 21, 2024
1ee4391
Added StableWithLogProb docs.
May 22, 2024
ff8cd1f
Cap near zero tolerance by inverse probability density.
May 27, 2024
77d9c9f
Make log_prob return data type same as that of the input value.
May 27, 2024
9e8044c
Added Stable distirbution log-probability calculation goodness of fit…
May 27, 2024
06b4bec
Added explanation of StableWithLogProb usage and results.
May 27, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions docs/source/distributions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -407,6 +407,13 @@ Stable
:undoc-members:
:show-inheritance:

StableWithLogProb
-----------------
.. autoclass:: pyro.distributions.StableWithLogProb
:members:
:undoc-members:
:show-inheritance:

TruncatedPolyaGamma
-------------------
.. autoclass:: pyro.distributions.TruncatedPolyaGamma
Expand Down
3 changes: 2 additions & 1 deletion pyro/distributions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@
from pyro.distributions.sine_skewed import SineSkewed
from pyro.distributions.softlaplace import SoftLaplace
from pyro.distributions.spanning_tree import SpanningTree
from pyro.distributions.stable import Stable
from pyro.distributions.stable import Stable, StableWithLogProb
from pyro.distributions.torch import __all__ as torch_dists
from pyro.distributions.torch_distribution import (
ExpandedDistribution,
Expand Down Expand Up @@ -234,6 +234,7 @@
"SoftLaplace",
"SpanningTree",
"Stable",
"StableWithLogProb",
"StudentT",
"TorchDistribution",
"TransformModule",
Expand Down
22 changes: 22 additions & 0 deletions pyro/distributions/stable.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from torch.distributions import constraints
from torch.distributions.utils import broadcast_all

from pyro.distributions.stable_log_prob import StableLogProb
from pyro.distributions.torch_distribution import TorchDistribution


Expand Down Expand Up @@ -204,3 +205,24 @@ def mean(self):
def variance(self):
var = self.scale * self.scale
return var.mul(2).masked_fill(self.stability < 2, math.inf)


class StableWithLogProb(StableLogProb, Stable):
r"""
Levy :math:`\alpha`-stable distribution that is based on
:class:`Stable` but with an added method for calculating the
log probability density using numerical integration.

This should be used in cases where reparameterization does not work
like when trying to estimate the skew :math:`\beta` parameter. Running
times are slower than with reparameterization.

The numerical integration implementation is based on the algorithm
proposed by Chambers, Mallows and Stuck (CMS) for simulating the
Levy :math:`\alpha`-stable distribution. The CMS algorithm involves a
nonlinear transformation of two independent random variables into
one stable random variable. The first random variable is uniformly
distributed while the second is exponentially distributed. The numerical
integration is performed over the first uniformly distributed random
variable.
"""
220 changes: 220 additions & 0 deletions pyro/distributions/stable_log_prob.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,220 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

import math
from functools import partial

import torch

value_near_zero_tolerance_alpha = 0.01
value_near_zero_tolerance_density = 0.1
alpha_near_one_tolerance = 0.05


finfo = torch.finfo(torch.float64)
MAX_LOG = math.log10(finfo.max)
MIN_LOG = math.log10(finfo.tiny)


def create_integrator(num_points):
from scipy.special import roots_legendre

roots, weights = roots_legendre(num_points)
roots = torch.Tensor(roots).double()
weights = torch.Tensor(weights).double()
log_weights = weights.log()
half_roots = roots * 0.5

def integrate(fn, domain):
sl = [slice(None)] + (len(domain.shape) - 1) * [None]
half_roots_sl = half_roots[sl]
value = domain[0] * (0.5 - half_roots_sl) + domain[1] * (0.5 + half_roots_sl)
return (
torch.logsumexp(fn(value) + log_weights[sl], dim=0)
+ ((domain[1] - domain[0]) / 2).log()
)

return integrate


def set_integrator(num_points):
global integrate
integrate = create_integrator(num_points)


# Stub which is replaced by the default integrator when called for the first time
# if a default integrator has not already been set.
def integrate(*args, **kwargs):
set_integrator(num_points=501)
return integrate(*args, **kwargs)


class StableLogProb:
def log_prob(self, value):
# Undo shift and scale
value = (value - self.loc) / self.scale
value_dtype = value.dtype

# Use double precision math
alpha = self.stability.double()
beta = self.skew.double()
value = value.double()

log_prob = _stable_log_prob(alpha, beta, value, self.coords)

return log_prob.to(dtype=value_dtype) - self.scale.log()


def _stable_log_prob(alpha, beta, value, coords):
# Convert to Nolan's parametrization S^0 where samples depend
# continuously on (alpha,beta), allowing interpolation around the hole at
# alpha=1.
if coords == "S":
value = torch.where(
alpha == 1, value, value - beta * (math.pi / 2 * alpha)
).tan()
elif coords != "S0":
raise ValueError("Unknown coords: {}".format(coords))

# Find near one alpha
idx = (alpha - 1).abs() < alpha_near_one_tolerance

log_prob = _unsafe_alpha_stable_log_prob_S0(
torch.where(idx, 1 + alpha_near_one_tolerance, alpha), beta, value
)

# Handle alpha near one by interpolation
if idx.any():
log_prob_pos = log_prob[idx]
log_prob_neg = _unsafe_alpha_stable_log_prob_S0(
(1 - alpha_near_one_tolerance) * log_prob_pos.new_ones(log_prob_pos.shape),
beta[idx],
value[idx],
)
weights = (alpha[idx] - 1) / (2 * alpha_near_one_tolerance) + 0.5
log_prob[idx] = torch.logsumexp(
torch.stack(
(log_prob_pos + weights.log(), log_prob_neg + (1 - weights).log()),
dim=0,
),
dim=0,
)

return log_prob


def _unsafe_alpha_stable_log_prob_S0(alpha, beta, Z):
# Calculate log-probability of Z in Nolan's parametrization S^0. This will fail if alpha is close to 1

# Convert from Nolan's parametrization S^0 where samples depend
# continuously on (alpha,beta), allowing interpolation around the hole at
# alpha=1.
Z = Z + beta * (math.pi / 2 * alpha).tan()

# Find near zero values
per_param_value_near_zero_tolerance = (
value_near_zero_tolerance_alpha * alpha / (1 - alpha).abs()
).clamp(
max=value_near_zero_tolerance_density
* _unsafe_alpha_stable_log_prob_at_zero(alpha, 0).exp().reciprocal()
)
idx = Z.abs() < per_param_value_near_zero_tolerance

# Calculate log-prob at safe values
log_prob = _unsafe_stable_log_prob(
alpha, beta, torch.where(idx, per_param_value_near_zero_tolerance, Z)
)

# Handle near zero values by interpolation
if idx.any():
log_prob_pos = log_prob[idx]
log_prob_neg = _unsafe_stable_log_prob(
alpha[idx], beta[idx], -per_param_value_near_zero_tolerance[idx]
)
weights = Z[idx] / (2 * per_param_value_near_zero_tolerance[idx]) + 0.5
log_prob[idx] = torch.logsumexp(
torch.stack(
(log_prob_pos + weights.log(), log_prob_neg + (1 - weights).log()),
dim=0,
),
dim=0,
)

return log_prob


def _unsafe_stable_log_prob(alpha, beta, Z):
# Calculate log-probability of Z. This will fail if alpha is close to 1
# or if Z is close to 0
ha = math.pi / 2 * alpha
b = beta * ha.tan()
atan_b = b.atan()
u_zero = -alpha.reciprocal() * atan_b

# If sample should be negative calculate with flipped beta and flipped value
flip_beta_x = Z < 0
beta = torch.where(flip_beta_x, -beta, beta)
u_zero = torch.where(flip_beta_x, -u_zero, u_zero)
Z = torch.where(flip_beta_x, -Z, Z)

# Set integration domwin
domain = torch.stack((u_zero, 0.5 * math.pi * u_zero.new_ones(u_zero.shape)), dim=0)

integrand = partial(
_unsafe_stable_given_uniform_log_prob, alpha=alpha, beta=beta, Z=Z
)

return integrate(integrand, domain) - math.log(math.pi)


def _unsafe_stable_given_uniform_log_prob(V, alpha, beta, Z):
# Calculate log-probability of Z given V. This will fail if alpha is close to 1
# or if Z is close to 0
inv_alpha_minus_one = (alpha - 1).reciprocal()
half_pi = math.pi / 2
eps = torch.finfo(V.dtype).eps
# make V belong to the open interval (-pi/2, pi/2)
V = V.clamp(min=2 * eps - half_pi, max=half_pi - 2 * eps)
ha = half_pi * alpha
b = beta * ha.tan()
atan_b = b.atan()
cos_V = V.cos()

# +/- `ha` term to keep the precision of alpha * (V + half_pi) when V ~ -half_pi
v = atan_b - ha + alpha * (V + half_pi)

term1_log = atan_b.cos().log() * inv_alpha_minus_one
term2_log = (Z * cos_V / v.sin()).log() * alpha * inv_alpha_minus_one
term3_log = ((v - V).cos() / cos_V).log()

W_log = term1_log + term2_log + term3_log

W = W_log.clamp(min=MIN_LOG, max=MAX_LOG).exp()

log_prob = -W + (alpha * W / Z / (alpha - 1)).abs().log()

# Infinite W means zero-probability
log_prob = torch.where(W == torch.inf, -torch.inf, log_prob)

log_prob = log_prob.clamp(min=MIN_LOG, max=MAX_LOG)

return log_prob


def _unsafe_alpha_stable_log_prob_at_zero(alpha, beta):
# Calculate log-probability at value of zero. This will fail if alpha is close to 1
inv_alpha = alpha.reciprocal()
half_pi = math.pi / 2
ha = half_pi * alpha
b = beta * ha.tan()
atan_b = b.atan()

term1_log = (inv_alpha * atan_b).cos().log()
term2_log = atan_b.cos().log() * inv_alpha
term3_log = torch.lgamma(1 + inv_alpha)

log_prob = term1_log - term2_log + term3_log - math.log(math.pi)

log_prob = log_prob.clamp(min=MIN_LOG, max=MAX_LOG)

return log_prob
6 changes: 5 additions & 1 deletion pyro/infer/reparam/stable.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,11 @@ def apply(self, msg):
is_observed = msg["is_observed"]

fn, event_dim = self._unwrap(fn)
assert isinstance(fn, dist.Stable) and fn.coords == "S0"
assert (
isinstance(fn, dist.Stable)
and fn.coords == "S0"
and not isinstance(fn, dist.StableWithLogProb)
)

# Strategy: Let X ~ S0(a,b,s,m) be the stable variable of interest.
# 1. WLOG scale and shift so s=1 and m=0, additionally shifting to convert
Expand Down
2 changes: 1 addition & 1 deletion pyro/infer/reparam/strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def _minimal_reparam(fn, is_observed):
return TransformReparam() # Then reparametrize new sites.
fn = fn.base_dist

if isinstance(fn, dist.Stable):
if isinstance(fn, dist.Stable) and not isinstance(fn, dist.StableWithLogProb):
if not is_observed:
return LatentStableReparam()
elif fn.skew.requires_grad or fn.skew.any():
Expand Down
Loading
Loading