Skip to content

Commit

Permalink
Separate fill values for image and mask targets (#283)
Browse files Browse the repository at this point in the history
* Update transforms.py

Added `mask_value` parameters to specify fill value applied to masks.

* Add mask_value for all augmentations that supports it.

* Fix error in params['fill_value']

* Fix codestyle

* Add mask_value to serialized parameters

* Added closing pool call

* Bump up version of opencv

* Add missing fill values for cv2.remap functions

* Revert change in opencv packages
  • Loading branch information
BloodAxe committed Jul 4, 2019
1 parent d85bab5 commit 2c1a148
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 16 deletions.
9 changes: 5 additions & 4 deletions albumentations/augmentations/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -799,12 +799,12 @@ def elastic_transform(image, alpha, sigma, alpha_affine, interpolation=cv2.INTER
mapx = np.float32(x + dx)
mapy = np.float32(y + dy)

return cv2.remap(image, mapx, mapy, interpolation, borderMode=border_mode)
return cv2.remap(image, mapx, mapy, interpolation, borderMode=border_mode, borderValue=value)


@preserve_shape
def elastic_transform_approx(image, alpha, sigma, alpha_affine, interpolation=cv2.INTER_LINEAR,
border_mode=cv2.BORDER_REFLECT_101, random_state=None):
border_mode=cv2.BORDER_REFLECT_101, value=None, random_state=None):
"""Elastic deformation of images as described in [Simard2003]_ (with modifications for speed).
Based on https://gist.github.com/erniejunior/601cdf56d2b424757de5
Expand All @@ -830,7 +830,8 @@ def elastic_transform_approx(image, alpha, sigma, alpha_affine, interpolation=cv
pts2 = pts1 + random_state.uniform(-alpha_affine, alpha_affine, size=pts1.shape).astype(np.float32)
matrix = cv2.getAffineTransform(pts1, pts2)

image = cv2.warpAffine(image, matrix, (width, height), flags=interpolation, borderMode=border_mode)
image = cv2.warpAffine(image, matrix, (width, height), flags=interpolation,
borderMode=border_mode, value=value)

dx = (random_state.rand(height, width).astype(np.float32) * 2 - 1)
cv2.GaussianBlur(dx, (17, 17), sigma, dst=dx)
Expand All @@ -845,7 +846,7 @@ def elastic_transform_approx(image, alpha, sigma, alpha_affine, interpolation=cv
mapx = np.float32(x + dx)
mapy = np.float32(y + dy)

return cv2.remap(image, mapx, mapy, interpolation, borderMode=border_mode)
return cv2.remap(image, mapx, mapy, interpolation, borderMode=border_mode, borderValue=value)


def invert(img):
Expand Down
59 changes: 47 additions & 12 deletions albumentations/augmentations/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ class PadIfNeeded(DualTransform):
Args:
p (float): probability of applying the transform. Default: 1.0.
value (list of ints [r, g, b]): padding value if border_mode is cv2.BORDER_CONSTANT.
mask_value (int): padding value for mask if border_mode is cv2.BORDER_CONSTANT.
Targets:
image, mask, bbox, keypoints
Expand All @@ -45,12 +46,13 @@ class PadIfNeeded(DualTransform):
"""

def __init__(self, min_height=1024, min_width=1024, border_mode=cv2.BORDER_REFLECT_101,
value=None, always_apply=False, p=1.0):
value=None, mask_value=None, always_apply=False, p=1.0):
super(PadIfNeeded, self).__init__(always_apply, p)
self.min_height = min_height
self.min_width = min_width
self.border_mode = border_mode
self.value = value
self.mask_value = mask_value

def update_params(self, params, **kwargs):
params = super(PadIfNeeded, self).update_params(params, **kwargs)
Expand Down Expand Up @@ -81,6 +83,10 @@ def apply(self, img, pad_top=0, pad_bottom=0, pad_left=0, pad_right=0, **params)
return F.pad_with_params(img, pad_top, pad_bottom, pad_left, pad_right,
border_mode=self.border_mode, value=self.value)

