Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support ambiguous conditioning on TransformedDistributions #108

Merged
merged 9 commits into from
Mar 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
165 changes: 150 additions & 15 deletions causal_pyro/counterfactual/conditioning.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,48 @@
from typing import Any, Dict, Literal, Optional, TypedDict, TypeVar
from typing import Any, Dict, Literal, Optional, TypedDict, TypeVar, Union

import pyro
import pyro.distributions as dist
import pyro.infer.reparam
import torch

from causal_pyro.counterfactual.internals import expand_obs_value_inplace_
from causal_pyro.counterfactual.selection import (
SelectCounterfactual,
SelectFactual,
get_factual_indices,
)
from causal_pyro.primitives import scatter
from causal_pyro.primitives import gather, indices_of, scatter, union

T = TypeVar("T")


def site_is_ambiguous(msg: Dict[str, Any]) -> bool:
"""
Helper function used with :func:`pyro.condition` to determine
whether a site is observed or ambiguous.

A sample site is ambiguous if it is marked observed, is downstream of an intervention,
and the observed value's index variables are a strict subset of the distribution's
indices and hence require clarification of which entries of the random variable
are fixed/observed (as opposed to random/unobserved).
"""
if not (
msg["type"] == "sample"
and msg["is_observed"]
and not pyro.poutine.util.site_is_subsample(msg)
):
return False

try:
return not msg["infer"]["_specified_conditioning"]
except KeyError:
value_indices = indices_of(msg["value"], event_dim=len(msg["fn"].event_shape))
dist_indices = indices_of(msg["fn"])
return (
bool(union(value_indices, dist_indices)) and value_indices != dist_indices
)


def no_ambiguity(msg: Dict[str, Any]) -> Dict[str, Any]:
"""
Helper function used with :func:`pyro.poutine.infer_config` to inform
Expand All @@ -40,6 +68,24 @@ class AmbiguousConditioningStrategy(pyro.infer.reparam.strategies.Strategy):
pass


CondStrategy = Union[
Dict[str, AmbiguousConditioningReparam], AmbiguousConditioningStrategy
]


class AmbiguousConditioningReparamMessenger(
pyro.poutine.reparam_messenger.ReparamMessenger
):
config: CondStrategy

def _pyro_sample(self, msg: pyro.infer.reparam.reparam.ReparamMessage) -> None:
if site_is_ambiguous(msg):
expand_obs_value_inplace_(msg)
msg["infer"]["_specified_conditioning"] = False
super()._pyro_sample(msg)
msg["infer"]["_specified_conditioning"] = True


class ConditionReparamMsg(TypedDict):
fn: pyro.distributions.Distribution
value: torch.Tensor
Expand Down Expand Up @@ -78,31 +124,120 @@ def apply(self, msg: ConditionReparamArgMsg) -> ConditionReparamMsg:
return {"fn": new_fn, "value": new_value, "is_observed": True}


class AutoFactualConditioning(AmbiguousConditioningStrategy):
class MinimalFactualConditioning(AmbiguousConditioningStrategy):
"""
Default strategy for handling ambiguity in conditioning, for use with
Reparameterization strategy for handling ambiguity in conditioning, for use with
counterfactual semantics handlers such as :class:`MultiWorldCounterfactual` .

This strategy automatically applies :class:`FactualConditioningReparam` to
all sites that are observed and are downstream of an intervention, provided
that the observed value's index variables are a strict subset of the distribution's
indices and hence require clarification of which entries of the random variable
are fixed/observed (as opposed to random/unobserved).
:class:`MinimalFactualConditioning` applies :class:`FactualConditioningReparam`
instances to all ambiguous sample sites in a model.

.. note::

A sample site is ambiguous if it is marked observed, is downstream of an intervention,
and the observed value's index variables are a strict subset of the distribution's
indices and hence require clarification of which entries of the random variable
are fixed/observed (as opposed to random/unobserved).

