Skip to content

Commit

Permalink
Merge pull request #133 from BasisResearch/eb-finish-slc
Browse files Browse the repository at this point in the history
Update text and code of SLC notebook
  • Loading branch information
SamWitty authored Jun 12, 2023
2 parents 370a9c8 + 8123cdf commit 57e8c3b
Show file tree
Hide file tree
Showing 5 changed files with 481 additions and 654 deletions.
6 changes: 2 additions & 4 deletions causal_pyro/counterfactual/handlers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,8 @@ def _pyro_gen_intervene_name(cls, msg: Dict[str, Any]) -> None:
msg["value"] = name if name is not None else cls.default_name
msg["done"] = True

@staticmethod
@pyro.poutine.block(hide_types=["intervene"])
def _pyro_split(msg: Dict[str, Any]) -> None:
def _pyro_split(self, msg: Dict[str, Any]) -> None:
if msg["done"]:
return

Expand All @@ -73,9 +72,8 @@ class SingleWorldCounterfactual(BaseCounterfactual):
Trivial counterfactual handler that returns the intervened value.
"""

@staticmethod
@pyro.poutine.block(hide_types=["intervene"])
def _pyro_split(msg: Dict[str, Any]) -> None:
def _pyro_split(self, msg: Dict[str, Any]) -> None:
obs, acts = msg["args"]
msg["value"] = intervene(obs, acts[-1], **msg["kwargs"])
msg["done"] = True
Expand Down
4 changes: 3 additions & 1 deletion causal_pyro/counterfactual/handlers/ambiguity.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,9 @@ def configure(

fn = msg["fn"]
while hasattr(fn, "base_dist"):
if isinstance(fn, dist.TransformedDistribution):
if isinstance(fn, dist.FoldedDistribution):
return FactualConditioningReparam()
elif isinstance(fn, dist.TransformedDistribution):
return ConditionTransformReparam()
else:
fn = fn.base_dist
Expand Down
2 changes: 1 addition & 1 deletion causal_pyro/indexed/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def _pyro_add_indices(self, msg):
# are still guaranteed to exit safely in the correct order.
self.plates[name] = self._enter_index_plate(
_LazyPlateMessenger(
name="__index_plate__" + name,
name=name,
dim=self.first_available_dim,
size=new_size,
)
Expand Down
8 changes: 7 additions & 1 deletion causal_pyro/indexed/internals.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,12 @@ def _indices_of_distribution(


class _LazyPlateMessenger(IndepMessenger):
prefix: str = "__index_plate__"

def __init__(self, name: str, *args, **kwargs):
self._orig_name: str = name
super().__init__(f"{self.prefix}_{name}", *args, **kwargs)

@property
def frame(self) -> CondIndepStackFrame:
return CondIndepStackFrame(
Expand All @@ -208,7 +214,7 @@ def frame(self) -> CondIndepStackFrame:
def _process_message(self, msg):
if msg["type"] not in ("sample",) or pyro.poutine.util.site_is_subsample(msg):
return
if self.frame.name in union(
if self._orig_name in union(
indices_of(msg["value"], event_dim=msg["fn"].event_dim),
indices_of(msg["fn"]),
):
Expand Down
1,115 changes: 468 additions & 647 deletions docs/source/slc.ipynb

Large diffs are not rendered by default.

0 comments on commit 57e8c3b

Please sign in to comment.