-
Notifications
You must be signed in to change notification settings - Fork 12
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Adding undo_split and a test thereof (+small lint) (#264)
* added undo_split and a test thereof (+small lint) * moved undo_split, added a test * unlinting counterfactual * unlinting ops * parametrize test WIP * parametrizing test WIP * reverting linting on test_internals.py * isort fix on test_internals.py * parametrizing test WIP * linting WIP * tests linted * cleanup * removed implicit Optional * isort/lint test_internals.py * reverting test_internals to original * black linted no -l 120, to prevent github check failure * make lint, make format * format * fix multi-antecedent case and format * remove obsolete comments * appease flake8 * add to sphinx --------- Co-authored-by: Eli <eli@basis.ai>
- Loading branch information
1 parent
8eb2f76
commit d9ee12d
Showing
3 changed files
with
214 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
import itertools | ||
from typing import Callable, Iterable, TypeVar | ||
|
||
from chirho.indexed.ops import IndexSet, gather, indices_of, scatter | ||
|
||
S = TypeVar("S") | ||
T = TypeVar("T") | ||
|
||
|
||
def undo_split(antecedents: Iterable[str] = (), event_dim: int = 0) -> Callable[[T], T]:
    """
    A helper function that undoes an upstream :func:`~chirho.counterfactual.ops.split` operation,
    meant to be used to create arguments to pass to :func:`~chirho.interventional.ops.intervene` ,
    :func:`~chirho.counterfactual.ops.split` or :func:`~chirho.counterfactual.ops.preempt` .
    Works by gathering the factual value and scattering it back into two alternative cases.

    :param antecedents: An iterable of upstream intervened sites which induced the :func:`split` to be reversed.
        It is consumed once, when :func:`undo_split` is called, so the returned callable can be applied
        repeatedly even if a one-shot iterator (e.g. a generator) was passed. Names that are not among
        the indices of the value being processed are ignored.
    :param event_dim: The event dimension of the value to be preempted.
    :return: A callable that applied to a site value object returns a site value object in which
        the factual value has been scattered back into two alternative cases.
    """
    # Snapshot the names up front: the closure below may run many times, and
    # re-iterating a generator on each call would yield nothing after the first.
    antecedent_names = tuple(antecedents)

    def _undo_split(value: T) -> T:
        # Look up the indices of ``value`` once instead of once per antecedent.
        present_indices = indices_of(value, event_dim=event_dim)
        antecedents_ = [a for a in antecedent_names if a in present_indices]

        # Index 0 of every relevant antecedent selects the factual world.
        factual_value = gather(
            value,
            IndexSet(**{antecedent: {0} for antecedent in antecedents_}),
            event_dim=event_dim,
        )

        # TODO exponential in len(antecedents) - add an indexed.ops.expand to do this cheaply
        # Write the factual value into every 0/1 combination of the relevant antecedents.
        return scatter(
            {
                IndexSet(
                    **{antecedent: {ind} for antecedent, ind in zip(antecedents_, inds)}
                ): factual_value
                for inds in itertools.product([0, 1], repeat=len(antecedents_))
            },
            event_dim=event_dim,
        )

    return _undo_split
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,165 @@ | ||
import pyro | ||
import pyro.distributions as dist | ||
import pyro.infer | ||
import pytest | ||
import torch | ||
|
||
from chirho.counterfactual.handlers import MultiWorldCounterfactual | ||
from chirho.counterfactual.handlers.explanation import undo_split | ||
from chirho.counterfactual.ops import preempt, split | ||
from chirho.indexed.ops import IndexSet, gather, indices_of | ||
|
||
|
||
def test_undo_split():
    """
    Split a tensor twice, undo only ``split2``, and check that the result has the
    same indices as the doubly-split value and matches its ``split2`` index-0 slice.
    """
    with MultiWorldCounterfactual():
        factual = torch.zeros(10)
        first_alternative = torch.ones(10)
        second_alternative = 2 * first_alternative

        split_once = split(factual, (first_alternative,), name="split1")
        split_twice = split(split_once, (second_alternative,), name="split2")

        revert_split2 = undo_split(antecedents=["split2"])
        reverted = revert_split2(split_twice)

        assert indices_of(split_twice) == indices_of(reverted)

        split2_factual = gather(split_twice, IndexSet(split2={0}))
        assert torch.all(split2_factual == reverted)
|
||
|
||
@pytest.mark.parametrize("plate_size", [4, 50, 200])
@pytest.mark.parametrize("event_shape", [(), (3,), (3, 2)])
def test_undo_split_parametrized(event_shape, plate_size):
    """
    Run split -> preempt -> undo_split inside a plate and check the traced ``w_undone`` site:
    its indices, its shape, and that in the ``split1`` index-1 slice preempted entries equal 5
    while the remaining entries equal the corresponding index-0 entries.
    """
    full_shape = torch.Size([plate_size, *event_shape])

    intervention_value = torch.ones(full_shape)
    preemption_value = intervention_value * 5
    case = torch.randint(0, 2, size=full_shape)

    @pyro.plate("data", size=plate_size, dim=-1)
    def model():
        w_prior = dist.Normal(0, 1).expand(event_shape).to_event(len(event_shape))
        w = pyro.sample("w", w_prior)
        w = split(w, (intervention_value,), name="split1")

        preempted = preempt(w, preemption_value, case, name="w_preempted")
        w = pyro.deterministic("w_preempted", preempted)

        undone = undo_split(antecedents=["split1"])(w)
        pyro.deterministic("w_undone", undone)

    with MultiWorldCounterfactual() as mwc, pyro.poutine.trace() as tr:
        model()

    w_undone = tr.trace.nodes["w_undone"]["value"]

    # Re-enter the handler so the indexed operations below resolve ``split1``.
    with mwc:
        assert indices_of(w_undone) == IndexSet(split1={0, 1})

        actual_shape = list(w_undone.shape)
        # Leading 2 for the split, singleton padding, then plate and event dims.
        padding = (1,) * (len(actual_shape) - len(event_shape) - 2)
        expected_shape = [2, *padding, plate_size, *event_shape]
        assert actual_shape == expected_shape

        counterfactual = gather(w_undone, IndexSet(split1={1})).squeeze()
        factual = gather(w_undone, IndexSet(split1={0})).squeeze()

        was_preempted = case == 1.0
        not_preempted = case == 0.0

        assert torch.all(counterfactual[was_preempted] == 5.0)
        assert torch.all(counterfactual[not_preempted] == factual[not_preempted])
|
||
|
||
def test_undo_split_with_interaction():
    """
    Chain split, undo_split and preempt on a scalar site and check the traced value
    of every intermediate site, including a final two-antecedent undo.
    """

    def model():
        x = pyro.sample("x", dist.Delta(torch.tensor(1.0)))

        first_split = split(x, (torch.tensor(0.5),), name="x_split", event_dim=0)
        x_split = pyro.deterministic("x_split", first_split, event_dim=0)

        first_undo = undo_split(antecedents=["x_split"])(x_split)
        x_undone = pyro.deterministic("x_undone", first_undo, event_dim=0)

        x_case = torch.tensor(1)
        preempted = preempt(
            x_undone, (torch.tensor(5.0),), x_case, name="x_preempted", event_dim=0
        )
        x_preempted = pyro.deterministic("x_preempted", preempted, event_dim=0)

        second_undo = undo_split(antecedents=["x"])(x_preempted)
        x_undone_2 = pyro.deterministic("x_undone_2", second_undo, event_dim=0)

        second_split = split(
            x_undone_2, (torch.tensor(2.0),), name="x_split2", event_dim=0
        )
        x_split2 = pyro.deterministic("x_split2", second_split, event_dim=0)

        third_undo = undo_split(antecedents=["x_split", "x_split2"])(x_split2)
        return pyro.deterministic("x_undone_3", third_undo, event_dim=0)

    with MultiWorldCounterfactual() as mwc, pyro.poutine.trace() as tr:
        model()

    nodes = tr.trace.nodes

    def first_two_entries(site):
        # Scalar values stored at positions 0 and 1 of the traced site value.
        value = nodes[site]["value"]
        return value[0].item(), value[1].item()

    with mwc:
        final_split = nodes["x_split2"]["value"]

        # (x_split index, x_split2 index) for each gathered value, in assertion order.
        world_pairs = [(0, 0), (1, 0), (0, 1), (1, 1)]
        gathered = tuple(
            gather(final_split, IndexSet(x_split={a}, x_split2={b}), event_dim=0)
            for a, b in world_pairs
        )

    assert first_two_entries("x_split") == (1.0, 0.5)
    assert first_two_entries("x_undone") == (1.0, 1.0)
    assert first_two_entries("x_preempted") == (5.0, 5.0)
    assert first_two_entries("x_undone_2") == (5.0, 5.0)

    assert torch.all(nodes["x_undone_3"]["value"] == 5.0)

    assert gathered == (5.0, 5.0, 2.0, 2.0)