"""

def configure(
self, msg: pyro.infer.reparam.reparam.ReparamMessage
) -> Optional[FactualConditioningReparam]:
if not site_is_ambiguous(msg):
return None

return FactualConditioningReparam()


class ConditionTransformReparamMsg(TypedDict):
fn: pyro.distributions.TransformedDistribution
value: torch.Tensor
is_observed: Literal[True]


class ConditionTransformReparamArgMsg(ConditionTransformReparamMsg):
name: str


class ConditionTransformReparam(AmbiguousConditioningReparam):
def apply(
self, msg: ConditionTransformReparamArgMsg
) -> ConditionTransformReparamMsg:
name, fn, value = msg["name"], msg["fn"], msg["value"]

tfm = (
fn.transforms[-1]
if len(fn.transforms) == 1
else dist.transforms.ComposeTransformModule(fn.transforms)
)
noise_dist = fn.base_dist
noise_event_dim = len(noise_dist.event_shape)
obs_event_dim = len(fn.event_shape)

# factual world
with SelectFactual(), pyro.poutine.infer_config(config_fn=no_ambiguity):
new_base_dist = dist.Delta(value, event_dim=obs_event_dim).mask(False)
new_noise_dist = dist.TransformedDistribution(new_base_dist, tfm.inv)
obs_noise = pyro.sample(
name + "_noise_likelihood", new_noise_dist, obs=tfm.inv(value)
)

# depends on strategy and indices of noise_dist
fw = get_factual_indices()
obs_noise = gather(obs_noise, fw, event_dim=noise_event_dim).expand(
obs_noise.shape
)
obs_noise = pyro.sample(name + "_noise_prior", noise_dist, obs=obs_noise)

# counterfactual world
with SelectCounterfactual(), pyro.poutine.infer_config(config_fn=no_ambiguity):
cf_noise_dist = dist.Delta(obs_noise, event_dim=noise_event_dim).mask(False)
cf_obs_dist = dist.TransformedDistribution(cf_noise_dist, tfm)
cf_obs_value = pyro.sample(name + "_cf_obs", cf_obs_dist)

# merge
new_value = scatter(
value, fw, result=cf_obs_value.clone(), event_dim=obs_event_dim
)
new_fn = dist.Delta(new_value, event_dim=obs_event_dim).mask(False)
return {"fn": new_fn, "value": new_value, "is_observed": msg["is_observed"]}


class AutoFactualConditioning(MinimalFactualConditioning):
"""
Reparameterization strategy for handling ambiguity in conditioning, for use with
counterfactual semantics handlers such as :class:`MultiWorldCounterfactual` .

When the distribution is a :class:`pyro.distributions.TransformedDistribution`,
:class:`AutoFactualConditioning` automatically applies :class:`ConditionTransformReparam`
to the site. Otherwise, it behaves like :class:`MinimalFactualConditioning` .

.. note::

This strategy is applied by default via :class:`MultiWorldCounterfactual`
and :class:`TwinWorldCounterfactual` unless otherwise specified.

.. note::

A sample site is ambiguous if it is marked observed, is downstream of an intervention,
and the observed value's index variables are a strict subset of the distribution's
indices and hence require clarification of which entries of the random variable
are fixed/observed (as opposed to random/unobserved).

