Add ZeroSumNormal distribution #1751

Merged
43 commits merged on Mar 30, 2024
Changes from 9 commits
Commits
fa63f9a
added zerosumnormal and tests
kylejcaron Feb 27, 2024
c28fd0c
added edge case handling for support shape
kylejcaron Feb 29, 2024
93bcf0f
removed commented out functions
kylejcaron Feb 29, 2024
d9f2b4e
added zerosumnormal to docs
kylejcaron Feb 29, 2024
0abb60b
fixed zerosumnormal support shape default
kylejcaron Feb 29, 2024
b28f38c
Added v1 of docstrings for zerosumnormal
kylejcaron Feb 29, 2024
4e1dd16
updated zsn docstring
kylejcaron Feb 29, 2024
8cd792c
improved init shape handling for zerosumnormal
kylejcaron Feb 29, 2024
dcbdd85
improved docstrings
kylejcaron Feb 29, 2024
13fff40
added ZeroSumTransform
kylejcaron Mar 5, 2024
514000c
made n_zerosum_axes an attribute for the zerosumtransform
kylejcaron Mar 5, 2024
d6315c3
removed commented out lines
kylejcaron Mar 5, 2024
907cd2e
added zerosumtransform class
kylejcaron Mar 7, 2024
fc3f053
switched zsn from ParameterFreeTransform to Transform
kylejcaron Mar 8, 2024
8187421
changed ZeroSumNormal to transformed distribution
kylejcaron Mar 25, 2024
0051342
changed input to tuple for _transform_to_zero_sum
kylejcaron Mar 25, 2024
1820a74
added forward and inverse shape to transform, fixed zero_sum constrai…
kylejcaron Mar 26, 2024
ee227bf
fixed failing zsn tests
kylejcaron Mar 26, 2024
bb4880c
added docstring, removed whitespace, fixed missing import
kylejcaron Mar 26, 2024
38b8f56
fixed allclose to be assert allclose
kylejcaron Mar 26, 2024
54533ff
Merge branch 'master' into zsn-dist
kylejcaron Mar 26, 2024
c8af390
linted and formatted
kylejcaron Mar 26, 2024
3034f4a
added sample code to docstring for zsn
kylejcaron Mar 26, 2024
ebdd309
updated docstring
kylejcaron Mar 26, 2024
8cb7a5f
removed list from ZeroSum constraint call
kylejcaron Mar 26, 2024
ae1586f
removed unneeded iteration, updated docstring
kylejcaron Mar 26, 2024
ab58216
updated constraint code
kylejcaron Mar 26, 2024
ad4e7c2
added ZeroSumTransform to docs
kylejcaron Mar 26, 2024
54547f2
fixed transform shapes
kylejcaron Mar 26, 2024
bdc6480
added doctest example for zsn
kylejcaron Mar 26, 2024
0b5070b
added constraint test
kylejcaron Mar 26, 2024
b1129bf
added zero_sum constraint to docs
kylejcaron Mar 26, 2024
5fcaf68
added type hinting to transforms file
kylejcaron Mar 26, 2024
619f90b
fixed docs formatting
kylejcaron Mar 27, 2024
2e79677
moved skip zsn from test_gof earlier
kylejcaron Mar 27, 2024
da382f5
reversed zerosumtransform
kylejcaron Mar 27, 2024
5aa5aeb
broadcasted mean and var of zsn
kylejcaron Mar 27, 2024
f7992d1
added stricter zero_sum constraint tol, improved mean and var functions
kylejcaron Mar 28, 2024
1e77815
fixed _transform_to_zero_sum
kylejcaron Mar 28, 2024
98f32f9
removed shape promote from zsn, changed broadcast to zeros_like
kylejcaron Mar 28, 2024
c639e70
chose better zsn test cases
kylejcaron Mar 28, 2024
8a7a905
Update zero_sum constraint feasible_like
kylejcaron Mar 28, 2024
d7f05ff
fixed docstring for doctests
kylejcaron Mar 29, 2024
7 changes: 7 additions & 0 deletions docs/source/distributions.rst
@@ -380,6 +380,13 @@ Weibull
:show-inheritance:
:member-order: bysource

ZeroSumNormal
^^^^^^^^^^^^^
.. autoclass:: numpyro.distributions.continuous.ZeroSumNormal
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource

Discrete Distributions
----------------------
2 changes: 2 additions & 0 deletions numpyro/distributions/__init__.py
@@ -47,6 +47,7 @@
StudentT,
Uniform,
Weibull,
ZeroSumNormal,
)
from numpyro.distributions.copula import GaussianCopula, GaussianCopulaBeta
from numpyro.distributions.directional import (
@@ -196,4 +197,5 @@
"ZeroInflatedDistribution",
"ZeroInflatedPoisson",
"ZeroInflatedNegativeBinomial2",
"ZeroSumNormal",
]
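
With ZeroSumNormal exported above, the class becomes reachable from the public numpyro.distributions namespace. A minimal smoke test (an illustrative sketch, assuming a build of this branch is installed; not part of the diff) might look like:

    from numpyro.distributions import ZeroSumNormal

    d = ZeroSumNormal(scale=1.0, support_shape=(3,))
    print(type(d).__name__, d.batch_shape, d.event_shape)  # ZeroSumNormal () (3,)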
122 changes: 122 additions & 0 deletions numpyro/distributions/continuous.py
@@ -2444,3 +2444,125 @@ def cdf(self, value):

def icdf(self, value):
return self._ald.icdf(value)


class ZeroSumNormal(Distribution):
r"""
Zero Sum Normal distribution adapted from PyMC [1] as described in [2]. This is a Normal distribution where one or
more axes are constrained to sum to zero (the last axis by default).

:param array_like scale: Standard deviation of the underlying normal distribution before the zero-sum constraint is
enforced.
:param int n_zerosum_axes: The number of trailing axes over which the zero-sum constraint is enforced (default 1).
:param tuple support_shape: The event shape of the distribution.

.. math::
\begin{align*}
ZSN(\sigma) = N(0, \sigma^2 (I - \tfrac{1}{n}J)) \\
\text{where} \ ~ J_{ij} = 1 \ ~ \text{and} \\
n = \text{number of zero-sum axes}
\end{align*}

**References**
[1] https://github.com/pymc-devs/pymc/blob/6252d2e58dc211c913ee2e652a4058d271d48bbd/pymc/distributions/multivariate.py#L2637
[2] https://www.pymc.io/projects/docs/en/stable/api/distributions/generated/pymc.ZeroSumNormal.html
"""
arg_constraints = {"scale": constraints.positive}
support = constraints.real
reparametrized_params = ["scale"]
pytree_aux_fields = ("n_zerosum_axes", "support_shape")

def __init__(self, scale=1.0, n_zerosum_axes=None, support_shape=None, *, validate_args=None):
if not all(i == 1 for i in jnp.shape(scale)):
raise ValueError("scale must have length one across the zero-sum axes")

self.n_zerosum_axes = self.check_zerosum_axes(n_zerosum_axes)
support_shape = self.check_support_shape(support_shape, self.n_zerosum_axes)
if jnp.ndim(scale) == 0:
(scale,) = promote_shapes(scale, shape=(1,))

batch_shape = jnp.shape(scale)[:-1]
self.scale = scale

super(ZeroSumNormal, self).__init__(
batch_shape=batch_shape,
event_shape=support_shape,
validate_args=validate_args
)

def sample(self, key, sample_shape=()):
assert is_prng_key(key)
zerosum_rv_ = random.normal(
key, shape=sample_shape + self.batch_shape + self.event_shape
) * self.scale

if not zerosum_rv_.shape:
return jnp.zeros(zerosum_rv_.shape)

for axis in range(self.n_zerosum_axes):
zerosum_rv_ -= zerosum_rv_.mean(axis=-axis - 1, keepdims=True)
return zerosum_rv_

@validate_sample
def log_prob(self, value):
shape = jnp.array(value.shape)
_deg_free_support_shape = shape.at[-self.n_zerosum_axes:].set(shape[-self.n_zerosum_axes:] - 1)
_full_size = jnp.prod(shape).astype(float)
_degrees_of_freedom = jnp.prod(_deg_free_support_shape).astype(float)
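# Note on the normalizer below: of the _full_size entries in `value`, only
# _degrees_of_freedom are free once the zero-sum constraint is applied, so the
# -log(scale * sqrt(2 * pi)) term is scaled by _degrees_of_freedom / _full_size
# before being summed over the zero-sum axes.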

if not value.shape or self.batch_shape:
value = jnp.expand_dims(value, -1)

log_pdf = jnp.sum(
-0.5 * jnp.pow(value / self.scale, 2)
- (jnp.log(jnp.sqrt(2.0 * jnp.pi)) + jnp.log(self.scale)) * _degrees_of_freedom / _full_size,
axis=tuple(np.arange(-self.n_zerosum_axes, 0)),
)
return log_pdf

@property
def mean(self):
return jnp.broadcast_to(0, self.batch_shape)

@property
def variance(self):
theoretical_var = self.scale.astype(float)**2
for axis in range(1, self.n_zerosum_axes + 1):
theoretical_var *= (1 - 1 / self.event_shape[-axis])

return theoretical_var
Member: Similar to the mean, we need to broadcast this to batch_shape + event_shape

kylejcaron (Contributor Author): got it, updated!
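
# Illustrative follow-up to the review note above (a sketch, not the code shown
# in this diff): the theoretical variance could be broadcast before being
# returned, e.g.
#     jnp.broadcast_to(theoretical_var, self.batch_shape + self.event_shape)
# which is what the later "broadcasted mean and var of zsn" commit addresses.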


def check_zerosum_axes(self, n_zerosum_axes):
if n_zerosum_axes is None:
n_zerosum_axes = 1

is_integer = isinstance(n_zerosum_axes, int)
is_jax_int_array = isinstance(n_zerosum_axes, jnp.ndarray) and jnp.issubdtype(n_zerosum_axes.dtype, jnp.integer)
if not (is_integer or is_jax_int_array):
raise TypeError("n_zerosum_axes has to be an integer")
if not n_zerosum_axes > 0:
raise ValueError("n_zerosum_axes has to be > 0")
return n_zerosum_axes

def check_support_shape(self, support_shape, n_zerosum_axes):
if support_shape is None:
return ()
assert n_zerosum_axes <= len(support_shape), "support_shape has to be as long as n_zerosum_axes"
assert all(shape > 0 for shape in support_shape), "support_shape must be a valid shape"
assert len(support_shape) > 0, "support_shape must be a valid shape"
return support_shape

@staticmethod
def infer_shapes(scale=1.0, n_zerosum_axes=None, support_shape=(1,)):
"""NumPyro assumes that the event and batch shapes can be entirely
determined by the shapes of the distribution inputs. This distribution
doesn't follow that convention, so the `infer_shapes` method can't be implemented.
"""
raise NotImplementedError()

def _validate_sample(self, value):
mask = super(ZeroSumNormal, self)._validate_sample(value)
batch_dim = jnp.ndim(value) - len(self.event_shape)
if batch_dim < jnp.ndim(mask):
mask = jnp.all(jnp.reshape(mask, jnp.shape(mask)[:batch_dim] + (-1,)), -1)
return mask
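
To show the pieces above working together, here is a usage sketch (illustrative only, assuming this branch is installed; it is not part of the diff). Draws sum to approximately zero along the constrained trailing axis, and log_prob reduces over the event shape:

    import jax
    import jax.numpy as jnp
    import numpyro.distributions as dist

    # One zero-sum axis of size 5, scalar scale.
    d = dist.ZeroSumNormal(scale=1.0, n_zerosum_axes=1, support_shape=(5,))
    samples = d.sample(jax.random.PRNGKey(0), sample_shape=(1000,))

    print(samples.shape)                                  # (1000, 5)
    print(jnp.allclose(samples.sum(-1), 0.0, atol=1e-5))  # True: each draw sums to ~0
    print(d.log_prob(samples).shape)                      # (1000,)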
15 changes: 14 additions & 1 deletion test/test_distributions.py
@@ -773,6 +773,10 @@ def get_sp_dist(jax_dist):
T(dist.Weibull, 0.2, 1.1),
T(dist.Weibull, 2.8, np.array([2.0, 2.0])),
T(dist.Weibull, 1.8, np.array([[1.0, 1.0], [2.0, 2.0]])),
T(dist.ZeroSumNormal, 1.0, None, (1,)),
T(dist.ZeroSumNormal, 1.0, 1, (1,)),
T(dist.ZeroSumNormal, np.array([2.0]), None, (1,)),
T(dist.ZeroSumNormal, 1.0, 2, (4, 5)),
T(
_GaussianMixture,
np.ones(3) / 3.0,
@@ -1296,6 +1300,7 @@ def test_jit_log_likelihood(jax_dist, sp_dist, params):
"LKJ",
"LKJCholesky",
"_SparseCAR",
"ZeroSumNormal",
):
pytest.xfail(reason="non-jittable params")

@@ -1454,6 +1459,9 @@ def test_gof(jax_dist, sp_dist, params):
if jax_dist is dist.ProjectedNormal:
dim = samples.shape[-1] - 1

if jax_dist is dist.ZeroSumNormal:
pytest.skip("skip gof test for ZeroSumNormal")

# Test each batch independently.
probs = probs.reshape(num_samples, -1)
samples = samples.reshape(probs.shape + d.event_shape)
@@ -1671,6 +1679,9 @@ def fn(*args):
if jax_dist is _SparseCAR and i == 3:
# skip taking grad w.r.t. adj_matrix
continue
if jax_dist is dist.ZeroSumNormal and i != 0:
# skip taking grad w.r.t. n_zerosum_axes and support_shape
continue
if isinstance(
params[i], dist.Distribution
): # skip taking grad w.r.t. base_dist
@@ -1857,7 +1868,7 @@ def get_min_shape(ix, batch_shape):
if isinstance(d_jax, dist.Gompertz):
pytest.skip("Gompertz distribution does not have `variance` implemented.")
if jnp.all(jnp.isfinite(d_jax.variance)):
assert_allclose(
jnp.allclose(
jnp.std(samples, 0), jnp.sqrt(d_jax.variance), rtol=0.05, atol=1e-2
)

@@ -1898,6 +1909,8 @@ def test_distribution_constraints(jax_dist, sp_dist, params, prepend_shape):
continue
if jax_dist is dist.GaussianRandomWalk and dist_args[i] == "num_steps":
continue
if jax_dist is dist.ZeroSumNormal and dist_args[i] in ("n_zerosum_axes", "support_shape"):
continue
if (
jax_dist is dist.SineBivariateVonMises
and dist_args[i] == "weighted_correlation"
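
Finally, a standalone sketch of the kind of moment check the modified test above performs (illustrative; the tolerances mirror the rtol/atol used in the test rather than reproducing it exactly):

    import jax
    import jax.numpy as jnp
    import numpyro.distributions as dist

    d = dist.ZeroSumNormal(scale=1.0, n_zerosum_axes=1, support_shape=(5,))
    samples = d.sample(jax.random.PRNGKey(1), sample_shape=(200_000,))

    # Empirical moments against the distribution's mean/variance properties:
    # the mean is 0 and the variance shrinks by a factor of (1 - 1/5).
    assert jnp.allclose(jnp.mean(samples, 0), d.mean, atol=1e-2)
    assert jnp.allclose(jnp.std(samples, 0), jnp.sqrt(d.variance), rtol=0.05, atol=1e-2)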