Skip to content

Commit

Permalink
Random grid shuffle (#311)
Browse files Browse the repository at this point in the history
* Added random grid shuffle

* Fixed bug with iaa.PerspectiveTransform

* Fixed style code & fill docs

* Fixed the misleading name & restored the test of augmentation part

* Fixed style code

* Fixed conditional & description
  • Loading branch information
toshiks authored and ternaus committed Aug 15, 2019
1 parent ef68cd4 commit 4cf6c36
Show file tree
Hide file tree
Showing 7 changed files with 160 additions and 3 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,7 @@ Spatial-level transforms will simultaneously change both an input image as well
| [PadIfNeeded](https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.augmentations.transforms.PadIfNeeded) |||||
| [RandomCrop](https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.augmentations.transforms.RandomCrop) |||||
| [RandomCropNearBBox](https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.augmentations.transforms.RandomCropNearBBox) |||| |
| [RandomGridShuffle](https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.augmentations.transforms.RandomGridShuffle) ||| | |
| [RandomRotate90](https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.augmentations.transforms.RandomRotate90) |||||
| [RandomScale](https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.augmentations.transforms.RandomScale) |||||
| [RandomSizedBBoxSafeCrop](https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.augmentations.transforms.RandomSizedBBoxSafeCrop) |||| |
Expand Down
19 changes: 19 additions & 0 deletions albumentations/augmentations/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -1256,3 +1256,22 @@ def py3round(number):

def noop(input_obj, **params):
    """Return ``input_obj`` unchanged; extra keyword arguments are accepted and ignored."""
    return input_obj


def swap_tiles_on_image(image, tiles):
    """Swap tiles on image.

    Every tile is read from the untouched input ``image`` and written into a copy,
    so the order of the tiles does not influence what is read.

    Args:
        image (np.ndarray): Input image.
        tiles (np.ndarray): array of tuples(current_left_up_corner_row, current_left_up_corner_col,
            old_left_up_corner_row, old_left_up_corner_col,
            height_tile, width_tile)

    Returns:
        np.ndarray: Copy of ``image`` in which each "current" tile region holds the pixels
            of the corresponding "old" tile region. The input array is not modified.
    """
    new_image = image.copy()

    # Unpack the six tile fields by name instead of indexing tile[0]..tile[5].
    for curr_row, curr_col, old_row, old_col, tile_height, tile_width in tiles:
        new_image[curr_row:curr_row + tile_height, curr_col:curr_col + tile_width] = \
            image[old_row:old_row + tile_height, old_col:old_col + tile_width]

    return new_image
92 changes: 91 additions & 1 deletion albumentations/augmentations/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
'Blur', 'VerticalFlip', 'HorizontalFlip', 'Flip', 'Normalize', 'Transpose',
'RandomCrop', 'RandomGamma', 'RandomRotate90', 'Rotate',
'ShiftScaleRotate', 'CenterCrop', 'OpticalDistortion', 'GridDistortion',
'ElasticTransform', 'HueSaturationValue', 'PadIfNeeded', 'RGBShift',
'ElasticTransform', 'RandomGridShuffle', 'HueSaturationValue', 'PadIfNeeded', 'RGBShift',
'RandomBrightness', 'RandomContrast', 'MotionBlur', 'MedianBlur',
'GaussianBlur', 'GaussNoise', 'CLAHE', 'ChannelShuffle', 'InvertImg',
'ToGray', 'JpegCompression', 'Cutout', 'CoarseDropout', 'ToFloat',
Expand Down Expand Up @@ -912,6 +912,96 @@ def get_transform_init_args_names(self):
'mask_value', 'approximate')


