Skip to content

Commit

Permalink
4855 lazy resampling impl -- Compose (#5860)
Browse files Browse the repository at this point in the history
part of #4855

upgrade #4911 to use the
latest dev API

### Description
Example usage:

for a sequence of spatial transforms

```py
xforms = [
    mt.LoadImageD(keys, ensure_channel_first=True),
    mt.Orientationd(keys, "RAS"),
    mt.SpacingD(keys, (1.5, 1.5, 1.5)),
    mt.CenterScaleCropD(keys, roi_scale=0.9),
    # mt.CropForegroundD(keys, source_key="seg", k_divisible=5),
    mt.RandRotateD(keys, prob=1.0, range_y=np.pi / 2, range_x=np.pi / 3),
    mt.RandSpatialCropD(keys, roi_size=(76, 87, 73)),
    mt.RandScaleCropD(keys, roi_scale=0.9),
    mt.Resized(keys, (30, 40, 60)),
    # mt.NormalizeIntensityd(keys),
    mt.ZoomD(keys, 1.3, keep_size=False),
    mt.FlipD(keys),
    mt.Rotate90D(keys),
    mt.RandAffined(keys),
    mt.ResizeWithPadOrCropd(keys, spatial_size=(32, 43, 54)),
    mt.DivisiblePadD(keys, k=3),
]
lazy_kwargs = dict(mode=("bilinear", 0), padding_mode=("border", "nearest"), dtype=(torch.float32, torch.uint8))
xform = mt.Compose(xforms, lazy_evaluation=True, overrides=lazy_kwargs, override_keys=keys)
xform.set_random_state(0)
```
lazy_evaluation=True preserves more details
![Screenshot 2023-01-17 at 00 31
40](https://user-images.githubusercontent.com/831580/212784981-ea39833b-54ab-42fb-bc03-38b012281857.png)
compared with the regular compose
![Screenshot 2023-01-17 at 00 31
43](https://user-images.githubusercontent.com/831580/212785016-ba3be8ff-f17f-47b4-8025-cd351a637a82.png)



### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: Wenqi Li <wenqil@nvidia.com>
Signed-off-by: Yiheng Wang <vennw@nvidia.com>
Signed-off-by: KumoLiu <yunl@nvidia.com>
Signed-off-by: Ben Murray <ben.murray@gmail.com>
Co-authored-by: Ben Murray <ben.murray@gmail.com>
Co-authored-by: binliu <binliu@nvidia.com>
Co-authored-by: Yiheng Wang <68361391+yiheng-wang-nv@users.noreply.github.com>
Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
Co-authored-by: Yiheng Wang <vennw@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: KumoLiu <yunl@nvidia.com>
  • Loading branch information
8 people authored Mar 23, 2023
1 parent 1cd0d7b commit b87375f
Show file tree
Hide file tree
Showing 16 changed files with 578 additions and 102 deletions.
6 changes: 6 additions & 0 deletions docs/source/transforms.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2206,3 +2206,9 @@ Utilities

.. automodule:: monai.transforms.utils_pytorch_numpy_unification
:members:

Lazy
----
.. automodule:: monai.transforms.lazy
:members:
:imported-members:
12 changes: 12 additions & 0 deletions monai/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,9 @@ def _pre_transform(self, item_transformed):
break
# this is to be consistent with CacheDataset even though it's not in a multi-thread situation.
_xform = deepcopy(_transform) if isinstance(_transform, ThreadUnsafe) else _transform
item_transformed = self.transform.evaluate_with_overrides(item_transformed, _xform)
item_transformed = apply_transform(_xform, item_transformed)
item_transformed = self.transform.evaluate_with_overrides(item_transformed, None)
if self.reset_ops_id:
reset_ops_id(item_transformed)
return item_transformed
Expand All @@ -348,7 +350,9 @@ def _post_transform(self, item_transformed):
or not isinstance(_transform, Transform)
):
start_post_randomize_run = True
item_transformed = self.transform.evaluate_with_overrides(item_transformed, _transform)
item_transformed = apply_transform(_transform, item_transformed)
item_transformed = self.transform.evaluate_with_overrides(item_transformed, None)
return item_transformed

def _cachecheck(self, item_transformed):
Expand Down Expand Up @@ -496,7 +500,9 @@ def _pre_transform(self, item_transformed):
if i == self.cache_n_trans:
break
_xform = deepcopy(_transform) if isinstance(_transform, ThreadUnsafe) else _transform
item_transformed = self.transform.evaluate_with_overrides(item_transformed, _xform)
item_transformed = apply_transform(_xform, item_transformed)
item_transformed = self.transform.evaluate_with_overrides(item_transformed, None)
reset_ops_id(item_transformed)
return item_transformed

Expand All @@ -514,7 +520,9 @@ def _post_transform(self, item_transformed):
raise ValueError("transform must be an instance of monai.transforms.Compose.")
for i, _transform in enumerate(self.transform.transforms):
if i >= self.cache_n_trans:
item_transformed = self.transform.evaluate_with_overrides(item_transformed, item_transformed)
item_transformed = apply_transform(_transform, item_transformed)
item_transformed = self.transform.evaluate_with_overrides(item_transformed, None)
return item_transformed


Expand Down Expand Up @@ -884,7 +892,9 @@ def _load_cache_item(self, idx: int):
if isinstance(_transform, RandomizableTrait) or not isinstance(_transform, Transform):
break
_xform = deepcopy(_transform) if isinstance(_transform, ThreadUnsafe) else _transform
item = self.transform.evaluate_with_overrides(item, _xform)
item = apply_transform(_xform, item)
item = self.transform.evaluate_with_overrides(item, None)
if self.as_contiguous:
item = convert_to_contiguous(item, memory_format=torch.contiguous_format)
return item
Expand Down Expand Up @@ -921,7 +931,9 @@ def _transform(self, index: int):
start_run = True
if self.copy_cache:
data = deepcopy(data)
data = self.transform.evaluate_with_overrides(data, _transform)
data = apply_transform(_transform, data)
data = self.transform.evaluate_with_overrides(data, None)
return data


Expand Down
9 changes: 9 additions & 0 deletions monai/data/meta_obj.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,15 @@ def pending_operations(self) -> list[dict]:
return self._pending_operations
return MetaObj.get_default_applied_operations() # the same default as applied_ops

@property
def has_pending_operations(self) -> bool:
    """
    Determine whether there are pending operations.

    Returns:
        True if there are pending operations; False if not.
    """
    return self.pending_operations is not None and len(self.pending_operations) > 0

def push_pending_operation(self, t: Any) -> None:
    """Append ``t`` to this object's list of pending operations."""
    self._pending_operations.append(t)

Expand Down
2 changes: 1 addition & 1 deletion monai/data/meta_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -492,7 +492,7 @@ def peek_pending_affine(self):
continue
res = convert_to_dst_type(res, next_matrix)[0]
next_matrix = monai.data.utils.to_affine_nd(r, next_matrix)
res = monai.transforms.lazy.utils.combine_transforms(res, next_matrix)
res = monai.transforms.lazy.utils.combine_transforms(res, next_matrix) # type: ignore
return res

def peek_pending_rank(self):
Expand Down
193 changes: 178 additions & 15 deletions monai/transforms/compose.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,20 +21,95 @@
import numpy as np

import monai
import monai.transforms as mt
from monai.apps.utils import get_logger
from monai.transforms.inverse import InvertibleTransform

# For backwards compatibility (so this still works: from monai.transforms.compose import MapTransform)
from monai.transforms.transform import ( # noqa: F401
LazyTransform,
MapTransform,
Randomizable,
RandomizableTransform,
Transform,
apply_transform,
)
from monai.utils import MAX_SEED, ensure_tuple, get_seed
from monai.utils.enums import TraceKeys
from monai.utils import MAX_SEED, TraceKeys, ensure_tuple, get_seed
from monai.utils.misc import to_tuple_of_dictionaries

__all__ = ["Compose", "OneOf", "RandomOrder"]
logger = get_logger(__name__)

__all__ = ["Compose", "OneOf", "RandomOrder", "evaluate_with_overrides"]


def evaluate_with_overrides(
    data,
    upcoming,
    lazy_evaluation: bool | None = False,
    overrides: dict | None = None,
    override_keys: Sequence[str] | None = None,
    verbose: bool = False,
):
    """
    Evaluate pending lazy operations on ``data`` when the upcoming transform requires it.

    The previously applied transform may have been lazily applied to MetaTensor `data` and
    set `data.has_pending_operations` to True. Given the upcoming transform ``upcoming``,
    this function determines whether `data.pending_operations` should be evaluated. If so, it will
    evaluate the lazily applied transforms.

    Currently, the conditions for evaluation are:

    - ``lazy_evaluation`` is ``True``, AND
    - the data is a ``MetaTensor`` and has pending operations, AND
    - the upcoming transform is an instance of ``Identity`` or ``IdentityD`` or ``None``.

    The returned `data` will then be ready for the ``upcoming`` transform.

    Args:
        data: data to be evaluated; dictionaries, lists and tuples are recursed into per item/key.
        upcoming: the upcoming transform; ``None`` marks the end of the pipeline.
        lazy_evaluation: whether to evaluate the pending operations.
        overrides: keyword arguments to apply transforms.
        override_keys: keys to which the override arguments are applied when evaluating transforms.
        verbose: whether to print debugging info when evaluating MetaTensor with pending operations.
    """
    if not lazy_evaluation:
        return data  # eager evaluation
    # shallow copy so the caller's overrides dict is never modified downstream
    overrides = (overrides or {}).copy()
    if isinstance(data, monai.data.MetaTensor):
        if data.has_pending_operations and ((isinstance(upcoming, (mt.Identityd, mt.Identity))) or upcoming is None):
            # resample once, applying all accumulated pending operations
            data, _ = mt.apply_transforms(data, None, overrides=overrides)
            if verbose:
                next_name = "final output" if upcoming is None else f"'{upcoming.__class__.__name__}'"
                logger.info(f"Evaluated - '{override_keys}' - up-to-date for - {next_name}")
        elif verbose:
            logger.info(
                f"Lazy - '{override_keys}' - upcoming: '{upcoming.__class__.__name__}'"
                f"- pending {len(data.pending_operations)}"
            )
        return data
    override_keys = ensure_tuple(override_keys)
    if isinstance(data, dict):
        if isinstance(upcoming, MapTransform):
            # only keys the upcoming MapTransform operates on are relevant
            applied_keys = {k for k in data if k in upcoming.keys}
            if not applied_keys:
                return data
        else:
            applied_keys = set(data.keys())

        keys_to_override = {k for k in applied_keys if k in override_keys}
        # generate a list of dictionaries with the appropriate override value per key
        dict_overrides = to_tuple_of_dictionaries(overrides, override_keys)
        for k in data:
            if k in keys_to_override:
                dict_for_key = dict_overrides[override_keys.index(k)]
                data[k] = evaluate_with_overrides(data[k], upcoming, lazy_evaluation, dict_for_key, k, verbose)
            else:
                data[k] = evaluate_with_overrides(data[k], upcoming, lazy_evaluation, None, k, verbose)

    if isinstance(data, (list, tuple)):
        return [evaluate_with_overrides(v, upcoming, lazy_evaluation, overrides, override_keys, verbose) for v in data]
    return data


class Compose(Randomizable, InvertibleTransform):
Expand Down Expand Up @@ -114,7 +189,21 @@ class Compose(Randomizable, InvertibleTransform):
log_stats: whether to log the detailed information of data and applied transform when error happened,
for NumPy array and PyTorch Tensor, log the data shape and value range,
for other metadata, log the values directly. default to `False`.
lazy_evaluation: whether to enable lazy evaluation for lazy transforms. If False, transforms will be
carried out on a transform by transform basis. If True, all lazy transforms will
be executed by accumulating changes and resampling as few times as possible.
A `monai.transforms.Identity[D]` transform in the pipeline will trigger the evaluation of
the pending operations and make the primary data up-to-date.
overrides: this optional parameter allows you to specify a dictionary of parameters that should be overridden
when executing a pipeline. Each parameter that is compatible with a given transform is then applied
to that transform before it is executed. Note that overrides are currently only applied when lazy_evaluation
is True. If lazy_evaluation is False they are ignored.
currently supported args are:
{``"mode"``, ``"padding_mode"``, ``"dtype"``, ``"align_corners"``, ``"resample_mode"``, ``device``},
please see also :py:func:`monai.transforms.lazy.apply_transforms` for more details.
override_keys: this optional parameter specifies the keys to which ``overrides`` are to be applied. If
``overrides`` is set, ``override_keys`` must also be set.
verbose: whether to print debugging info when lazy_evaluation=True.
"""

def __init__(
Expand All @@ -123,6 +212,10 @@ def __init__(
map_items: bool = True,
unpack_items: bool = False,
log_stats: bool = False,
lazy_evaluation: bool | None = None,
overrides: dict | None = None,
override_keys: Sequence[str] | None = None,
verbose: bool = False,
) -> None:
if transforms is None:
transforms = []
Expand All @@ -132,6 +225,16 @@ def __init__(
self.log_stats = log_stats
self.set_random_state(seed=get_seed())

self.lazy_evaluation = lazy_evaluation
self.overrides = overrides
self.override_keys = override_keys
self.verbose = verbose

if self.lazy_evaluation is not None:
for t in self.flatten().transforms: # TODO: test Compose of Compose/OneOf
if isinstance(t, LazyTransform):
t.lazy_evaluation = self.lazy_evaluation

def set_random_state(self, seed: int | None = None, state: np.random.RandomState | None = None) -> Compose:
super().set_random_state(seed=seed, state=state)
for _transform in self.transforms:
Expand Down Expand Up @@ -172,9 +275,26 @@ def __len__(self):
"""Return number of transformations."""
return len(self.flatten().transforms)

def evaluate_with_overrides(self, input_, upcoming_xform):
    """
    Evaluate any pending lazy operations on ``input_`` using this Compose's
    ``lazy_evaluation``/``overrides``/``override_keys``/``verbose`` settings.

    Args:
        input_: input data to be transformed.
        upcoming_xform: a transform used to determine whether to evaluate with override.
    """
    return evaluate_with_overrides(
        input_,
        upcoming_xform,
        lazy_evaluation=self.lazy_evaluation,
        overrides=self.overrides,
        override_keys=self.override_keys,
        verbose=self.verbose,
    )

def __call__(self, input_):
    """Apply each transform in order, interleaving lazy-evaluation checks before
    every transform and once more after the last one."""
    for xform in self.transforms:
        input_ = self.evaluate_with_overrides(input_, xform)
        input_ = apply_transform(xform, input_, self.map_items, self.unpack_items, self.log_stats)
    # a final pass with no upcoming transform flushes any remaining pending operations
    return self.evaluate_with_overrides(input_, None)

def inverse(self, data):
Expand Down Expand Up @@ -204,7 +324,21 @@ class OneOf(Compose):
log_stats: whether to log the detailed information of data and applied transform when error happened,
for NumPy array and PyTorch Tensor, log the data shape and value range,
for other metadata, log the values directly. default to `False`.
lazy_evaluation: whether to enable lazy evaluation for lazy transforms. If True, all lazy transforms will
be executed by accumulating changes and resampling as few times as possible. If False, transforms will be
carried out on a transform by transform basis.
A `monai.transforms.Identity[D]` transform in the pipeline will trigger the evaluation of
the pending operations and make the primary data up-to-date.
overrides: this optional parameter allows you to specify a dictionary of parameters that should be overridden
when executing a pipeline. Each parameter that is compatible with a given transform is then applied
to that transform before it is executed. Note that overrides are currently only applied when lazy_evaluation
is True. If lazy_evaluation is False they are ignored.
currently supported args are:
{``"mode"``, ``"padding_mode"``, ``"dtype"``, ``"align_corners"``, ``"resample_mode"``, ``device``},
please see also :py:func:`monai.transforms.lazy.apply_transforms` for more details.
override_keys: this optional parameter specifies the keys to which ``overrides`` are to be applied. If
``overrides`` is set, ``override_keys`` must also be set.
verbose: whether to print debugging info when lazy_evaluation=True.
"""

def __init__(
Expand All @@ -214,8 +348,14 @@ def __init__(
map_items: bool = True,
unpack_items: bool = False,
log_stats: bool = False,
lazy_evaluation: bool | None = None,
overrides: dict | None = None,
override_keys: Sequence[str] | None = None,
verbose: bool = False,
) -> None:
super().__init__(transforms, map_items, unpack_items, log_stats)
super().__init__(
transforms, map_items, unpack_items, log_stats, lazy_evaluation, overrides, override_keys, verbose
)
if len(self.transforms) == 0:
weights = []
elif weights is None or isinstance(weights, float):
Expand Down Expand Up @@ -265,8 +405,8 @@ def __call__(self, data):
self.push_transform(data, extra_info={"index": index})
elif isinstance(data, Mapping):
for key in data: # dictionary not change size during iteration
if isinstance(data[key], monai.data.MetaTensor) or self.trace_key(key) in data:
self.push_transform(data, key, extra_info={"index": index})
if isinstance(data[key], monai.data.MetaTensor):
self.push_transform(data[key], extra_info={"index": index})
return data

def inverse(self, data):
Expand All @@ -278,7 +418,7 @@ def inverse(self, data):
index = self.pop_transform(data)[TraceKeys.EXTRA_INFO]["index"]
elif isinstance(data, Mapping):
for key in data:
if isinstance(data[key], monai.data.MetaTensor) or self.trace_key(key) in data:
if isinstance(data[key], monai.data.MetaTensor):
index = self.pop_transform(data, key)[TraceKeys.EXTRA_INFO]["index"]
else:
raise RuntimeError(
Expand Down Expand Up @@ -306,7 +446,21 @@ class RandomOrder(Compose):
log_stats: whether to log the detailed information of data and applied transform when error happened,
for NumPy array and PyTorch Tensor, log the data shape and value range,
for other metadata, log the values directly. default to `False`.
lazy_evaluation: whether to enable lazy evaluation for lazy transforms. If True, all lazy transforms will
be executed by accumulating changes and resampling as few times as possible. If False, transforms will be
carried out on a transform by transform basis.
A `monai.transforms.Identity[D]` transform in the pipeline will trigger the evaluation of
the pending operations and make the primary data up-to-date.
overrides: this optional parameter allows you to specify a dictionary of parameters that should be overridden
when executing a pipeline. Each parameter that is compatible with a given transform is then applied
to that transform before it is executed. Note that overrides are currently only applied when lazy_evaluation
is True. If lazy_evaluation is False they are ignored.
currently supported args are:
{``"mode"``, ``"padding_mode"``, ``"dtype"``, ``"align_corners"``, ``"resample_mode"``, ``device``},
please see also :py:func:`monai.transforms.lazy.apply_transforms` for more details.
override_keys: this optional parameter specifies the keys to which ``overrides`` are to be applied. If
``overrides`` is set, ``override_keys`` must also be set.
verbose: whether to print debugging info when lazy_evaluation=True.
"""

def __init__(
Expand All @@ -315,8 +469,14 @@ def __init__(
map_items: bool = True,
unpack_items: bool = False,
log_stats: bool = False,
lazy_evaluation: bool | None = None,
overrides: dict | None = None,
override_keys: Sequence[str] | None = None,
verbose: bool = False,
) -> None:
super().__init__(transforms, map_items, unpack_items, log_stats)
super().__init__(
transforms, map_items, unpack_items, log_stats, lazy_evaluation, overrides, override_keys, verbose
)

def __call__(self, input_):
if len(self.transforms) == 0:
Expand All @@ -331,8 +491,8 @@ def __call__(self, input_):
self.push_transform(input_, extra_info={"applied_order": applied_order})
elif isinstance(input_, Mapping):
for key in input_: # dictionary not change size during iteration
if isinstance(input_[key], monai.data.MetaTensor) or self.trace_key(key) in input_:
self.push_transform(input_, key, extra_info={"applied_order": applied_order})
if isinstance(input_[key], monai.data.MetaTensor):
self.push_transform(input_[key], extra_info={"applied_order": applied_order})
return input_

def inverse(self, data):
Expand All @@ -344,7 +504,7 @@ def inverse(self, data):
applied_order = self.pop_transform(data)[TraceKeys.EXTRA_INFO]["applied_order"]
elif isinstance(data, Mapping):
for key in data:
if isinstance(data[key], monai.data.MetaTensor) or self.trace_key(key) in data:
if isinstance(data[key], monai.data.MetaTensor):
applied_order = self.pop_transform(data, key)[TraceKeys.EXTRA_INFO]["applied_order"]
else:
raise RuntimeError(
Expand All @@ -356,5 +516,8 @@ def inverse(self, data):

# loop backwards over transforms
for o in reversed(applied_order):
data = apply_transform(self.transforms[o].inverse, data, self.map_items, self.unpack_items, self.log_stats)
if isinstance(self.transforms[o], InvertibleTransform):
data = apply_transform(
self.transforms[o].inverse, data, self.map_items, self.unpack_items, self.log_stats
)
return data
Loading

0 comments on commit b87375f

Please sign in to comment.