Separate shape logic into a separate file (#4708)
* Separate shape stuff into a separate file.

* Fix name of new function.

* Add missing import

* Fix import order

* Add doc-strings for find_size and maybe_resize.

* Move shape.py contents to shape_utils.py. Fix tests.

* Pass size kwarg to maybe_resize.

* No keyword arguments to maybe_resize.

* Move convert_size and convert_shape back to dist().

Co-authored-by: Michael Osthege <m.osthege@fz-juelich.de>
twiecki and michaelosthege authored Jun 1, 2021
1 parent ab5f44f commit 669a6e7
Showing 3 changed files with 326 additions and 224 deletions.
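The pymc3/distributions/distribution.py diff below removes the module-private shape helpers and imports them from the new pymc3.distributions.shape_utils module instead. As a rough sketch of the relocated interface (names and signatures inferred from this diff; the leading underscores are dropped, and find_size and maybe_resize appear to be newly factored out):

from pymc3.distributions.shape_utils import (
    Dims,                  # type aliases for user-provided dims/shape/size
    Shape,
    Size,
    convert_dims,          # e.g. "town" -> ("town",); validates Ellipsis placement
    convert_shape,         # e.g. 5 -> (5,); Ellipsis only in the last position
    convert_size,          # e.g. 5 -> (5,); Ellipsis not allowed
    find_size,             # maps (shape, size, ndim_supp) to the size for RV creation
    maybe_resize,          # re-creates or resizes an RV with unexpected dimensionality
    resize_from_dims,      # resize shape implied by named dims
    resize_from_observed,  # resize shape implied by observed data
)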
241 changes: 34 additions & 207 deletions pymc3/distributions/distribution.py
@@ -20,20 +20,28 @@

from abc import ABCMeta
from copy import copy
from typing import Optional, Sequence, Tuple, Union
from typing import Optional

import aesara
import aesara.tensor as at
import dill
import numpy as np

from aesara.graph.basic import Variable
from aesara.tensor.random.op import RandomVariable
from aesara.tensor.var import TensorVariable

from pymc3.aesaraf import change_rv_size, pandas_to_array
from pymc3.aesaraf import change_rv_size
from pymc3.distributions import _logcdf, _logp
from pymc3.exceptions import ShapeError, ShapeWarning
from pymc3.distributions.shape_utils import (
Dims,
Shape,
Size,
convert_dims,
convert_shape,
convert_size,
find_size,
maybe_resize,
resize_from_dims,
resize_from_observed,
)
from pymc3.util import UNSET, get_repr_for_variable
from pymc3.vartypes import string_types

@@ -51,20 +59,6 @@

PLATFORM = sys.platform

# User-provided shape, dims and size can be lazily specified as scalars
Shape = Union[int, TensorVariable, Sequence[Union[int, TensorVariable, type(Ellipsis)]]]
Dims = Union[str, Sequence[Union[str, None, type(Ellipsis)]]]
Size = Union[int, TensorVariable, Sequence[Union[int, TensorVariable]]]

# After conversion to vectors
WeakShape = Union[TensorVariable, Tuple[Union[int, TensorVariable, type(Ellipsis)], ...]]
WeakDims = Tuple[Union[str, None, type(Ellipsis)], ...]

# After Ellipsis were substituted
StrongShape = Union[TensorVariable, Tuple[Union[int, TensorVariable], ...]]
StrongDims = Sequence[Union[str, None]]
StrongSize = Union[TensorVariable, Tuple[Union[int, TensorVariable], ...]]
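# Illustrative examples (not exhaustive) of values matching the user-facing aliases:
#   Shape: 5, (2, 3), (2, ...)          Ellipsis allowed in the last position only
#   Dims:  "town", ("town", None, ...)  entries are names, None, or a trailing Ellipsis
#   Size:  5, (2, 3)                    Ellipsis is never allowed in `size`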


class _Unpickling:
pass
@@ -128,135 +122,6 @@ def logcdf(op, var, rvs_to_values, *dist_params, **kwargs):
return new_cls


def _convert_dims(dims: Dims) -> Optional[WeakDims]:
""" Process a user-provided dims variable into None or a valid dims tuple. """
if dims is None:
return None

if isinstance(dims, str):
dims = (dims,)
elif isinstance(dims, (list, tuple)):
dims = tuple(dims)
else:
raise ValueError(f"The `dims` parameter must be a tuple, str or list. Actual: {type(dims)}")

if any(d == Ellipsis for d in dims[:-1]):
raise ValueError(f"Ellipsis in `dims` may only appear in the last position. Actual: {dims}")

return dims
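# For example: _convert_dims("town") returns ("town",), and
# _convert_dims(["town", ...]) returns ("town", Ellipsis), while
# _convert_dims((..., "town")) raises because the Ellipsis is not in the last position.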


def _convert_shape(shape: Shape) -> Optional[WeakShape]:
""" Process a user-provided shape variable into None or a valid shape object. """
if shape is None:
return None

if isinstance(shape, int) or (isinstance(shape, TensorVariable) and shape.ndim == 0):
shape = (shape,)
elif isinstance(shape, (list, tuple)):
shape = tuple(shape)
else:
raise ValueError(
f"The `shape` parameter must be a tuple, TensorVariable, int or list. Actual: {type(shape)}"
)

if isinstance(shape, tuple) and any(s == Ellipsis for s in shape[:-1]):
raise ValueError(
f"Ellipsis in `shape` may only appear in the last position. Actual: {shape}"
)

return shape
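# For example: _convert_shape(5) returns (5,), _convert_shape([2, 3]) returns
# (2, 3), and _convert_shape((2, ...)) keeps the trailing Ellipsis as (2, Ellipsis).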


def _convert_size(size: Size) -> Optional[StrongSize]:
""" Process a user-provided size variable into None or a valid size object. """
if size is None:
return None

if isinstance(size, int) or (isinstance(size, TensorVariable) and size.ndim == 0):
size = (size,)
elif isinstance(size, (list, tuple)):
size = tuple(size)
else:
raise ValueError(
f"The `size` parameter must be a tuple, TensorVariable, int or list. Actual: {type(size)}"
)

if isinstance(size, tuple) and Ellipsis in size:
raise ValueError(f"The `size` parameter cannot contain an Ellipsis. Actual: {size}")

return size
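# For example: _convert_size(5) returns (5,), while _convert_size((2, ...))
# raises because `size` may not contain an Ellipsis.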


def _resize_from_dims(
dims: WeakDims, ndim_implied: int, model
) -> Tuple[int, StrongSize, StrongDims]:
"""Determines a potential resize shape from a `dims` tuple.
Parameters
----------
dims : array-like
A vector of dimension names, None or Ellipsis.
ndim_implied : int
Number of RV dimensions that were implied from its inputs alone.
model : pm.Model
The current model on stack.
Returns
-------
ndim_resize : int
Number of dimensions that should be added through resizing.
resize_shape : array-like
The shape of the new dimensions.
"""
if Ellipsis in dims:
# Auto-complete the dims tuple to the full length.
# We don't have a way to know the names of implied
# dimensions, so they will be `None`.
dims = (*dims[:-1], *[None] * ndim_implied)

ndim_resize = len(dims) - ndim_implied

# All resize dims must be known already (numerically or symbolically).
unknowndim_resize_dims = set(dims[:ndim_resize]) - set(model.dim_lengths)
if unknowndim_resize_dims:
raise KeyError(
f"Dimensions {unknowndim_resize_dims} are unknown to the model and cannot be used to specify a `size`."
)

# The numeric/symbolic resize tuple can be created using model.RV_dim_lengths
resize_shape = tuple(model.dim_lengths[dname] for dname in dims[:ndim_resize])
return ndim_resize, resize_shape, dims
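# Worked example (with a hypothetical "year" dim registered on the model):
# dims=("year", ...) with ndim_implied=1 expands to dims=("year", None), so
# ndim_resize=1 and resize_shape=(model.dim_lengths["year"],).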


def _resize_from_observed(
observed, ndim_implied: int
) -> Tuple[int, StrongSize, Union[np.ndarray, Variable]]:
"""Determines a potential resize shape from observations.
Parameters
----------
observed : scalar, array-like
The value of the `observed` kwarg to the RV creation.
ndim_implied : int
Number of RV dimensions that were implied from its inputs alone.
Returns
-------
ndim_resize : int
Number of dimensions that should be added through resizing.
resize_shape : array-like
The shape of the new dimensions.
observed : scalar, array-like
Observations as numpy array or `Variable`.
"""
if not hasattr(observed, "shape"):
observed = pandas_to_array(observed)
ndim_resize = observed.ndim - ndim_implied
resize_shape = tuple(observed.shape[d] for d in range(ndim_resize))
return ndim_resize, resize_shape, observed
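# Worked example: observed=np.ones((5, 3)) with ndim_implied=1 yields
# ndim_resize=1 and resize_shape=(5,), i.e. the leading axis becomes the batch.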


class Distribution(metaclass=DistributionMeta):
"""Statistical distribution"""

@@ -335,7 +200,7 @@ def __new__(
raise ValueError(
f"Passing both `dims` ({dims}) and `size` ({kwargs['size']}) is not supported!"
)
dims = _convert_dims(dims)
dims = convert_dims(dims)

# Create the RV without specifying testval, because the testval may have a shape
# that only matches after replicating with a size implied by dims (see below).
@@ -346,9 +211,9 @@
# `dims` are only available with this API, because `.dist()` can be used
# without a modelcontext and dims are not tracked at the Aesara level.
if dims is not None:
ndim_resize, resize_shape, dims = _resize_from_dims(dims, ndim_actual, model)
ndim_resize, resize_shape, dims = resize_from_dims(dims, ndim_actual, model)
elif observed is not None:
ndim_resize, resize_shape, observed = _resize_from_observed(observed, ndim_actual)
ndim_resize, resize_shape, observed = resize_from_observed(observed, ndim_actual)

if resize_shape:
# A batch size was specified through `dims`, or implied by `observed`.
@@ -398,65 +263,27 @@ def dist(
raise ValueError(
f"Passing both `shape` ({shape}) and `size` ({size}) is not supported!"
)
shape = _convert_shape(shape)
size = _convert_size(size)

ndim_supp = cls.rv_op.ndim_supp
ndim_expected = None
ndim_batch = None
create_size = None

if shape is not None:
if Ellipsis in shape:
# Ellipsis short-hands all implied dimensions. Therefore
# we don't know how many dimensions to expect.
ndim_expected = ndim_batch = None
# Create the RV with its implied shape and resize later
create_size = None
else:
ndim_expected = len(tuple(shape))
ndim_batch = ndim_expected - ndim_supp
create_size = tuple(shape)[:ndim_batch]
elif size is not None:
ndim_expected = ndim_supp + len(tuple(size))
ndim_batch = ndim_expected - ndim_supp
create_size = size

shape = convert_shape(shape)
size = convert_size(size)

create_size, ndim_expected, ndim_batch, ndim_supp = find_size(
shape=shape, size=size, ndim_supp=cls.rv_op.ndim_supp
)
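        # Assuming find_size (defined in shape_utils.py, not shown in this diff)
        # mirrors the branch logic removed above, e.g.:
        #   find_size(shape=(2, 3), size=None, ndim_supp=1) -> ((2,), 2, 1, 1)
        #   find_size(shape=None, size=(4,), ndim_supp=1)   -> ((4,), 2, 1, 1)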
# Create the RV with a `size` right away.
# This is not necessarily the final result.
rv_out = cls.rv_op(*dist_params, size=create_size, **kwargs)
ndim_actual = rv_out.ndim
ndims_unexpected = ndim_actual != ndim_expected

if shape is not None and ndims_unexpected:
if Ellipsis in shape:
# Resize and we're done!
rv_out = change_rv_size(rv_var=rv_out, new_size=shape[:-1], expand=True)
else:
# This is rare, but happens, for example, with MvNormal(np.ones((2, 3)), np.eye(3), shape=(2, 3)).
# Recreate the RV without passing `size`, to create it with just the implied dimensions.
rv_out = cls.rv_op(*dist_params, size=None, **kwargs)

# Now resize by any remaining "extra" dimensions that were not implied from support and parameters
if rv_out.ndim < ndim_expected:
expand_shape = shape[: ndim_expected - rv_out.ndim]
rv_out = change_rv_size(rv_var=rv_out, new_size=expand_shape, expand=True)
if not rv_out.ndim == ndim_expected:
raise ShapeError(
f"Failed to create the RV with the expected dimensionality. "
f"This indicates a severe problem. Please open an issue.",
actual=ndim_actual,
expected=ndim_batch + ndim_supp,
)

# Warn about the edge cases where the RV Op creates more dimensions than
# it should based on `size` and `RVOp.ndim_supp`.
if size is not None and ndims_unexpected:
warnings.warn(
f"You may have expected a ({len(tuple(size))}+{ndim_supp})-dimensional RV, but the resulting RV will be {ndim_actual}-dimensional."
' To silence this warning use `warnings.simplefilter("ignore", pm.ShapeWarning)`.',
ShapeWarning,
)
rv_out = maybe_resize(
rv_out,
cls.rv_op,
dist_params,
ndim_expected,
ndim_batch,
ndim_supp,
shape,
size,
**kwargs,
)
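        # maybe_resize (also in shape_utils.py, body not shown here) presumably
        # wraps the logic removed above: re-create the RV without `size` when its
        # dimensionality is unexpected, resize Ellipsis-terminated shapes, and
        # emit the ShapeWarning for size-related dimensionality surprises.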

if testval is not None:
rv_out.tag.test_value = testval