def apply_to_mask(self, img, pad_top=0, pad_bottom=0, pad_left=0, pad_right=0, **params):
return F.pad_with_params(img, pad_top, pad_bottom, pad_left, pad_right,
border_mode=self.border_mode, value=self.mask_value)

def apply_to_bbox(self, bbox, pad_top=0, pad_bottom=0, pad_left=0, pad_right=0, rows=0, cols=0, **params):
x_min, y_min, x_max, y_max = denormalize_bbox(bbox, rows, cols)
bbox = [x_min + pad_left, y_min + pad_top, x_max + pad_left, y_max + pad_top]
Expand All @@ -91,7 +97,7 @@ def apply_to_keypoint(self, keypoint, pad_top=0, pad_bottom=0, pad_left=0, pad_r
return [x + pad_left, y + pad_top, a, s]

def get_transform_init_args_names(self):
return ('min_height', 'min_width', 'border_mode', 'value')
return ('min_height', 'min_width', 'border_mode', 'value', 'mask_value')


class Crop(DualTransform):
Expand Down Expand Up @@ -383,6 +389,7 @@ class Rotate(DualTransform):
cv2.BORDER_CONSTANT, cv2.BORDER_REPLICATE, cv2.BORDER_REFLECT, cv2.BORDER_WRAP, cv2.BORDER_REFLECT_101.
Default: cv2.BORDER_REFLECT_101
value (list of ints [r, g, b]): padding value if border_mode is cv2.BORDER_CONSTANT.
mask_value (scalar or list of ints): padding value if border_mode is cv2.BORDER_CONSTANT applied for masks.
p (float): probability of applying the transform. Default: 0.5.
Targets:
Expand All @@ -393,16 +400,20 @@ class Rotate(DualTransform):
"""

def __init__(self, limit=90, interpolation=cv2.INTER_LINEAR, border_mode=cv2.BORDER_REFLECT_101,
value=None, always_apply=False, p=.5):
value=None, mask_value=None, always_apply=False, p=.5):
super(Rotate, self).__init__(always_apply, p)
self.limit = to_tuple(limit)
self.interpolation = interpolation
self.border_mode = border_mode
self.value = value
self.mask_value = mask_value

def apply(self, img, angle=0, interpolation=cv2.INTER_LINEAR, **params):
return F.rotate(img, angle, interpolation, self.border_mode, self.value)

def apply_to_mask(self, img, angle=0, **params):
return F.rotate(img, angle, cv2.INTER_NEAREST, self.border_mode, self.mask_value)

def get_params(self):
return {'angle': random.uniform(self.limit[0], self.limit[1])}

Expand All @@ -413,7 +424,7 @@ def apply_to_keypoint(self, keypoint, angle=0, **params):
return F.keypoint_rotate(keypoint, angle, **params)

def get_transform_init_args_names(self):
return ('limit', 'interpolation', 'border_mode', 'value')
return ('limit', 'interpolation', 'border_mode', 'value', 'mask_value')


class RandomScale(DualTransform):
Expand Down Expand Up @@ -477,6 +488,7 @@ class ShiftScaleRotate(DualTransform):
cv2.BORDER_CONSTANT, cv2.BORDER_REPLICATE, cv2.BORDER_REFLECT, cv2.BORDER_WRAP, cv2.BORDER_REFLECT_101.
Default: cv2.BORDER_REFLECT_101
value (list of ints [r, g, b]): padding value if border_mode is cv2.BORDER_CONSTANT.
mask_value (scalar or list of ints): padding value if border_mode is cv2.BORDER_CONSTANT applied for masks.
p (float): probability of applying the transform. Default: 0.5.
Targets:
Expand All @@ -487,18 +499,22 @@ class ShiftScaleRotate(DualTransform):
"""

def __init__(self, shift_limit=0.0625, scale_limit=0.1, rotate_limit=45, interpolation=cv2.INTER_LINEAR,
border_mode=cv2.BORDER_REFLECT_101, value=None, always_apply=False, p=0.5):
border_mode=cv2.BORDER_REFLECT_101, value=None, mask_value=None, always_apply=False, p=0.5):
super(ShiftScaleRotate, self).__init__(always_apply, p)
self.shift_limit = to_tuple(shift_limit)
self.scale_limit = to_tuple(scale_limit, bias=1.0)
self.rotate_limit = to_tuple(rotate_limit)
self.interpolation = interpolation
self.border_mode = border_mode
self.value = value
self.mask_value = mask_value

