diff --git a/albumentations/pytorch/transforms.py b/albumentations/pytorch/transforms.py index 38246ef43..fb6afdfe4 100644 --- a/albumentations/pytorch/transforms.py +++ b/albumentations/pytorch/transforms.py @@ -1,19 +1,89 @@ from __future__ import absolute_import +import warnings + +import numpy as np import torch +from torchvision.transforms import functional as F from ..core.transforms_interface import BasicTransform -__all__ = ['ToTensor'] +__all__ = ['ToTensor', 'ToTensorV2'] + + +def img_to_tensor(im, normalize=None): + tensor = torch.from_numpy(np.moveaxis(im / (255. if im.dtype == np.uint8 else 1), -1, 0).astype(np.float32)) + if normalize is not None: + return F.normalize(tensor, **normalize) + return tensor + + +def mask_to_tensor(mask, num_classes, sigmoid): + # todo + if num_classes > 1: + if not sigmoid: + # softmax + long_mask = np.zeros((mask.shape[:2]), dtype=np.int64) + if len(mask.shape) == 3: + for c in range(mask.shape[2]): + long_mask[mask[..., c] > 0] = c + else: + long_mask[mask > 127] = 1 + long_mask[mask == 0] = 0 + mask = long_mask + else: + mask = np.moveaxis(mask / (255. if mask.dtype == np.uint8 else 1), -1, 0).astype(np.float32) + else: + mask = np.expand_dims(mask / (255. if mask.dtype == np.uint8 else 1), 0).astype(np.float32) + return torch.from_numpy(mask) class ToTensor(BasicTransform): + """Convert image and mask to `torch.Tensor` and divide by 255 if image or mask are `uint8` type. + WARNING! Please use this with care and look into sources before usage. + + Args: + num_classes (int): only for segmentation + sigmoid (bool, optional): only for segmentation, transform mask to LongTensor or not. + normalize (dict, optional): dict with keys [mean, std] to pass it into torchvision.normalize + + """ + + def __init__(self, num_classes=1, sigmoid=True, normalize=None): + super(ToTensor, self).__init__(always_apply=True, p=1.) + self.num_classes = num_classes + self.sigmoid = sigmoid + self.normalize = normalize + warnings.warn("ToTensor is deprecated and will be replaced by ToTensorV2 " + "in albumentations 0.4.0", DeprecationWarning) + + def __call__(self, force_apply=True, **kwargs): + kwargs.update({'image': img_to_tensor(kwargs['image'], self.normalize)}) + if 'mask' in kwargs.keys(): + kwargs.update({'mask': mask_to_tensor(kwargs['mask'], self.num_classes, sigmoid=self.sigmoid)}) + + for k, v in kwargs.items(): + if self._additional_targets.get(k) == 'image': + kwargs.update({k: img_to_tensor(kwargs[k], self.normalize)}) + if self._additional_targets.get(k) == 'mask': + kwargs.update({k: mask_to_tensor(kwargs[k], self.num_classes, sigmoid=self.sigmoid)}) + return kwargs + + @property + def targets(self): + raise NotImplementedError + + def get_transform_init_args_names(self): + return 'num_classes', 'sigmoid', 'normalize' + + +class ToTensorV2(BasicTransform): """Convert image and mask to `torch.Tensor`. """ def __init__(self): - super(ToTensor, self).__init__(always_apply=True) + super(ToTensorV2, self).__init__(always_apply=True) @property def targets(self): @@ -29,7 +99,7 @@ def apply_to_mask(self, mask, **params): return torch.from_numpy(mask) def get_transform_init_args_names(self): - return {} + return [] def get_params_dependent_on_targets(self, params): return {} diff --git a/tests/test_pytorch.py b/tests/test_pytorch.py index 8dcd5e330..04ddc8e2a 100644 --- a/tests/test_pytorch.py +++ b/tests/test_pytorch.py @@ -1,12 +1,13 @@ +import pytest import numpy as np import torch import albumentations as A -from albumentations.pytorch.transforms import ToTensor +from albumentations.pytorch.transforms import ToTensor, ToTensorV2 -def test_torch_to_tensor_augmentations(image, mask): - aug = ToTensor() +def test_torch_to_tensor_v2_augmentations(image, mask): + aug = ToTensorV2() data = aug(image=image, mask=mask, force_apply=True) assert isinstance(data['image'], torch.Tensor) and data['image'].shape == image.shape[::-1] assert isinstance(data['mask'], torch.Tensor) and data['mask'].shape == mask.shape @@ -14,9 +15,9 @@ def test_torch_to_tensor_augmentations(image, mask): assert data['mask'].dtype == torch.uint8 -def test_additional_targets_for_totensor(): +def test_additional_targets_for_totensorv2(): aug = A.Compose( - [ToTensor()], additional_targets={'image2': 'image', 'mask2': 'mask'}) + [ToTensorV2()], additional_targets={'image2': 'image', 'mask2': 'mask'}) for i in range(10): image1 = np.random.randint(low=0, high=256, size=(100, 100, 3), dtype=np.uint8) image2 = image1.copy() @@ -29,3 +30,25 @@ def test_additional_targets_for_totensor(): assert isinstance(res['mask2'], torch.Tensor) and res['mask2'].shape == mask2.shape assert np.array_equal(res['image'], res['image2']) assert np.array_equal(res['mask'], res['mask2']) + + +def test_torch_to_tensor_augmentations(image, mask): + with pytest.warns(DeprecationWarning): + aug = ToTensor() + data = aug(image=image, mask=mask, force_apply=True) + assert data['image'].dtype == torch.float32 + assert data['mask'].dtype == torch.float32 + + +def test_additional_targets_for_totensor(): + with pytest.warns(DeprecationWarning): + aug = A.Compose( + [ToTensor(num_classes=4)], additional_targets={'image2': 'image', 'mask2': 'mask'}) + for i in range(10): + image1 = np.random.randint(low=0, high=256, size=(100, 100, 3), dtype=np.uint8) + image2 = image1.copy() + mask1 = np.random.randint(low=0, high=256, size=(100, 100, 4), dtype=np.uint8) + mask2 = mask1.copy() + res = aug(image=image1, image2=image2, mask=mask1, mask2=mask2) + assert np.array_equal(res['image'], res['image2']) + assert np.array_equal(res['mask'], res['mask2'])