diff --git a/RELEASE-NOTES.md b/RELEASE-NOTES.md index 962ba3bda9c..0ebd582bc40 100644 --- a/RELEASE-NOTES.md +++ b/RELEASE-NOTES.md @@ -8,6 +8,11 @@ ### New Features - The `CAR` distribution has been added to allow for use of conditional autoregressions which often are used in spatial and network models. +- The dimensionality of model variables can now be parametrized through either of `shape`, `dims` or `size` (see [#4625](https://github.com/pymc-devs/pymc3/pull/4625)): + - With `shape` the length of dimensions must be given numerically or as scalar Aesara `Variables`. Using `shape` restricts the model variable to the exact length and re-sizing is no longer possible. + - `dims` is arguably the most elegant parametrization, because it allows you to resize `pm.Data` variables and leads to well defined coordinates in `InferenceData` objects. + - The `size` kwarg creates new dimensions in addition to what is implied by RV parameters. + - An `Ellipsis` (`...`) in the last position of `shape` or `dims` can be used as short-hand notation for implied dimensions. - ... ### Maintenance diff --git a/pymc3/distributions/distribution.py b/pymc3/distributions/distribution.py index 4960592c8cd..e95eecaed94 100644 --- a/pymc3/distributions/distribution.py +++ b/pymc3/distributions/distribution.py @@ -13,6 +13,7 @@ # limitations under the License. import contextvars import inspect +import logging import multiprocessing import sys import types @@ -20,23 +21,19 @@ from abc import ABCMeta from copy import copy -from typing import TYPE_CHECKING +from typing import Any, Optional, Sequence, Tuple, Union +import aesara +import aesara.tensor as at import dill +from aesara.graph.basic import Variable from aesara.tensor.random.op import RandomVariable +from pymc3.aesaraf import change_rv_size, pandas_to_array from pymc3.distributions import _logcdf, _logp - -if TYPE_CHECKING: - from typing import Optional, Callable - -import aesara -import aesara.graph.basic -import aesara.tensor as at - from pymc3.util import UNSET, get_repr_for_variable -from pymc3.vartypes import string_types +from pymc3.vartypes import isgenerator, string_types __all__ = [ "DensityDist", @@ -46,12 +43,18 @@ "NoDistribution", ] +_log = logging.getLogger(__file__) + vectorized_ppc = contextvars.ContextVar( "vectorized_ppc", default=None ) # type: contextvars.ContextVar[Optional[Callable]] PLATFORM = sys.platform +Shape = Union[int, Sequence[Union[str, type(Ellipsis)]], Variable] +Dims = Union[str, Sequence[Union[str, None, type(Ellipsis)]]] +Size = Union[int, Tuple[int, ...]] + class _Unpickling: pass @@ -122,13 +125,111 @@ def logcdf(op, var, rvs_to_values, *dist_params, **kwargs): return new_cls +def _valid_ellipsis_position(items: Union[None, Shape, Dims, Size]) -> bool: + if items is not None and not isinstance(items, Variable) and Ellipsis in items: + if any(i == Ellipsis for i in items[:-1]): + return False + return True + + +def _validate_shape_dims_size( + shape: Any = None, dims: Any = None, size: Any = None +) -> Tuple[Optional[Shape], Optional[Dims], Optional[Size]]: + # Raise on unsupported parametrization + if shape is not None and dims is not None: + raise ValueError("Passing both `shape` ({shape}) and `dims` ({dims}) is not supported!") + if dims is not None and size is not None: + raise ValueError("Passing both `dims` ({dims}) and `size` ({size}) is not supported!") + if shape is not None and size is not None: + raise ValueError("Passing both `shape` ({shape}) and `size` ({size}) is not supported!") + + # Raise on invalid types + if not isinstance(shape, (type(None), int, list, tuple, Variable)): + raise ValueError("The `shape` parameter must be an int, list or tuple.") + if not isinstance(dims, (type(None), str, list, tuple)): + raise ValueError("The `dims` parameter must be a str, list or tuple.") + if not isinstance(size, (type(None), int, list, tuple)): + raise ValueError("The `size` parameter must be an int, list or tuple.") + + # Auto-convert non-tupled parameters + if isinstance(shape, int): + shape = (shape,) + if isinstance(dims, str): + dims = (dims,) + if isinstance(size, int): + size = (size,) + + # Convert to actual tuples + if not isinstance(shape, (type(None), tuple, Variable)): + shape = tuple(shape) + if not isinstance(dims, (type(None), tuple)): + dims = tuple(dims) + if not isinstance(size, (type(None), tuple)): + size = tuple(size) + + if not _valid_ellipsis_position(shape): + raise ValueError( + f"Ellipsis in `shape` may only appear in the last position. Actual: {shape}" + ) + if not _valid_ellipsis_position(dims): + raise ValueError(f"Ellipsis in `dims` may only appear in the last position. Actual: {dims}") + if size is not None and Ellipsis in size: + raise ValueError("The `size` parameter cannot contain an Ellipsis. Actual: {size}") + return shape, dims, size + + class Distribution(metaclass=DistributionMeta): """Statistical distribution""" rv_class = None rv_op = None - def __new__(cls, name, *args, **kwargs): + def __new__( + cls, + name: str, + *args, + rng=None, + dims: Optional[Dims] = None, + testval=None, + observed=None, + total_size=None, + transform=UNSET, + **kwargs, + ) -> RandomVariable: + """Adds a RandomVariable corresponding to a PyMC3 distribution to the current model. + + Note that all remaining kwargs must be compatible with ``.dist()`` + + Parameters + ---------- + cls : type + A PyMC3 distribution. + name : str + Name for the new model variable. + rng : optional + Random number generator to use with the RandomVariable. + dims : tuple, optional + A tuple of dimension names known to the model. + testval : optional + Test value to be attached to the output RV. + Must match its shape exactly. + observed : optional + Observed data to be passed when registering the random variable in the model. + See ``Model.register_rv``. + total_size : float, optional + See ``Model.register_rv``. + transform : optional + See ``Model.register_rv``. + **kwargs + Keyword arguments that will be forwarded to ``.dist()``. + Most prominently: ``shape`` and ``size`` + + Returns + ------- + rv : RandomVariable + The created RV, registered in the Model. + """ + try: from pymc3.model import Model @@ -141,40 +242,127 @@ def __new__(cls, name, *args, **kwargs): "for a standalone distribution." ) - rng = kwargs.pop("rng", None) + if not isinstance(name, string_types): + raise TypeError(f"Name needs to be a string but got: {name}") if rng is None: rng = model.default_rng - if not isinstance(name, string_types): - raise TypeError(f"Name needs to be a string but got: {name}") + _, dims, _ = _validate_shape_dims_size(dims=dims) + resize = None - data = kwargs.pop("observed", None) + # Create the RV without specifying testval, because the testval may have a shape + # that only matches after replicating with a size implied by dims (see below). + rv_out = cls.dist(*args, rng=rng, testval=None, **kwargs) + n_implied = rv_out.ndim - total_size = kwargs.pop("total_size", None) + # `dims` are only available with this API, because `.dist()` can be used + # without a modelcontext and dims are not yet tracked at the Aesara level. + if dims is not None: + if Ellipsis in dims: + # Auto-complete the dims tuple to the full length + dims = (*dims[:-1], *[None] * rv_out.ndim) - dims = kwargs.pop("dims", None) + n_resize = len(dims) - n_implied - if "shape" in kwargs: - raise DeprecationWarning("The `shape` keyword is deprecated; use `size`.") + # All resize dims must be known already (numerically or symbolically). + unknown_resize_dims = set(dims[:n_resize]) - set(model.dim_lengths) + if unknown_resize_dims: + raise KeyError( + f"Dimensions {unknown_resize_dims} are unknown to the model and cannot be used to specify a `size`." + ) + + # The numeric/symbolic resize tuple can be created using model.RV_dim_lengths + resize = tuple(model.dim_lengths[dname] for dname in dims[:n_resize]) + elif observed is not None: + if isgenerator(observed): + observed = pandas_to_array(observed).astype(rv_out.dtype) + else: + observed = at.as_tensor_variable(observed, dtype=rv_out.dtype) + n_resize = observed.ndim - n_implied + resize = tuple(observed.shape[d] for d in range(n_resize)) - transform = kwargs.pop("transform", UNSET) + if resize: + # A batch size was specified through `dims`, or implied by `observed`. + rv_out = change_rv_size(rv_var=rv_out, new_size=resize, expand=True) - rv_out = cls.dist(*args, rng=rng, **kwargs) + if dims is not None: + # Now that we have a handle on the output RV, we can register named implied dimensions that + # were not yet known to the model, such that they can be used for size further downstream. + for di, dname in enumerate(dims[n_resize:]): + if not dname in model.dim_lengths: + model.add_coord(dname, values=None, length=rv_out.shape[n_resize + di]) - return model.register_rv(rv_out, name, data, total_size, dims=dims, transform=transform) + if testval is not None: + # Assigning the testval earlier causes trouble because the RV may not be created with the final shape already. + rv_out.tag.test_value = testval + + return model.register_rv(rv_out, name, observed, total_size, dims=dims, transform=transform) @classmethod - def dist(cls, dist_params, **kwargs): + def dist( + cls, + dist_params, + *, + shape: Optional[Shape] = None, + size: Optional[Size] = None, + testval=None, + **kwargs, + ) -> RandomVariable: + """Creates a RandomVariable corresponding to the `cls` distribution. - testval = kwargs.pop("testval", None) + Parameters + ---------- + dist_params + shape : tuple, optional + A tuple of sizes for each dimension of the new RV. + + Ellipsis (...) may be used in the last position of the tuple, + and automatically expand to the shape implied by RV inputs. + size : int, tuple, Variable, optional + A scalar or tuple for replicating the RV in addition + to its implied shape/dimensionality. + testval : optional + Test value to be attached to the output RV. + Must match its shape exactly. - rv_var = cls.rv_op(*dist_params, **kwargs) + Returns + ------- + rv : RandomVariable + The created RV. + """ + if "dims" in kwargs: + raise NotImplementedError("The use of a `.dist(dims=...)` API is not yet supported.") + + shape, _, size = _validate_shape_dims_size(shape=shape, size=size) + + # Create the RV without specifying size or testval. + # The size will be expanded later (if necessary) and only then the testval fits. + rv_native = cls.rv_op(*dist_params, size=None, **kwargs) + + if shape is None and size is None: + size = () + elif shape is not None: + if isinstance(shape, Variable): + size = () + else: + if Ellipsis in shape: + size = tuple(shape[:-1]) + else: + size = tuple(shape[: len(shape) - rv_native.ndim]) + # no-op conditions: + # `elif size is not None` (User already specified how to expand the RV) + # `else` (Unreachable) + + if size: + rv_out = change_rv_size(rv_var=rv_native, new_size=size, expand=True) + else: + rv_out = rv_native if testval is not None: - rv_var.tag.test_value = testval + rv_out.tag.test_value = testval - return rv_var + return rv_out def _distr_parameters_for_repr(self): """Return the names of the parameters for this distribution (e.g. "mu" diff --git a/pymc3/model.py b/pymc3/model.py index 17ad3670436..f72dd81716f 100644 --- a/pymc3/model.py +++ b/pymc3/model.py @@ -47,7 +47,6 @@ from pandas import Series from pymc3.aesaraf import ( - change_rv_size, gradient, hessian, inputvars, @@ -959,7 +958,7 @@ def add_coord( ---------- name : str Name of the dimension. - Forbidden: {"chain", "draw"} + Forbidden: {"chain", "draw", "__sample__"} values : optional, array-like Coordinate values or ``None`` (for auto-numbering). If ``None`` is passed, a ``length`` must be specified. @@ -967,9 +966,10 @@ def add_coord( A symbolic scalar of the dimensions length. Defaults to ``aesara.shared(len(values))``. """ - if name in {"draw", "chain"}: + if name in {"draw", "chain", "__sample__"}: raise ValueError( - "Dimensions can not be named `draw` or `chain`, as they are reserved for the sampler's outputs." + "Dimensions can not be named `draw`, `chain` or `__sample__`, " + "as those are reserved for use in `InferenceData`." ) if values is None and length is None: raise ValueError( @@ -981,7 +981,7 @@ def add_coord( ) if name in self.coords: if not values.equals(self.coords[name]): - raise ValueError("Duplicate and incompatiple coordinate: %s." % name) + raise ValueError(f"Duplicate and incompatiple coordinate: {name}.") else: self._coords[name] = values self._dim_lengths[name] = length or aesara.shared(len(values)) @@ -1019,7 +1019,8 @@ def set_data( New values for the shared variable. coords : optional, dict New coordinate values for dimensions of the shared variable. - Must be provided for all named dimensions that change in length. + Must be provided for all named dimensions that change in length + and already have coordinate values. """ shared_object = self[name] if not isinstance(shared_object, SharedVariable): @@ -1134,6 +1135,7 @@ def make_obs_var( ========== rv_var The random variable that is observed. + Its dimensionality must be compatible with the data already. data The observed data. dims: tuple @@ -1145,22 +1147,6 @@ def make_obs_var( name = rv_var.name data = pandas_to_array(data).astype(rv_var.dtype) - # The shapes of the observed random variable and its data might not - # match. We need need to update the observed random variable's `size` - # (i.e. number of samples) so that it matches the data. - - # Setting `size` produces a random variable with shape `size + - # support_shape`, where `len(support_shape) == op.ndim_supp`, we need - # to disregard the last `op.ndim_supp`-many dimensions when we - # determine the appropriate `size` value from `data.shape`. - ndim_supp = rv_var.owner.op.ndim_supp - if ndim_supp > 0: - new_size = data.shape[:-ndim_supp] - else: - new_size = data.shape - - rv_var = change_rv_size(rv_var, new_size) - if aesara.config.compute_test_value != "off": test_value = getattr(rv_var.tag, "test_value", None) diff --git a/pymc3/tests/sampler_fixtures.py b/pymc3/tests/sampler_fixtures.py index 30a14a6a1e9..b4f5cc6cffd 100644 --- a/pymc3/tests/sampler_fixtures.py +++ b/pymc3/tests/sampler_fixtures.py @@ -92,7 +92,7 @@ class BetaBinomialFixture(KnownCDF): @classmethod def make_model(cls): with pm.Model() as model: - p = pm.Beta("p", [0.5, 0.5, 1.0], [0.5, 0.5, 1.0], size=3) + p = pm.Beta("p", [0.5, 0.5, 1.0], [0.5, 0.5, 1.0]) pm.Binomial("y", p=p, n=[4, 12, 9], observed=[1, 2, 9]) return model diff --git a/pymc3/tests/test_data_container.py b/pymc3/tests/test_data_container.py index dddc1dfb236..488586951bb 100644 --- a/pymc3/tests/test_data_container.py +++ b/pymc3/tests/test_data_container.py @@ -23,6 +23,7 @@ import pymc3 as pm from pymc3.distributions import logpt +from pymc3.exceptions import ShapeError from pymc3.tests.helpers import SeededTest @@ -160,22 +161,42 @@ def test_shared_data_as_rv_input(self): """ with pm.Model() as m: x = pm.Data("x", [1.0, 2.0, 3.0]) - _ = pm.Normal("y", mu=x, size=3) - trace = pm.sample( - chains=1, return_inferencedata=False, compute_convergence_checks=False + assert x.eval().shape == (3,) + y = pm.Normal("y", mu=x, size=2) + assert y.eval().shape == (2, 3) + idata = pm.sample( + chains=1, + tune=500, + draws=550, + return_inferencedata=True, + compute_convergence_checks=False, ) + samples = idata.posterior["y"] + assert samples.shape == (1, 550, 2, 3) np.testing.assert_allclose(np.array([1.0, 2.0, 3.0]), x.get_value(), atol=1e-1) - np.testing.assert_allclose(np.array([1.0, 2.0, 3.0]), trace["y"].mean(0), atol=1e-1) + np.testing.assert_allclose( + np.array([1.0, 2.0, 3.0]), samples.mean(("chain", "draw", "y_dim_0")), atol=1e-1 + ) with m: pm.set_data({"x": np.array([2.0, 4.0, 6.0])}) - trace = pm.sample( - chains=1, return_inferencedata=False, compute_convergence_checks=False + assert x.eval().shape == (3,) + assert y.eval().shape == (2, 3) + idata = pm.sample( + chains=1, + tune=500, + draws=620, + return_inferencedata=True, + compute_convergence_checks=False, ) + samples = idata.posterior["y"] + assert samples.shape == (1, 620, 2, 3) np.testing.assert_allclose(np.array([2.0, 4.0, 6.0]), x.get_value(), atol=1e-1) - np.testing.assert_allclose(np.array([2.0, 4.0, 6.0]), trace["y"].mean(0), atol=1e-1) + np.testing.assert_allclose( + np.array([2.0, 4.0, 6.0]), samples.mean(("chain", "draw", "y_dim_0")), atol=1e-1 + ) def test_shared_scalar_as_rv_input(self): # See https://github.com/pymc-devs/pymc3/issues/3139 @@ -284,6 +305,38 @@ def test_explicit_coords(self): assert isinstance(pmodel.dim_lengths["columns"], ScalarSharedVariable) assert pmodel.dim_lengths["columns"].eval() == 7 + def test_symbolic_coords(self): + """ + In v4 dimensions can be created without passing coordinate values. + Their length is then automatically linked to the corresponding Tensor dimension. + """ + with pm.Model() as pmodel: + intensity = pm.Data("intensity", np.ones((2, 3)), dims=("row", "column")) + assert "row" in pmodel.dim_lengths + assert "column" in pmodel.dim_lengths + assert isinstance(pmodel.dim_lengths["row"], TensorVariable) + assert isinstance(pmodel.dim_lengths["column"], TensorVariable) + assert pmodel.dim_lengths["row"].eval() == 2 + assert pmodel.dim_lengths["column"].eval() == 3 + + intensity.set_value(np.ones((4, 5))) + assert pmodel.dim_lengths["row"].eval() == 4 + assert pmodel.dim_lengths["column"].eval() == 5 + pass + + def test_no_resize_of_implied_dimensions(self): + with pm.Model() as pmodel: + # Imply a dimension through RV params + pm.Normal("n", mu=[1, 2, 3], dims="city") + # _Use_ the dimension for a data variable + inhabitants = pm.Data("inhabitants", [100, 200, 300], dims="city") + + # Attempting to re-size the dimension through the data variable would + # cause shape problems in InferenceData conversion, because the RV remains (3,). + with pytest.raises(ShapeError, match="was not initialized from a shared variable"): + pmodel.set_data("inhabitants", [1, 2, 3, 4]) + pass + def test_implicit_coords_series(self): ser_sales = pd.Series( data=np.random.randint(low=0, high=30, size=22), diff --git a/pymc3/tests/test_distributions_random.py b/pymc3/tests/test_distributions_random.py index 8146c132d7d..0d13bc5720f 100644 --- a/pymc3/tests/test_distributions_random.py +++ b/pymc3/tests/test_distributions_random.py @@ -174,12 +174,7 @@ def get_random_variable(self, shape, with_vector_params=False, name=None): # in the test case parametrization "None" means "no specified (default)" return self.distribution(name, transform=None, **params) else: - ndim_supp = self.distribution.rv_op.ndim_supp - if ndim_supp == 0: - size = shape - else: - size = shape[:-ndim_supp] - return self.distribution(name, size=size, transform=None, **params) + return self.distribution(name, shape=shape, transform=None, **params) except TypeError: if np.sum(np.atleast_1d(shape)) == 0: pytest.skip("Timeseries must have positive shape") @@ -188,10 +183,9 @@ def get_random_variable(self, shape, with_vector_params=False, name=None): @staticmethod def sample_random_variable(random_variable, size): """ Draws samples from a RandomVariable using its .random() method. """ - if size is None: - return random_variable.eval() - else: - return change_rv_size(random_variable, size, expand=True).eval() + if size: + random_variable = change_rv_size(random_variable, size, expand=True) + return random_variable.eval() @pytest.mark.parametrize("size", [None, (), 1, (1,), 5, (4, 5)], ids=str) @pytest.mark.parametrize("shape", [None, ()], ids=str) diff --git a/pymc3/tests/test_logp.py b/pymc3/tests/test_logp.py index aea9db1fdc5..215e155e2fa 100644 --- a/pymc3/tests/test_logp.py +++ b/pymc3/tests/test_logp.py @@ -70,7 +70,7 @@ def test_logpt_basic(): @pytest.mark.parametrize( - "indices, size", + "indices, shape", [ (slice(0, 2), 5), (np.r_[True, True, False, False, True], 5), @@ -78,15 +78,15 @@ def test_logpt_basic(): ((np.array([0, 1, 4]), np.array([0, 1, 4])), (5, 5)), ], ) -def test_logpt_incsubtensor(indices, size): +def test_logpt_incsubtensor(indices, shape): """Make sure we can compute a log-likelihood for ``Y[idx] = data`` where ``Y`` is univariate.""" - mu = floatX(np.power(10, np.arange(np.prod(size)))).reshape(size) + mu = floatX(np.power(10, np.arange(np.prod(shape)))).reshape(shape) data = mu[indices] sigma = 0.001 rng = aesara.shared(np.random.RandomState(232), borrow=True) - a = Normal.dist(mu, sigma, size=size, rng=rng) + a = Normal.dist(mu, sigma, shape=shape, rng=rng) a.name = "a" a_idx = at.set_subtensor(a[indices], data) diff --git a/pymc3/tests/test_sampling.py b/pymc3/tests/test_sampling.py index 1b061deb609..250924396dc 100644 --- a/pymc3/tests/test_sampling.py +++ b/pymc3/tests/test_sampling.py @@ -213,7 +213,7 @@ def test_return_inferencedata(self, monkeypatch): return_inferencedata=True, discard_tuned_samples=True, idata_kwargs={"prior": prior}, - random_seed=-1 + random_seed=-1, ) assert "prior" in result assert isinstance(result, InferenceData) @@ -385,11 +385,10 @@ def test_shared_named(self): "theta0", mu=np.atleast_2d(0), tau=np.atleast_2d(1e20), - size=(1, 1), testval=np.atleast_2d(0), ) theta = pm.Normal( - "theta", mu=at.dot(G_var, theta0), tau=np.atleast_2d(1e20), size=(1, 1) + "theta", mu=at.dot(G_var, theta0), tau=np.atleast_2d(1e20), shape=(1, 1) ) res = theta.eval() assert np.isclose(res, 0.0) @@ -401,11 +400,10 @@ def test_shared_unnamed(self): "theta0", mu=np.atleast_2d(0), tau=np.atleast_2d(1e20), - size=(1, 1), testval=np.atleast_2d(0), ) theta = pm.Normal( - "theta", mu=at.dot(G_var, theta0), tau=np.atleast_2d(1e20), size=(1, 1) + "theta", mu=at.dot(G_var, theta0), tau=np.atleast_2d(1e20), shape=(1, 1) ) res = theta.eval() assert np.isclose(res, 0.0) @@ -417,11 +415,10 @@ def test_constant_named(self): "theta0", mu=np.atleast_2d(0), tau=np.atleast_2d(1e20), - size=(1, 1), testval=np.atleast_2d(0), ) theta = pm.Normal( - "theta", mu=at.dot(G_var, theta0), tau=np.atleast_2d(1e20), size=(1, 1) + "theta", mu=at.dot(G_var, theta0), tau=np.atleast_2d(1e20), shape=(1, 1) ) res = theta.eval() @@ -936,14 +933,14 @@ def test_ignores_observed(self): npt.assert_array_almost_equal(prior["positive_mu"], np.abs(prior["mu"]), decimal=4) def test_respects_shape(self): - for shape in (2, (2,), (10, 2), (10, 10)): + for shape in ((2,), (10, 2), (10, 10)): with pm.Model(): - mu = pm.Gamma("mu", 3, 1, size=1) - goals = pm.Poisson("goals", mu, size=shape) + mu = pm.Gamma("mu", 3, 1) + assert mu.eval().shape == () + goals = pm.Poisson("goals", mu, shape=shape) + assert goals.eval().shape == shape, f"Current shape setting: {shape}" trace1 = pm.sample_prior_predictive(10, var_names=["mu", "mu", "goals"]) trace2 = pm.sample_prior_predictive(10, var_names=["mu", "goals"]) - if shape == 2: # want to test shape as an int - shape = (2,) assert trace1["goals"].shape == (10,) + shape assert trace2["goals"].shape == (10,) + shape @@ -971,14 +968,14 @@ def test_multivariate2(self): def test_layers(self): with pm.Model() as model: - a = pm.Uniform("a", lower=0, upper=1, size=10) - b = pm.Binomial("b", n=1, p=a, size=10) + a = pm.Uniform("a", lower=0, upper=1, size=5) + b = pm.Binomial("b", n=1, p=a, size=7) model.default_rng.get_value(borrow=True).seed(232093) b_sampler = aesara.function([], b) avg = np.stack([b_sampler() for i in range(10000)]).mean(0) - npt.assert_array_almost_equal(avg, 0.5 * np.ones((10,)), decimal=2) + npt.assert_array_almost_equal(avg, 0.5 * np.ones((7, 5)), decimal=2) def test_transformed(self): n = 18 diff --git a/pymc3/tests/test_shape_handling.py b/pymc3/tests/test_shape_handling.py index 37c06193226..e0e0bc78d7a 100644 --- a/pymc3/tests/test_shape_handling.py +++ b/pymc3/tests/test_shape_handling.py @@ -11,7 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import logging +import aesara import numpy as np import pytest @@ -19,6 +21,7 @@ import pymc3 as pm +from pymc3.distributions.distribution import _validate_shape_dims_size from pymc3.distributions.shape_utils import ( broadcast_dist_samples_shape, broadcast_dist_samples_to, @@ -28,6 +31,8 @@ to_tuple, ) +_log = logging.getLogger(__file__) + test_shapes = [ (tuple(), (1,), (4,), (5, 4)), (tuple(), (1,), (7,), (5, 4)), @@ -219,3 +224,134 @@ def test_sample_generate_values(fixture_model, fixture_sizes): prior = pm.sample_prior_predictive(samples=fixture_sizes) for rv in RVs: assert prior[rv.name].shape == size + tuple(rv.distribution.shape) + + +class TestShapeDimsSize: + @pytest.mark.parametrize("param_shape", [(), (3,)]) + @pytest.mark.parametrize("batch_shape", [(), (3,)]) + @pytest.mark.parametrize( + "parametrization", + [ + "implicit", + "shape", + "shape...", + "dims", + "dims...", + "size", + ], + ) + def test_param_and_batch_shape_combos( + self, param_shape: tuple, batch_shape: tuple, parametrization: str + ): + coords = {} + param_dims = [] + batch_dims = [] + + # Create coordinates corresponding to the parameter shape + for d in param_shape: + dname = f"param_dim_{d}" + coords[dname] = [f"c_{i}" for i in range(d)] + param_dims.append(dname) + assert len(param_dims) == len(param_shape) + # Create coordinates corresponding to the batch shape + for d in batch_shape: + dname = f"batch_dim_{d}" + coords[dname] = [f"c_{i}" for i in range(d)] + batch_dims.append(dname) + assert len(batch_dims) == len(batch_shape) + + with pm.Model(coords=coords) as pmodel: + mu = aesara.shared(np.random.normal(size=param_shape)) + + with pytest.warns(None): + if parametrization == "implicit": + rv = pm.Normal("rv", mu=mu).shape == param_shape + else: + if parametrization == "shape": + rv = pm.Normal("rv", mu=mu, shape=batch_shape + param_shape) + assert rv.eval().shape == batch_shape + param_shape + elif parametrization == "shape...": + rv = pm.Normal("rv", mu=mu, shape=(*batch_shape, ...)) + assert rv.eval().shape == batch_shape + param_shape + elif parametrization == "dims": + rv = pm.Normal("rv", mu=mu, dims=batch_dims + param_dims) + assert rv.eval().shape == batch_shape + param_shape + elif parametrization == "dims...": + rv = pm.Normal("rv", mu=mu, dims=(*batch_dims, ...)) + n_size = len(batch_shape) + n_implied = len(param_shape) + ndim = n_size + n_implied + assert len(pmodel.RV_dims["rv"]) == ndim, pmodel.RV_dims + assert len(pmodel.RV_dims["rv"][:n_size]) == len(batch_dims) + assert len(pmodel.RV_dims["rv"][n_size:]) == len(param_dims) + if n_implied > 0: + assert pmodel.RV_dims["rv"][-1] is None + elif parametrization == "size": + rv = pm.Normal("rv", mu=mu, size=batch_shape) + assert rv.eval().shape == batch_shape + param_shape + else: + raise NotImplementedError("Invalid test case parametrization.") + pass + + def test_define_dims_on_the_fly(self): + with pm.Model() as pmodel: + agedata = aesara.shared(np.array([10, 20, 30])) + + # Associate the "patient" dim with an implied dimension + age = pm.Normal("age", agedata, dims=("patient",)) + assert "patient" in pmodel.dim_lengths + assert pmodel.dim_lengths["patient"].eval() == 3 + + # Use the dim to replicate a new RV + effect = pm.Normal("effect", 0, dims=("patient",)) + assert effect.ndim == 1 + assert effect.eval().shape == (3,) + + # Now change the length of the implied dimension + agedata.set_value([1, 2, 3, 4]) + # The change should propagate all the way through + assert effect.eval().shape == (4,) + pass + + def test_dist_api_works(self): + mu = aesara.shared(np.array([1, 2, 3])) + with pytest.raises(NotImplementedError, match="API is not yet supported"): + pm.Normal.dist(mu=mu, dims=("town",)) + assert pm.Normal.dist(mu=mu, shape=(3,)).eval().shape == (3,) + assert pm.Normal.dist(mu=mu, shape=(5, 3)).eval().shape == (5, 3) + assert pm.Normal.dist(mu=mu, shape=(7, ...)).eval().shape == (7, 3) + assert pm.Normal.dist(mu=mu, size=(4,)).eval().shape == (4, 3) + pass + + def test_lazy_flavors(self): + + _validate_shape_dims_size(shape=5) + _validate_shape_dims_size(dims="town") + _validate_shape_dims_size(size=7) + + assert pm.Uniform.dist(2, [4, 5], size=[3, 4]).eval().shape == (3, 4, 2) + assert pm.Uniform.dist(2, [4, 5], shape=[3, 2]).eval().shape == (3, 2) + with pm.Model(coords=dict(town=["Greifswald", "Madrid"])): + assert pm.Normal("n2", mu=[1, 2], dims=("town",)).eval().shape == (2,) + pass + + def test_invalid_flavors(self): + # redundant parametrizations + with pytest.raises(ValueError, match="Passing both"): + _validate_shape_dims_size(shape=(2,), dims=("town",)) + with pytest.raises(ValueError, match="Passing both"): + _validate_shape_dims_size(dims=("town",), size=(2,)) + with pytest.raises(ValueError, match="Passing both"): + _validate_shape_dims_size(shape=(3,), size=(3,)) + + # invalid, but not necessarly rare + with pytest.raises(ValueError, match="must be an int, list or tuple"): + _validate_shape_dims_size(size="notasize") + + # invalid ellipsis positions + with pytest.raises(ValueError, match="may only appear in the last position"): + _validate_shape_dims_size(shape=(3, ..., 2)) + with pytest.raises(ValueError, match="may only appear in the last position"): + _validate_shape_dims_size(dims=(..., "town")) + with pytest.raises(ValueError, match="cannot contain"): + _validate_shape_dims_size(size=(3, ...)) diff --git a/pymc3/tests/test_step.py b/pymc3/tests/test_step.py index fd021398794..6ef93eb5d1b 100644 --- a/pymc3/tests/test_step.py +++ b/pymc3/tests/test_step.py @@ -982,7 +982,7 @@ def test_linalg(self, caplog): a = Normal("a", size=2, testval=floatX(np.zeros(2))) a = at.switch(a > 0, np.inf, a) b = at.slinalg.solve(floatX(np.eye(2)), a) - Normal("c", mu=b, size=2, testval=floatX(np.r_[0.0, 0.0])) + Normal("c", mu=b, shape=(2,), testval=floatX(np.r_[0.0, 0.0])) caplog.clear() trace = sample(20, init=None, tune=5, chains=2) warns = [msg.msg for msg in caplog.records] diff --git a/pymc3/tests/test_transforms.py b/pymc3/tests/test_transforms.py index fd32d8b9b65..e040e6e244b 100644 --- a/pymc3/tests/test_transforms.py +++ b/pymc3/tests/test_transforms.py @@ -274,11 +274,11 @@ def test_chain_jacob_det(): class TestElementWiseLogp(SeededTest): - def build_model(self, distfam, params, size, transform, testval=None): + def build_model(self, distfam, params, *, size=None, shape=None, transform=None, testval=None): if testval is not None: testval = pm.floatX(testval) with pm.Model() as m: - distfam("x", size=size, transform=transform, testval=testval, **params) + distfam("x", size=size, shape=shape, transform=transform, testval=testval, **params) return m def check_transform_elementwise_logp(self, model): @@ -328,32 +328,34 @@ def test_half_normal(self, sd, size): model = self.build_model(pm.HalfNormal, {"sd": sd}, size=size, transform=tr.log) self.check_transform_elementwise_logp(model) - @pytest.mark.parametrize("lam,size", [(2.5, 2), (5.0, (2, 3)), (np.ones(3), (4, 3))]) + @pytest.mark.parametrize("lam,size", [(2.5, 2), (5.0, (2, 3)), (np.ones(3), (4, 5))]) def test_exponential(self, lam, size): model = self.build_model(pm.Exponential, {"lam": lam}, size=size, transform=tr.log) self.check_transform_elementwise_logp(model) @pytest.mark.parametrize( - "a,b,size", + "a,b,shape", [ (1.0, 1.0, 2), (0.5, 0.5, (2, 3)), (np.ones(3), np.ones(3), (4, 3)), ], ) - def test_beta(self, a, b, size): - model = self.build_model(pm.Beta, {"alpha": a, "beta": b}, size=size, transform=tr.logodds) + def test_beta(self, a, b, shape): + model = self.build_model( + pm.Beta, {"alpha": a, "beta": b}, shape=shape, transform=tr.logodds + ) self.check_transform_elementwise_logp(model) @pytest.mark.parametrize( - "lower,upper,size", + "lower,upper,shape", [ (0.0, 1.0, 2), (0.5, 5.5, (2, 3)), (pm.floatX(np.zeros(3)), pm.floatX(np.ones(3)), (4, 3)), ], ) - def test_uniform(self, lower, upper, size): + def test_uniform(self, lower, upper, shape): def transform_params(rv_var): _, _, _, lower, upper = rv_var.owner.inputs lower = at.as_tensor_variable(lower) if lower is not None else None @@ -362,25 +364,25 @@ def transform_params(rv_var): interval = tr.Interval(transform_params) model = self.build_model( - pm.Uniform, {"lower": lower, "upper": upper}, size=size, transform=interval + pm.Uniform, {"lower": lower, "upper": upper}, shape=shape, transform=interval ) self.check_transform_elementwise_logp(model) @pytest.mark.parametrize( - "mu,kappa,size", [(0.0, 1.0, 2), (-0.5, 5.5, (2, 3)), (np.zeros(3), np.ones(3), (4, 3))] + "mu,kappa,shape", [(0.0, 1.0, 2), (-0.5, 5.5, (2, 3)), (np.zeros(3), np.ones(3), (4, 3))] ) @pytest.mark.xfail(reason="Distribution not refactored yet") - def test_vonmises(self, mu, kappa, size): + def test_vonmises(self, mu, kappa, shape): model = self.build_model( - pm.VonMises, {"mu": mu, "kappa": kappa}, size=size, transform=tr.circular + pm.VonMises, {"mu": mu, "kappa": kappa}, shape=shape, transform=tr.circular ) self.check_transform_elementwise_logp(model) @pytest.mark.parametrize( - "a,size", [(np.ones(2), None), (np.ones((2, 3)) * 0.5, None), (np.ones(3), (4,))] + "a,shape", [(np.ones(2), None), (np.ones((2, 3)) * 0.5, None), (np.ones(3), (4,))] ) - def test_dirichlet(self, a, size): - model = self.build_model(pm.Dirichlet, {"a": a}, size=size, transform=tr.stick_breaking) + def test_dirichlet(self, a, shape): + model = self.build_model(pm.Dirichlet, {"a": a}, shape=shape, transform=tr.stick_breaking) self.check_vectortransform_elementwise_logp(model, vect_opt=1) def test_normal_ordered(self): @@ -394,59 +396,59 @@ def test_normal_ordered(self): self.check_vectortransform_elementwise_logp(model, vect_opt=0) @pytest.mark.parametrize( - "sd,size", + "sd,shape", [ (2.5, (2,)), (np.ones(3), (4, 3)), ], ) @pytest.mark.xfail(condition=(aesara.config.floatX == "float32"), reason="Fails on float32") - def test_half_normal_ordered(self, sd, size): - testval = np.sort(np.abs(np.random.randn(*size))) + def test_half_normal_ordered(self, sd, shape): + testval = np.sort(np.abs(np.random.randn(*shape))) model = self.build_model( pm.HalfNormal, {"sd": sd}, - size=size, + shape=shape, testval=testval, transform=tr.Chain([tr.log, tr.ordered]), ) self.check_vectortransform_elementwise_logp(model, vect_opt=0) - @pytest.mark.parametrize("lam,size", [(2.5, (2,)), (np.ones(3), (4, 3))]) - def test_exponential_ordered(self, lam, size): - testval = np.sort(np.abs(np.random.randn(*size))) + @pytest.mark.parametrize("lam,shape", [(2.5, (2,)), (np.ones(3), (4, 3))]) + def test_exponential_ordered(self, lam, shape): + testval = np.sort(np.abs(np.random.randn(*shape))) model = self.build_model( pm.Exponential, {"lam": lam}, - size=size, + shape=shape, testval=testval, transform=tr.Chain([tr.log, tr.ordered]), ) self.check_vectortransform_elementwise_logp(model, vect_opt=0) @pytest.mark.parametrize( - "a,b,size", + "a,b,shape", [ (1.0, 1.0, (2,)), (np.ones(3), np.ones(3), (4, 3)), ], ) - def test_beta_ordered(self, a, b, size): - testval = np.sort(np.abs(np.random.rand(*size))) + def test_beta_ordered(self, a, b, shape): + testval = np.sort(np.abs(np.random.rand(*shape))) model = self.build_model( pm.Beta, {"alpha": a, "beta": b}, - size=size, + shape=shape, testval=testval, transform=tr.Chain([tr.logodds, tr.ordered]), ) self.check_vectortransform_elementwise_logp(model, vect_opt=0) @pytest.mark.parametrize( - "lower,upper,size", + "lower,upper,shape", [(0.0, 1.0, (2,)), (pm.floatX(np.zeros(3)), pm.floatX(np.ones(3)), (4, 3))], ) - def test_uniform_ordered(self, lower, upper, size): + def test_uniform_ordered(self, lower, upper, shape): def transform_params(rv_var): _, _, _, lower, upper = rv_var.owner.inputs lower = at.as_tensor_variable(lower) if lower is not None else None @@ -455,43 +457,45 @@ def transform_params(rv_var): interval = tr.Interval(transform_params) - testval = np.sort(np.abs(np.random.rand(*size))) + testval = np.sort(np.abs(np.random.rand(*shape))) model = self.build_model( pm.Uniform, {"lower": lower, "upper": upper}, - size=size, + shape=shape, testval=testval, transform=tr.Chain([interval, tr.ordered]), ) self.check_vectortransform_elementwise_logp(model, vect_opt=1) - @pytest.mark.parametrize("mu,kappa,size", [(0.0, 1.0, (2,)), (np.zeros(3), np.ones(3), (4, 3))]) + @pytest.mark.parametrize( + "mu,kappa,shape", [(0.0, 1.0, (2,)), (np.zeros(3), np.ones(3), (4, 3))] + ) @pytest.mark.xfail(reason="Distribution not refactored yet") - def test_vonmises_ordered(self, mu, kappa, size): - testval = np.sort(np.abs(np.random.rand(*size))) + def test_vonmises_ordered(self, mu, kappa, shape): + testval = np.sort(np.abs(np.random.rand(*shape))) model = self.build_model( pm.VonMises, {"mu": mu, "kappa": kappa}, - size=size, + shape=shape, testval=testval, transform=tr.Chain([tr.circular, tr.ordered]), ) self.check_vectortransform_elementwise_logp(model, vect_opt=0) @pytest.mark.parametrize( - "lower,upper,size,transform", + "lower,upper,shape,transform", [ (0.0, 1.0, (2,), tr.stick_breaking), (0.5, 5.5, (2, 3), tr.stick_breaking), (np.zeros(3), np.ones(3), (4, 3), tr.Chain([tr.sum_to_1, tr.logodds])), ], ) - def test_uniform_other(self, lower, upper, size, transform): - testval = np.ones(size) / size[-1] + def test_uniform_other(self, lower, upper, shape, transform): + testval = np.ones(shape) / shape[-1] model = self.build_model( pm.Uniform, {"lower": lower, "upper": upper}, - size=size, + shape=shape, testval=testval, transform=transform, )