Skip to content

Commit

Permalink
Remove size from Distribution signatures (#5788)
Browse files Browse the repository at this point in the history
* Remove size from .dist() signature

Closes #5754

* Align (Half)Flat signatures with superclass

* Don't mention size in the docstring

Co-authored-by: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com>

* Revert changes in `rng_fn` signatures

Co-authored-by: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com>

Co-authored-by: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com>
  • Loading branch information
michaelosthege and ricardoV94 authored May 23, 2022
1 parent d0af6b1 commit b5a5b56
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 9 deletions.
8 changes: 4 additions & 4 deletions pymc/distributions/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,8 +364,8 @@ def __new__(cls, *args, **kwargs):
return super().__new__(cls, *args, **kwargs)

@classmethod
def dist(cls, *, size=None, **kwargs):
res = super().dist([], size=size, **kwargs)
def dist(cls, **kwargs):
res = super().dist([], **kwargs)
return res

def moment(rv, size):
Expand Down Expand Up @@ -432,8 +432,8 @@ def __new__(cls, *args, **kwargs):
return super().__new__(cls, *args, **kwargs)

@classmethod
def dist(cls, *, size=None, **kwargs):
res = super().dist([], size=size, **kwargs)
def dist(cls, **kwargs):
res = super().dist([], **kwargs)
return res

def moment(rv, size):
Expand Down
11 changes: 6 additions & 5 deletions pymc/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,8 +226,8 @@ def __new__(
transform : optional
See ``Model.register_rv``.
**kwargs
Keyword arguments that will be forwarded to ``.dist()``.
Most prominently: ``shape`` and ``size``
Keyword arguments that will be forwarded to ``.dist()`` or the Aesara RV Op.
Most prominently: ``shape`` for ``.dist()`` or ``dtype`` for the Op.
Returns
-------
Expand Down Expand Up @@ -298,7 +298,6 @@ def dist(
dist_params,
*,
shape: Optional[Shape] = None,
size: Optional[Size] = None,
**kwargs,
) -> RandomVariable:
"""Creates a RandomVariable corresponding to the `cls` distribution.
Expand All @@ -312,8 +311,9 @@ def dist(
An Ellipsis (...) may be inserted in the last position to short-hand refer to
all the dimensions that the RV would get if no shape/size/dims were passed at all.
size : int, tuple, Variable, optional
For creating the RV like in Aesara/NumPy.
**kwargs
Keyword arguments that will be forwarded to the Aesara RV Op.
Most prominently: ``size`` or ``dtype``.
Returns
-------
Expand All @@ -337,6 +337,7 @@ def dist(

if "dims" in kwargs:
raise NotImplementedError("The use of a `.dist(dims=...)` API is not supported.")
size = kwargs.pop("size", None)
if shape is not None and size is not None:
raise ValueError(
f"Passing both `shape` ({shape}) and `size` ({size}) is not supported!"
Expand Down

0 comments on commit b5a5b56

Please sign in to comment.