class RandomGridShuffle(DualTransform):
    """Random shuffle grid's cells on image.

    The image is split into ``grid[0]`` x ``grid[1]`` cells. Cells are only permuted
    among cells of identical (height, width), so every cell is replaced by one that
    fits exactly even when the image size is not divisible by the grid size.

    Args:
        grid ((int, int)): size of grid for splitting image.

    Targets:
        image, mask

    Image types:
        uint8, float32
    """

    def __init__(self, grid=(3, 3), always_apply=False, p=1.0):
        super(RandomGridShuffle, self).__init__(always_apply, p)
        self.grid = grid

    def apply(self, img, tiles=None, **params):
        # Without tiles there is nothing to move; an empty list yields a plain copy.
        if tiles is None:
            tiles = []

        return F.swap_tiles_on_image(img, tiles)

    def apply_to_mask(self, img, tiles=None, **params):
        # Same tile moves as for the image.
        if tiles is None:
            tiles = []

        return F.swap_tiles_on_image(img, tiles)

    def get_params_dependent_on_targets(self, params):
        """Build the tile permutation for the image found in ``params['image']``.

        Returns:
            dict: ``{"tiles": np.ndarray}`` with one row per grid cell:
                (current_row, current_col, old_row, old_col, tile_height, tile_width).

        Raises:
            ValueError: if a grid value is not positive, or if the grid has more rows/columns
                than half of the image height/width.
        """
        height, width = params['image'].shape[:2]
        n, m = self.grid

        if n <= 0 or m <= 0:
            raise ValueError("Grid's values must be positive. Current grid [%s, %s]" % (n, m))

        if n > height // 2 or m > width // 2:
            raise ValueError("Incorrect size cell of grid. Just shuffle pixels of image")

        # Seed a local NumPy generator from the `random` module.
        random_state = np.random.RandomState(random.randint(0, 10000))

        # Cell boundaries along each axis. The builtin `int` is used because the
        # `np.int` alias was deprecated in NumPy 1.20 and removed in 1.24.
        height_split = np.linspace(0, height, n + 1, dtype=int)
        width_split = np.linspace(0, width, m + 1, dtype=int)

        height_matrix, width_matrix = np.meshgrid(height_split, width_split, indexing='ij')

        # Upper-left corner of every cell, shape (n, m).
        index_height_matrix = height_matrix[:-1, :-1]
        index_width_matrix = width_matrix[:-1, :-1]

        # Lower-right corner of every cell, shape (n, m).
        shifted_index_height_matrix = height_matrix[1:, 1:]
        shifted_index_width_matrix = width_matrix[1:, 1:]

        height_tile_sizes = shifted_index_height_matrix - index_height_matrix
        width_tile_sizes = shifted_index_width_matrix - index_width_matrix

        tiles_sizes = np.stack((height_tile_sizes, width_tile_sizes), axis=2)

        # new_index_matrix[i, j] holds the (row, col) grid index of the cell
        # whose content will be placed into cell (i, j).
        index_matrix = np.indices((n, m))
        new_index_matrix = np.stack(index_matrix, axis=2)

        # Permute only among cells that share the same (height, width).
        for bbox_size in np.unique(tiles_sizes.reshape(-1, 2), axis=0):
            eq_mat = np.all(tiles_sizes == bbox_size, axis=2)
            new_index_matrix[eq_mat] = random_state.permutation(new_index_matrix[eq_mat])

        new_index_matrix = np.split(new_index_matrix, 2, axis=2)

        # Pixel coordinates of the source ("old") cell for every destination cell.
        old_x = index_height_matrix[new_index_matrix[0], new_index_matrix[1]].reshape(-1)
        old_y = index_width_matrix[new_index_matrix[0], new_index_matrix[1]].reshape(-1)

        shift_x = height_tile_sizes.reshape(-1)
        shift_y = width_tile_sizes.reshape(-1)

        curr_x = index_height_matrix.reshape(-1)
        curr_y = index_width_matrix.reshape(-1)

        tiles = np.stack([curr_x, curr_y, old_x, old_y, shift_x, shift_y], axis=1)

        return {
            "tiles": tiles,
        }

    @property
    def targets_as_params(self):
        # The image is needed to compute the tile layout from its height and width.
        return ['image']

    def get_transform_init_args_names(self):
        return ('grid',)


