Skip to content

Commit

Permalink
device inference
Browse files Browse the repository at this point in the history
  • Loading branch information
eb8680 committed Mar 1, 2023
1 parent 9efdda2 commit 15337f8
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 3 deletions.
22 changes: 21 additions & 1 deletion causal_pyro/counterfactual/internals.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ def indexset_as_mask(
*,
event_dim: int = 0,
name_to_dim: Optional[Dict[Hashable, int]] = None,
device: torch.device = torch.device("cpu"),
) -> torch.Tensor:
"""
Get a dense mask tensor for indexing into a tensor from an indexset.
Expand All @@ -36,7 +37,7 @@ def indexset_as_mask(
for name, values in indexset.items():
inds[name_to_dim[name]] = torch.tensor(list(sorted(values)), dtype=torch.long)
batch_shape[name_to_dim[name]] = max(len(values), max(values) + 1)
mask = torch.zeros(tuple(batch_shape), dtype=torch.bool)
mask = torch.zeros(tuple(batch_shape), dtype=torch.bool, device=device)
mask[tuple(inds)] = True
return mask[(...,) + (None,) * event_dim]

Expand Down Expand Up @@ -297,6 +298,25 @@ def _pyro_add_indices(self, msg):
), f"cannot add {name}={indices} to {self.plates[name].size}"


def get_sample_msg_device(
dist: pyro.distributions.Distribution,
value: Optional[Union[torch.Tensor, float, int, bool]],
) -> torch.device:
# some gross code to infer the device of the obs_mask tensor
# because distributions are hard to introspect
if isinstance(value, torch.Tensor):
return value.device
else:
dist_ = dist
while hasattr(dist_, "base_dist"):
dist_ = dist_.base_dist
for param_name in dist_.arg_constraints.keys():
p = getattr(dist_, param_name)
if isinstance(p, torch.Tensor):
return p.device
raise ValueError(f"could not infer device for {dist} and {value}")


def expand_reparam_msg_value_inplace(
config_fn: Callable[..., Optional[pyro.infer.reparam.reparam.Reparam]]
) -> Callable[..., Optional[pyro.infer.reparam.reparam.Reparam]]:
Expand Down
9 changes: 7 additions & 2 deletions causal_pyro/counterfactual/selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,11 @@

import pyro

from causal_pyro.counterfactual.internals import get_index_plates, indexset_as_mask
from causal_pyro.counterfactual.internals import (
get_index_plates,
get_sample_msg_device,
indexset_as_mask,
)
from causal_pyro.primitives import IndexSet


Expand All @@ -16,7 +20,8 @@ def indices(self) -> IndexSet:
raise NotImplementedError

def _pyro_sample(self, msg: Dict[str, Any]) -> None:
mask = indexset_as_mask(self.indices)
mask_device = get_sample_msg_device(msg["fn"], msg["value"])
mask = indexset_as_mask(self.indices, device=mask_device)
msg["mask"] = mask if msg["mask"] is None else msg["mask"] & mask


Expand Down

0 comments on commit 15337f8

Please sign in to comment.