From fa15eec1ad668efec89780775363ed69e9a07626 Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Fri, 27 Oct 2023 00:25:37 +0800 Subject: [PATCH] simplify `list_data_collate` and `collate_meta_tensor` (#7165) Fixes #5917 ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: KumoLiu Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> --- monai/data/utils.py | 42 +++++++++++++++++++++++++++++++----------- 1 file changed, 31 insertions(+), 11 deletions(-) diff --git a/monai/data/utils.py b/monai/data/utils.py index 8c5ae88289..164fa78814 100644 --- a/monai/data/utils.py +++ b/monai/data/utils.py @@ -50,8 +50,12 @@ issequenceiterable, look_up_option, optional_import, + pytorch_after, ) +if pytorch_after(1, 13): + # import private code for reuse purposes, comment in case things break in the future + from torch.utils.data._utils.collate import collate_tensor_fn, default_collate_fn_map pd, _ = optional_import("pandas") DataFrame, _ = optional_import("pandas", name="DataFrame") nib, _ = optional_import("nibabel") @@ -444,6 +448,23 @@ def pickle_operations(data, key=PICKLE_KEY_SUFFIX, is_encode: bool = True): return data +def collate_meta_tensor_fn(batch, *, collate_fn_map=None): + """ + Collate a sequence of meta tensor into a single batched metatensor. This is called by `collage_meta_tensor` + and so should not be used as a collate function directly in dataloaders. + """ + collate_fn = collate_tensor_fn if pytorch_after(1, 13) else default_collate + collated = collate_fn(batch) # type: ignore + meta_dicts = [i.meta or TraceKeys.NONE for i in batch] + common_ = set.intersection(*[set(d.keys()) for d in meta_dicts if isinstance(d, dict)]) + if common_: + meta_dicts = [{k: d[k] for k in common_} if isinstance(d, dict) else TraceKeys.NONE for d in meta_dicts] + collated.meta = default_collate(meta_dicts) + collated.applied_operations = [i.applied_operations or TraceKeys.NONE for i in batch] + collated.is_batch = True + return collated + + def collate_meta_tensor(batch): """collate a sequence of meta tensor sequences/dictionaries into a single batched metatensor or a dictionary of batched metatensor""" @@ -451,15 +472,7 @@ def collate_meta_tensor(batch): raise NotImplementedError() elem_0 = first(batch) if isinstance(elem_0, MetaObj): - collated = default_collate(batch) - meta_dicts = [i.meta or TraceKeys.NONE for i in batch] - common_ = set.intersection(*[set(d.keys()) for d in meta_dicts if isinstance(d, dict)]) - if common_: - meta_dicts = [{k: d[k] for k in common_} if isinstance(d, dict) else TraceKeys.NONE for d in meta_dicts] - collated.meta = default_collate(meta_dicts) - collated.applied_operations = [i.applied_operations or TraceKeys.NONE for i in batch] - collated.is_batch = True - return collated + return collate_meta_tensor_fn(batch) if isinstance(elem_0, Mapping): return {k: collate_meta_tensor([d[k] for d in batch]) for k in elem_0} if isinstance(elem_0, (tuple, list)): @@ -479,9 +492,16 @@ def list_data_collate(batch: Sequence): Need to use this collate if apply some transforms that can generate batch data. """ + + if pytorch_after(1, 13): + # needs to go here to avoid circular import + from monai.data.meta_tensor import MetaTensor + + default_collate_fn_map.update({MetaTensor: collate_meta_tensor_fn}) elem = batch[0] data = [i for k in batch for i in k] if isinstance(elem, list) else batch key = None + collate_fn = default_collate if pytorch_after(1, 13) else collate_meta_tensor try: if config.USE_META_DICT: data = pickle_operations(data) # bc 0.9.0 @@ -490,9 +510,9 @@ def list_data_collate(batch: Sequence): for k in elem: key = k data_for_batch = [d[key] for d in data] - ret[key] = collate_meta_tensor(data_for_batch) + ret[key] = collate_fn(data_for_batch) else: - ret = collate_meta_tensor(data) + ret = collate_fn(data) return ret except RuntimeError as re: re_str = str(re)