Skip to content

Commit

Permalink
Superpixels from imgaug (#849)
Browse files Browse the repository at this point in the history
  • Loading branch information
Dipet authored Mar 5, 2021
1 parent e3a2403 commit ae50578
Show file tree
Hide file tree
Showing 6 changed files with 157 additions and 4 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,6 @@ Pixel-level transforms will change just an input image and will leave any additi
- [GlassBlur](https://albumentations.ai/docs/api_reference/augmentations/transforms/#albumentations.augmentations.transforms.GlassBlur)
- [HistogramMatching](https://albumentations.ai/docs/api_reference/augmentations/domain_adaptation/#albumentations.augmentations.domain_adaptation.HistogramMatching)
- [HueSaturationValue](https://albumentations.ai/docs/api_reference/augmentations/transforms/#albumentations.augmentations.transforms.HueSaturationValue)
- [IAASuperpixels](https://albumentations.ai/docs/api_reference/imgaug/transforms/#albumentations.imgaug.transforms.IAASuperpixels)
- [ISONoise](https://albumentations.ai/docs/api_reference/augmentations/transforms/#albumentations.augmentations.transforms.ISONoise)
- [ImageCompression](https://albumentations.ai/docs/api_reference/augmentations/transforms/#albumentations.augmentations.transforms.ImageCompression)
- [InvertImg](https://albumentations.ai/docs/api_reference/augmentations/transforms/#albumentations.augmentations.transforms.InvertImg)
Expand All @@ -149,6 +148,7 @@ Pixel-level transforms will change just an input image and will leave any additi
- [RandomToneCurve](https://albumentations.ai/docs/api_reference/augmentations/transforms/#albumentations.augmentations.transforms.RandomToneCurve)
- [Sharpen](https://albumentations.ai/docs/api_reference/augmentations/transforms/#albumentations.augmentations.transforms.Sharpen)
- [Solarize](https://albumentations.ai/docs/api_reference/augmentations/transforms/#albumentations.augmentations.transforms.Solarize)
- [Superpixels](https://albumentations.ai/docs/api_reference/augmentations/transforms/#albumentations.augmentations.transforms.Superpixels)
- [ToFloat](https://albumentations.ai/docs/api_reference/augmentations/transforms/#albumentations.augmentations.transforms.ToFloat)
- [ToGray](https://albumentations.ai/docs/api_reference/augmentations/transforms/#albumentations.augmentations.transforms.ToGray)
- [ToSepia](https://albumentations.ai/docs/api_reference/augmentations/transforms/#albumentations.augmentations.transforms.ToSepia)
Expand Down
62 changes: 62 additions & 0 deletions albumentations/augmentations/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
from itertools import product
import cv2
import numpy as np
import skimage

from typing import Sequence, Optional, Union
from albumentations.augmentations.keypoints_utils import angle_to_2pi_range

MAX_VALUES_BY_DTYPE = {
Expand Down Expand Up @@ -1724,3 +1726,63 @@ def adjust_hue_torchvision(img, factor):
img = cv2.cvtColor(img, cv2.COLOR_RGB2HSV)
img[..., 0] = np.mod(img[..., 0] + factor * 360, 360)
return cv2.cvtColor(img, cv2.COLOR_HSV2RGB)


@preserve_shape
def superpixels(
image: np.ndarray, n_segments: int, replace_samples: Sequence[bool], max_size: Optional[int], interpolation: int
) -> np.ndarray:
if not np.any(replace_samples):
return image

orig_shape = image.shape
if max_size is not None:
size = max(image.shape[:2])
if size > max_size:
scale = max_size / size
height, width = image.shape[:2]
new_height, new_width = int(height * scale), int(width * scale)
resize_fn = _maybe_process_in_chunks(
cv2.resize, dsize=(new_width, new_height), interpolation=interpolation
)
image = resize_fn(image)

from skimage.segmentation import slic

segments = skimage.segmentation.slic(image, n_segments=n_segments, compactness=10)

min_value = 0
max_value = MAX_VALUES_BY_DTYPE[image.dtype]
image = np.copy(image)
if image.ndim == 2:
image = image.reshape(*image.shape, 1)
nb_channels = image.shape[2]
for c in range(nb_channels):
# segments+1 here because otherwise regionprops always misses the last label
regions = skimage.measure.regionprops(segments + 1, intensity_image=image[..., c])
for ridx, region in enumerate(regions):
# with mod here, because slic can sometimes create more superpixel than requested.
# replace_samples then does not have enough values, so we just start over with the first one again.
if replace_samples[ridx % len(replace_samples)]:
mean_intensity = region.mean_intensity
image_sp_c = image[..., c]

if image_sp_c.dtype.kind in ["i", "u", "b"]:
# After rounding the value can end up slightly outside of the value_range. Hence, we need to clip.
# We do clip via min(max(...)) instead of np.clip because
# the latter one does not seem to keep dtypes for dtypes with large itemsizes (e.g. uint64).
value: Union[int, float]
value = int(np.round(mean_intensity))
value = min(max(value, min_value), max_value)
else:
value = mean_intensity

image_sp_c[segments == ridx] = value

if orig_shape != image.shape:
resize_fn = _maybe_process_in_chunks(
cv2.resize, dsize=(orig_shape[1], orig_shape[0]), interpolation=interpolation
)
image = resize_fn(image)

return image
79 changes: 76 additions & 3 deletions albumentations/augmentations/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,14 @@
import warnings
from enum import IntEnum
from types import LambdaType
from typing import Optional
from typing import Optional, Union, Sequence, Tuple

import cv2
import numpy as np
from skimage.measure import label

from . import functional as F
from .geometric import functional as FGeometric
from .bbox_utils import denormalize_bbox, normalize_bbox, union_of_bboxes
from .bbox_utils import denormalize_bbox, normalize_bbox
from ..core.transforms_interface import (
DualTransform,
ImageOnlyTransform,
Expand Down Expand Up @@ -76,6 +75,7 @@
"ColorJitter",
"Sharpen",
"Emboss",
"Superpixels",
]


Expand Down Expand Up @@ -2953,3 +2953,76 @@ def apply(self, img, emboss_matrix=None, **params):

def get_transform_init_args_names(self):
return ("alpha", "strength")


class Superpixels(ImageOnlyTransform):
"""Transform images parially/completely to their superpixel representation.
This implementation uses skimage's version of the SLIC algorithm.
Args:
p_replace (float or tuple of float): Defines for any segment the probability that the pixels within that
segment are replaced by their average color (otherwise, the pixels are not changed).
Examples:
* A probability of ``0.0`` would mean, that the pixels in no
segment are replaced by their average color (image is not
changed at all).
* A probability of ``0.5`` would mean, that around half of all
segments are replaced by their average color.
* A probability of ``1.0`` would mean, that all segments are
replaced by their average color (resulting in a voronoi
image).
Behaviour based on chosen data types for this parameter:
* If a ``float``, then that ``flat`` will always be used.
* If ``tuple`` ``(a, b)``, then a random probability will be
sampled from the interval ``[a, b]`` per image.
n_segments (int, or tuple of int): Rough target number of how many superpixels to generate (the algorithm
may deviate from this number). Lower value will lead to coarser superpixels.
Higher values are computationally more intensive and will hence lead to a slowdown
* If a single ``int``, then that value will always be used as the
number of segments.
* If a ``tuple`` ``(a, b)``, then a value from the discrete
interval ``[a..b]`` will be sampled per image.
max_size (int or None): Maximum image size at which the augmentation is performed.
If the width or height of an image exceeds this value, it will be
downscaled before the augmentation so that the longest side matches `max_size`.
This is done to speed up the process. The final output image has the same size as the input image.
Note that in case `p_replace` is below ``1.0``,
the down-/upscaling will affect the not-replaced pixels too.
Use ``None`` to apply no down-/upscaling.
interpolation (OpenCV flag): flag that is used to specify the interpolation algorithm. Should be one of:
cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA, cv2.INTER_LANCZOS4.
Default: cv2.INTER_LINEAR.
p (float): probability of applying the transform. Default: 0.5.
Targets:
image
"""

def __init__(
self,
p_replace: Union[float, Sequence[float]] = 0.1,
n_segments: Union[int, Sequence[int]] = 100,
max_size: Optional[int] = 128,
interpolation: int = cv2.INTER_LINEAR,
always_apply: bool = False,
p: float = 0.5,
):
super().__init__(always_apply=always_apply, p=p)
self.p_replace = to_tuple(p_replace, p_replace)
self.n_segments = to_tuple(n_segments, n_segments)
self.max_size = max_size
self.interpolation = interpolation

if min(self.n_segments) < 1:
raise ValueError(f"n_segments must be >= 1. Got: {n_segments}")

def get_transform_init_args_names(self) -> Tuple[str, str, str, str]:
return ("p_replace", "n_segments", "max_size", "interpolation")

def get_params(self) -> dict:
n_segments = random.randint(*self.n_segments)
p = random.uniform(*self.p_replace)
return {"replace_samples": np.random.random(n_segments) < p, "n_segments": n_segments}

def apply(self, img: np.ndarray, replace_samples: Sequence[bool] = (False,), n_segments: int = 1, **kwargs):
return F.superpixels(img, n_segments, replace_samples, self.max_size, self.interpolation)
1 change: 1 addition & 0 deletions albumentations/imgaug/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,7 @@ def __init__(self, p_replace=0.1, n_segments=100, always_apply=False, p=0.5):
super(IAASuperpixels, self).__init__(always_apply, p)
self.p_replace = p_replace
self.n_segments = n_segments
warnings.warn("IAASuperpixels is deprecated. Please use Superpixels instead.", FutureWarning)

@property
def processor(self):
Expand Down
9 changes: 9 additions & 0 deletions tests/test_augmentations.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@
Sharpen,
Emboss,
CropAndPad,
Superpixels,
)


Expand Down Expand Up @@ -118,6 +119,7 @@
],
[Sharpen, {}],
[Emboss, {}],
[Superpixels, {}],
],
)
def test_image_only_augmentations(augmentation_cls, params, image, mask):
Expand Down Expand Up @@ -167,6 +169,7 @@ def test_image_only_augmentations(augmentation_cls, params, image, mask):
],
[Sharpen, {}],
[Emboss, {}],
[Superpixels, {}],
],
)
def test_image_only_augmentations_with_float_values(augmentation_cls, params, float_image, mask):
Expand Down Expand Up @@ -328,6 +331,7 @@ def test_imgaug_dual_augmentations(augmentation_cls, image, mask):
[Sharpen, {}],
[Emboss, {}],
[CropAndPad, {"px": 10}],
[Superpixels, {}],
],
)
def test_augmentations_wont_change_input(augmentation_cls, params, image, mask):
Expand Down Expand Up @@ -399,6 +403,7 @@ def test_augmentations_wont_change_input(augmentation_cls, params, image, mask):
[Sharpen, {}],
[Emboss, {}],
[CropAndPad, {"px": 10}],
[Superpixels, {}],
],
)
def test_augmentations_wont_change_float_input(augmentation_cls, params, float_image):
Expand Down Expand Up @@ -452,6 +457,7 @@ def test_augmentations_wont_change_float_input(augmentation_cls, params, float_i
[Sharpen, {}],
[Emboss, {}],
[CropAndPad, {"px": 10}],
[Superpixels, {}],
],
)
def test_augmentations_wont_change_shape_grayscale(augmentation_cls, params, image, mask):
Expand Down Expand Up @@ -532,6 +538,7 @@ def test_augmentations_wont_change_shape_grayscale(augmentation_cls, params, ima
[Sharpen, {}],
[Emboss, {}],
[CropAndPad, {"px": 10}],
[Superpixels, {}],
],
)
def test_augmentations_wont_change_shape_rgb(augmentation_cls, params, image, mask):
Expand Down Expand Up @@ -597,6 +604,7 @@ def test_mask_fill_value(augmentation_cls, params):
[GridDropout, {}],
[Perspective, {}],
[CropAndPad, {"px": 10}],
[Superpixels, {}],
],
)
def test_multichannel_image_augmentations(augmentation_cls, params):
Expand Down Expand Up @@ -627,6 +635,7 @@ def test_multichannel_image_augmentations(augmentation_cls, params):
[GridDropout, {}],
[Perspective, {}],
[CropAndPad, {"px": 10}],
[Superpixels, {}],
],
)
def test_multichannel_image_augmentations_diff_channels(augmentation_cls, params):
Expand Down
8 changes: 8 additions & 0 deletions tests/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ def set_seed(seed):
[A.Sharpen, {}],
[A.Emboss, {}],
[A.CropAndPad, {"px": 10}],
[A.Superpixels, {}],
],
)
@pytest.mark.parametrize("p", [0.5, 1])
Expand Down Expand Up @@ -269,6 +270,10 @@ def test_augmentations_serialization(augmentation_cls, params, p, seed, image, m
"pad_mode": cv2.BORDER_REFLECT101,
},
],
[
A.Superpixels,
{"p_replace": (0.5, 0.7), "n_segments": (20, 30), "max_size": 25, "interpolation": cv2.INTER_CUBIC},
],
],
)

Expand Down Expand Up @@ -367,6 +372,7 @@ def test_augmentations_serialization_to_file_with_custom_parameters(
[A.Perspective, {}],
[A.Sharpen, {}],
[A.Emboss, {}],
[A.Superpixels, {}],
],
)
@pytest.mark.parametrize("p", [0.5, 1])
Expand Down Expand Up @@ -435,6 +441,7 @@ def test_augmentations_for_bboxes_serialization(
[A.Perspective, {}],
[A.Sharpen, {}],
[A.Emboss, {}],
[A.Superpixels, {}],
],
)
@pytest.mark.parametrize("p", [0.5, 1])
Expand Down Expand Up @@ -722,6 +729,7 @@ def test_transform_pipeline_serialization_with_keypoints(seed, image, keypoints,
[A.Emboss, {}],
[A.RandomToneCurve, {}],
[A.CropAndPad, {"px": -12}],
[A.Superpixels, {}],
],
)
@pytest.mark.parametrize("seed", TEST_SEEDS)
Expand Down

0 comments on commit ae50578

Please sign in to comment.