diff --git a/tests/conf/vhr10.yaml b/tests/conf/vhr10.yaml new file mode 100644 index 00000000000..0ea0909b971 --- /dev/null +++ b/tests/conf/vhr10.yaml @@ -0,0 +1,17 @@ +model: + class_path: ObjectDetectionTask + init_args: + model: "faster-rcnn" + backbone: "resnet50" + num_classes: 11 + lr: 2.5e-5 + patience: 10 +data: + class_path: VHR10DataModule + init_args: + batch_size: 1 + num_workers: 0 + patch_size: 4 + dict_kwargs: + root: "tests/data/vhr10" + download: true diff --git a/tests/data/vhr10/NWPU VHR-10 dataset.rar b/tests/data/vhr10/NWPU VHR-10 dataset.rar index 6a1b98fed27..0f836ac8e17 100644 Binary files a/tests/data/vhr10/NWPU VHR-10 dataset.rar and b/tests/data/vhr10/NWPU VHR-10 dataset.rar differ diff --git a/tests/data/vhr10/annotations.json b/tests/data/vhr10/annotations.json index 60de0cb14f3..8b6183354ed 100644 --- a/tests/data/vhr10/annotations.json +++ b/tests/data/vhr10/annotations.json @@ -1 +1 @@ -{"images": [{"file_name": "001.jpg", "height": 8, "width": 8, "id": 0}, {"file_name": "002.jpg", "height": 8, "width": 8, "id": 1}, {"file_name": "003.jpg", "height": 8, "width": 8, "id": 2}, {"file_name": "004.jpg", "height": 8, "width": 8, "id": 3}, {"file_name": "005.jpg", "height": 8, "width": 8, "id": 4}], "annotations": [{"id": 0, "image_id": 0, "category_id": 1, "area": 4.0, "bbox": [4, 4, 2, 2], "iscrowd": 0}, {"id": 1, "image_id": 1, "category_id": 1, "area": 4.0, "bbox": [4, 4, 2, 2], "segmentation": [[1, 1, 2, 2, 3, 3, 4, 5, 5]], "iscrowd": 0}, {"id": 2, "image_id": 2, "category_id": 1, "area": 4.0, "bbox": [4, 4, 2, 2], "segmentation": [[1, 1, 2, 2, 3, 3, 4, 5, 5]], "iscrowd": 0}, {"id": 3, "image_id": 3, "category_id": 1, "area": 4.0, "bbox": [4, 4, 2, 2], "segmentation": [[1, 1, 2, 2, 3, 3, 4, 5, 5]], "iscrowd": 0}, {"id": 4, "image_id": 4, "category_id": 1, "area": 4.0, "bbox": [4, 4, 2, 2], "segmentation": [[1, 1, 2, 2, 3, 3, 4, 5, 5]], "iscrowd": 0}]} \ No newline at end of file +{"images": [{"file_name": "001.jpg", "height": 8, "width": 8, "id": 0}, {"file_name": "002.jpg", "height": 8, "width": 8, "id": 1}, {"file_name": "003.jpg", "height": 8, "width": 8, "id": 2}, {"file_name": "004.jpg", "height": 8, "width": 8, "id": 3}, {"file_name": "005.jpg", "height": 8, "width": 8, "id": 4}], "annotations": [{"id": 0, "image_id": 0, "category_id": 1, "area": 4.0, "bbox": [4, 4, 2, 2], "segmentation": [[1, 1, 2, 2, 3, 3, 4, 5, 5]], "iscrowd": 0}, {"id": 1, "image_id": 1, "category_id": 1, "area": 4.0, "bbox": [4, 4, 2, 2], "segmentation": [[1, 1, 2, 2, 3, 3, 4, 5, 5]], "iscrowd": 0}, {"id": 2, "image_id": 2, "category_id": 1, "area": 4.0, "bbox": [4, 4, 2, 2], "segmentation": [[1, 1, 2, 2, 3, 3, 4, 5, 5]], "iscrowd": 0}, {"id": 3, "image_id": 3, "category_id": 1, "area": 4.0, "bbox": [4, 4, 2, 2], "segmentation": [[1, 1, 2, 2, 3, 3, 4, 5, 5]], "iscrowd": 0}, {"id": 4, "image_id": 4, "category_id": 1, "area": 4.0, "bbox": [4, 4, 2, 2], "segmentation": [[1, 1, 2, 2, 3, 3, 4, 5, 5]], "iscrowd": 0}]} \ No newline at end of file diff --git a/tests/data/vhr10/data.py b/tests/data/vhr10/data.py index 97cc4e54d29..63e4855a529 100755 --- a/tests/data/vhr10/data.py +++ b/tests/data/vhr10/data.py @@ -5,7 +5,6 @@ import os import shutil import subprocess -from copy import deepcopy import numpy as np from PIL import Image @@ -47,7 +46,7 @@ def generate_test_data(root: str, n_imgs: int = 3) -> str: ) ann = 0 - for i, img in enumerate(ANNOTATION_FILE["images"]): + for _, img in enumerate(ANNOTATION_FILE["images"]): annot = { "id": ann, "image_id": img["id"], @@ -57,12 +56,7 @@ def generate_test_data(root: str, n_imgs: int = 3) -> str: "segmentation": [[1, 1, 2, 2, 3, 3, 4, 5, 5]], "iscrowd": 0, } - if i != 0: - ANNOTATION_FILE["annotations"].append(annot) - else: - noseg_annot = deepcopy(annot) - del noseg_annot["segmentation"] - ANNOTATION_FILE["annotations"].append(noseg_annot) + ANNOTATION_FILE["annotations"].append(annot) ann += 1 with open(ann_file, "w") as j: diff --git a/tests/datasets/test_vhr10.py b/tests/datasets/test_vhr10.py index 805b84a3117..5480acb4ef4 100644 --- a/tests/datasets/test_vhr10.py +++ b/tests/datasets/test_vhr10.py @@ -35,11 +35,11 @@ def dataset( monkeypatch.setattr(torchgeo.datasets.utils, "download_url", download_url) url = os.path.join("tests", "data", "vhr10", "NWPU VHR-10 dataset.rar") monkeypatch.setitem(VHR10.image_meta, "url", url) - md5 = "5fddb0dfd56a80638831df9f90cbf37a" + md5 = "92769845cae6a4e8c74bfa1a0d1d4a80" monkeypatch.setitem(VHR10.image_meta, "md5", md5) url = os.path.join("tests", "data", "vhr10", "annotations.json") monkeypatch.setitem(VHR10.target_meta, "url", url) - md5 = "833899cce369168e0d4ee420dac326dc" + md5 = "567c4cd8c12624864ff04865de504c58" monkeypatch.setitem(VHR10.target_meta, "md5", md5) root = str(tmp_path) split = request.param diff --git a/tests/trainers/test_detection.py b/tests/trainers/test_detection.py index 51b4ac5c400..e4151ac0d29 100644 --- a/tests/trainers/test_detection.py +++ b/tests/trainers/test_detection.py @@ -67,7 +67,7 @@ def plot(*args: Any, **kwargs: Any) -> None: class TestObjectDetectionTask: - @pytest.mark.parametrize("name", ["nasa_marine_debris"]) + @pytest.mark.parametrize("name", ["nasa_marine_debris", "vhr10"]) @pytest.mark.parametrize("model_name", ["faster-rcnn", "fcos", "retinanet"]) def test_trainer( self, monkeypatch: MonkeyPatch, name: str, model_name: str, fast_dev_run: bool diff --git a/tests/transforms/test_transforms.py b/tests/transforms/test_transforms.py index a9ea200982a..64412c2bc6c 100644 --- a/tests/transforms/test_transforms.py +++ b/tests/transforms/test_transforms.py @@ -23,7 +23,7 @@ def batch_gray() -> dict[str, Tensor]: return { "image": torch.tensor([[[[1, 2, 3], [4, 5, 6], [7, 8, 9]]]], dtype=torch.float), "mask": torch.tensor([[[0, 0, 1], [0, 1, 1], [1, 1, 1]]], dtype=torch.long), - "boxes": torch.tensor([[[0, 1], [1, 1], [1, 0], [0, 0]]], dtype=torch.float), + "boxes": torch.tensor([[0.0, 0.0, 2.0, 2.0]], dtype=torch.float), "labels": torch.tensor([[0, 1]]), } @@ -42,7 +42,7 @@ def batch_rgb() -> dict[str, Tensor]: dtype=torch.float, ), "mask": torch.tensor([[[0, 0, 1], [0, 1, 1], [1, 1, 1]]], dtype=torch.long), - "boxes": torch.tensor([[[0, 1], [1, 1], [1, 0], [0, 0]]], dtype=torch.float), + "boxes": torch.tensor([[0.0, 0.0, 2.0, 2.0]], dtype=torch.float), "labels": torch.tensor([[0, 1]]), } @@ -63,7 +63,7 @@ def batch_multispectral() -> dict[str, Tensor]: dtype=torch.float, ), "mask": torch.tensor([[[0, 0, 1], [0, 1, 1], [1, 1, 1]]], dtype=torch.long), - "boxes": torch.tensor([[[0, 1], [1, 1], [1, 0], [0, 0]]], dtype=torch.float), + "boxes": torch.tensor([[0.0, 0.0, 2.0, 2.0]], dtype=torch.float), "labels": torch.tensor([[0, 1]]), } @@ -79,7 +79,7 @@ def test_augmentation_sequential_gray(batch_gray: dict[str, Tensor]) -> None: expected = { "image": torch.tensor([[[[3, 2, 1], [6, 5, 4], [9, 8, 7]]]], dtype=torch.float), "mask": torch.tensor([[[1, 0, 0], [1, 1, 0], [1, 1, 1]]], dtype=torch.long), - "boxes": torch.tensor([[[1, 0], [2, 0], [2, 1], [1, 1]]], dtype=torch.float), + "boxes": torch.tensor([[1.0, 0.0, 3.0, 2.0]], dtype=torch.float), "labels": torch.tensor([[0, 1]]), } augs = transforms.AugmentationSequential( @@ -102,7 +102,7 @@ def test_augmentation_sequential_rgb(batch_rgb: dict[str, Tensor]) -> None: dtype=torch.float, ), "mask": torch.tensor([[[1, 0, 0], [1, 1, 0], [1, 1, 1]]], dtype=torch.long), - "boxes": torch.tensor([[[1, 0], [2, 0], [2, 1], [1, 1]]], dtype=torch.float), + "boxes": torch.tensor([[1.0, 0.0, 3.0, 2.0]], dtype=torch.float), "labels": torch.tensor([[0, 1]]), } augs = transforms.AugmentationSequential( @@ -129,7 +129,7 @@ def test_augmentation_sequential_multispectral( dtype=torch.float, ), "mask": torch.tensor([[[1, 0, 0], [1, 1, 0], [1, 1, 1]]], dtype=torch.long), - "boxes": torch.tensor([[[1, 0], [2, 0], [2, 1], [1, 1]]], dtype=torch.float), + "boxes": torch.tensor([[1.0, 0.0, 3.0, 2.0]], dtype=torch.float), "labels": torch.tensor([[0, 1]]), } augs = transforms.AugmentationSequential( @@ -156,7 +156,7 @@ def test_augmentation_sequential_image_only( dtype=torch.float, ), "mask": torch.tensor([[[0, 0, 1], [0, 1, 1], [1, 1, 1]]], dtype=torch.long), - "boxes": torch.tensor([[[0, 1], [1, 1], [1, 0], [0, 0]]], dtype=torch.float), + "boxes": torch.tensor([[0.0, 0.0, 2.0, 2.0]], dtype=torch.float), "labels": torch.tensor([[0, 1]]), } augs = transforms.AugmentationSequential( @@ -188,7 +188,7 @@ def test_sequential_transforms_augmentations( dtype=torch.float, ), "mask": torch.tensor([[[1, 0, 0], [1, 1, 0], [1, 1, 1]]], dtype=torch.long), - "boxes": torch.tensor([[[1, 0], [2, 0], [2, 1], [1, 1]]], dtype=torch.float), + "boxes": torch.tensor([[1.0, 0.0, 3.0, 2.0]], dtype=torch.float), "labels": torch.tensor([[0, 1]]), } train_transforms = transforms.AugmentationSequential( diff --git a/torchgeo/datamodules/__init__.py b/torchgeo/datamodules/__init__.py index 620d8334e44..1f05f31b3b4 100644 --- a/torchgeo/datamodules/__init__.py +++ b/torchgeo/datamodules/__init__.py @@ -38,6 +38,7 @@ from .usavars import USAVarsDataModule from .utils import MisconfigurationException from .vaihingen import Vaihingen2DDataModule +from .vhr10 import VHR10DataModule from .xview import XView2DataModule __all__ = ( @@ -79,6 +80,7 @@ "UCMercedDataModule", "USAVarsDataModule", "Vaihingen2DDataModule", + "VHR10DataModule", "XView2DataModule", # Base classes "BaseDataModule", diff --git a/torchgeo/datamodules/nasa_marine_debris.py b/torchgeo/datamodules/nasa_marine_debris.py index f740df5eb76..76848bc4e4b 100644 --- a/torchgeo/datamodules/nasa_marine_debris.py +++ b/torchgeo/datamodules/nasa_marine_debris.py @@ -5,28 +5,13 @@ from typing import Any +import kornia.augmentation as K import torch -from torch import Tensor from ..datasets import NASAMarineDebris +from ..transforms import AugmentationSequential from .geo import NonGeoDataModule -from .utils import dataset_split - - -def collate_fn(batch: list[dict[str, Tensor]]) -> dict[str, Any]: - """Custom object detection collate fn to handle variable boxes. - - Args: - batch: list of sample dicts return by dataset - - Returns: - batch dict output - """ - output: dict[str, Any] = {} - output["image"] = torch.stack([sample["image"] for sample in batch]) - output["boxes"] = [sample["boxes"] for sample in batch] - output["labels"] = [torch.tensor([1] * len(sample["boxes"])) for sample in batch] - return output +from .utils import AugPipe, collate_fn_detection, dataset_split class NASAMarineDebrisDataModule(NonGeoDataModule): @@ -35,6 +20,8 @@ class NASAMarineDebrisDataModule(NonGeoDataModule): .. versionadded:: 0.2 """ + std = torch.tensor(255) + def __init__( self, batch_size: int = 64, @@ -58,7 +45,14 @@ def __init__( self.val_split_pct = val_split_pct self.test_split_pct = test_split_pct - self.collate_fn = collate_fn + self.aug = AugPipe( + AugmentationSequential( + K.Normalize(mean=self.mean, std=self.std), data_keys=["image", "boxes"] + ), + batch_size, + ) + + self.collate_fn = collate_fn_detection def setup(self, stage: str) -> None: """Set up datasets. diff --git a/torchgeo/datamodules/utils.py b/torchgeo/datamodules/utils.py index d0bb6af9934..ae098c9930c 100644 --- a/torchgeo/datamodules/utils.py +++ b/torchgeo/datamodules/utils.py @@ -5,10 +5,13 @@ import math from collections.abc import Iterable -from typing import Any, Optional, Union +from typing import Any, Callable, Optional, Union import numpy as np -from torch import Generator +import torch +from einops import rearrange +from torch import Generator, Tensor +from torch.nn import Module from torch.utils.data import Subset, TensorDataset, random_split from ..datasets import NonGeoDataset @@ -19,6 +22,86 @@ class MisconfigurationException(Exception): """Exception used to inform users of misuse with Lightning.""" +class AugPipe(Module): + """Pipeline for applying augmentations sequentially on select data keys. + + .. versionadded:: 0.6 + """ + + def __init__( + self, augs: Callable[[dict[str, Any]], dict[str, Any]], batch_size: int + ) -> None: + """Initialize a new AugPipe instance. + + Args: + augs: Augmentations to apply. + batch_size: Batch size + """ + super().__init__() + self.augs = augs + self.batch_size = batch_size + + def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: + """Apply the augmentation. + + Args: + batch: Input batch. + + Returns: + Augmented batch. + """ + batch_len = len(batch["image"]) + for bs in range(batch_len): + batch_dict = { + "image": batch["image"][bs], + "labels": batch["labels"][bs], + "boxes": batch["boxes"][bs], + } + + if "masks" in batch: + batch_dict["masks"] = batch["masks"][bs] + + batch_dict = self.augs(batch_dict) + + batch["image"][bs] = batch_dict["image"] + batch["labels"][bs] = batch_dict["labels"] + batch["boxes"][bs] = batch_dict["boxes"] + + if "masks" in batch: + batch["masks"][bs] = batch_dict["masks"] + + # Stack images + batch["image"] = rearrange(batch["image"], "b () c h w -> b c h w") + + return batch + + +def collate_fn_detection(batch: list[dict[str, Tensor]]) -> dict[str, Any]: + """Custom collate fn for object detection and instance segmentation. + + Args: + batch: list of sample dicts return by dataset + + Returns: + batch dict output + + .. versionadded:: 0.6 + """ + output: dict[str, Any] = {} + output["image"] = [sample["image"] for sample in batch] + output["boxes"] = [sample["boxes"].float() for sample in batch] + if "labels" in batch[0]: + output["labels"] = [sample["labels"] for sample in batch] + else: + output["labels"] = [ + torch.tensor([1] * len(sample["boxes"])) for sample in batch + ] + + if "masks" in batch[0]: + output["masks"] = [sample["masks"] for sample in batch] + return output + + def dataset_split( dataset: Union[TensorDataset, NonGeoDataset], val_pct: float, diff --git a/torchgeo/datamodules/vhr10.py b/torchgeo/datamodules/vhr10.py new file mode 100644 index 00000000000..0059d6c71f2 --- /dev/null +++ b/torchgeo/datamodules/vhr10.py @@ -0,0 +1,83 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""NWPU VHR-10 datamodule.""" + +from typing import Any, Union + +import kornia.augmentation as K +import torch + +from ..datasets import VHR10 +from ..samplers.utils import _to_tuple +from ..transforms import AugmentationSequential +from .geo import NonGeoDataModule +from .utils import AugPipe, collate_fn_detection, dataset_split + + +class VHR10DataModule(NonGeoDataModule): + """LightningDataModule implementation for the VHR10 dataset. + + .. versionadded:: 0.6 + """ + + std = torch.tensor(255) + + def __init__( + self, + batch_size: int = 64, + patch_size: Union[tuple[int, int], int] = 512, + num_workers: int = 0, + val_split_pct: float = 0.2, + test_split_pct: float = 0.2, + **kwargs: Any, + ) -> None: + """Initialize a new VHR10DataModule instance. + + Args: + batch_size: Size of each mini-batch. + patch_size: Size of each patch, either ``size`` or ``(height, width)``. + num_workers: Number of workers for parallel data loading. + val_split_pct: Percentage of the dataset to use as a validation set. + test_split_pct: Percentage of the dataset to use as a test set. + **kwargs: Additional keyword arguments passed to + :class:`~torchgeo.datasets.VHR10`. + """ + super().__init__(VHR10, batch_size, num_workers, **kwargs) + + self.val_split_pct = val_split_pct + self.test_split_pct = test_split_pct + self.patch_size = _to_tuple(patch_size) + + self.collate_fn = collate_fn_detection + + self.train_aug = AugPipe( + AugmentationSequential( + K.Normalize(mean=self.mean, std=self.std), + K.Resize(self.patch_size), + K.RandomHorizontalFlip(), + K.ColorJiggle(0.1, 0.1, 0.1, 0.1, p=0.7), + K.RandomVerticalFlip(), + data_keys=["image", "boxes", "masks"], + ), + batch_size, + ) + self.aug = AugPipe( + AugmentationSequential( + K.Normalize(mean=self.mean, std=self.std), + K.Resize(self.patch_size), + data_keys=["image", "boxes", "masks"], + ), + batch_size, + ) + + def setup(self, stage: str) -> None: + """Set up datasets. + + Args: + stage: Either 'fit', 'validate', 'test', or 'predict'. + """ + self.dataset = VHR10(**self.kwargs) + self.train_dataset, self.val_dataset, self.test_dataset = dataset_split( + self.dataset, self.val_split_pct, self.test_split_pct + ) diff --git a/torchgeo/datasets/vhr10.py b/torchgeo/datasets/vhr10.py index db0807ee930..43756df71a5 100644 --- a/torchgeo/datasets/vhr10.py +++ b/torchgeo/datasets/vhr10.py @@ -45,10 +45,7 @@ def convert_coco_poly_to_mask( mask = torch.as_tensor(mask, dtype=torch.uint8) mask = mask.any(dim=2) masks.append(mask) - if masks: - masks_tensor = torch.stack(masks, dim=0) - else: - masks_tensor = torch.zeros((0, height, width), dtype=torch.uint8) + masks_tensor = torch.stack(masks, dim=0) return masks_tensor @@ -89,10 +86,8 @@ def __call__(self, sample: dict[str, Any]) -> dict[str, Any]: categories = [obj["category_id"] for obj in anno] classes = torch.tensor(categories, dtype=torch.int64) - if "segmentation" in anno[0]: - segmentations = [obj["segmentation"] for obj in anno] - else: - segmentations = [] + segmentations = [obj["segmentation"] for obj in anno] + masks = convert_coco_poly_to_mask(segmentations, h, w) keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0]) @@ -258,8 +253,7 @@ def __getitem__(self, index: int) -> dict[str, Any]: sample = self.coco_convert(sample) sample["labels"] = sample["label"]["labels"] sample["boxes"] = sample["label"]["boxes"] - if "masks" in sample["label"]: - sample["masks"] = sample["label"]["masks"] + sample["masks"] = sample["label"]["masks"] del sample["label"] if self.transforms is not None: @@ -296,6 +290,7 @@ def _load_image(self, id_: int) -> Tensor: with Image.open(filename) as img: array: "np.typing.NDArray[np.int_]" = np.array(img) tensor = torch.from_numpy(array) + tensor = tensor.float() # Convert from HxWxC to CxHxW tensor = tensor.permute((2, 0, 1)) return tensor @@ -439,7 +434,7 @@ def plot( ncols += 1 # Display image - fig, axs = plt.subplots(ncols=ncols, squeeze=False, figsize=(ncols * 10, 10)) + fig, axs = plt.subplots(ncols=ncols, squeeze=False, figsize=(ncols * 10, 13)) axs[0, 0].imshow(image) axs[0, 0].axis("off") @@ -536,9 +531,9 @@ def plot( if show_titles: axs[0, 1].set_title("Prediction") - plt.tight_layout() - if suptitle is not None: plt.suptitle(suptitle) + plt.tight_layout() + return fig diff --git a/torchgeo/transforms/transforms.py b/torchgeo/transforms/transforms.py index f09a52050ff..c3c3d54f56b 100644 --- a/torchgeo/transforms/transforms.py +++ b/torchgeo/transforms/transforms.py @@ -10,6 +10,7 @@ from einops import rearrange from kornia.contrib import extract_tensor_patches from kornia.geometry import crop_by_indices +from kornia.geometry.boxes import Boxes from torch import Tensor from torch.nn.modules import Module @@ -47,6 +48,8 @@ def __init__( keys.append("input") elif key == "boxes": keys.append("bbox") + elif key == "masks": + keys.append("mask") else: keys.append(key) @@ -67,10 +70,19 @@ def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: dtype[key] = batch[key].dtype batch[key] = batch[key].float() + # Convert shape of boxes from [N, 4] to [N, 4, 2] + if "boxes" in batch and ( + isinstance(batch["boxes"], list) or batch["boxes"].ndim == 2 + ): + batch["boxes"] = Boxes.from_tensor(batch["boxes"]).data + # Kornia requires masks to have a channel dimension - if "mask" in batch and len(batch["mask"].shape) == 3: + if "mask" in batch and batch["mask"].ndim == 3: batch["mask"] = rearrange(batch["mask"], "b h w -> b () h w") + if "masks" in batch and batch["masks"].ndim == 3: + batch["masks"] = rearrange(batch["masks"], "c h w -> () c h w") + inputs = [batch[k] for k in self.data_keys] outputs_list: Union[Tensor, list[Tensor]] = self.augs(*inputs) outputs_list = ( @@ -85,9 +97,17 @@ def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: for key in self.data_keys: batch[key] = batch[key].to(dtype[key]) + # Convert boxes to default [N, 4] + if "boxes" in batch: + batch["boxes"] = Boxes(batch["boxes"]).to_tensor( + mode="xyxy" + ) # type:ignore[assignment] + # Torchmetrics does not support masks with a channel dimension if "mask" in batch and batch["mask"].shape[1] == 1: batch["mask"] = rearrange(batch["mask"], "b () h w -> b h w") + if "masks" in batch and batch["masks"].ndim == 4: + batch["masks"] = rearrange(batch["masks"], "() c h w -> c h w") return batch