def apply(self, img, angle=0, scale=0, dx=0, dy=0, interpolation=cv2.INTER_LINEAR, **params):
return F.shift_scale_rotate(img, angle, scale, dx, dy, interpolation, self.border_mode, self.value)

def apply_to_mask(self, img, angle=0, scale=0, dx=0, dy=0, **params):
return F.shift_scale_rotate(img, angle, scale, dx, dy, cv2.INTER_NEAREST, self.border_mode, self.mask_value)

def apply_to_keypoint(self, keypoint, angle=0, scale=0, dx=0, dy=0, rows=0, cols=0, interpolation=cv2.INTER_LINEAR,
**params):
return F.keypoint_shift_scale_rotate(keypoint, angle, scale, dx, dy, rows, cols)
Expand All @@ -520,6 +536,7 @@ def get_transform_init_args(self):
'interpolation': self.interpolation,
'border_mode': self.border_mode,
'value': self.value,
'mask_value': self.mask_value
}


Expand Down Expand Up @@ -779,24 +796,28 @@ class OpticalDistortion(DualTransform):
"""

def __init__(self, distort_limit=0.05, shift_limit=0.05, interpolation=cv2.INTER_LINEAR,
border_mode=cv2.BORDER_REFLECT_101, value=None, always_apply=False, p=0.5):
border_mode=cv2.BORDER_REFLECT_101, value=None, mask_value=None, always_apply=False, p=0.5):
super(OpticalDistortion, self).__init__(always_apply, p)
self.shift_limit = to_tuple(shift_limit)
self.distort_limit = to_tuple(distort_limit)
self.interpolation = interpolation
self.border_mode = border_mode
self.value = value
self.mask_value = mask_value

def apply(self, img, k=0, dx=0, dy=0, interpolation=cv2.INTER_LINEAR, **params):
return F.optical_distortion(img, k, dx, dy, interpolation, self.border_mode, self.value)

def apply_to_mask(self, img, k=0, dx=0, dy=0, **params):
return F.optical_distortion(img, k, dx, dy, cv2.INTER_NEAREST, self.border_mode, self.mask_value)

def get_params(self):
return {'k': random.uniform(self.distort_limit[0], self.distort_limit[1]),
'dx': round(random.uniform(self.shift_limit[0], self.shift_limit[1])),
'dy': round(random.uniform(self.shift_limit[0], self.shift_limit[1]))}

def get_transform_init_args_names(self):
return ('distort_limit', 'shift_limit', 'interpolation', 'border_mode', 'value')
return ('distort_limit', 'shift_limit', 'interpolation', 'border_mode', 'value', 'mask_value')


class GridDistortion(DualTransform):
Expand All @@ -809,16 +830,22 @@ class GridDistortion(DualTransform):
"""

def __init__(self, num_steps=5, distort_limit=0.3, interpolation=cv2.INTER_LINEAR,
border_mode=cv2.BORDER_REFLECT_101, value=None, always_apply=False, p=0.5):
border_mode=cv2.BORDER_REFLECT_101, value=None, mask_value=None, always_apply=False, p=0.5):
super(GridDistortion, self).__init__(always_apply, p)
self.num_steps = num_steps
self.distort_limit = to_tuple(distort_limit)
self.interpolation = interpolation
self.border_mode = border_mode
self.value = value
self.mask_value = mask_value

def apply(self, img, stepsx=[], stepsy=[], interpolation=cv2.INTER_LINEAR, **params):
return F.grid_distortion(img, self.num_steps, stepsx, stepsy, interpolation, self.border_mode, self.value)
return F.grid_distortion(img, self.num_steps, stepsx, stepsy, interpolation,
self.border_mode, self.value)

def apply_to_mask(self, img, stepsx=[], stepsy=[], **params):
return F.grid_distortion(img, self.num_steps, stepsx, stepsy, cv2.INTER_NEAREST,
self.border_mode, self.mask_value)

def get_params(self):
stepsx = [1 + random.uniform(self.distort_limit[0], self.distort_limit[1]) for i in
Expand All @@ -831,7 +858,7 @@ def get_params(self):
}

def get_transform_init_args_names(self):
return ('num_steps', 'distort_limit', 'interpolation', 'border_mode', 'value')
return ('num_steps', 'distort_limit', 'interpolation', 'border_mode', 'value', 'mask_value')


class ElasticTransform(DualTransform):
Expand All @@ -855,26 +882,34 @@ class ElasticTransform(DualTransform):
"""

def __init__(self, alpha=1, sigma=50, alpha_affine=50, interpolation=cv2.INTER_LINEAR,
border_mode=cv2.BORDER_REFLECT_101, value=None, always_apply=False, approximate=False, p=0.5):
border_mode=cv2.BORDER_REFLECT_101, value=None, mask_value=None,
always_apply=False, approximate=False, p=0.5):
super(ElasticTransform, self).__init__(always_apply, p)
self.alpha = alpha
self.alpha_affine = alpha_affine
self.sigma = sigma
self.interpolation = interpolation
self.border_mode = border_mode
self.value = value
self.mask_value = mask_value
self.approximate = approximate

def apply(self, img, random_state=None, interpolation=cv2.INTER_LINEAR, **params):
return F.elastic_transform(img, self.alpha, self.sigma, self.alpha_affine, interpolation,
self.border_mode, self.value, np.random.RandomState(random_state),
self.approximate)

def apply_to_mask(self, img, random_state=None, **params):
return F.elastic_transform(img, self.alpha, self.sigma, self.alpha_affine, cv2.INTER_NEAREST,
self.border_mode, self.mask_value, np.random.RandomState(random_state),
self.approximate)

def get_params(self):
return {'random_state': random.randint(0, 10000)}

def get_transform_init_args_names(self):
return ('alpha', 'sigma', 'alpha_affine', 'interpolation', 'border_mode', 'value', 'approximate')
return ('alpha', 'sigma', 'alpha_affine', 'interpolation', 'border_mode', 'value',
'mask_value', 'approximate')


class Normalize(ImageOnlyTransform):
Expand Down
21 changes: 21 additions & 0 deletions tests/test_augmentations.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import random

import cv2
import numpy as np
import pytest

Expand Down Expand Up @@ -352,3 +355,21 @@ def test_image_only_crop_around_bbox_augmentation(augmentation_cls, params, imag
annotations = {'image': image, 'cropping_bbox': [-59, 77, 177, 231]}
data = aug(**annotations)
assert data['image'].dtype == np.uint8


@pytest.mark.parametrize(['augmentation_cls', 'params'], [
[PadIfNeeded, {'min_height': 514, 'min_width': 514,
'border_mode': cv2.BORDER_CONSTANT, 'value': 100, 'mask_value': 1}],
[Rotate, {'border_mode': cv2.BORDER_CONSTANT, 'value': 100, 'mask_value': 1}],
[ShiftScaleRotate, {'border_mode': cv2.BORDER_CONSTANT, 'value': 100, 'mask_value': 1}],
[OpticalDistortion, {'border_mode': cv2.BORDER_CONSTANT, 'value': 100, 'mask_value': 1}],
[ElasticTransform, {'border_mode': cv2.BORDER_CONSTANT, 'value': 100, 'mask_value': 1}],
[GridDistortion, {'border_mode': cv2.BORDER_CONSTANT, 'value': 100, 'mask_value': 1}],
])
def test_mask_fill_value(augmentation_cls, params):
random.seed(42)
aug = augmentation_cls(p=1, **params)
input = {'image': np.zeros((512, 512), dtype=np.uint8) + 100, 'mask': np.ones((512, 512))}
output = aug(**input)
assert (output['image'] == 100).all()
assert (output['mask'] == 1).all()
2 changes: 2 additions & 0 deletions tests/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,8 @@ def test_multiprocessing_support(augmentation_cls, params):

pool = Pool(8)
pool.map(__test_multiprocessing_support_proc, map(lambda x: (x, aug), [image] * 100))
pool.close()
pool.join()


def test_force_apply():
Expand Down

0 comments on commit 2c1a148

Please sign in to comment.