Skip to content

Commit

Permalink
Adding undo_split and a test thereof (+small lint) (#264)
Browse files Browse the repository at this point in the history
* added undo_split and a test thereof (+small lint)

* moved undo_split, added a test

* unlinting counterfactual

* unlinting ops

* parametrize test WIP

* parametrizing test WIP

* reverting linting on test_internals.p

* isort fix on test_internals.py

* parametrizing test WIP

* linting WIP

* tests linted

* cleanup

* removed implicit Optional

* isort/lint test_internals.py

* reverting test_internals to orginal

* black linted no -l 120, to prevent github check failure

* make lint, make format

* format

* fix multi-antecedent case and format

* remove obsolete comments

* appease flake8

* add to sphinx

---------

Co-authored-by: Eli <eli@basis.ai>
  • Loading branch information
rfl-urbaniak and eb8680 committed Nov 29, 2023
1 parent 8eb2f76 commit d9ee12d
Show file tree
Hide file tree
Showing 3 changed files with 214 additions and 0 deletions.
45 changes: 45 additions & 0 deletions chirho/counterfactual/handlers/explanation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import itertools
from typing import Callable, Iterable, TypeVar

from chirho.indexed.ops import IndexSet, gather, indices_of, scatter

S = TypeVar("S")
T = TypeVar("T")


def undo_split(antecedents: Iterable[str] = [], event_dim: int = 0) -> Callable[[T], T]:
"""
A helper function that undoes an upstream :func:`~chirho.counterfactual.ops.split` operation,
meant to meant to be used to create arguments to pass to :func:`~chirho.interventional.ops.intervene` ,
:func:`~chirho.counterfactual.ops.split` or :func:`~chirho.counterfactual.ops.preempt` .
Works by gathering the factual value and scattering it back into two alternative cases.
:param antecedents: A list of upstream intervened sites which induced the :func:`split` to be reversed.
:param event_dim: The event dimension of the value to be preempted.
:return: A callable that applied to a site value object returns a site value object in which
the factual value has been scattered back into two alternative cases.
"""

def _undo_split(value: T) -> T:
antecedents_ = [
a for a in antecedents if a in indices_of(value, event_dim=event_dim)
]

factual_value = gather(
value,
IndexSet(**{antecedent: {0} for antecedent in antecedents_}),
event_dim=event_dim,
)

# TODO exponential in len(antecedents) - add an indexed.ops.expand to do this cheaply
return scatter(
{
IndexSet(
**{antecedent: {ind} for antecedent, ind in zip(antecedents_, inds)}
): factual_value
for inds in itertools.product(*[[0, 1]] * len(antecedents_))
},
event_dim=event_dim,
)

return _undo_split
4 changes: 4 additions & 0 deletions docs/source/counterfactual.rst
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@ Handlers
:members:
:undoc-members:

.. automodule:: chirho.counterfactual.handlers.explanation
:members:
:undoc-members:

Internals
---------

Expand Down
165 changes: 165 additions & 0 deletions tests/counterfactual/test_handlers_explanation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
import pyro
import pyro.distributions as dist
import pyro.infer
import pytest
import torch

from chirho.counterfactual.handlers import MultiWorldCounterfactual
from chirho.counterfactual.handlers.explanation import undo_split
from chirho.counterfactual.ops import preempt, split
from chirho.indexed.ops import IndexSet, gather, indices_of


def test_undo_split():
with MultiWorldCounterfactual():
x_obs = torch.zeros(10)
x_cf_1 = torch.ones(10)
x_cf_2 = 2 * x_cf_1
x_split = split(x_obs, (x_cf_1,), name="split1")
x_split = split(x_split, (x_cf_2,), name="split2")

undo_split2 = undo_split(antecedents=["split2"])
x_undone = undo_split2(x_split)

assert indices_of(x_split) == indices_of(x_undone)
assert torch.all(gather(x_split, IndexSet(split2={0})) == x_undone)


@pytest.mark.parametrize("plate_size", [4, 50, 200])
@pytest.mark.parametrize("event_shape", [(), (3,), (3, 2)])
def test_undo_split_parametrized(event_shape, plate_size):
joint_dims = torch.Size([plate_size, *event_shape])

replace1 = torch.ones(joint_dims)
preemption_tensor = replace1 * 5
case = torch.randint(0, 2, size=joint_dims)

@pyro.plate("data", size=plate_size, dim=-1)
def model():
w = pyro.sample(
"w", dist.Normal(0, 1).expand(event_shape).to_event(len(event_shape))
)
w = split(w, (replace1,), name="split1")

w = pyro.deterministic(
"w_preempted", preempt(w, preemption_tensor, case, name="w_preempted")
)

w = pyro.deterministic("w_undone", undo_split(antecedents=["split1"])(w))

with MultiWorldCounterfactual() as mwc:
with pyro.poutine.trace() as tr:
model()

nd = tr.trace.nodes

with mwc:
assert indices_of(nd["w_undone"]["value"]) == IndexSet(split1={0, 1})

w_undone_shape = list(nd["w_undone"]["value"].shape)
desired_shape = list(
(2,)
+ (1,) * (len(w_undone_shape) - len(event_shape) - 2)
+ (plate_size,)
+ event_shape
)
assert w_undone_shape == desired_shape

cf_values = gather(nd["w_undone"]["value"], IndexSet(split1={1})).squeeze()
observed_values = gather(
nd["w_undone"]["value"], IndexSet(split1={0})
).squeeze()

preempted_values = cf_values[case == 1.0]
reverted_values = cf_values[case == 0.0]
picked_values = observed_values[case == 0.0]

assert torch.all(preempted_values == 5.0)
assert torch.all(reverted_values == picked_values)


def test_undo_split_with_interaction():
def model():
x = pyro.sample("x", dist.Delta(torch.tensor(1.0)))

x_split = pyro.deterministic(
"x_split",
split(x, (torch.tensor(0.5),), name="x_split", event_dim=0),
event_dim=0,
)

x_undone = pyro.deterministic(
"x_undone", undo_split(antecedents=["x_split"])(x_split), event_dim=0
)

x_case = torch.tensor(1)
x_preempted = pyro.deterministic(
"x_preempted",
preempt(
x_undone, (torch.tensor(5.0),), x_case, name="x_preempted", event_dim=0
),
event_dim=0,
)

x_undone_2 = pyro.deterministic(
"x_undone_2", undo_split(antecedents=["x"])(x_preempted), event_dim=0
)

x_split2 = pyro.deterministic(
"x_split2",
split(x_undone_2, (torch.tensor(2.0),), name="x_split2", event_dim=0),
event_dim=0,
)

x_undone_3 = pyro.deterministic(
"x_undone_3",
undo_split(antecedents=["x_split", "x_split2"])(x_split2),
event_dim=0,
)

return x_undone_3

with MultiWorldCounterfactual() as mwc:
with pyro.poutine.trace() as tr:
model()

nd = tr.trace.nodes

with mwc:
x_split_2 = nd["x_split2"]["value"]
x_00 = gather(
x_split_2, IndexSet(x_split={0}, x_split2={0}), event_dim=0
) # 5.0
x_10 = gather(
x_split_2, IndexSet(x_split={1}, x_split2={0}), event_dim=0
) # 5.0
x_01 = gather(
x_split_2, IndexSet(x_split={0}, x_split2={1}), event_dim=0
) # 2.0
x_11 = gather(
x_split_2, IndexSet(x_split={1}, x_split2={1}), event_dim=0
) # 2.0

assert (
nd["x_split"]["value"][0].item() == 1.0
and nd["x_split"]["value"][1].item() == 0.5
)

assert (
nd["x_undone"]["value"][0].item() == 1.0
and nd["x_undone"]["value"][1].item() == 1.0
)

assert (
nd["x_preempted"]["value"][0].item() == 5.0
and nd["x_preempted"]["value"][1].item() == 5.0
)

assert (
nd["x_undone_2"]["value"][0].item() == 5.0
and nd["x_undone_2"]["value"][1].item() == 5.0
)

assert torch.all(nd["x_undone_3"]["value"] == 5.0)

assert (x_00, x_10, x_01, x_11) == (5.0, 5.0, 2.0, 2.0)

0 comments on commit d9ee12d

Please sign in to comment.