Skip to content

Commit

Permalink
Merge branch 'v4' into reintro_shape
Browse files Browse the repository at this point in the history
  • Loading branch information
twiecki authored May 25, 2021
2 parents ab5f44f + 0970af0 commit 08cbb33
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 10 deletions.
8 changes: 7 additions & 1 deletion pymc3/aesaraf.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,13 @@ def pandas_to_array(data):
if hasattr(data, "to_numpy") and hasattr(data, "isnull"):
# typically, but not limited to pandas objects
vals = data.to_numpy()
mask = data.isnull().to_numpy()
null_data = data.isnull()
if hasattr(null_data, "to_numpy"):
# pandas Series
mask = null_data.to_numpy()
else:
# pandas Index
mask = null_data
if mask.any():
# there are missing values
ret = np.ma.MaskedArray(vals, mask)
Expand Down
27 changes: 23 additions & 4 deletions pymc3/tests/test_aesaraf.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def test_change_rv_size():
loc = at.as_tensor_variable([1, 2])
rv = normal(loc=loc)
assert rv.ndim == 1
assert rv.eval().shape == (2,)
assert tuple(rv.shape.eval()) == (2,)

with pytest.raises(ShapeError, match="must be ≤1-dimensional"):
change_rv_size(rv, new_size=[[2, 3]])
Expand All @@ -61,7 +61,7 @@ def test_change_rv_size():

rv_new = change_rv_size(rv, new_size=(3,), expand=True)
assert rv_new.ndim == 2
assert rv_new.eval().shape == (3, 2)
assert tuple(rv_new.shape.eval()) == (3, 2)

# Make sure that the shape used to determine the expanded size doesn't
# depend on the old `RandomVariable`.
Expand All @@ -71,7 +71,7 @@ def test_change_rv_size():

rv_newer = change_rv_size(rv_new, new_size=(4,), expand=True)
assert rv_newer.ndim == 3
assert rv_newer.eval().shape == (4, 3, 2)
assert tuple(rv_newer.shape.eval()) == (4, 3, 2)

# Make sure we avoid introducing a `Cast` by converting the new size before
# constructing the new `RandomVariable`
Expand All @@ -80,7 +80,19 @@ def test_change_rv_size():
rv_newer = change_rv_size(rv, new_size=new_size, expand=False)
assert rv_newer.ndim == 2
assert isinstance(rv_newer.owner.inputs[1], Constant)
assert rv_newer.eval().shape == (4, 3)
assert tuple(rv_newer.shape.eval()) == (4, 3)

rv = normal(0, 1)
new_size = at.as_tensor(np.array([4, 3], dtype="int32"))
rv_newer = change_rv_size(rv, new_size=new_size, expand=True)
assert rv_newer.ndim == 2
assert tuple(rv_newer.shape.eval()) == (4, 3)

rv = normal(0, 1)
new_size = at.as_tensor(2, dtype="int32")
rv_newer = change_rv_size(rv, new_size=new_size, expand=True)
assert rv_newer.ndim == 1
assert tuple(rv_newer.shape.eval()) == (2,)


class TestBroadcasting:
Expand Down Expand Up @@ -436,6 +448,13 @@ def test_pandas_to_array(input_dtype):
assert isinstance(wrapped, TensorVariable)


def test_pandas_to_array_pandas_index():
    """A pandas Index (which lacks .to_numpy on isnull's result path) converts cleanly."""
    idx = pd.Index([1, 2, 3])
    np.testing.assert_array_equal(pandas_to_array(idx), np.array([1, 2, 3]))


def test_walk_model():
d = at.vector("d")
b = at.vector("b")
Expand Down
8 changes: 4 additions & 4 deletions pymc3/tests/test_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -688,10 +688,10 @@ def test_deterministic_of_observed_modified_interface(self):
meas_in_1 = pm.aesaraf.floatX(2 + 4 * np.random.randn(100))
meas_in_2 = pm.aesaraf.floatX(5 + 4 * np.random.randn(100))
with pm.Model() as model:
mu_in_1 = pm.Normal("mu_in_1", 0, 1)
sigma_in_1 = pm.HalfNormal("sd_in_1", 1)
mu_in_2 = pm.Normal("mu_in_2", 0, 1)
sigma_in_2 = pm.HalfNormal("sd__in_2", 1)
mu_in_1 = pm.Normal("mu_in_1", 0, 1, testval=0)
sigma_in_1 = pm.HalfNormal("sd_in_1", 1, testval=1)
mu_in_2 = pm.Normal("mu_in_2", 0, 1, testval=0)
sigma_in_2 = pm.HalfNormal("sd__in_2", 1, testval=1)

in_1 = pm.Normal("in_1", mu_in_1, sigma_in_1, observed=meas_in_1)
in_2 = pm.Normal("in_2", mu_in_2, sigma_in_2, observed=meas_in_2)
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
aesara>=2.0.8
aesara>=2.0.9
arviz>=0.11.2
cachetools>=4.2.1
dill
Expand Down

0 comments on commit 08cbb33

Please sign in to comment.