Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Port InferenceData conversion code to pymc3 codebase #4489

Merged
merged 7 commits into from
Mar 26, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion .github/workflows/arviz_compat.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ on:

jobs:
pytest:
if: false
strategy:
matrix:
os: [ubuntu-latest, macos-latest]
Expand Down
7 changes: 6 additions & 1 deletion pymc3/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,12 @@ def __set_compiler_flags():

from pymc3 import gp, ode, sampling
from pymc3.aesaraf import *
from pymc3.backends import load_trace, save_trace
from pymc3.backends import (
load_trace,
predictions_to_inference_data,
save_trace,
to_inference_data,
)
from pymc3.backends.tracetab import *
from pymc3.blocking import *
from pymc3.data import *
Expand Down
26 changes: 25 additions & 1 deletion pymc3/aesaraf.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,12 @@
from aesara import scalar
from aesara import tensor as aet
from aesara.gradient import grad
from aesara.graph.basic import Apply, graph_inputs
from aesara.graph.basic import Apply, Constant, graph_inputs
from aesara.graph.op import Op
from aesara.sandbox.rng_mrg import MRG_RandomStream as RandomStream
from aesara.tensor.elemwise import Elemwise
from aesara.tensor.sharedvar import SharedVariable
from aesara.tensor.subtensor import AdvancedIncSubtensor, AdvancedIncSubtensor1
from aesara.tensor.var import TensorVariable

from pymc3.data import GeneratorAdapter
Expand All @@ -48,6 +50,28 @@
]


def extract_obs_data(x: TensorVariable) -> np.ndarray:
"""Extract data observed symbolic variables.

Raises
------
TypeError

"""
if isinstance(x, Constant):
return x.data
if isinstance(x, SharedVariable):
return x.get_value()
if x.owner and isinstance(x.owner.op, (AdvancedIncSubtensor, AdvancedIncSubtensor1)):
array_data = extract_obs_data(x.owner.inputs[0])
mask_idx = tuple(extract_obs_data(i) for i in x.owner.inputs[2:])
mask = np.zeros_like(array_data)
mask[mask_idx] = 1
return np.ma.MaskedArray(array_data, mask)

raise TypeError(f"Data cannot be extracted from {x}")


def inputvars(a):
"""
Get the inputs into a aesara variables
Expand Down
1 change: 1 addition & 0 deletions pymc3/backends/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
Saved backends can be loaded using `arviz.from_netcdf`

"""
from pymc3.backends.arviz import predictions_to_inference_data, to_inference_data
from pymc3.backends.ndarray import (
NDArray,
load_trace,
Expand Down
Loading