Skip to content

Commit

Permalink
Disable use of Arviz in pymc3.tests.test_data_container
Browse files Browse the repository at this point in the history
  • Loading branch information
brandonwillard committed Mar 22, 2021
1 parent 2b4fd10 commit 6d8c136
Showing 1 changed file with 53 additions and 8 deletions.
61 changes: 53 additions & 8 deletions pymc3/tests/test_data_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,13 @@ def test_sample_posterior_predictive_after_set_data(self):
y = pm.Data("y", [1.0, 2.0, 3.0])
beta = pm.Normal("beta", 0, 10.0)
pm.Normal("obs", beta * x, np.sqrt(1e-2), observed=y)
trace = pm.sample(1000, tune=1000, chains=1)
trace = pm.sample(
1000,
tune=1000,
chains=1,
return_inferencedata=False,
compute_convergence_checks=False,
)
# Predict on new data.
with model:
x_test = [5, 6, 9]
Expand All @@ -86,13 +92,27 @@ def test_sample_after_set_data(self):
y = pm.Data("y", [1.0, 2.0, 3.0])
beta = pm.Normal("beta", 0, 10.0)
pm.Normal("obs", beta * x, np.sqrt(1e-2), observed=y)
pm.sample(1000, init=None, tune=1000, chains=1)
pm.sample(
1000,
init=None,
tune=1000,
chains=1,
return_inferencedata=False,
compute_convergence_checks=False,
)
# Predict on new data.
new_x = [5.0, 6.0, 9.0]
new_y = [5.0, 6.0, 9.0]
with model:
pm.set_data(new_data={"x": new_x, "y": new_y})
new_trace = pm.sample(1000, init=None, tune=1000, chains=1)
new_trace = pm.sample(
1000,
init=None,
tune=1000,
chains=1,
return_inferencedata=False,
compute_convergence_checks=False,
)
pp_trace = pm.sample_posterior_predictive(new_trace, 1000)

assert pp_trace["obs"].shape == (1000, 3)
Expand All @@ -110,7 +130,14 @@ def test_shared_data_as_index(self):
pm.Normal("obs", alpha[index], np.sqrt(1e-2), observed=y)

prior_trace = pm.sample_prior_predictive(1000, var_names=["alpha"])
trace = pm.sample(1000, init=None, tune=1000, chains=1)
trace = pm.sample(
1000,
init=None,
tune=1000,
chains=1,
return_inferencedata=False,
compute_convergence_checks=False,
)

# Predict on new data
new_index = np.array([0, 1, 2])
Expand All @@ -132,14 +159,18 @@ def test_shared_data_as_rv_input(self):
with pm.Model() as m:
x = pm.Data("x", [1.0, 2.0, 3.0])
_ = pm.Normal("y", mu=x, size=3)
trace = pm.sample(chains=1)
trace = pm.sample(
chains=1, return_inferencedata=False, compute_convergence_checks=False
)

np.testing.assert_allclose(np.array([1.0, 2.0, 3.0]), x.get_value(), atol=1e-1)
np.testing.assert_allclose(np.array([1.0, 2.0, 3.0]), trace["y"].mean(0), atol=1e-1)

with m:
pm.set_data({"x": np.array([2.0, 4.0, 6.0])})
trace = pm.sample(chains=1)
trace = pm.sample(
chains=1, return_inferencedata=False, compute_convergence_checks=False
)

np.testing.assert_allclose(np.array([2.0, 4.0, 6.0]), x.get_value(), atol=1e-1)
np.testing.assert_allclose(np.array([2.0, 4.0, 6.0]), trace["y"].mean(0), atol=1e-1)
Expand Down Expand Up @@ -175,7 +206,14 @@ def test_set_data_to_non_data_container_variables(self):
y = np.array([1.0, 2.0, 3.0])
beta = pm.Normal("beta", 0, 10.0)
pm.Normal("obs", beta * x, np.sqrt(1e-2), observed=y)
pm.sample(1000, init=None, tune=1000, chains=1)
pm.sample(
1000,
init=None,
tune=1000,
chains=1,
return_inferencedata=False,
compute_convergence_checks=False,
)
with pytest.raises(TypeError) as error:
pm.set_data({"beta": [1.1, 2.2, 3.3]}, model=model)
error.match("defined as `pymc3.Data` inside the model")
Expand All @@ -188,7 +226,14 @@ def test_model_to_graphviz_for_model_with_data_container(self):
beta = pm.Normal("beta", 0, 10.0)
obs_sigma = floatX(np.sqrt(1e-2))
pm.Normal("obs", beta * x, obs_sigma, observed=y)
pm.sample(1000, init=None, tune=1000, chains=1)
pm.sample(
1000,
init=None,
tune=1000,
chains=1,
return_inferencedata=False,
compute_convergence_checks=False,
)

for formatting in {"latex", "latex_with_params"}:
with pytest.raises(ValueError, match="Unsupported formatting"):
Expand Down

0 comments on commit 6d8c136

Please sign in to comment.