Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

6777 fixes boolean indexing of batched metatensor #6781

Merged
merged 2 commits into from
Jul 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 51 additions & 45 deletions monai/data/meta_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ def update_meta(rets: Sequence, func, args, kwargs) -> Sequence:
:py:func:`MetaTensor._copy_meta`).
"""
out = []
metas = None
metas = None # optional output metadicts for each of the return value in `rets`
is_batch = any(x.is_batch for x in MetaObj.flatten_meta_objs(args, kwargs.values()) if hasattr(x, "is_batch"))
for idx, ret in enumerate(rets):
# if not `MetaTensor`, nothing to do.
Expand All @@ -219,55 +219,61 @@ def update_meta(rets: Sequence, func, args, kwargs) -> Sequence:
# the following is not implemented but the network arch may run into this case:
# if func == torch.cat and any(m.is_batch if hasattr(m, "is_batch") else False for m in meta_args):
# raise NotImplementedError("torch.cat is not implemented for batch of MetaTensors.")

# If we have a batch of data, then we need to be careful if a slice of
# the data is returned. Depending on how the data are indexed, we return
# some or all of the metadata, and the return object may or may not be a
# batch of data (e.g., `batch[:,-1]` versus `batch[0]`).
if is_batch:
# if indexing e.g., `batch[0]`
if func == torch.Tensor.__getitem__:
batch_idx = args[1]
if isinstance(batch_idx, Sequence):
batch_idx = batch_idx[0]
# if using e.g., `batch[:, -1]` or `batch[..., -1]`, then the
# first element will be `slice(None, None, None)` and `Ellipsis`,
# respectively. Don't need to do anything with the metadata.
if batch_idx not in (slice(None, None, None), Ellipsis, None) and idx == 0:
ret_meta = decollate_batch(args[0], detach=False)[batch_idx]
if isinstance(ret_meta, list) and ret_meta: # e.g. batch[0:2], re-collate
try:
ret_meta = list_data_collate(ret_meta)
except (TypeError, ValueError, RuntimeError, IndexError) as e:
raise ValueError(
"Inconsistent batched metadata dicts when slicing a batch of MetaTensors, "
"please convert it into a torch Tensor using `x.as_tensor()` or "
"a numpy array using `x.array`."
) from e
elif isinstance(ret_meta, MetaObj): # e.g. `batch[0]` or `batch[0, 1]`, batch_idx is int
ret_meta.is_batch = False
if hasattr(ret_meta, "__dict__"):
ret.__dict__ = ret_meta.__dict__.copy()
# `unbind` is used for `next(iter(batch))`. Also for `decollate_batch`.
# But we only want to split the batch if the `unbind` is along the 0th
# dimension.
elif func == torch.Tensor.unbind:
if len(args) > 1:
dim = args[1]
elif "dim" in kwargs:
dim = kwargs["dim"]
else:
dim = 0
if dim == 0:
if metas is None:
metas = decollate_batch(args[0], detach=False)
ret.__dict__ = metas[idx].__dict__.copy()
ret.is_batch = False

ret = MetaTensor._handle_batched(ret, idx, metas, func, args, kwargs)
out.append(ret)
# if the input was a tuple, then return it as a tuple
return tuple(out) if isinstance(rets, tuple) else out

@classmethod
def _handle_batched(cls, ret, idx, metas, func, args, kwargs):
"""utility function to handle batched MetaTensors."""
# If we have a batch of data, then we need to be careful if a slice of
# the data is returned. Depending on how the data are indexed, we return
# some or all of the metadata, and the return object may or may not be a
# batch of data (e.g., `batch[:,-1]` versus `batch[0]`).
# if indexing e.g., `batch[0]`
if func == torch.Tensor.__getitem__:
if idx > 0 or len(args) < 2 or len(args[0]) < 1:
return ret
batch_idx = args[1][0] if isinstance(args[1], Sequence) else args[1]
# if using e.g., `batch[:, -1]` or `batch[..., -1]`, then the
# first element will be `slice(None, None, None)` and `Ellipsis`,
# respectively. Don't need to do anything with the metadata.
if batch_idx in (slice(None, None, None), Ellipsis, None) or isinstance(batch_idx, torch.Tensor):
return ret
dec_batch = decollate_batch(args[0], detach=False)
ret_meta = dec_batch[batch_idx]
if isinstance(ret_meta, list) and ret_meta: # e.g. batch[0:2], re-collate
try:
ret_meta = list_data_collate(ret_meta)
except (TypeError, ValueError, RuntimeError, IndexError) as e:
raise ValueError(
"Inconsistent batched metadata dicts when slicing a batch of MetaTensors, "
"please consider converting it into a torch Tensor using `x.as_tensor()` or "
"a numpy array using `x.array`."
) from e
elif isinstance(ret_meta, MetaObj): # e.g. `batch[0]` or `batch[0, 1]`, batch_idx is int
ret_meta.is_batch = False
if hasattr(ret_meta, "__dict__"):
ret.__dict__ = ret_meta.__dict__.copy()
# `unbind` is used for `next(iter(batch))`. Also for `decollate_batch`.
# But we only want to split the batch if the `unbind` is along the 0th dimension.
elif func == torch.Tensor.unbind:
if len(args) > 1:
dim = args[1]
elif "dim" in kwargs:
dim = kwargs["dim"]
else:
dim = 0
if dim == 0:
if metas is None:
metas = decollate_batch(args[0], detach=False)
if hasattr(metas[idx], "__dict__"):
ret.__dict__ = metas[idx].__dict__.copy()
ret.is_batch = False
return ret

@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None) -> Any:
"""Wraps all torch functions."""
Expand Down
4 changes: 4 additions & 0 deletions tests/test_meta_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,6 +413,10 @@ def test_slicing(self):
x.is_batch = True
with self.assertRaises(ValueError):
x[slice(0, 8)]
x = MetaTensor(np.zeros((3, 3, 4)))
x.is_batch = True
self.assertEqual(x[torch.tensor([True, False, True])].shape, (2, 3, 4))
self.assertEqual(x[[True, False, True]].shape, (2, 3, 4))

@parameterized.expand(DTYPES)
@SkipIfBeforePyTorchVersion((1, 8))
Expand Down