From cb0280d3eb37a7a81af0bae97b47e6c38968c619 Mon Sep 17 00:00:00 2001 From: Dipet Date: Tue, 11 May 2021 19:13:06 +0300 Subject: [PATCH 1/9] PiecewiseAffine implementation --- .../augmentations/geometric/functional.py | 178 +++++++++++++++- .../augmentations/geometric/transforms.py | 193 +++++++++++++++++- 2 files changed, 363 insertions(+), 8 deletions(-) diff --git a/albumentations/augmentations/geometric/functional.py b/albumentations/augmentations/geometric/functional.py index 291a9a3f0..88ea141a7 100644 --- a/albumentations/augmentations/geometric/functional.py +++ b/albumentations/augmentations/geometric/functional.py @@ -1,19 +1,15 @@ import cv2 import math import numpy as np +import six.moves as sm import skimage.transform from scipy.ndimage.filters import gaussian_filter from ..bbox_utils import denormalize_bbox, normalize_bbox -from ..functional import ( - angle_2pi_range, - preserve_channel_dim, - _maybe_process_in_chunks, - preserve_shape, -) +from ..functional import angle_2pi_range, preserve_channel_dim, _maybe_process_in_chunks, preserve_shape, clipped -from typing import Union, List, Sequence +from typing import Union, List, Sequence, Tuple, Optional def bbox_rot90(bbox, factor, rows, cols): # skipcq: PYL-W0613 @@ -597,3 +593,171 @@ def safe_rotate_enlarged_img_size(angle: float, rows: int, cols: int): return int(r_cols), int(r_rows) else: return int(r_rows), int(r_cols) + + +@clipped +def piecewise_affine( + img: np.ndarray, + matrix: skimage.transform.PiecewiseAffineTransform, + interpolation: int, + mode: str, + cval: float, +) -> np.ndarray: + return skimage.transform.warp( + img, matrix, order=interpolation, mode=mode, cval=cval, preserve_range=True, output_shape=img.shape + ) + + +def to_distance_maps( + keypoints: Sequence[Sequence[float]], height: int, width: int, inverted: bool = False +) -> np.ndarray: + """Generate a ``(H,W,N)`` array of distance maps for ``N`` keypoints. + + The ``n``-th distance map contains at every location ``(y, x)`` the + euclidean distance to the ``n``-th keypoint. + + This function can be used as a helper when augmenting keypoints with a + method that only supports the augmentation of images. + + Args: + keypoint (sequence of float): keypoint coordinates + height (int): image height + width (int): image width + inverted (bool): If ``True``, inverted distance maps are returned where each + distance value d is replaced by ``d/(d+1)``, i.e. the distance + maps have values in the range ``(0.0, 1.0]`` with ``1.0`` denoting + exactly the position of the respective keypoint. + + Returns: + (H,W,N) ndarray + A ``float32`` array containing ``N`` distance maps for ``N`` + keypoints. Each location ``(y, x, n)`` in the array denotes the + euclidean distance at ``(y, x)`` to the ``n``-th keypoint. + If `inverted` is ``True``, the distance ``d`` is replaced + by ``d/(d+1)``. The height and width of the array match the + height and width in ``KeypointsOnImage.shape``. + """ + distance_maps = np.zeros((height, width, len(keypoints)), dtype=np.float32) + + yy = np.arange(0, height) + xx = np.arange(0, width) + grid_xx, grid_yy = np.meshgrid(xx, yy) + + for i, (x, y) in enumerate(keypoints): + distance_maps[:, :, i] = (grid_xx - x) ** 2 + (grid_yy - y) ** 2 + + distance_maps = np.sqrt(distance_maps) + if inverted: + return 1 / (distance_maps + 1) + return distance_maps + + +def from_distance_maps( + distance_maps: np.ndarray, + inverted: bool, + if_not_found_coords: Optional[Union[Sequence[int], dict]], + threshold: Optional[float] = None, +) -> List[Tuple[float, float]]: + """Convert outputs of ``to_distance_maps()`` to ``KeypointsOnImage``. + This is the inverse of `to_distance_maps`. + + Args: + distance_maps (np.ndarray): The distance maps. ``N`` is the number of keypoints. + inverted (bool): Whether the given distance maps were generated in inverted mode + (i.e. :func:`KeypointsOnImage.to_distance_maps` was called with ``inverted=True``) or in non-inverted mode. + if_not_found_coords (tuple, list, dict or None, optional): + Coordinates to use for keypoints that cannot be found in `distance_maps`. + + * If this is a ``list``/``tuple``, it must contain two ``int`` values. + * If it is a ``dict``, it must contain the keys ``x`` and ``y`` with each containing one ``int`` value. + * If this is ``None``, then the keypoint will not be added. + threshold (float): The search for keypoints works by searching for the + argmin (non-inverted) or argmax (inverted) in each channel. This + parameters contains the maximum (non-inverted) or minimum (inverted) value to accept in order to view a hit + as a keypoint. Use ``None`` to use no min/max. + nb_channels (None, int): Number of channels of the image on which the keypoints are placed. + Some keypoint augmenters require that information. If set to ``None``, the keypoint's shape will be set + to ``(height, width)``, otherwise ``(height, width, nb_channels)``. + """ + assert ( + distance_maps.ndim == 3 + ), f"Expected three-dimensional input, got {distance_maps.ndim} dimensions and shape {distance_maps.shape}." + height, width, nb_keypoints = distance_maps.shape + + drop_if_not_found = False + if if_not_found_coords is None: + drop_if_not_found = True + if_not_found_x = -1 + if_not_found_y = -1 + elif isinstance(if_not_found_coords, (tuple, list)): + assert ( + len(if_not_found_coords) == 2 + ), f"Expected tuple/list 'if_not_found_coords' to contain exactly two entries, got {len(if_not_found_coords)}." + if_not_found_x = if_not_found_coords[0] + if_not_found_y = if_not_found_coords[1] + elif isinstance(if_not_found_coords, dict): + if_not_found_x = if_not_found_coords["x"] + if_not_found_y = if_not_found_coords["y"] + else: + raise Exception( + f"Expected if_not_found_coords to be None or tuple or list or dict, got {type(if_not_found_coords)}." + ) + + keypoints = [] + for i in sm.xrange(nb_keypoints): + if inverted: + hitidx_flat = np.argmax(distance_maps[..., i]) + else: + hitidx_flat = np.argmin(distance_maps[..., i]) + hitidx_ndim = np.unravel_index(hitidx_flat, (height, width)) + if not inverted and threshold is not None: + found = distance_maps[hitidx_ndim[0], hitidx_ndim[1], i] < threshold + elif inverted and threshold is not None: + found = distance_maps[hitidx_ndim[0], hitidx_ndim[1], i] >= threshold + else: + found = True + if found: + keypoints.append((hitidx_ndim[1], hitidx_ndim[0])) + else: + if not drop_if_not_found: + keypoints.append((if_not_found_x, if_not_found_y)) + + return keypoints + + +def keypoint_piecewise_affine( + keypoint: Sequence[float], + matrix: skimage.transform.PiecewiseAffineTransform, + h: int, + w: int, +) -> Tuple[float, float, float, float]: + x, y, a, s = keypoint + dist_maps = to_distance_maps([(x, y)], h, w, True) + dist_maps = piecewise_affine(dist_maps, matrix, 0, "constant", 0) + x, y = from_distance_maps(dist_maps, True, {"x": -1, "y": -1}, 0.01)[0] + return x, y, a, s + + +def bbox_piecewise_affine( + bbox: Sequence[float], + matrix: skimage.transform.PiecewiseAffineTransform, + h: int, + w: int, +) -> Tuple[float, float, float, float]: + x1, y1, x2, y2 = denormalize_bbox(tuple(bbox), h, w) + keypoints = [ + (x1, y1), + (x2, y1), + (x2, y2), + (x1, y2), + ] + dist_maps = to_distance_maps(keypoints, h, w, True) + dist_maps = piecewise_affine(dist_maps, matrix, 0, "constant", 0) + keypoints = from_distance_maps(dist_maps, True, {"x": -1, "y": -1}, 0.01) + keypoints = [i for i in keypoints if 0 <= i[0] < w and 0 <= i[1] < h] + keypoints_arr = np.array(keypoints) + x1 = keypoints_arr[:, 0].min() + y1 = keypoints_arr[:, 1].min() + x2 = keypoints_arr[:, 0].max() + y2 = keypoints_arr[:, 1].max() + return normalize_bbox((x1, y1, x2, y2), h, w) diff --git a/albumentations/augmentations/geometric/transforms.py b/albumentations/augmentations/geometric/transforms.py index a9a7956f6..5d1edd35d 100644 --- a/albumentations/augmentations/geometric/transforms.py +++ b/albumentations/augmentations/geometric/transforms.py @@ -8,7 +8,7 @@ from . import functional as F from ...core.transforms_interface import DualTransform, to_tuple -__all__ = ["ShiftScaleRotate", "ElasticTransform", "Perspective", "Affine"] +__all__ = ["ShiftScaleRotate", "ElasticTransform", "Perspective", "Affine", "PiecewiseAffine"] class ShiftScaleRotate(DualTransform): @@ -690,3 +690,194 @@ def _compute_affine_warp_output_shape( matrix_to_fit = skimage.transform.SimilarityTransform(translation=translation) matrix = matrix + matrix_to_fit return matrix, output_shape + + +class _ConcavePolygonRecoverer: + def __init__(self): + pass + + +class PiecewiseAffine(DualTransform): + """Apply affine transformations that differ between local neighbourhoods. + This augmentation places a regular grid of points on an image and randomly moves the neighbourhood of these point + around via affine transformations. This leads to local distortions. + + This is mostly a wrapper around scikit-image's ``PiecewiseAffine``. + See also ``Affine`` for a similar technique. + + Note: + This augmenter is very slow. Try to use ``ElasticTransformation`` instead, which is at least 10x faster. + + Note: + For coordinate-based inputs (keypoints, bounding boxes, polygons, ...), + this augmenter still has to perform an image-based augmentation, + which will make it significantly slower and not fully correct for such inputs than other transforms. + + Args: + scale (float, tuple of float): Each point on the regular grid is moved around via a normal distribution. + This scale factor is equivalent to the normal distribution's sigma. + Note that the jitter (how far each point is moved in which direction) is multiplied by the height/width of + the image if ``absolute_scale=False`` (default), so this scale can be the same for different sized images. + Recommended values are in the range ``0.01`` to ``0.05`` (weak to strong augmentations). + * If a single ``float``, then that value will always be used as the scale. + * If a tuple ``(a, b)`` of ``float`` s, then a random value will + be uniformly sampled per image from the interval ``[a, b]``. + nb_rows (int, tuple of int): Number of rows of points that the regular grid should have. + Must be at least ``2``. For large images, you might want to pick a higher value than ``4``. + You might have to then adjust scale to lower values. + * If a single ``int``, then that value will always be used as the number of rows. + * If a tuple ``(a, b)``, then a value from the discrete interval + ``[a..b]`` will be uniformly sampled per image. + nb_cols (int, tuple of int): Number of columns. Analogous to `nb_rows`. + interpolation (int): The order of interpolation. The order has to be in the range 0-5: + - 0: Nearest-neighbor + - 1: Bi-linear (default) + - 2: Bi-quadratic + - 3: Bi-cubic + - 4: Bi-quartic + - 5: Bi-quintic + mask_interpolation (int): same as interpolation but for mask. + cval (number): The constant value to use when filling in newly created pixels. + cval_mask (number): Same as cval but only for masks. + mode (str): {'constant', 'edge', 'symmetric', 'reflect', 'wrap'}, optional + Points outside the boundaries of the input are filled according + to the given mode. Modes match the behaviour of `numpy.pad`. + absolute_scale (bool): Take `scale` as an absolute value rather than a relative value. + polygon_recoverer ('auto', None or _ConcavePolygonRecoverer): The class to use to repair invalid polygons. + If ``"auto"``, a new instance of `_ConcavePolygonRecoverer` will be created. + If ``None``, no polygon recoverer will be used. + If an object, then that object will be used and must provide a ``recover_from()`` method, similar to + `_ConcavePolygonRecoverer`. + + Targets: + image, mask, keypoints, bboxes + + Image types: + uint8, float32 + + """ + + def __init__( + self, + scale: Union[float, Sequence[float]] = (0.03, 0.05), + nb_rows: Union[int, Sequence[int]] = 4, + nb_cols: Union[int, Sequence[int]] = 4, + interpolation: int = 1, + mask_interpolation: int = 0, + cval: int = 0, + cval_mask: int = 0, + mode: str = "constant", + absolute_scale: bool = False, + polygon_recoverer: Optional[Union[str, _ConcavePolygonRecoverer]] = None, + always_apply: bool = False, + p: float = 0.5, + ): + super(PiecewiseAffine, self).__init__(always_apply, p) + + self.scale = to_tuple(scale, scale) + self.nb_rows = to_tuple(nb_rows, nb_rows) + self.nb_cols = to_tuple(nb_cols, nb_cols) + self.interpolation = interpolation + self.mask_interpolation = mask_interpolation + self.cval = cval + self.cval_mask = cval_mask + self.mode = mode + self.absolute_scale = absolute_scale + + self.polygon_recoverer = polygon_recoverer + if polygon_recoverer == "auto": + self.polygon_recoverer = _ConcavePolygonRecoverer() + + def get_transform_init_args_names(self): + return ( + "scale", + "nb_rows", + "nb_cols", + "interpolation", + "mask_interpolation", + "cval", + "cval_mask", + "mode", + "absolute_scale", + ) + + @property + def targets_as_params(self): + return ["image"] + + def get_params_dependent_on_targets(self, params) -> dict: + h, w = params["image"].shape[:2] + + nb_rows = np.clip(random.randint(*self.nb_rows), 2, None) + nb_cols = np.clip(random.randint(*self.nb_cols), 2, None) + nb_cells = nb_cols * nb_rows + scale = random.uniform(*self.scale) + + state = np.random.RandomState(random.randint(0, 1 << 31)) + jitter = state.normal(0, scale, (nb_cells, 2)) + if not np.any(jitter > 0): + return {"matrix": None} + + y = np.linspace(0, h, nb_rows) + x = np.linspace(0, w, nb_cols) + + # (H, W) and (H, W) for H=rows, W=cols + xx_src, yy_src = np.meshgrid(x, y) + + # (1, HW, 2) => (HW, 2) for H=rows, W=cols + points_src = np.dstack([yy_src.flat, xx_src.flat])[0] + + if self.absolute_scale: + jitter[:, 0] = jitter[:, 0] / h if h > 0 else 0.0 + jitter[:, 1] = jitter[:, 1] / jitter[1] if w > 0 else 0.0 + + jitter[:, 0] = jitter[:, 0] * h + jitter[:, 1] = jitter[:, 1] * w + + points_dest = np.copy(points_src) + points_dest[:, 0] = points_dest[:, 0] + jitter[:, 0] + points_dest[:, 1] = points_dest[:, 1] + jitter[:, 1] + + # Restrict all destination points to be inside the image plane. + # This is necessary, as otherwise keypoints could be augmented + # outside of the image plane and these would be replaced by + # (-1, -1), which would not conform with the behaviour of the other augmenters. + points_dest[:, 0] = np.clip(points_dest[:, 0], 0, h - 1) + points_dest[:, 1] = np.clip(points_dest[:, 1], 0, w - 1) + + matrix = skimage.transform.PiecewiseAffineTransform() + matrix.estimate(points_src[:, ::-1], points_dest[:, ::-1]) + + return { + "matrix": matrix, + } + + def apply( + self, img: np.ndarray, matrix: skimage.transform.PiecewiseAffineTransform = None, **params + ) -> np.ndarray: + return F.piecewise_affine(img, matrix, self.interpolation, self.mode, self.cval) + + def apply_to_mask( + self, img: np.ndarray, matrix: skimage.transform.PiecewiseAffineTransform = None, **params + ) -> np.ndarray: + return F.piecewise_affine(img, matrix, self.mask_interpolation, self.mode, self.cval_mask) + + def apply_to_bbox( + self, + bbox: Sequence[float], + rows: int = 0, + cols: int = 0, + matrix: skimage.transform.PiecewiseAffineTransform = None, + **params + ) -> Sequence[float]: + return F.bbox_piecewise_affine(bbox, matrix, rows, cols) + + def apply_to_keypoint( + self, + keypoint: Sequence[float], + rows: int = 0, + cols: int = 0, + matrix: skimage.transform.PiecewiseAffineTransform = None, + **params + ): + return F.keypoint_piecewise_affine(keypoint, matrix, rows, cols) From c3ddec59d89845070fd025881762562fce64266f Mon Sep 17 00:00:00 2001 From: mdruzhinin Date: Sat, 15 May 2021 15:35:24 +0300 Subject: [PATCH 2/9] Update README.md --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index e5cdaf8bf..8a8e35ff5 100644 --- a/README.md +++ b/README.md @@ -178,6 +178,7 @@ Spatial-level transforms will simultaneously change both an input image as well | [OpticalDistortion](https://albumentations.ai/docs/api_reference/augmentations/transforms/#albumentations.augmentations.transforms.OpticalDistortion) | ✓ | ✓ | | | | [PadIfNeeded](https://albumentations.ai/docs/api_reference/augmentations/transforms/#albumentations.augmentations.transforms.PadIfNeeded) | ✓ | ✓ | ✓ | ✓ | | [Perspective](https://albumentations.ai/docs/api_reference/augmentations/geometric/transforms/#albumentations.augmentations.geometric.transforms.Perspective) | ✓ | ✓ | ✓ | ✓ | +| [PiecewiseAffine](https://albumentations.ai/docs/api_reference/augmentations/geometric/transforms/#albumentations.augmentations.geometric.transforms.PiecewiseAffine) | ✓ | ✓ | ✓ | ✓ | | [RandomCrop](https://albumentations.ai/docs/api_reference/augmentations/crops/transforms/#albumentations.augmentations.crops.transforms.RandomCrop) | ✓ | ✓ | ✓ | ✓ | | [RandomCropNearBBox](https://albumentations.ai/docs/api_reference/augmentations/crops/transforms/#albumentations.augmentations.crops.transforms.RandomCropNearBBox) | ✓ | ✓ | ✓ | ✓ | | [RandomGridShuffle](https://albumentations.ai/docs/api_reference/augmentations/transforms/#albumentations.augmentations.transforms.RandomGridShuffle) | ✓ | ✓ | | | From 8d8a787be8dd9fe936e4ad967e1aa7d27a2e53c2 Mon Sep 17 00:00:00 2001 From: mdruzhinin Date: Sat, 15 May 2021 16:00:33 +0300 Subject: [PATCH 3/9] PiecewiseAffine remove polygon recoverer --- .../augmentations/geometric/transforms.py | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/albumentations/augmentations/geometric/transforms.py b/albumentations/augmentations/geometric/transforms.py index 5d1edd35d..cee4232b3 100644 --- a/albumentations/augmentations/geometric/transforms.py +++ b/albumentations/augmentations/geometric/transforms.py @@ -692,11 +692,6 @@ def _compute_affine_warp_output_shape( return matrix, output_shape -class _ConcavePolygonRecoverer: - def __init__(self): - pass - - class PiecewiseAffine(DualTransform): """Apply affine transformations that differ between local neighbourhoods. This augmentation places a regular grid of points on an image and randomly moves the neighbourhood of these point @@ -743,11 +738,6 @@ class PiecewiseAffine(DualTransform): Points outside the boundaries of the input are filled according to the given mode. Modes match the behaviour of `numpy.pad`. absolute_scale (bool): Take `scale` as an absolute value rather than a relative value. - polygon_recoverer ('auto', None or _ConcavePolygonRecoverer): The class to use to repair invalid polygons. - If ``"auto"``, a new instance of `_ConcavePolygonRecoverer` will be created. - If ``None``, no polygon recoverer will be used. - If an object, then that object will be used and must provide a ``recover_from()`` method, similar to - `_ConcavePolygonRecoverer`. Targets: image, mask, keypoints, bboxes @@ -768,7 +758,6 @@ def __init__( cval_mask: int = 0, mode: str = "constant", absolute_scale: bool = False, - polygon_recoverer: Optional[Union[str, _ConcavePolygonRecoverer]] = None, always_apply: bool = False, p: float = 0.5, ): @@ -784,10 +773,6 @@ def __init__( self.mode = mode self.absolute_scale = absolute_scale - self.polygon_recoverer = polygon_recoverer - if polygon_recoverer == "auto": - self.polygon_recoverer = _ConcavePolygonRecoverer() - def get_transform_init_args_names(self): return ( "scale", From 907838fae7eeb90f07367b09d3f92f1a133bdb72 Mon Sep 17 00:00:00 2001 From: mdruzhinin Date: Sat, 15 May 2021 16:25:38 +0300 Subject: [PATCH 4/9] PiecewiseAffine fix jitter --- albumentations/augmentations/geometric/transforms.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/albumentations/augmentations/geometric/transforms.py b/albumentations/augmentations/geometric/transforms.py index cee4232b3..746608bb1 100644 --- a/albumentations/augmentations/geometric/transforms.py +++ b/albumentations/augmentations/geometric/transforms.py @@ -814,7 +814,7 @@ def get_params_dependent_on_targets(self, params) -> dict: if self.absolute_scale: jitter[:, 0] = jitter[:, 0] / h if h > 0 else 0.0 - jitter[:, 1] = jitter[:, 1] / jitter[1] if w > 0 else 0.0 + jitter[:, 1] = jitter[:, 1] / w if w > 0 else 0.0 jitter[:, 0] = jitter[:, 0] * h jitter[:, 1] = jitter[:, 1] * w From f44b527eb4c168b9d71f27bc15f3c7a67013825c Mon Sep 17 00:00:00 2001 From: mdruzhinin Date: Sat, 15 May 2021 16:27:46 +0300 Subject: [PATCH 5/9] PiecewiseAffine tests --- tests/test_augmentations.py | 10 ++++++++++ tests/test_serialization.py | 17 +++++++++++++++++ tests/test_transforms.py | 3 +++ 3 files changed, 30 insertions(+) diff --git a/tests/test_augmentations.py b/tests/test_augmentations.py index bdde330d3..5a456b879 100644 --- a/tests/test_augmentations.py +++ b/tests/test_augmentations.py @@ -73,6 +73,7 @@ CropAndPad, Superpixels, Affine, + PiecewiseAffine, ) @@ -209,6 +210,7 @@ def test_image_only_augmentations_with_float_values(augmentation_cls, params, fl [Perspective, {}], [CropAndPad, {"px": 10}], [Affine, {}], + [PiecewiseAffine, {}], ], ) def test_dual_augmentations(augmentation_cls, params, image, mask): @@ -243,6 +245,7 @@ def test_dual_augmentations(augmentation_cls, params, image, mask): [Perspective, {}], [CropAndPad, {"px": 10}], [Affine, {}], + [PiecewiseAffine, {}], ], ) def test_dual_augmentations_with_float_values(augmentation_cls, params, float_image, mask): @@ -340,6 +343,7 @@ def test_imgaug_dual_augmentations(augmentation_cls, image, mask): [CropAndPad, {"px": 10}], [Superpixels, {}], [Affine, {}], + [PiecewiseAffine, {}], ], ) def test_augmentations_wont_change_input(augmentation_cls, params, image, mask): @@ -414,6 +418,7 @@ def test_augmentations_wont_change_input(augmentation_cls, params, image, mask): [CropAndPad, {"px": 10}], [Superpixels, {}], [Affine, {}], + [PiecewiseAffine, {}], ], ) def test_augmentations_wont_change_float_input(augmentation_cls, params, float_image): @@ -470,6 +475,7 @@ def test_augmentations_wont_change_float_input(augmentation_cls, params, float_i [CropAndPad, {"px": 10}], [Superpixels, {}], [Affine, {}], + [PiecewiseAffine, {}], ], ) def test_augmentations_wont_change_shape_grayscale(augmentation_cls, params, image, mask): @@ -553,6 +559,7 @@ def test_augmentations_wont_change_shape_grayscale(augmentation_cls, params, ima [CropAndPad, {"px": 10}], [Superpixels, {}], [Affine, {}], + [PiecewiseAffine, {}], ], ) def test_augmentations_wont_change_shape_rgb(augmentation_cls, params, image, mask): @@ -589,6 +596,7 @@ def test_image_only_crop_around_bbox_augmentation(augmentation_cls, params, imag [ElasticTransform, {"border_mode": cv2.BORDER_CONSTANT, "value": 100, "mask_value": 1}], [GridDistortion, {"border_mode": cv2.BORDER_CONSTANT, "value": 100, "mask_value": 1}], [Affine, {"mode": cv2.BORDER_CONSTANT, "cval_mask": 1, "cval": 100}], + [PiecewiseAffine, {"mode": "constant", "cval_mask": 1, "cval": 100}], ], ) def test_mask_fill_value(augmentation_cls, params): @@ -623,6 +631,7 @@ def test_mask_fill_value(augmentation_cls, params): [CropAndPad, {"px": 10}], [Superpixels, {}], [Affine, {}], + [PiecewiseAffine, {}], ], ) def test_multichannel_image_augmentations(augmentation_cls, params): @@ -656,6 +665,7 @@ def test_multichannel_image_augmentations(augmentation_cls, params): [CropAndPad, {"px": 10}], [Superpixels, {}], [Affine, {}], + [PiecewiseAffine, {}], ], ) def test_multichannel_image_augmentations_diff_channels(augmentation_cls, params): diff --git a/tests/test_serialization.py b/tests/test_serialization.py index e8f4bf9ea..a4ec4632f 100644 --- a/tests/test_serialization.py +++ b/tests/test_serialization.py @@ -80,6 +80,7 @@ def set_seed(seed): [A.CropAndPad, {"px": 10}], [A.Superpixels, {}], [A.Affine, {}], + [A.PiecewiseAffine, {}], ], ) @pytest.mark.parametrize("p", [0.5, 1]) @@ -315,6 +316,20 @@ def test_augmentations_serialization(augmentation_cls, params, p, seed, image, m "fit_output": True, }, ], + [ + A.PiecewiseAffine, + { + "scale": 0.33, + "nb_rows": (10, 20), + "nb_cols": 33, + "interpolation": 2, + "mask_interpolation": 1, + "cval": 10, + "cval_mask": 20, + "mode": "edge", + "absolute_scale": True, + }, + ], ], ) @@ -416,6 +431,7 @@ def test_augmentations_serialization_to_file_with_custom_parameters( [A.Emboss, {}], [A.Superpixels, {}], [A.Affine, {}], + [A.PiecewiseAffine, {}], ], ) @pytest.mark.parametrize("p", [0.5, 1]) @@ -487,6 +503,7 @@ def test_augmentations_for_bboxes_serialization( [A.Emboss, {}], [A.Superpixels, {}], [A.Affine, {}], + [A.PiecewiseAffine, {}], ], ) @pytest.mark.parametrize("p", [0.5, 1]) diff --git a/tests/test_transforms.py b/tests/test_transforms.py index 8e7be5143..35fbc8170 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -165,6 +165,7 @@ def test_elastic_transform_interpolation(monkeypatch, interpolation): [A.GlassBlur, {}], [A.Perspective, {}], [A.Affine, {}], + [A.PiecewiseAffine, {}], ], ) def test_binary_mask_interpolation(augmentation_cls, params): @@ -193,6 +194,7 @@ def test_binary_mask_interpolation(augmentation_cls, params): [A.GlassBlur, {}], [A.Perspective, {}], [A.Affine, {}], + [A.PiecewiseAffine, {}], ], ) def test_semantic_mask_interpolation(augmentation_cls, params): @@ -232,6 +234,7 @@ def __test_multiprocessing_support_proc(args): [A.GlassBlur, {}], [A.Perspective, {}], [A.Affine, {}], + [A.PiecewiseAffine, {}], ], ) def test_multiprocessing_support(augmentation_cls, params, multiprocessing_context): From ceafa50c9f98bf9cd933319d24384602586d49a9 Mon Sep 17 00:00:00 2001 From: mdruzhinin Date: Sun, 16 May 2021 10:31:34 +0300 Subject: [PATCH 6/9] Remove six.xrange --- albumentations/augmentations/geometric/functional.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/albumentations/augmentations/geometric/functional.py b/albumentations/augmentations/geometric/functional.py index 88ea141a7..afbdb0cc1 100644 --- a/albumentations/augmentations/geometric/functional.py +++ b/albumentations/augmentations/geometric/functional.py @@ -1,7 +1,6 @@ import cv2 import math import numpy as np -import six.moves as sm import skimage.transform from scipy.ndimage.filters import gaussian_filter @@ -704,7 +703,7 @@ def from_distance_maps( ) keypoints = [] - for i in sm.xrange(nb_keypoints): + for i in range(nb_keypoints): if inverted: hitidx_flat = np.argmax(distance_maps[..., i]) else: From 93721194318bd82f906f8ec3f6d790f095b56df5 Mon Sep 17 00:00:00 2001 From: mdruzhinin Date: Sun, 16 May 2021 13:32:40 +0300 Subject: [PATCH 7/9] Remove asserts --- .../augmentations/geometric/functional.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/albumentations/augmentations/geometric/functional.py b/albumentations/augmentations/geometric/functional.py index afbdb0cc1..2f611dfe9 100644 --- a/albumentations/augmentations/geometric/functional.py +++ b/albumentations/augmentations/geometric/functional.py @@ -678,9 +678,11 @@ def from_distance_maps( Some keypoint augmenters require that information. If set to ``None``, the keypoint's shape will be set to ``(height, width)``, otherwise ``(height, width, nb_channels)``. """ - assert ( - distance_maps.ndim == 3 - ), f"Expected three-dimensional input, got {distance_maps.ndim} dimensions and shape {distance_maps.shape}." + if distance_maps.ndim != 3: + raise ValueError( + f"Expected three-dimensional input, " + f"got {distance_maps.ndim} dimensions and shape {distance_maps.shape}." + ) height, width, nb_keypoints = distance_maps.shape drop_if_not_found = False @@ -689,16 +691,18 @@ def from_distance_maps( if_not_found_x = -1 if_not_found_y = -1 elif isinstance(if_not_found_coords, (tuple, list)): - assert ( - len(if_not_found_coords) == 2 - ), f"Expected tuple/list 'if_not_found_coords' to contain exactly two entries, got {len(if_not_found_coords)}." + if len(if_not_found_coords) != 2: + raise ValueError( + f"Expected tuple/list 'if_not_found_coords' to contain exactly two entries, " + f"got {len(if_not_found_coords)}." + ) if_not_found_x = if_not_found_coords[0] if_not_found_y = if_not_found_coords[1] elif isinstance(if_not_found_coords, dict): if_not_found_x = if_not_found_coords["x"] if_not_found_y = if_not_found_coords["y"] else: - raise Exception( + raise ValueError( f"Expected if_not_found_coords to be None or tuple or list or dict, got {type(if_not_found_coords)}." ) From 63468a6d54953263ebee7d3313cccb46464e2f3b Mon Sep 17 00:00:00 2001 From: mdruzhinin Date: Sun, 16 May 2021 13:44:29 +0300 Subject: [PATCH 8/9] Added keypoints_threshold argument for PiecewiseAffine --- albumentations/augmentations/geometric/functional.py | 6 ++++-- albumentations/augmentations/geometric/transforms.py | 12 ++++++++++-- tests/test_serialization.py | 1 + 3 files changed, 15 insertions(+), 4 deletions(-) diff --git a/albumentations/augmentations/geometric/functional.py b/albumentations/augmentations/geometric/functional.py index 2f611dfe9..2716f7073 100644 --- a/albumentations/augmentations/geometric/functional.py +++ b/albumentations/augmentations/geometric/functional.py @@ -733,11 +733,12 @@ def keypoint_piecewise_affine( matrix: skimage.transform.PiecewiseAffineTransform, h: int, w: int, + keypoints_threshold: float, ) -> Tuple[float, float, float, float]: x, y, a, s = keypoint dist_maps = to_distance_maps([(x, y)], h, w, True) dist_maps = piecewise_affine(dist_maps, matrix, 0, "constant", 0) - x, y = from_distance_maps(dist_maps, True, {"x": -1, "y": -1}, 0.01)[0] + x, y = from_distance_maps(dist_maps, True, {"x": -1, "y": -1}, keypoints_threshold)[0] return x, y, a, s @@ -746,6 +747,7 @@ def bbox_piecewise_affine( matrix: skimage.transform.PiecewiseAffineTransform, h: int, w: int, + keypoints_threshold: float, ) -> Tuple[float, float, float, float]: x1, y1, x2, y2 = denormalize_bbox(tuple(bbox), h, w) keypoints = [ @@ -756,7 +758,7 @@ def bbox_piecewise_affine( ] dist_maps = to_distance_maps(keypoints, h, w, True) dist_maps = piecewise_affine(dist_maps, matrix, 0, "constant", 0) - keypoints = from_distance_maps(dist_maps, True, {"x": -1, "y": -1}, 0.01) + keypoints = from_distance_maps(dist_maps, True, {"x": -1, "y": -1}, keypoints_threshold) keypoints = [i for i in keypoints if 0 <= i[0] < w and 0 <= i[1] < h] keypoints_arr = np.array(keypoints) x1 = keypoints_arr[:, 0].min() diff --git a/albumentations/augmentations/geometric/transforms.py b/albumentations/augmentations/geometric/transforms.py index 746608bb1..f98907ec4 100644 --- a/albumentations/augmentations/geometric/transforms.py +++ b/albumentations/augmentations/geometric/transforms.py @@ -738,6 +738,11 @@ class PiecewiseAffine(DualTransform): Points outside the boundaries of the input are filled according to the given mode. Modes match the behaviour of `numpy.pad`. absolute_scale (bool): Take `scale` as an absolute value rather than a relative value. + keypoints_threshold (float): Used as threshold in conversion from distance maps to keypoints. + The search for keypoints works by searching for the + argmin (non-inverted) or argmax (inverted) in each channel. This + parameters contains the maximum (non-inverted) or minimum (inverted) value to accept in order to view a hit + as a keypoint. Use ``None`` to use no min/max. Default: 0.01 Targets: image, mask, keypoints, bboxes @@ -759,6 +764,7 @@ def __init__( mode: str = "constant", absolute_scale: bool = False, always_apply: bool = False, + keypoints_threshold: float = 0.01, p: float = 0.5, ): super(PiecewiseAffine, self).__init__(always_apply, p) @@ -772,6 +778,7 @@ def __init__( self.cval_mask = cval_mask self.mode = mode self.absolute_scale = absolute_scale + self.keypoints_threshold = keypoints_threshold def get_transform_init_args_names(self): return ( @@ -784,6 +791,7 @@ def get_transform_init_args_names(self): "cval_mask", "mode", "absolute_scale", + "keypoints_threshold", ) @property @@ -855,7 +863,7 @@ def apply_to_bbox( matrix: skimage.transform.PiecewiseAffineTransform = None, **params ) -> Sequence[float]: - return F.bbox_piecewise_affine(bbox, matrix, rows, cols) + return F.bbox_piecewise_affine(bbox, matrix, rows, cols, self.keypoints_threshold) def apply_to_keypoint( self, @@ -865,4 +873,4 @@ def apply_to_keypoint( matrix: skimage.transform.PiecewiseAffineTransform = None, **params ): - return F.keypoint_piecewise_affine(keypoint, matrix, rows, cols) + return F.keypoint_piecewise_affine(keypoint, matrix, rows, cols, self.keypoints_threshold) diff --git a/tests/test_serialization.py b/tests/test_serialization.py index a4ec4632f..60c50a566 100644 --- a/tests/test_serialization.py +++ b/tests/test_serialization.py @@ -328,6 +328,7 @@ def test_augmentations_serialization(augmentation_cls, params, p, seed, image, m "cval_mask": 20, "mode": "edge", "absolute_scale": True, + "keypoints_threshold": 0.1, }, ], ], From 538114f15700a69229d457ecf2df4be568fea9a4 Mon Sep 17 00:00:00 2001 From: mdruzhinin Date: Sun, 16 May 2021 15:44:32 +0300 Subject: [PATCH 9/9] Try restart github actions