diff --git a/albumentations/augmentations/functional.py b/albumentations/augmentations/functional.py index 666f3f935..126904a9d 100644 --- a/albumentations/augmentations/functional.py +++ b/albumentations/augmentations/functional.py @@ -799,12 +799,12 @@ def elastic_transform(image, alpha, sigma, alpha_affine, interpolation=cv2.INTER mapx = np.float32(x + dx) mapy = np.float32(y + dy) - return cv2.remap(image, mapx, mapy, interpolation, borderMode=border_mode) + return cv2.remap(image, mapx, mapy, interpolation, borderMode=border_mode, borderValue=value) @preserve_shape def elastic_transform_approx(image, alpha, sigma, alpha_affine, interpolation=cv2.INTER_LINEAR, - border_mode=cv2.BORDER_REFLECT_101, random_state=None): + border_mode=cv2.BORDER_REFLECT_101, value=None, random_state=None): """Elastic deformation of images as described in [Simard2003]_ (with modifications for speed). Based on https://gist.github.com/erniejunior/601cdf56d2b424757de5 @@ -830,7 +830,8 @@ def elastic_transform_approx(image, alpha, sigma, alpha_affine, interpolation=cv pts2 = pts1 + random_state.uniform(-alpha_affine, alpha_affine, size=pts1.shape).astype(np.float32) matrix = cv2.getAffineTransform(pts1, pts2) - image = cv2.warpAffine(image, matrix, (width, height), flags=interpolation, borderMode=border_mode) + image = cv2.warpAffine(image, matrix, (width, height), flags=interpolation, + borderMode=border_mode, value=value) dx = (random_state.rand(height, width).astype(np.float32) * 2 - 1) cv2.GaussianBlur(dx, (17, 17), sigma, dst=dx) @@ -845,7 +846,7 @@ def elastic_transform_approx(image, alpha, sigma, alpha_affine, interpolation=cv mapx = np.float32(x + dx) mapy = np.float32(y + dy) - return cv2.remap(image, mapx, mapy, interpolation, borderMode=border_mode) + return cv2.remap(image, mapx, mapy, interpolation, borderMode=border_mode, borderValue=value) def invert(img): diff --git a/albumentations/augmentations/transforms.py b/albumentations/augmentations/transforms.py index 04549b3ef..ff006f9fe 100644 --- a/albumentations/augmentations/transforms.py +++ b/albumentations/augmentations/transforms.py @@ -35,6 +35,7 @@ class PadIfNeeded(DualTransform): Args: p (float): probability of applying the transform. Default: 1.0. value (list of ints [r, g, b]): padding value if border_mode is cv2.BORDER_CONSTANT. + mask_value (int): padding value for mask if border_mode is cv2.BORDER_CONSTANT. Targets: image, mask, bbox, keypoints @@ -45,12 +46,13 @@ class PadIfNeeded(DualTransform): """ def __init__(self, min_height=1024, min_width=1024, border_mode=cv2.BORDER_REFLECT_101, - value=None, always_apply=False, p=1.0): + value=None, mask_value=None, always_apply=False, p=1.0): super(PadIfNeeded, self).__init__(always_apply, p) self.min_height = min_height self.min_width = min_width self.border_mode = border_mode self.value = value + self.mask_value = mask_value def update_params(self, params, **kwargs): params = super(PadIfNeeded, self).update_params(params, **kwargs) @@ -81,6 +83,10 @@ def apply(self, img, pad_top=0, pad_bottom=0, pad_left=0, pad_right=0, **params) return F.pad_with_params(img, pad_top, pad_bottom, pad_left, pad_right, border_mode=self.border_mode, value=self.value) + def apply_to_mask(self, img, pad_top=0, pad_bottom=0, pad_left=0, pad_right=0, **params): + return F.pad_with_params(img, pad_top, pad_bottom, pad_left, pad_right, + border_mode=self.border_mode, value=self.mask_value) + def apply_to_bbox(self, bbox, pad_top=0, pad_bottom=0, pad_left=0, pad_right=0, rows=0, cols=0, **params): x_min, y_min, x_max, y_max = denormalize_bbox(bbox, rows, cols) bbox = [x_min + pad_left, y_min + pad_top, x_max + pad_left, y_max + pad_top] @@ -91,7 +97,7 @@ def apply_to_keypoint(self, keypoint, pad_top=0, pad_bottom=0, pad_left=0, pad_r return [x + pad_left, y + pad_top, a, s] def get_transform_init_args_names(self): - return ('min_height', 'min_width', 'border_mode', 'value') + return ('min_height', 'min_width', 'border_mode', 'value', 'mask_value') class Crop(DualTransform): @@ -383,6 +389,7 @@ class Rotate(DualTransform): cv2.BORDER_CONSTANT, cv2.BORDER_REPLICATE, cv2.BORDER_REFLECT, cv2.BORDER_WRAP, cv2.BORDER_REFLECT_101. Default: cv2.BORDER_REFLECT_101 value (list of ints [r, g, b]): padding value if border_mode is cv2.BORDER_CONSTANT. + mask_value (scalar or list of ints): padding value if border_mode is cv2.BORDER_CONSTANT applied for masks. p (float): probability of applying the transform. Default: 0.5. Targets: @@ -393,16 +400,20 @@ class Rotate(DualTransform): """ def __init__(self, limit=90, interpolation=cv2.INTER_LINEAR, border_mode=cv2.BORDER_REFLECT_101, - value=None, always_apply=False, p=.5): + value=None, mask_value=None, always_apply=False, p=.5): super(Rotate, self).__init__(always_apply, p) self.limit = to_tuple(limit) self.interpolation = interpolation self.border_mode = border_mode self.value = value + self.mask_value = mask_value def apply(self, img, angle=0, interpolation=cv2.INTER_LINEAR, **params): return F.rotate(img, angle, interpolation, self.border_mode, self.value) + def apply_to_mask(self, img, angle=0, **params): + return F.rotate(img, angle, cv2.INTER_NEAREST, self.border_mode, self.mask_value) + def get_params(self): return {'angle': random.uniform(self.limit[0], self.limit[1])} @@ -413,7 +424,7 @@ def apply_to_keypoint(self, keypoint, angle=0, **params): return F.keypoint_rotate(keypoint, angle, **params) def get_transform_init_args_names(self): - return ('limit', 'interpolation', 'border_mode', 'value') + return ('limit', 'interpolation', 'border_mode', 'value', 'mask_value') class RandomScale(DualTransform): @@ -477,6 +488,7 @@ class ShiftScaleRotate(DualTransform): cv2.BORDER_CONSTANT, cv2.BORDER_REPLICATE, cv2.BORDER_REFLECT, cv2.BORDER_WRAP, cv2.BORDER_REFLECT_101. Default: cv2.BORDER_REFLECT_101 value (list of ints [r, g, b]): padding value if border_mode is cv2.BORDER_CONSTANT. + mask_value (scalar or list of ints): padding value if border_mode is cv2.BORDER_CONSTANT applied for masks. p (float): probability of applying the transform. Default: 0.5. Targets: @@ -487,7 +499,7 @@ class ShiftScaleRotate(DualTransform): """ def __init__(self, shift_limit=0.0625, scale_limit=0.1, rotate_limit=45, interpolation=cv2.INTER_LINEAR, - border_mode=cv2.BORDER_REFLECT_101, value=None, always_apply=False, p=0.5): + border_mode=cv2.BORDER_REFLECT_101, value=None, mask_value=None, always_apply=False, p=0.5): super(ShiftScaleRotate, self).__init__(always_apply, p) self.shift_limit = to_tuple(shift_limit) self.scale_limit = to_tuple(scale_limit, bias=1.0) @@ -495,10 +507,14 @@ def __init__(self, shift_limit=0.0625, scale_limit=0.1, rotate_limit=45, interpo self.interpolation = interpolation self.border_mode = border_mode self.value = value + self.mask_value = mask_value def apply(self, img, angle=0, scale=0, dx=0, dy=0, interpolation=cv2.INTER_LINEAR, **params): return F.shift_scale_rotate(img, angle, scale, dx, dy, interpolation, self.border_mode, self.value) + def apply_to_mask(self, img, angle=0, scale=0, dx=0, dy=0, **params): + return F.shift_scale_rotate(img, angle, scale, dx, dy, cv2.INTER_NEAREST, self.border_mode, self.mask_value) + def apply_to_keypoint(self, keypoint, angle=0, scale=0, dx=0, dy=0, rows=0, cols=0, interpolation=cv2.INTER_LINEAR, **params): return F.keypoint_shift_scale_rotate(keypoint, angle, scale, dx, dy, rows, cols) @@ -520,6 +536,7 @@ def get_transform_init_args(self): 'interpolation': self.interpolation, 'border_mode': self.border_mode, 'value': self.value, + 'mask_value': self.mask_value } @@ -779,24 +796,28 @@ class OpticalDistortion(DualTransform): """ def __init__(self, distort_limit=0.05, shift_limit=0.05, interpolation=cv2.INTER_LINEAR, - border_mode=cv2.BORDER_REFLECT_101, value=None, always_apply=False, p=0.5): + border_mode=cv2.BORDER_REFLECT_101, value=None, mask_value=None, always_apply=False, p=0.5): super(OpticalDistortion, self).__init__(always_apply, p) self.shift_limit = to_tuple(shift_limit) self.distort_limit = to_tuple(distort_limit) self.interpolation = interpolation self.border_mode = border_mode self.value = value + self.mask_value = mask_value def apply(self, img, k=0, dx=0, dy=0, interpolation=cv2.INTER_LINEAR, **params): return F.optical_distortion(img, k, dx, dy, interpolation, self.border_mode, self.value) + def apply_to_mask(self, img, k=0, dx=0, dy=0, **params): + return F.optical_distortion(img, k, dx, dy, cv2.INTER_NEAREST, self.border_mode, self.mask_value) + def get_params(self): return {'k': random.uniform(self.distort_limit[0], self.distort_limit[1]), 'dx': round(random.uniform(self.shift_limit[0], self.shift_limit[1])), 'dy': round(random.uniform(self.shift_limit[0], self.shift_limit[1]))} def get_transform_init_args_names(self): - return ('distort_limit', 'shift_limit', 'interpolation', 'border_mode', 'value') + return ('distort_limit', 'shift_limit', 'interpolation', 'border_mode', 'value', 'mask_value') class GridDistortion(DualTransform): @@ -809,16 +830,22 @@ class GridDistortion(DualTransform): """ def __init__(self, num_steps=5, distort_limit=0.3, interpolation=cv2.INTER_LINEAR, - border_mode=cv2.BORDER_REFLECT_101, value=None, always_apply=False, p=0.5): + border_mode=cv2.BORDER_REFLECT_101, value=None, mask_value=None, always_apply=False, p=0.5): super(GridDistortion, self).__init__(always_apply, p) self.num_steps = num_steps self.distort_limit = to_tuple(distort_limit) self.interpolation = interpolation self.border_mode = border_mode self.value = value + self.mask_value = mask_value def apply(self, img, stepsx=[], stepsy=[], interpolation=cv2.INTER_LINEAR, **params): - return F.grid_distortion(img, self.num_steps, stepsx, stepsy, interpolation, self.border_mode, self.value) + return F.grid_distortion(img, self.num_steps, stepsx, stepsy, interpolation, + self.border_mode, self.value) + + def apply_to_mask(self, img, stepsx=[], stepsy=[], **params): + return F.grid_distortion(img, self.num_steps, stepsx, stepsy, cv2.INTER_NEAREST, + self.border_mode, self.mask_value) def get_params(self): stepsx = [1 + random.uniform(self.distort_limit[0], self.distort_limit[1]) for i in @@ -831,7 +858,7 @@ def get_params(self): } def get_transform_init_args_names(self): - return ('num_steps', 'distort_limit', 'interpolation', 'border_mode', 'value') + return ('num_steps', 'distort_limit', 'interpolation', 'border_mode', 'value', 'mask_value') class ElasticTransform(DualTransform): @@ -855,7 +882,8 @@ class ElasticTransform(DualTransform): """ def __init__(self, alpha=1, sigma=50, alpha_affine=50, interpolation=cv2.INTER_LINEAR, - border_mode=cv2.BORDER_REFLECT_101, value=None, always_apply=False, approximate=False, p=0.5): + border_mode=cv2.BORDER_REFLECT_101, value=None, mask_value=None, + always_apply=False, approximate=False, p=0.5): super(ElasticTransform, self).__init__(always_apply, p) self.alpha = alpha self.alpha_affine = alpha_affine @@ -863,6 +891,7 @@ def __init__(self, alpha=1, sigma=50, alpha_affine=50, interpolation=cv2.INTER_L self.interpolation = interpolation self.border_mode = border_mode self.value = value + self.mask_value = mask_value self.approximate = approximate def apply(self, img, random_state=None, interpolation=cv2.INTER_LINEAR, **params): @@ -870,11 +899,17 @@ def apply(self, img, random_state=None, interpolation=cv2.INTER_LINEAR, **params self.border_mode, self.value, np.random.RandomState(random_state), self.approximate) + def apply_to_mask(self, img, random_state=None, **params): + return F.elastic_transform(img, self.alpha, self.sigma, self.alpha_affine, cv2.INTER_NEAREST, + self.border_mode, self.mask_value, np.random.RandomState(random_state), + self.approximate) + def get_params(self): return {'random_state': random.randint(0, 10000)} def get_transform_init_args_names(self): - return ('alpha', 'sigma', 'alpha_affine', 'interpolation', 'border_mode', 'value', 'approximate') + return ('alpha', 'sigma', 'alpha_affine', 'interpolation', 'border_mode', 'value', + 'mask_value', 'approximate') class Normalize(ImageOnlyTransform): diff --git a/tests/test_augmentations.py b/tests/test_augmentations.py index a80666d50..1eae4e7f3 100644 --- a/tests/test_augmentations.py +++ b/tests/test_augmentations.py @@ -1,3 +1,6 @@ +import random + +import cv2 import numpy as np import pytest @@ -352,3 +355,21 @@ def test_image_only_crop_around_bbox_augmentation(augmentation_cls, params, imag annotations = {'image': image, 'cropping_bbox': [-59, 77, 177, 231]} data = aug(**annotations) assert data['image'].dtype == np.uint8 + + +@pytest.mark.parametrize(['augmentation_cls', 'params'], [ + [PadIfNeeded, {'min_height': 514, 'min_width': 514, + 'border_mode': cv2.BORDER_CONSTANT, 'value': 100, 'mask_value': 1}], + [Rotate, {'border_mode': cv2.BORDER_CONSTANT, 'value': 100, 'mask_value': 1}], + [ShiftScaleRotate, {'border_mode': cv2.BORDER_CONSTANT, 'value': 100, 'mask_value': 1}], + [OpticalDistortion, {'border_mode': cv2.BORDER_CONSTANT, 'value': 100, 'mask_value': 1}], + [ElasticTransform, {'border_mode': cv2.BORDER_CONSTANT, 'value': 100, 'mask_value': 1}], + [GridDistortion, {'border_mode': cv2.BORDER_CONSTANT, 'value': 100, 'mask_value': 1}], +]) +def test_mask_fill_value(augmentation_cls, params): + random.seed(42) + aug = augmentation_cls(p=1, **params) + input = {'image': np.zeros((512, 512), dtype=np.uint8) + 100, 'mask': np.ones((512, 512))} + output = aug(**input) + assert (output['image'] == 100).all() + assert (output['mask'] == 1).all() diff --git a/tests/test_transforms.py b/tests/test_transforms.py index f3a04ba93..6a7bd995b 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -169,6 +169,8 @@ def test_multiprocessing_support(augmentation_cls, params): pool = Pool(8) pool.map(__test_multiprocessing_support_proc, map(lambda x: (x, aug), [image] * 100)) + pool.close() + pool.join() def test_force_apply():