class Normalize(ImageOnlyTransform):
"""Divide pixel values by 255 = 2**8 - 1, subtract mean per channel and divide by std per channel.
Expand Down
2 changes: 1 addition & 1 deletion albumentations/imgaug/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,7 @@ def __init__(self, scale=(0.05, 0.1), keep_size=True, always_apply=False, p=.5):

@property
def processor(self):
    # NOTE(review): this page renders a diff without +/- markers. The first return
    # (positional keep_size) is presumably the removed line and the second
    # (keyword keep_size) its replacement; as written here the second is unreachable.
    return iaa.PerspectiveTransform(self.scale, self.keep_size)
    return iaa.PerspectiveTransform(self.scale, keep_size=self.keep_size)

def get_transform_init_args_names(self):
    # Returns the fixed tuple of names ('scale', 'keep_size').
    # NOTE(review): presumably the constructor arguments of the enclosing transform — confirm.
    return ('scale', 'keep_size')
8 changes: 7 additions & 1 deletion tests/test_augmentations.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from albumentations import (
RandomCrop, PadIfNeeded, VerticalFlip, HorizontalFlip, Flip, Transpose,
RandomRotate90, Rotate, ShiftScaleRotate, CenterCrop, OpticalDistortion,
GridDistortion, ElasticTransform, ToGray, RandomGamma, JpegCompression,
GridDistortion, ElasticTransform, RandomGridShuffle, ToGray, RandomGamma, JpegCompression,
HueSaturationValue, RGBShift, Blur, MotionBlur, MedianBlur, GaussianBlur,
GaussNoise, CLAHE, ChannelShuffle, InvertImg, IAAEmboss, IAASuperpixels,
IAASharpen, IAAAdditiveGaussianNoise, IAAPiecewiseAffine, IAAPerspective,
Expand Down Expand Up @@ -101,6 +101,7 @@ def test_image_only_augmentations_with_float_values(augmentation_cls, params, fl
[RandomCrop, {'height': 10, 'width': 10}],
[RandomSizedCrop, {'min_max_height': (4, 8), 'height': 10, 'width': 10}],
[ISONoise, {}],
[RandomGridShuffle, {}]
])
def test_dual_augmentations(augmentation_cls, params, image, mask):
aug = augmentation_cls(p=1, **params)
Expand All @@ -124,6 +125,7 @@ def test_dual_augmentations(augmentation_cls, params, image, mask):
[CenterCrop, {'height': 10, 'width': 10}],
[RandomCrop, {'height': 10, 'width': 10}],
[RandomSizedCrop, {'min_max_height': (4, 8), 'height': 10, 'width': 10}],
[RandomGridShuffle, {}]
])
def test_dual_augmentations_with_float_values(augmentation_cls, params, float_image, mask):
aug = augmentation_cls(p=1, **params)
Expand Down Expand Up @@ -193,6 +195,7 @@ def test_imgaug_dual_augmentations(augmentation_cls, image, mask):
[RandomShadow, {}],
[ChannelDropout, {}],
[ISONoise, {}],
[RandomGridShuffle, {}],
[Solarize, {}],
])
def test_augmentations_wont_change_input(augmentation_cls, params, image, mask):
Expand Down Expand Up @@ -244,6 +247,7 @@ def test_augmentations_wont_change_input(augmentation_cls, params, image, mask):
[RandomSunFlare, {}],
[RandomShadow, {}],
[ChannelDropout, {}],
[RandomGridShuffle, {}],
[Solarize, {}],
])
def test_augmentations_wont_change_float_input(augmentation_cls, params, float_image):
Expand Down Expand Up @@ -277,6 +281,7 @@ def test_augmentations_wont_change_float_input(augmentation_cls, params, float_i
[GaussNoise, {}],
[ToFloat, {}],
[FromFloat, {}],
[RandomGridShuffle, {}],
[Solarize, {}],
])
def test_augmentations_wont_change_shape_grayscale(augmentation_cls, params, image, mask):
Expand Down Expand Up @@ -343,6 +348,7 @@ def test_augmentations_wont_change_shape_grayscale(augmentation_cls, params, ima
[RandomShadow, {}],
[ChannelDropout, {}],
[ISONoise, {}],
[RandomGridShuffle, {}],
[Solarize, {}],
])
def test_augmentations_wont_change_shape_rgb(augmentation_cls, params, image, mask):
Expand Down
38 changes: 38 additions & 0 deletions tests/test_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -888,6 +888,44 @@ def test_brightness_contrast():
F._brightness_contrast_adjust_non_uint(image_float))


def test_swap_tiles_on_image_with_empty_tiles():
    """An empty tile list must return an image equal to the input."""
    # 4x4 image whose rows are filled with 1, 2, 3 and 4 respectively.
    source = np.array([[value] * 4 for value in (1, 2, 3, 4)], dtype=np.uint8)

    untouched = F.swap_tiles_on_image(source, [])

    assert np.array_equal(source, untouched)


def test_swap_tiles_on_image_with_non_empty_tiles():
    """The top-left and bottom-right 2x2 tiles must exchange their contents."""
    # 4x4 image whose rows are filled with 1, 2, 3 and 4 respectively.
    source = np.array([[value] * 4 for value in (1, 2, 3, 4)], dtype=np.uint8)

    # Two 2x2 tiles: (0, 0) receives the tile from (2, 2) and vice versa.
    swap_spec = np.array([
        [0, 0, 2, 2, 2, 2],
        [2, 2, 0, 0, 2, 2],
    ])

    expected = np.array([
        [3, 3, 1, 1],
        [4, 4, 2, 2],
        [3, 3, 1, 1],
        [4, 4, 2, 2],
    ], dtype=np.uint8)

    assert np.array_equal(F.swap_tiles_on_image(source, swap_spec), expected)


@pytest.mark.parametrize('dtype', list(F.MAX_VALUES_BY_DTYPE.keys()))
def test_solarize(dtype):
max_value = F.MAX_VALUES_BY_DTYPE[dtype]
Expand Down
3 changes: 3 additions & 0 deletions tests/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
[A.RandomScale, {}],
[A.SmallestMaxSize, {}],
[A.LongestMaxSize, {}],
[A.RandomGridShuffle, {}],
[A.Solarize, {}],
])
@pytest.mark.parametrize('p', [0.5, 1])
Expand Down Expand Up @@ -159,6 +160,7 @@ def test_augmentations_serialization(augmentation_cls, params, p, seed, image, m
[A.Resize, {'height': 64, 'width': 64}],
[A.SmallestMaxSize, {'max_size': 64, 'interpolation': cv2.INTER_CUBIC}],
[A.LongestMaxSize, {'max_size': 128, 'interpolation': cv2.INTER_CUBIC}],
[A.RandomGridShuffle, {'grid': (5, 5)}],
[A.Solarize, {'threshold': 32}],
])
@pytest.mark.parametrize('p', [0.5, 1])
Expand Down Expand Up @@ -556,6 +558,7 @@ def test_transform_pipeline_serialization_with_keypoints(seed, image, keypoints,
[A.Normalize, {}],
[A.ToFloat, {}],
[A.FromFloat, {}],
[A.RandomGridShuffle, {}],
[A.Solarize, {}],
])
@pytest.mark.parametrize('seed', TEST_SEEDS)
Expand Down

0 comments on commit 4cf6c36

Please sign in to comment.