"""

def configure(
self, msg: pyro.infer.reparam.reparam.ReparamMessage
) -> Optional[FactualConditioningReparam]:
if (
not msg["is_observed"]
or pyro.poutine.util.site_is_subsample(msg)
or msg["infer"].get("_specified_conditioning", False)
):
if not site_is_ambiguous(msg):
return None

return FactualConditioningReparam()
fn = msg["fn"]
while hasattr(fn, "base_dist"):
if isinstance(fn, dist.TransformedDistribution):
return ConditionTransformReparam()
else:
fn = fn.base_dist

return super().configure(msg)
15 changes: 4 additions & 11 deletions causal_pyro/counterfactual/handlers.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,15 @@
from typing import Any, Dict, Optional, Union

from causal_pyro.counterfactual.conditioning import (
AmbiguousConditioningReparam,
AmbiguousConditioningStrategy,
AmbiguousConditioningReparamMessenger,
AutoFactualConditioning,
CondStrategy,
)
from causal_pyro.counterfactual.internals import (
ExpandReparamMessenger,
IndexPlatesMessenger,
)
from causal_pyro.counterfactual.internals import IndexPlatesMessenger
from causal_pyro.primitives import IndexSet, scatter

CondStrategy = Union[
Dict[str, AmbiguousConditioningReparam], AmbiguousConditioningStrategy
]


class BaseCounterfactual(ExpandReparamMessenger):
class BaseCounterfactual(AmbiguousConditioningReparamMessenger):
"""
Base class for counterfactual handlers.
"""
Expand Down
42 changes: 11 additions & 31 deletions causal_pyro/counterfactual/internals.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,7 @@ def _pyro_add_indices(self, msg):
), f"cannot add {name}={indices} to {self.plates[name].size}"


class ExpandReparamMessenger(pyro.poutine.reparam_messenger.ReparamMessenger):
def expand_obs_value_inplace_(msg: pyro.infer.reparam.reparam.ReparamMessage) -> None:
"""
Slightly gross workaround that mutates the msg in place
to avoid triggering overzealous validation logic in
Expand All @@ -317,36 +317,16 @@ class ExpandReparamMessenger(pyro.poutine.reparam_messenger.ReparamMessenger):
the observed entries, it just packs counterfactual values around them;
the equality check being approximated by that logic would still pass.
"""

def _pyro_sample(self, msg: pyro.infer.reparam.reparam.ReparamMessage) -> None:
if msg["is_observed"]:
value_indices = indices_of(
msg["value"], event_dim=len(msg["fn"].event_shape)
)
dist_indices = indices_of(msg["fn"])
if not union(value_indices, dist_indices) or value_indices == dist_indices:
# not ambiguous
msg["infer"]["_specified_conditioning"] = msg["infer"].get(
"_specified_conditioning", True
)

if (
not msg["infer"].get("_specified_conditioning", False)
and msg["value"] is not None
and not pyro.poutine.util.site_is_subsample(msg)
):
msg["value"] = torch.as_tensor(msg["value"])
msg["infer"]["orig_shape"] = msg["value"].shape
_custom_init = getattr(msg["value"], "_pyro_custom_init", False)
msg["value"] = msg["value"].expand(
torch.broadcast_shapes(
msg["fn"].batch_shape + msg["fn"].event_shape,
msg["value"].shape,
)
)
setattr(msg["value"], "_pyro_custom_init", _custom_init)

return super()._pyro_sample(msg)
msg["value"] = torch.as_tensor(msg["value"])
msg["infer"]["orig_shape"] = msg["value"].shape
_custom_init = getattr(msg["value"], "_pyro_custom_init", False)
msg["value"] = msg["value"].expand(
torch.broadcast_shapes(
msg["fn"].batch_shape + msg["fn"].event_shape,
msg["value"].shape,
)
)
setattr(msg["value"], "_pyro_custom_init", _custom_init)


def get_sample_msg_device(
Expand Down
90 changes: 90 additions & 0 deletions tests/test_conditioning.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
import logging

import pyro
import pyro.distributions as dist
import pytest
import torch

from causal_pyro.counterfactual.handlers import (
Factual,
MultiWorldCounterfactual,
TwinWorldCounterfactual,
)
from causal_pyro.counterfactual.selection import SelectCounterfactual, SelectFactual
from causal_pyro.primitives import indices_of
from causal_pyro.query.do_messenger import do

logger = logging.getLogger(__name__)


@pytest.mark.parametrize(
"cf_class", [MultiWorldCounterfactual, TwinWorldCounterfactual]
)
@pytest.mark.parametrize("cf_dim", [-1, -2, -3])
@pytest.mark.parametrize("event_shape", [(), (4,), (4, 3)])
def test_ambiguous_conditioning_transform(cf_class, cf_dim, event_shape):
event_dim = len(event_shape)

def model():
# x
# / \
# v v
# y --> z
X = pyro.sample(
"x",
dist.TransformedDistribution(
dist.Normal(0.0, 1).expand(event_shape).to_event(event_dim),
[dist.transforms.ExpTransform()],
),
)
Y = pyro.sample(
"y",
dist.TransformedDistribution(
dist.Normal(0.0, 1).expand(event_shape).to_event(event_dim),
[dist.transforms.AffineTransform(X, 1.0, event_dim=event_dim)],
),
)
Z = pyro.sample(
"z",
dist.TransformedDistribution(
dist.Normal(0.0, 1).expand(event_shape).to_event(event_dim),
[
dist.transforms.AffineTransform(
0.3 * X + 0.7 * Y, 1.0, event_dim=event_dim
)
],
),
)
return X, Y, Z

observations = {
"z": torch.full(event_shape, 1.0),
"x": torch.full(event_shape, 1.1),
"y": torch.full(event_shape, 1.3),
}
interventions = {
"z": torch.full(event_shape, 0.5),
"x": torch.full(event_shape, 0.6),
}

queried_model = pyro.condition(data=observations)(do(actions=interventions)(model))
cf_handler = cf_class(cf_dim)

with Factual():
obs_tr = pyro.poutine.trace(queried_model).get_trace()
obs_log_prob = obs_tr.log_prob_sum()

with cf_handler, SelectCounterfactual():
cf_tr = pyro.poutine.trace(queried_model).get_trace()
cf_log_prob = cf_tr.log_prob_sum()

with cf_handler, SelectFactual():
fact_tr = pyro.poutine.trace(queried_model).get_trace()
fact_log_prob = fact_tr.log_prob_sum()

assert set(obs_tr.nodes.keys()) < set(cf_tr.nodes.keys())
assert set(fact_tr.nodes.keys()) == set(cf_tr.nodes.keys())
assert cf_log_prob != 0.0
assert fact_log_prob != 0.0
assert cf_log_prob != fact_log_prob
assert torch.allclose(obs_log_prob, fact_log_prob)