Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding undo_split and a test thereof (+small lint) #264

Merged
merged 23 commits into from
Sep 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 45 additions & 0 deletions chirho/counterfactual/handlers/explanation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import itertools
from typing import Callable, Iterable, TypeVar

from chirho.indexed.ops import IndexSet, gather, indices_of, scatter

S = TypeVar("S")
T = TypeVar("T")


def undo_split(antecedents: Iterable[str] = [], event_dim: int = 0) -> Callable[[T], T]:
"""
A helper function that undoes an upstream :func:`~chirho.counterfactual.ops.split` operation,
meant to meant to be used to create arguments to pass to :func:`~chirho.interventional.ops.intervene` ,
:func:`~chirho.counterfactual.ops.split` or :func:`~chirho.counterfactual.ops.preempt` .
Works by gathering the factual value and scattering it back into two alternative cases.

:param antecedents: A list of upstream intervened sites which induced the :func:`split` to be reversed.
:param event_dim: The event dimension of the value to be preempted.
:return: A callable that applied to a site value object returns a site value object in which
the factual value has been scattered back into two alternative cases.
"""

def _undo_split(value: T) -> T:
antecedents_ = [
a for a in antecedents if a in indices_of(value, event_dim=event_dim)
]

factual_value = gather(
value,
IndexSet(**{antecedent: {0} for antecedent in antecedents_}),
event_dim=event_dim,
)

# TODO exponential in len(antecedents) - add an indexed.ops.expand to do this cheaply
return scatter(
{
IndexSet(
**{antecedent: {ind} for antecedent, ind in zip(antecedents_, inds)}
): factual_value
for inds in itertools.product(*[[0, 1]] * len(antecedents_))
},
event_dim=event_dim,
)

return _undo_split
4 changes: 4 additions & 0 deletions docs/source/counterfactual.rst
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@ Handlers
:members:
:undoc-members:

.. automodule:: chirho.counterfactual.handlers.explanation
:members:
:undoc-members:

Internals
---------

Expand Down
165 changes: 165 additions & 0 deletions tests/counterfactual/test_handlers_explanation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
import pyro
import pyro.distributions as dist
import pyro.infer
import pytest
import torch

from chirho.counterfactual.handlers import MultiWorldCounterfactual
from chirho.counterfactual.handlers.explanation import undo_split
from chirho.counterfactual.ops import preempt, split
from chirho.indexed.ops import IndexSet, gather, indices_of


def test_undo_split():
with MultiWorldCounterfactual():
x_obs = torch.zeros(10)
x_cf_1 = torch.ones(10)
x_cf_2 = 2 * x_cf_1
x_split = split(x_obs, (x_cf_1,), name="split1")
x_split = split(x_split, (x_cf_2,), name="split2")

undo_split2 = undo_split(antecedents=["split2"])
x_undone = undo_split2(x_split)

assert indices_of(x_split) == indices_of(x_undone)
assert torch.all(gather(x_split, IndexSet(split2={0})) == x_undone)


@pytest.mark.parametrize("plate_size", [4, 50, 200])
@pytest.mark.parametrize("event_shape", [(), (3,), (3, 2)])
def test_undo_split_parametrized(event_shape, plate_size):
joint_dims = torch.Size([plate_size, *event_shape])

replace1 = torch.ones(joint_dims)
preemption_tensor = replace1 * 5
case = torch.randint(0, 2, size=joint_dims)

@pyro.plate("data", size=plate_size, dim=-1)
def model():
w = pyro.sample(
"w", dist.Normal(0, 1).expand(event_shape).to_event(len(event_shape))
)
w = split(w, (replace1,), name="split1")

w = pyro.deterministic(
"w_preempted", preempt(w, preemption_tensor, case, name="w_preempted")
)

w = pyro.deterministic("w_undone", undo_split(antecedents=["split1"])(w))

with MultiWorldCounterfactual() as mwc:
with pyro.poutine.trace() as tr:
model()

nd = tr.trace.nodes

with mwc:
assert indices_of(nd["w_undone"]["value"]) == IndexSet(split1={0, 1})

w_undone_shape = list(nd["w_undone"]["value"].shape)
desired_shape = list(
(2,)
+ (1,) * (len(w_undone_shape) - len(event_shape) - 2)
+ (plate_size,)
+ event_shape
)
assert w_undone_shape == desired_shape

cf_values = gather(nd["w_undone"]["value"], IndexSet(split1={1})).squeeze()
observed_values = gather(
nd["w_undone"]["value"], IndexSet(split1={0})
).squeeze()

preempted_values = cf_values[case == 1.0]
reverted_values = cf_values[case == 0.0]
picked_values = observed_values[case == 0.0]

assert torch.all(preempted_values == 5.0)
assert torch.all(reverted_values == picked_values)


def test_undo_split_with_interaction():
def model():
x = pyro.sample("x", dist.Delta(torch.tensor(1.0)))

x_split = pyro.deterministic(
"x_split",
split(x, (torch.tensor(0.5),), name="x_split", event_dim=0),
event_dim=0,
)

x_undone = pyro.deterministic(
"x_undone", undo_split(antecedents=["x_split"])(x_split), event_dim=0
)

x_case = torch.tensor(1)
x_preempted = pyro.deterministic(
"x_preempted",
preempt(
x_undone, (torch.tensor(5.0),), x_case, name="x_preempted", event_dim=0
),
event_dim=0,
)

x_undone_2 = pyro.deterministic(
"x_undone_2", undo_split(antecedents=["x"])(x_preempted), event_dim=0
)

x_split2 = pyro.deterministic(
"x_split2",
split(x_undone_2, (torch.tensor(2.0),), name="x_split2", event_dim=0),
event_dim=0,
)

x_undone_3 = pyro.deterministic(
"x_undone_3",
undo_split(antecedents=["x_split", "x_split2"])(x_split2),
event_dim=0,
)

return x_undone_3

with MultiWorldCounterfactual() as mwc:
with pyro.poutine.trace() as tr:
model()

nd = tr.trace.nodes

with mwc:
x_split_2 = nd["x_split2"]["value"]
x_00 = gather(
x_split_2, IndexSet(x_split={0}, x_split2={0}), event_dim=0
) # 5.0
x_10 = gather(
x_split_2, IndexSet(x_split={1}, x_split2={0}), event_dim=0
) # 5.0
x_01 = gather(
x_split_2, IndexSet(x_split={0}, x_split2={1}), event_dim=0
) # 2.0
x_11 = gather(
x_split_2, IndexSet(x_split={1}, x_split2={1}), event_dim=0
) # 2.0

assert (
nd["x_split"]["value"][0].item() == 1.0
and nd["x_split"]["value"][1].item() == 0.5
)

assert (
nd["x_undone"]["value"][0].item() == 1.0
and nd["x_undone"]["value"][1].item() == 1.0
)

assert (
nd["x_preempted"]["value"][0].item() == 5.0
and nd["x_preempted"]["value"][1].item() == 5.0
)

assert (
nd["x_undone_2"]["value"][0].item() == 5.0
and nd["x_undone_2"]["value"][1].item() == 5.0
)

assert torch.all(nd["x_undone_3"]["value"] == 5.0)

assert (x_00, x_10, x_01, x_11) == (5.0, 5.0, 2.0, 2.0)
Loading