Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a BiasedPreemptions handler #239

Merged
merged 2 commits into from
Sep 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 90 additions & 6 deletions chirho/counterfactual/handlers/counterfactual.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,14 @@ def _pyro_intervene(msg: Dict[str, Any]) -> None:
@staticmethod
def _pyro_preempt(msg: Dict[str, Any]) -> None:
obs, acts, case = msg["args"]
msg["kwargs"]["name"] = f"__split_{msg['name']}"
if msg["kwargs"].get("name", None) is None:
msg["kwargs"]["name"] = msg["name"]

if case is not None:
return

case_dist = pyro.distributions.Categorical(torch.ones(len(acts) + 1))
case = pyro.sample(msg["kwargs"]["name"], case_dist.mask(False), obs=case)
case = pyro.sample(msg["name"], case_dist.mask(False), obs=case)
msg["args"] = (obs, acts, case)


Expand Down Expand Up @@ -100,23 +105,102 @@ class Preemptions(Generic[T], pyro.poutine.messenger.Messenger):
or one of its subclasses, typically from an auxiliary discrete random variable.

:param actions: A mapping from sample site names to interventions.
:param prefix: Prefix usable for naming any auxiliary random variables.
"""

actions: Mapping[str, Intervention[T]]
prefix: str

def __init__(self, actions: Mapping[str, Intervention[T]]):
def __init__(
self, actions: Mapping[str, Intervention[T]], *, prefix: str = "__split_"
):
self.actions = actions
self.prefix = prefix
super().__init__()

def _pyro_post_sample(self, msg: Dict[str, Any]) -> None:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's a bit of mismatch in the way split and preempt unpack their arguments versus intervene and the structure of the complex Intervention type in chirho.interventional.ops. I removed the annotation here to prevent mypy from causing CI failures over this minor issue, which can be dealt with directly in a separate PR.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This issue is related to and perhaps subsumed by #26

def _pyro_post_sample(self, msg):
try:
action = self.actions[msg["name"]]
except KeyError:
return
msg["value"] = preempt(
msg["value"],
(action,),
(action,) if not isinstance(action, tuple) else action,
None,
event_dim=len(msg["fn"].event_shape),
name=msg["name"],
name=f"{self.prefix}{msg['name']}",
)


class BiasedPreemptions(pyro.poutine.messenger.Messenger):
"""
Effect handler that applies the operation :func:`~chirho.counterfactual.ops.preempt`
to sample sites in a probabilistic program,
similar to the handler :func:`~chirho.observational.handlers.condition`
for :func:`~chirho.observational.ops.observe` .
or the handler :func:`~chirho.interventional.handlers.do`
for :func:`~chirho.interventional.ops.intervene` .

See the documentation for :func:`~chirho.counterfactual.ops.preempt` for more details.

This handler introduces an auxiliary discrete random variable at each preempted sample site
whose name is the name of the sample site prefixed by ``prefix``, and
whose value is used as the ``case`` argument to :func:`preempt`,
to determine whether the preemption returns the present value of the site
or the new value specified for the site in ``actions``

The distributions of the auxiliary discrete random variables are parameterized by ``bias``.
By default, ``bias == 0`` and the value returned by the sample site is equally likely
to be the factual case (i.e. the present value of the site) or one of the counterfactual cases
(i.e. the new value(s) specified for the site in ``actions``).
When ``0 < bias <= 0.5``, the preemption is less than equally likely to occur.
When ``-0.5 <= bias < 0``, the preemption is more than equally likely to occur.

More specifically, the probability of the factual case is ``0.5 - bias``,
and the probability of each counterfactual case is ``(0.5 + bias) / num_actions``,
where ``num_actions`` is the number of counterfactual actions for the sample site (usually 1).

:param actions: A mapping from sample site names to interventions.
:param bias: The scalar bias towards the factual case. Must be between -0.5 and 0.5.
:param prefix: The prefix for naming the auxiliary discrete random variables.
"""

actions: Mapping[str, Intervention[torch.Tensor]]
bias: float
prefix: str

def __init__(
self,
actions: Mapping[str, Intervention[torch.Tensor]],
*,
bias: float = 0.0,
prefix: str = "__witness_split_",
):
assert -0.5 <= bias <= 0.5, "bias must be between -0.5 and 0.5"
self.actions = actions
self.bias = bias
self.prefix = prefix
super().__init__()

def _pyro_post_sample(self, msg):
try:
action = self.actions[msg["name"]]
except KeyError:
return

action = (action,) if not isinstance(action, tuple) else action
num_actions = len(action) if isinstance(action, tuple) else 1
weights = torch.tensor(
[0.5 - self.bias] + ([(0.5 + self.bias) / num_actions] * num_actions),
device=msg["value"].device,
)
case_dist = pyro.distributions.Categorical(probs=weights)
case = pyro.sample(f"{self.prefix}{msg['name']}", case_dist)

msg["value"] = preempt(
msg["value"],
action,
case,
event_dim=len(msg["fn"].event_shape),
name=f"{self.prefix}{msg['name']}",
)
15 changes: 11 additions & 4 deletions tests/counterfactual/test_counterfactual_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
SingleWorldFactual,
TwinWorldCounterfactual,
)
from chirho.counterfactual.handlers.counterfactual import Preemptions
from chirho.counterfactual.handlers.counterfactual import BiasedPreemptions, Preemptions
from chirho.counterfactual.handlers.selection import SelectFactual
from chirho.counterfactual.ops import preempt, split
from chirho.indexed.ops import IndexSet, gather, indices_of, union
Expand Down Expand Up @@ -675,14 +675,14 @@ def model():

@pytest.mark.parametrize("cf_dim", [-2, -3, None])
@pytest.mark.parametrize("event_shape", [(), (4,), (4, 3)])
def test_cf_handler_preemptions(cf_dim, event_shape):
@pytest.mark.parametrize("use_biased_preemption", [False, True])
def test_cf_handler_preemptions(cf_dim, event_shape, use_biased_preemption):
event_dim = len(event_shape)

splits = {"x": torch.tensor(0.0)}
preemptions = {"y": torch.tensor(1.0)}

@do(actions=splits)
@Preemptions(actions=preemptions)
@pyro.plate("data", size=1000, dim=-1)
def model():
w = pyro.sample(
Expand All @@ -693,7 +693,14 @@ def model():
z = pyro.sample("z", dist.Normal(x + y, 1).to_event(len(event_shape)))
return dict(w=w, x=x, y=y, z=z)

with MultiWorldCounterfactual(cf_dim):
if use_biased_preemption:
preemption_handler = BiasedPreemptions(
actions=preemptions, bias=0.1, prefix="__split_"
)
else:
preemption_handler = Preemptions(actions=preemptions)

with MultiWorldCounterfactual(cf_dim), preemption_handler:
tr = pyro.poutine.trace(model).get_trace()
assert all(f"__split_{k}" in tr.nodes for k in preemptions.keys())
assert indices_of(tr.nodes["w"]["value"], event_dim=event_dim) == IndexSet()
Expand Down