Allow tuple-valued multiple interventions #121

Merged
merged 13 commits into from
Apr 26, 2023

eb8680
Contributor

@eb8680 eb8680 commented Apr 24, 2023

Addresses #26 and is necessary for supporting name-based indexing in #14 and other tutorials.

This PR adds support for passing multiple values to intervene in a tuple:

x = intervene(x, (x_cf_1, x_cf_2))

or a function that returns a tuple (for dependent interventions):

x = intervene(x, lambda x: (x + 1, x + 2))

By default, intervene ignores all but the final entry of the tuple:

assert intervene(x, (x_cf_1, x_cf_2)) is x_cf_2
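
As a rough illustration of this default, here is a toy stand-in written in plain Python. It is not the causal_pyro implementation: `toy_intervene` is a hypothetical name, and plain floats stand in for Tensors. It only sketches the rule that, with no handler active, a tuple (or a function returning a tuple) collapses to its final entry.

```python
def toy_intervene(obs, act):
    """Toy default semantics, assuming no counterfactual handler is active."""
    # Dependent intervention: compute the new value(s) from the observed one.
    if callable(act):
        act = act(obs)
    # Multiple interventions: outside a handler only the final entry is kept.
    if isinstance(act, tuple):
        return act[-1]
    # Single intervention: the observed value is simply replaced.
    return act


x, x_cf_1, x_cf_2 = 0.0, 1.0, 2.0
result_tuple = toy_intervene(x, (x_cf_1, x_cf_2))
result_lambda = toy_intervene(x, lambda v: (v + 1, v + 2))
result_single = toy_intervene(x, 5.0)
```

In this sketch `result_tuple` is `x_cf_2`, matching the assertion above.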

However, when a counterfactual handler is active, it returns all of the values stacked along a new index variable:

with MultiWorldCounterfactual():
    x_cf = intervene(x, (x_cf_a, x_cf_b), name="x")
    assert indices_of(x_cf)              == IndexSet(x={0, 1, 2})
    assert gather(x_cf, IndexSet(x={0})) == x
    assert gather(x_cf, IndexSet(x={1})) == x_cf_a
    assert gather(x_cf, IndexSet(x={2})) == x_cf_b
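
The indexing convention in this example (index 0 for the factual value, then one index per tuple entry) can be mimicked with a plain list. This is only a sketch under that assumption: `toy_multiworld_intervene` and `toy_gather` are hypothetical helpers, and the real handler stacks Tensors along a named index rather than building a list.

```python
def toy_multiworld_intervene(obs, act):
    """Toy multi-world semantics: keep the factual value and every intervened value."""
    if callable(act):
        act = act(obs)
    if not isinstance(act, tuple):
        act = (act,)
    # Index 0 is the factual world; indices 1..n are the intervened worlds.
    return [obs, *act]


def toy_gather(worlds, index):
    """Select the value that lives in a single world."""
    return worlds[index]


x, x_cf_a, x_cf_b = 0.0, 1.0, 2.0
worlds = toy_multiworld_intervene(x, (x_cf_a, x_cf_b))
world_indices = set(range(len(worlds)))
```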

Getting this PR working also required some minor refactoring of causal_pyro.counterfactual.

Tested:

  • Strengthened two existing unit tests and added new cases to cover new functionality

@eb8680 eb8680 added the enhancement (New feature or request) and status:awaiting review (Awaiting response from reviewer) labels Apr 24, 2023
@eb8680 eb8680 requested a review from cscherrer April 24, 2023 11:57
@eb8680 eb8680 added this to the Initial public release milestone Apr 24, 2023
@eb8680 eb8680 self-assigned this Apr 24, 2023
@cscherrer
Contributor

This seems to assume setting a variable to a tuple or lambda is disallowed. Is that right? Maybe Pyro already requires everything to be a float tensor?

@eb8680
Contributor Author

eb8680 commented Apr 24, 2023

> This seems to assume setting a variable to a tuple or lambda is disallowed. Is that right? Maybe Pyro already requires everything to be a float tensor?

Hmm, I'm not quite sure what you mean. Interventions in causal_pyro are polymorphic and type-preserving, so it's not possible to set variables to completely arbitrary values. Any additional structure in interventions must be accounted for in the type Intervention[T]:

def intervene(obs: T, act: Intervention[T]) -> T: ...

This signature is currently only implemented for a few simple concrete types T because we've been focused on defining counterfactuals polymorphically. For example, there is an implementation for number- and Tensor-valued observations with Tensor-valued interventions or dependent Tensor-valued interventions:

def intervene(obs: Tensor, act: Tensor | Callable[[Tensor], Tensor]) -> Tensor: ...

This PR adds support for (dependent) tuple-valued Interventions on numbers and Tensors:

def intervene(obs: Tensor, act: tuple[Tensor, ...] | Callable[[Tensor], tuple[Tensor, ...]]) -> Tensor: ...

but there is currently no implementation of intervene for the case where T = tuple[float, ...] or T = tuple[Tensor, ...].

@cscherrer
Contributor

My concern was that someone could write a bivariate_normal distribution where values are represented as (float, float) pairs. If x follows such a distribution, then it seems like

x = intervene(x, (x1, x2))

would lead to problems, because it wouldn't be clear whether x should be set to (x1, x2) or whether this is intended as a multiple intervention. Or rather, that could be determined, but it depends on the support type of x and the types of x1 and x2, which I'm guessing is not covered here.

If Pyro explicitly disallows tuple-valued distributions, this is a non-issue here. But it's still worth noting that this interface may be very Pyro-specific, since in general distributions can have arbitrary support type.

Thinking about this a little more, I think the real root of the concern is that this approach dramatically changes the semantics of x = intervene(x, x0) based on the type of x0. It may happen to be a tuple or a lambda, in which case this special behavior is triggered. It just seems a little magical.

But I'm not sure what could be better. Maybe keyword arguments? Something like

x = intervene(x, modify=lambda x: (x + 1, x + 2))

Or maybe this is fine in Pyro and we just need to be careful if we generalize it elsewhere.
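
To make the concern concrete, here is a toy sketch in plain Python. Nothing in it is causal_pyro API: `ambiguous_intervene` stands in for dispatch on the type of the second argument, and `keyword_intervene` with its `value=` and `values=` keywords is a purely hypothetical alternative in the spirit of the keyword idea above.

```python
def ambiguous_intervene(obs, act):
    """Type-based dispatch: any tuple is read as multiple interventions."""
    if isinstance(act, tuple):
        return act[-1]
    return act


def keyword_intervene(obs, *, value=None, values=None):
    """Hypothetical keyword-based variant: the caller states the intent."""
    if (value is None) == (values is None):
        raise TypeError("pass exactly one of value= or values=")
    if values is not None:
        # Multiple interventions; keep only the last one, as in the default.
        return values[-1]
    # Single intervention; the value is used as given, even if it is a tuple.
    return value


point = (0.0, 0.0)      # a pair-valued observation, e.g. from a bivariate normal
new_point = (1.0, 2.0)  # intended as ONE new pair value

type_based = ambiguous_intervene(point, new_point)
keyword_based = keyword_intervene(point, value=new_point)
```

With type-based dispatch the pair is split and only its second coordinate survives, while the keyword form keeps the pair intact.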

@eb8680
Contributor Author

eb8680 commented Apr 25, 2023

> Thinking about this a little more, I think the real root of the concern is that this approach dramatically changes the semantics of x = intervene(x, x0) based on the type of x0. It may happen to be a tuple or a lambda, in which case this special behavior is triggered. It just seems a little magical.

I see - in case it's not clear, intervene already works this way by design, for reasons unrelated to Pyro or Python implementation details.

As discussed in #26, we really do want the type of act to depend on the type of obs, and for the behavior of intervene to depend on the types of both, subject to the constraint that the result type matches the type of obs (i.e. intervene has type Callable[[T, Intervention[T]], T], as in the previous comment).

The type Intervention[T] describing the additional structure allowed in interventions in valid implementations of intervene can be written as a Python sum type:

Intervention = T | Callable[[T], T] | tuple[T, ...] | Callable[[T], tuple[T, ...]] | ...

For example, the following behavior is valid (i.e. the signature is a subtype of Callable[[T, Intervention[T]], T]) and is already implemented, although it is not actually separated so cleanly because of Python's lack of support for dispatch on generic types:

# valid - dependent interventions
def intervene(obs: T, act: Callable[[T], T]) -> T:
  return act(obs)

The following is a subtype of this signature (by the substitution T = Callable[..., T]) and hence a valid subtype of Callable[[T, Intervention[T]], T]:

# valid - can replace a function with a function of that function
def intervene(obs: Callable[..., T], act: Callable[[Callable[..., T]], Callable[..., T]]) -> Callable[..., T]: ...

The following is not a subtype of Callable[[T, Callable[[T], T]], T], but it is still a valid subtype of Callable[[T, Intervention[T]], T] (by the substitution T = Callable[..., T]):

# valid - can replace a function with another function of the same type
def intervene(obs: Callable[..., T], act: Callable[..., T]) -> Callable[..., T]: ...

By contrast, the following signatures are not valid subtypes of Callable[[T, Intervention[T]], T] and cannot describe valid implementations of intervene:

# invalid - cannot replace a function with a constant
def intervene(obs: Callable[..., T], act: T) -> T: ...

# invalid - cannot replace a constant with a function
def intervene(obs: T, act: Callable[..., T]) -> Callable[..., T]: ...
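
For the simple case T = float, the type-preservation constraint can be sketched with the standard `typing` module. This is an illustration only: the four-member `Intervention` alias below omits the trailing `| ...` from the definition above, and `apply_float_intervention` is a hypothetical helper that checks the result type at runtime, which the library is not claimed to do.

```python
from typing import Callable, Tuple, TypeVar, Union, get_args

T = TypeVar("T")

# Generic alias mirroring the first four members of Intervention[T].
Intervention = Union[
    T,
    Callable[[T], T],
    Tuple[T, ...],
    Callable[[T], Tuple[T, ...]],
]


def apply_float_intervention(obs: float, act: Intervention[float]) -> float:
    """Apply act to obs and insist that the result has the type of obs."""
    result = act(obs) if callable(act) else act
    if isinstance(result, tuple):
        result = result[-1]  # default semantics: keep the final entry
    if not isinstance(result, float):
        raise TypeError("an intervention must preserve the type of obs")
    return result


members = get_args(Intervention[float])
shifted = apply_float_intervention(1.0, lambda v: v + 1.0)
last_of_tuple = apply_float_intervention(1.0, (3.0, 4.0))
```

An act such as `lambda v: str(v)` is rejected by this helper, which is the runtime analogue of the invalid signatures above.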

@cscherrer cscherrer merged commit bf2b23a into master Apr 26, 2023
@cscherrer cscherrer deleted the eb-split-op branch April 26, 2023 15:07