Skip to content

Commit

Permalink
Feature CropNonEmptyMaskIfExists augmentation (#342)
Browse files Browse the repository at this point in the history
* Add transfrom CropWithMaskIfExist

* Fix transform

* Fix crop 2

* Update transforms.py

* Update transforms.py

* Add test of crop_non_empty_mask

* Update augmentation

* Update tests

* Add CropNonEmptyMaskIfExists to transforms list

* Add failed test for crop

* Fix bug; h, w renamed; apply_always=False

* Fix docstring
  • Loading branch information
qubvel authored and ternaus committed Sep 10, 2019
1 parent 389d31a commit d96dabe
Show file tree
Hide file tree
Showing 5 changed files with 173 additions and 35 deletions.
67 changes: 34 additions & 33 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -145,39 +145,40 @@ Pixel-level transforms will change just an input image and will leave any additi
## Spatial-level transforms
Spatial-level transforms will simultaneously change both an input image as well as additional targets such as masks, bounding boxes, and keypoints. The following table shows which additional targets are supported by each transform.

| Transform | Image | Masks | BBoxes | Keypoints |
| ----------------------------------------------------------------------------------------------------------------------------------------------------------------- | :---: | :---: | :----: | :-------: |
| [CenterCrop](https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.augmentations.transforms.CenterCrop) |||||
| [Crop](https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.augmentations.transforms.Crop) |||| |
| [ElasticTransform](https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.augmentations.transforms.ElasticTransform) ||| | |
| [Flip](https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.augmentations.transforms.Flip) |||||
| [GridDistortion](https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.augmentations.transforms.GridDistortion) ||| | |
| [HorizontalFlip](https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.augmentations.transforms.HorizontalFlip) |||||
| [IAAAffine](https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.imgaug.transforms.IAAAffine) |||||
| [IAACropAndPad](https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.imgaug.transforms.IAACropAndPad) |||||
| [IAAFliplr](https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.imgaug.transforms.IAAFliplr) |||||
| [IAAFlipud](https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.imgaug.transforms.IAAFlipud) |||||
| [IAAPerspective](https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.imgaug.transforms.IAAPerspective) |||||
| [IAAPiecewiseAffine](https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.imgaug.transforms.IAAPiecewiseAffine) |||||
| [Lambda](https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.augmentations.transforms.Lambda) |||||
| [LongestMaxSize](https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.augmentations.transforms.LongestMaxSize) |||| |
| NoOp |||||
| [OpticalDistortion](https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.augmentations.transforms.OpticalDistortion) ||| | |
| [PadIfNeeded](https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.augmentations.transforms.PadIfNeeded) |||||
| [RandomCrop](https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.augmentations.transforms.RandomCrop) |||||
| [RandomCropNearBBox](https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.augmentations.transforms.RandomCropNearBBox) |||| |
| [RandomGridShuffle](https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.augmentations.transforms.RandomGridShuffle) ||| | |
| [RandomResizedCrop](https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.augmentations.transforms.RandomResizedCrop) |||||
| [RandomRotate90](https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.augmentations.transforms.RandomRotate90) |||||
| [RandomScale](https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.augmentations.transforms.RandomScale) |||||
| [RandomSizedBBoxSafeCrop](https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.augmentations.transforms.RandomSizedBBoxSafeCrop) |||| |
| [RandomSizedCrop](https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.augmentations.transforms.RandomSizedCrop) |||||
| [Resize](https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.augmentations.transforms.Resize) |||| |
| [Rotate](https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.augmentations.transforms.Rotate) |||||
| [ShiftScaleRotate](https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.augmentations.transforms.ShiftScaleRotate) |||||
| [SmallestMaxSize](https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.augmentations.transforms.SmallestMaxSize) |||| |
| [Transpose](https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.augmentations.transforms.Transpose) |||| |
| [VerticalFlip](https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.augmentations.transforms.VerticalFlip) |||||
| Transform | Image | Masks | BBoxes | Keypoints |
| ------------------------------------------------------------------------------------------------------------------------------------------------------------------- | :---: | :---: | :----: | :-------: |
| [CenterCrop](https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.augmentations.transforms.CenterCrop) |||||
| [Crop](https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.augmentations.transforms.Crop) |||| |
| [CropNonEmptyMaskIfExists](https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.augmentations.transforms.CropNonEmptyMaskIfExists) ||| | |
| [ElasticTransform](https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.augmentations.transforms.ElasticTransform) ||| | |
| [Flip](https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.augmentations.transforms.Flip) |||||
| [GridDistortion](https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.augmentations.transforms.GridDistortion) ||| | |
| [HorizontalFlip](https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.augmentations.transforms.HorizontalFlip) |||||
| [IAAAffine](https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.imgaug.transforms.IAAAffine) |||||
| [IAACropAndPad](https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.imgaug.transforms.IAACropAndPad) |||||
| [IAAFliplr](https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.imgaug.transforms.IAAFliplr) |||||
| [IAAFlipud](https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.imgaug.transforms.IAAFlipud) |||||
| [IAAPerspective](https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.imgaug.transforms.IAAPerspective) |||||
| [IAAPiecewiseAffine](https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.imgaug.transforms.IAAPiecewiseAffine) |||||
| [Lambda](https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.augmentations.transforms.Lambda) |||||
| [LongestMaxSize](https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.augmentations.transforms.LongestMaxSize) |||| |
| NoOp |||||
| [OpticalDistortion](https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.augmentations.transforms.OpticalDistortion) ||| | |
| [PadIfNeeded](https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.augmentations.transforms.PadIfNeeded) |||||
| [RandomCrop](https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.augmentations.transforms.RandomCrop) |||||
| [RandomCropNearBBox](https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.augmentations.transforms.RandomCropNearBBox) |||| |
| [RandomGridShuffle](https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.augmentations.transforms.RandomGridShuffle) ||| | |
| [RandomResizedCrop](https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.augmentations.transforms.RandomResizedCrop) |||||
| [RandomRotate90](https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.augmentations.transforms.RandomRotate90) |||||
| [RandomScale](https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.augmentations.transforms.RandomScale) |||||
| [RandomSizedBBoxSafeCrop](https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.augmentations.transforms.RandomSizedBBoxSafeCrop) |||| |
| [RandomSizedCrop](https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.augmentations.transforms.RandomSizedCrop) |||||
| [Resize](https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.augmentations.transforms.Resize) |||| |
| [Rotate](https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.augmentations.transforms.Rotate) |||||
| [ShiftScaleRotate](https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.augmentations.transforms.ShiftScaleRotate) |||||
| [SmallestMaxSize](https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.augmentations.transforms.SmallestMaxSize) |||| |
| [Transpose](https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.augmentations.transforms.Transpose) |||| |
| [VerticalFlip](https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.augmentations.transforms.VerticalFlip) |||||

## Migrating from torchvision to albumentations

Expand Down
85 changes: 84 additions & 1 deletion albumentations/augmentations/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
'RandomBrightness', 'RandomContrast', 'MotionBlur', 'MedianBlur',
'GaussianBlur', 'GaussNoise', 'CLAHE', 'ChannelShuffle', 'InvertImg',
'ToGray', 'JpegCompression', 'Cutout', 'CoarseDropout', 'ToFloat',
'FromFloat', 'Crop', 'RandomScale', 'LongestMaxSize', 'SmallestMaxSize',
'FromFloat', 'Crop', 'CropNonEmptyMaskIfExists', 'RandomScale', 'LongestMaxSize', 'SmallestMaxSize',
'Resize', 'RandomSizedCrop', 'RandomResizedCrop', 'RandomBrightnessContrast',
'RandomCropNearBBox', 'RandomSizedBBoxSafeCrop', 'RandomSnow',
'RandomRain', 'RandomFog', 'RandomSunFlare', 'RandomShadow', 'Lambda',
Expand Down Expand Up @@ -878,6 +878,89 @@ def get_transform_init_args_names(self):
return ('height', 'width', 'erosion_rate', 'interpolation')


class CropNonEmptyMaskIfExists(DualTransform):
"""Crop area with mask if mask is non-empty, else make random crop.
Args:
height (int): vertical size of crop in pixels
width (int): horizontal size of crop in pixels
ignore_values (list of int): values to ignore in mask, `0` values are always ignored
(e.g. if background value is 5 set `ignore_values=[5]` to ignore)
ignore_channels (list of int): channels to ignore in mask
(e.g. if background is a first channel set `ignore_channels=[0]` to ignore)
p (float): probability of applying the transform. Default: 1.0.
Targets:
image, mask
Image types:
uint8, float32
"""

def __init__(self, height, width, ignore_values=None,
ignore_channels=None, always_apply=False, p=1.0):
super(CropNonEmptyMaskIfExists, self).__init__(always_apply, p)

if ignore_values is not None and not isinstance(ignore_values, list):
raise ValueError('Expected `ignore_values` of type `list`, got `{}`'.format(type(ignore_values)))
if ignore_channels is not None and not isinstance(ignore_channels, list):
raise ValueError('Expected `ignore_channels` of type `list`, got `{}`'.format(type(ignore_channels)))

self.height = height
self.width = width
self.ignore_values = ignore_values
self.ignore_channels = ignore_channels

def apply(self, img, x_min=0, x_max=0, y_min=0, y_max=0, **params):
return F.crop(img, x_min, y_min, x_max, y_max)

@property
def targets_as_params(self):
return ['mask']

def get_params_dependent_on_targets(self, params):
mask = params['mask']
mask_height, mask_width = mask.shape[:2]

if self.ignore_values is not None:
ignore_values_np = np.array(self.ignore_values)
mask = np.where(np.isin(mask, ignore_values_np), 0, mask)

if mask.ndim == 3 and self.ignore_channels is not None:
target_channels = np.array([ch for ch in range(mask.shape[-1])
if ch not in self.ignore_channels])
mask = np.take(mask, target_channels, axis=-1)

if self.height > mask_height or self.width > mask_width:
raise ValueError('Crop size ({},{}) is larger than image ({},{})'.format(
self.height, self.width, mask_height, mask_width))

if mask.sum() == 0:
x_min = random.randint(0, mask_width - self.width)
y_min = random.randint(0, mask_height - self.height)
else:
mask = mask.sum(axis=-1) if mask.ndim == 3 else mask
non_zero_yx = np.argwhere(mask)
y, x = random.choice(non_zero_yx)
x_min = x - random.randint(0, self.width - 1)
y_min = y - random.randint(0, self.height - 1)
x_min = np.clip(x_min, 0, mask_width - self.width)
y_min = np.clip(y_min, 0, mask_height - self.height)

x_max = x_min + self.width
y_max = y_min + self.height

return {
'x_min': x_min,
'x_max': x_max,
'y_min': y_min,
'y_max': y_max,
}

def get_transform_init_args_names(self):
return ('height', 'width', 'ignore_values', 'ignore_channels')


class OpticalDistortion(DualTransform):
"""
Targets:
Expand Down
Loading

0 comments on commit d96dabe

Please sign in to comment.