-
Notifications
You must be signed in to change notification settings - Fork 12
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Allow tuple-valued multiple interventions #121
Conversation
This seems to assume setting a variable to a tuple or lambda is disallowed. Is that right? Maybe Pyro already requires everything to be a float tensor? |
Hmm, I'm not quite sure what you mean. Interventions in def intervene(obs: T, act: Intervention[T]) -> T: ... This signature is currently only implemented for a few simple concrete types def intervene(obs: Tensor, act: Tensor | Callable[[Tensor], Tensor]) -> Tensor: ... This PR adds support for (dependent) tuple-valued def intervene(obs: Tensor, act: tuple[Tensor, ...] | Callable[[Tensor], tuple[Tensor, ...]]) -> Tensor: ... but there is currently no implementation of |
My concern was that someone could write a x = intervene(x, (x1, x2)) would lead to problems, because it wouldn't be clear If Pyro explicitly disallows tuple-valued distributions, this is a non-issue here. But it's still worth noting that this interface may be very Pyro-specific, since in general distributions can have arbitrary support type. Thinking about this a little more, I think the real root of the concern is that this approach dramatically changes the semantics of But I'm not sure what could be better. Maybe keyword arguments? Something like x = intervene(x, modify=lambda x: (x + 1, x + 2)) Or maybe this is fine in Pyro and we just need to be careful if we generalize it elsewhere. |
I see - in case it's not clear, As discussed in #26, we really do want the type of The type Intervention = T | Callable[[T], T] | tuple[T, ...] | Callable[[T], tuple[T, ...]] | ... For example, the following behavior is valid (i.e. the signature is a subtype of # valid - dependent interventions
def intervene(obs: T, act: Callable[[T], T]) -> T:
return act(obs) The following is a subtype of this signature (by the substitution # valid - can replace a function with a function of that function
def intervene(obs: Callable[..., T], act: Callable[[Callable[..., T]], Callable[..., T]]) -> Callable[..., T]: ... The following is not a subtype of # valid - can replace a function with another function of the same type
def intervene(obs: Callable[..., T], act: Callable[..., T]) -> Callable[..., T]: ... By contrast, the following signatures are not valid subtypes of # invalid - cannot replace a function with a constant
def intervene(obs: Callable[..., T], act: T) -> T: ...
# invalid - cannot replace a constant with a function
def intervene(obs: T, act: Callable[..., T]) -> Callable[..., T]: ... |
Addresses #26 and necessary for supporting name-based indexing in #14 and other tutorials
This PR adds support for passing multiple values to
intervene
in a tuple:or a function that returns a tuple (for dependent interventions):
By default,
intervene
ignores all but the final entry of the tuple:However, when a counterfactual handler is active, it returns all of the values stacked along a new index variable:
Getting this PR working also required some minor refactoring of
causal_pyro.counterfactual
.Tested: