Skip to content

Commit

Permalink
Add handlers for masking indexed random variables (#97)
Browse files Browse the repository at this point in the history
* selection

* all tests pass

* remove xfail

* nit

* suppress warning

* nit

* rename selection handlers

* Add reparameterizers for resolving conditioning ambiguity

* move to counterfactual

* rename factual_conditioning -> conditioning

* configuration for deepscm

* move hack into internals

* Include obs weight

* skip when not ambiguous

* Squashed commit of the following:

commit 6222760
Author: Eli <eli@basis.ai>
Date:   Wed Mar 1 10:05:05 2023 -0500

    import

commit fa2891d
Author: Eli <eli@basis.ai>
Date:   Wed Mar 1 09:50:46 2023 -0500

    simplify api by removing merge

* Squashed commit of the following:

commit b37fb99
Author: Eli <eli@basis.ai>
Date:   Wed Mar 1 10:22:43 2023 -0500

    lint

commit 6222760
Author: Eli <eli@basis.ai>
Date:   Wed Mar 1 10:05:05 2023 -0500

    import

commit fa2891d
Author: Eli <eli@basis.ai>
Date:   Wed Mar 1 09:50:46 2023 -0500

    simplify api by removing merge

* lint

* fix tests

* remove debug

* device inference

* fix rebase

* Rebase

* small bug

* separate out transform conditioning

* docs

* functions

* docs

* docstrings

* lint

* Add tests and refactor to make them pass

* clean up test and docstring

* rename base

* Support ambiguous conditioning on TransformedDistributions (#108)

* Add transform conditioning

* incorporate changes

* format

* test file

* docs

* test passes

* fix bug

* refactor

* move default strategy into handlers

* lint
  • Loading branch information
eb8680 authored Mar 16, 2023
1 parent 58c3d0a commit b92b0b0
Show file tree
Hide file tree
Showing 9 changed files with 659 additions and 31 deletions.
243 changes: 243 additions & 0 deletions causal_pyro/counterfactual/conditioning.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,243 @@
from typing import Any, Dict, Literal, Optional, TypedDict, TypeVar, Union

import pyro
import pyro.distributions as dist
import pyro.infer.reparam
import torch

from causal_pyro.counterfactual.internals import expand_obs_value_inplace_
from causal_pyro.counterfactual.selection import (
SelectCounterfactual,
SelectFactual,
get_factual_indices,
)
from causal_pyro.primitives import gather, indices_of, scatter, union

T = TypeVar("T")


def site_is_ambiguous(msg: Dict[str, Any]) -> bool:
"""
Helper function used with :func:`pyro.condition` to determine
whether a site is observed or ambiguous.
A sample site is ambiguous if it is marked observed, is downstream of an intervention,
and the observed value's index variables are a strict subset of the distribution's
indices and hence require clarification of which entries of the random variable
are fixed/observed (as opposed to random/unobserved).
"""
if not (
msg["type"] == "sample"
and msg["is_observed"]
and not pyro.poutine.util.site_is_subsample(msg)
):
return False

try:
return not msg["infer"]["_specified_conditioning"]
except KeyError:
value_indices = indices_of(msg["value"], event_dim=len(msg["fn"].event_shape))
dist_indices = indices_of(msg["fn"])
return (
bool(union(value_indices, dist_indices)) and value_indices != dist_indices
)


def no_ambiguity(msg: Dict[str, Any]) -> Dict[str, Any]:
"""
Helper function used with :func:`pyro.poutine.infer_config` to inform
:class:`AmbiguousConditioningReparam` that all ambiguity in the current
context has been resolved.
"""
return {"_specified_conditioning": True}


class AmbiguousConditioningReparam(pyro.infer.reparam.reparam.Reparam):
"""
Abstract base class for reparameterizers that handle ambiguous conditioning.
"""

pass


class AmbiguousConditioningStrategy(pyro.infer.reparam.strategies.Strategy):
"""
Abstract base class for strategies that handle ambiguous conditioning.
"""

pass


CondStrategy = Union[
Dict[str, AmbiguousConditioningReparam], AmbiguousConditioningStrategy
]


class AmbiguousConditioningReparamMessenger(
pyro.poutine.reparam_messenger.ReparamMessenger
):
config: CondStrategy

def _pyro_sample(self, msg: pyro.infer.reparam.reparam.ReparamMessage) -> None:
if site_is_ambiguous(msg):
expand_obs_value_inplace_(msg)
msg["infer"]["_specified_conditioning"] = False
super()._pyro_sample(msg)
msg["infer"]["_specified_conditioning"] = True


class ConditionReparamMsg(TypedDict):
fn: pyro.distributions.Distribution
value: torch.Tensor
is_observed: Literal[True]


class ConditionReparamArgMsg(ConditionReparamMsg):
name: str


class FactualConditioningReparam(AmbiguousConditioningReparam):
"""
Factual conditioning reparameterizer.
This :class:`pyro.infer.reparam.reparam.Reparam` is used to resolve inherent
semantic ambiguity in conditioning in the presence of interventions by
splitting the observed value into a factual and counterfactual component,
associating the observed value with the factual random variable,
and sampling the counterfactual random variable from its prior.
"""

@pyro.poutine.infer_config(config_fn=no_ambiguity)
def apply(self, msg: ConditionReparamArgMsg) -> ConditionReparamMsg:
with SelectFactual():
fv = pyro.sample(msg["name"] + "_factual", msg["fn"], obs=msg["value"])

with SelectCounterfactual():
cv = pyro.sample(msg["name"] + "_counterfactual", msg["fn"])

event_dim = len(msg["fn"].event_shape)
fw_indices = get_factual_indices()
new_value: torch.Tensor = scatter(
fv, fw_indices, result=cv.clone(), event_dim=event_dim
)
new_fn = dist.Delta(new_value, event_dim=event_dim).mask(False)
return {"fn": new_fn, "value": new_value, "is_observed": True}


class MinimalFactualConditioning(AmbiguousConditioningStrategy):
"""
Reparameterization strategy for handling ambiguity in conditioning, for use with
counterfactual semantics handlers such as :class:`MultiWorldCounterfactual` .
:class:`MinimalFactualConditioning` applies :class:`FactualConditioningReparam`
instances to all ambiguous sample sites in a model.
.. note::
A sample site is ambiguous if it is marked observed, is downstream of an intervention,
and the observed value's index variables are a strict subset of the distribution's
indices and hence require clarification of which entries of the random variable
are fixed/observed (as opposed to random/unobserved).
"""

def configure(
self, msg: pyro.infer.reparam.reparam.ReparamMessage
) -> Optional[FactualConditioningReparam]:
if not site_is_ambiguous(msg):
return None

return FactualConditioningReparam()


class ConditionTransformReparamMsg(TypedDict):
fn: pyro.distributions.TransformedDistribution
value: torch.Tensor
is_observed: Literal[True]


class ConditionTransformReparamArgMsg(ConditionTransformReparamMsg):
name: str


class ConditionTransformReparam(AmbiguousConditioningReparam):
def apply(
self, msg: ConditionTransformReparamArgMsg
) -> ConditionTransformReparamMsg:
name, fn, value = msg["name"], msg["fn"], msg["value"]

tfm = (
fn.transforms[-1]
if len(fn.transforms) == 1
else dist.transforms.ComposeTransformModule(fn.transforms)
)
noise_dist = fn.base_dist
noise_event_dim = len(noise_dist.event_shape)
obs_event_dim = len(fn.event_shape)

# factual world
with SelectFactual(), pyro.poutine.infer_config(config_fn=no_ambiguity):
new_base_dist = dist.Delta(value, event_dim=obs_event_dim).mask(False)
new_noise_dist = dist.TransformedDistribution(new_base_dist, tfm.inv)
obs_noise = pyro.sample(
name + "_noise_likelihood", new_noise_dist, obs=tfm.inv(value)
)

# depends on strategy and indices of noise_dist
fw = get_factual_indices()
obs_noise = gather(obs_noise, fw, event_dim=noise_event_dim).expand(
obs_noise.shape
)
obs_noise = pyro.sample(name + "_noise_prior", noise_dist, obs=obs_noise)

# counterfactual world
with SelectCounterfactual(), pyro.poutine.infer_config(config_fn=no_ambiguity):
cf_noise_dist = dist.Delta(obs_noise, event_dim=noise_event_dim).mask(False)
cf_obs_dist = dist.TransformedDistribution(cf_noise_dist, tfm)
cf_obs_value = pyro.sample(name + "_cf_obs", cf_obs_dist)

# merge
new_value = scatter(
value, fw, result=cf_obs_value.clone(), event_dim=obs_event_dim
)
new_fn = dist.Delta(new_value, event_dim=obs_event_dim).mask(False)
return {"fn": new_fn, "value": new_value, "is_observed": msg["is_observed"]}


class AutoFactualConditioning(MinimalFactualConditioning):
"""
Reparameterization strategy for handling ambiguity in conditioning, for use with
counterfactual semantics handlers such as :class:`MultiWorldCounterfactual` .
When the distribution is a :class:`pyro.distributions.TransformedDistribution`,
:class:`AutoFactualConditioning` automatically applies :class:`ConditionTransformReparam`
to the site. Otherwise, it behaves like :class:`MinimalFactualConditioning` .
.. note::
This strategy is applied by default via :class:`MultiWorldCounterfactual`
and :class:`TwinWorldCounterfactual` unless otherwise specified.
.. note::
A sample site is ambiguous if it is marked observed, is downstream of an intervention,
and the observed value's index variables are a strict subset of the distribution's
indices and hence require clarification of which entries of the random variable
are fixed/observed (as opposed to random/unobserved).
"""

def configure(
self, msg: pyro.infer.reparam.reparam.ReparamMessage
) -> Optional[FactualConditioningReparam]:
if not site_is_ambiguous(msg):
return None

fn = msg["fn"]
while hasattr(fn, "base_dist"):
if isinstance(fn, dist.TransformedDistribution):
return ConditionTransformReparam()
else:
fn = fn.base_dist

return super().configure(msg)
20 changes: 16 additions & 4 deletions causal_pyro/counterfactual/handlers.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,28 @@
from typing import Any, Dict

import pyro
from typing import Any, Dict, Optional

from causal_pyro.counterfactual.conditioning import (
AmbiguousConditioningReparamMessenger,
AutoFactualConditioning,
CondStrategy,
)
from causal_pyro.counterfactual.internals import IndexPlatesMessenger
from causal_pyro.primitives import IndexSet, scatter


class BaseCounterfactual(pyro.poutine.messenger.Messenger):
class BaseCounterfactual(AmbiguousConditioningReparamMessenger):
"""
Base class for counterfactual handlers.
"""

def __init__(self, config: Optional[CondStrategy] = None):
if config is None:
config = AutoFactualConditioning()
super().__init__(config=config)

def _pyro_get_index_plates(self, msg: Dict[str, Any]) -> None:
msg["stop"], msg["done"] = True, True
msg["value"] = {}

def _pyro_intervene(self, msg: Dict[str, Any]) -> None:
msg["stop"] = True

Expand Down
Loading

0 comments on commit b92b0b0

Please sign in to comment.