Skip to content

Commit

Permalink
Histogram equalization transform
Browse files Browse the repository at this point in the history
  • Loading branch information
Dipet committed Aug 16, 2019
1 parent 00c349b commit 9f71038
Show file tree
Hide file tree
Showing 7 changed files with 268 additions and 4 deletions.
105 changes: 105 additions & 0 deletions albumentations/augmentations/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,111 @@ def solarize(img, threshold=128):
return result_img


def _equalize_pil(img, mask=None):
    """Equalize one image channel with the lookup-table scheme used for the 'pil' mode.

    Args:
        img (np.ndarray): single-channel image; it is remapped through a 256-entry
            uint8 lookup table, so 8-bit data is expected.
        mask (np.ndarray): optional mask with the same shape as ``img``. Only pixels
            where the mask is non-zero (truthy) contribute to the histogram.

    Returns:
        np.ndarray: the equalized image. An unchanged copy of ``img`` is returned when
        at most one intensity level is present in the analysed pixels, or when there
        are too few pixels for the equalization step to be non-zero.
    """
    if mask is None:
        pixels = img
    else:
        # Coerce to bool so that integer masks (e.g. uint8 0/1 or 0/255) select pixels
        # instead of being interpreted as integer indices by NumPy fancy indexing.
        pixels = img[np.asarray(mask, dtype=bool)]

    histogram = np.histogram(pixels, 256, range=(0, 255))[0]
    occupied = histogram[histogram > 0]

    if len(occupied) <= 1:
        return img.copy()

    # The last occupied level is excluded so that it maps to the top of the range.
    step = np.sum(occupied[:-1]) // 255
    if not step:
        return img.copy()

    lut = np.empty(256, dtype=np.uint8)
    n = step // 2
    for i in range(256):
        lut[i] = min(n // step, 255)
        n += histogram[i]

    return cv2.LUT(img, lut)


def _equalize_cv(img, mask=None):
    """Equalize one image channel for the 'cv' mode, optionally restricted by a mask.

    Without a mask this simply delegates to ``cv2.equalizeHist``. With a mask, the
    histogram is computed only from the selected pixels and the resulting lookup
    table is applied to the whole image.

    Args:
        img (np.ndarray): single-channel image; 8-bit data is expected because the
            lookup table has 256 uint8 entries.
        mask (np.ndarray): optional mask with the same shape as ``img``. Only pixels
            where the mask is non-zero (truthy) contribute to the histogram.

    Returns:
        np.ndarray: the equalized image. If the mask selects no pixels an unchanged
        copy of ``img`` is returned; if all selected pixels share one intensity, an
        image filled with that intensity is returned.
    """
    if mask is None:
        return cv2.equalizeHist(img)

    # Coerce to bool so that integer masks (e.g. uint8 0/1 or 0/255) select pixels
    # instead of being interpreted as integer indices by NumPy fancy indexing.
    selected = img[np.asarray(mask, dtype=bool)]
    histogram = np.histogram(selected, 256, range=(0, 255))[0]

    total = np.sum(histogram)
    if total == 0:
        # Empty mask: there is nothing to analyse. Previously the scan for the first
        # occupied bin ran off the end of the histogram and raised IndexError.
        return img.copy()

    # Index of the lowest occupied intensity level (guaranteed to exist: total > 0).
    first = next(level for level, count in enumerate(histogram) if count > 0)

    if histogram[first] == total:
        return np.full_like(img, first)

    scale = 255 / (total - histogram[first])
    _sum = 0

    # Levels up to and including ``first`` keep the initial value 0.
    lut = np.zeros(256, dtype=np.uint8)
    for i in range(first + 1, len(histogram)):
        _sum += histogram[i]
        lut[i] = clip(round(_sum * scale), np.dtype('uint8'), 255)

    return cv2.LUT(img, lut)


@preserve_channel_dim
def equalize(img, mask=None, mode='cv', by_channels=True):
    """Apply histogram equalization to an image.

    Args:
        img (np.ndarray): RGB or grayscale image.
        mask (np.ndarray): optional mask. When given, only the pixels it selects are
            used to build the histogram. May be a 1-channel or a 3-channel array.
        mode (str): {'cv', 'pil'}. Selects the OpenCV-style or the Pillow-style
            equalization routine.
        by_channels (bool): if True, every channel is equalized independently;
            otherwise the image is converted to YCrCb and only the `Y` channel is
            equalized before converting back to RGB.

    Returns:
        Equalized image.

    Raises:
        ValueError: if `mode` is not supported, if a 3-channel mask is passed with a
            grayscale image, or if `by_channels` is False and the mask is not
            1-channel.
    """
    supported_modes = ['cv', 'pil']
    if mode not in supported_modes:
        raise ValueError('Unsupported equalization mode. Supports: {}. '
                         'Got: {}'.format(supported_modes, mode))

    has_mask = mask is not None
    if has_mask and is_rgb_image(mask) and is_grayscale_image(img):
        raise ValueError('Wrong mask shape. Image shape: {}. '
                         'Mask shape: {}'.format(img.shape, mask.shape))
    if has_mask and not by_channels and not is_grayscale_image(mask):
        raise ValueError('When by_channels=False only 1-channel mask supports. '
                         'Mask shape: {}'.format(mask.shape))

    equalize_channel = _equalize_pil if mode == 'pil' else _equalize_cv

    # Grayscale input: a single pass over the only channel.
    if is_grayscale_image(img):
        return equalize_channel(img, mask)

    if by_channels:
        # A missing or 1-channel mask is shared by all channels; a 3-channel mask
        # provides one plane per image channel.
        shared_mask = mask is None or is_grayscale_image(mask)
        equalized = img.copy()
        for channel in range(3):
            channel_mask = mask if shared_mask else mask[..., channel]
            equalized[..., channel] = equalize_channel(img[..., channel], channel_mask)
        return equalized

    # Equalize luminance only, leaving the chroma channels untouched.
    ycrcb = cv2.cvtColor(img, cv2.COLOR_RGB2YCrCb)
    ycrcb[..., 0] = equalize_channel(ycrcb[..., 0], mask)
    return cv2.cvtColor(ycrcb, cv2.COLOR_YCrCb2RGB)


@clipped
def shift_rgb(img, r_shift, g_shift, b_shift):
if img.dtype == np.uint8:
Expand Down
37 changes: 35 additions & 2 deletions albumentations/augmentations/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
'Resize', 'RandomSizedCrop', 'RandomBrightnessContrast',
'RandomCropNearBBox', 'RandomSizedBBoxSafeCrop', 'RandomSnow',
'RandomRain', 'RandomFog', 'RandomSunFlare', 'RandomShadow', 'Lambda',
'ChannelDropout', 'ISONoise', 'Solarize'
'ChannelDropout', 'ISONoise', 'Solarize', 'Equalize'
]


Expand Down Expand Up @@ -1563,7 +1563,7 @@ class Solarize(ImageOnlyTransform):
p (float): probability of applying the transform. Default: 0.5.
Targets:
image
image
Image types:
any
Expand All @@ -1587,6 +1587,39 @@ def get_transform_init_args_names(self):
return ('threshold', )


class Equalize(ImageOnlyTransform):
    """Histogram equalization transform.

    Args:
        mode (str): {'cv', 'pil'}. Selects the OpenCV-style or the Pillow-style
            equalization routine.
        by_channels (bool): if True, every channel is equalized independently;
            otherwise the image is converted to YCrCb and only the `Y` channel is
            equalized.
        p (float): forwarded to the base transform together with `always_apply`.
            Default: 0.5.

    Targets:
        image

    Image types:
        uint8
    """

    def __init__(self, mode='cv', by_channels=True, always_apply=False, p=0.5):
        # Validate eagerly so a bad mode fails at construction, not at apply time.
        supported_modes = ['cv', 'pil']
        if mode not in supported_modes:
            message = 'Unsupported equalization mode. Supports: {}. Got: {}'.format(supported_modes, mode)
            raise ValueError(message)

        super(Equalize, self).__init__(always_apply, p)
        self.by_channels = by_channels
        self.mode = mode

    def apply(self, image, **params):
        """Equalize `image` with the configured mode; extra params are ignored."""
        return F.equalize(image, by_channels=self.by_channels, mode=self.mode)

    def get_transform_init_args_names(self):
        """Return the names of the constructor arguments specific to this transform."""
        return 'mode', 'by_channels'


class RGBShift(ImageOnlyTransform):
"""Randomly shift values for each channel of the input RGB image.
Expand Down
9 changes: 9 additions & 0 deletions benchmark/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,15 @@ def albumentations(self, img):
return albumentations.solarize(img)


class Equalize(BenchmarkTest):
    """Benchmark case for histogram equalization."""

    def __init__(self):
        """Nothing to configure for this benchmark."""

    def albumentations(self, img):
        """Equalize `img` with the albumentations implementation using its defaults."""
        equalized = albumentations.equalize(img)
        return equalized


class RandomCrop64(BenchmarkTest):

def __init__(self):
Expand Down
8 changes: 6 additions & 2 deletions tests/test_augmentations.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
Cutout, CoarseDropout, Normalize, ToFloat, FromFloat,
RandomBrightnessContrast, RandomSnow, RandomRain, RandomFog,
RandomSunFlare, RandomCropNearBBox, RandomShadow, RandomSizedCrop,
ChannelDropout, ISONoise, Solarize)
ChannelDropout, ISONoise, Solarize, Equalize)


@pytest.mark.parametrize(['augmentation_cls', 'params'], [
Expand Down Expand Up @@ -43,6 +43,7 @@
[ChannelDropout, {}],
[ISONoise, {}],
[Solarize, {}],
[Equalize, {}],
])
def test_image_only_augmentations(augmentation_cls, params, image, mask):
aug = augmentation_cls(p=1, **params)
Expand Down Expand Up @@ -75,7 +76,7 @@ def test_image_only_augmentations(augmentation_cls, params, image, mask):
[RandomSunFlare, {}],
[RandomShadow, {}],
[ChannelDropout, {}],
[Solarize, {}]
[Solarize, {}],
])
def test_image_only_augmentations_with_float_values(augmentation_cls, params, float_image, mask):
aug = augmentation_cls(p=1, **params)
Expand Down Expand Up @@ -194,6 +195,7 @@ def test_imgaug_dual_augmentations(augmentation_cls, image, mask):
[ChannelDropout, {}],
[ISONoise, {}],
[Solarize, {}],
[Equalize, {}],
])
def test_augmentations_wont_change_input(augmentation_cls, params, image, mask):
image_copy = image.copy()
Expand Down Expand Up @@ -278,6 +280,7 @@ def test_augmentations_wont_change_float_input(augmentation_cls, params, float_i
[ToFloat, {}],
[FromFloat, {}],
[Solarize, {}],
[Equalize, {}],
])
def test_augmentations_wont_change_shape_grayscale(augmentation_cls, params, image, mask):
aug = augmentation_cls(p=1, **params)
Expand Down Expand Up @@ -344,6 +347,7 @@ def test_augmentations_wont_change_shape_grayscale(augmentation_cls, params, ima
[ChannelDropout, {}],
[ISONoise, {}],
[Solarize, {}],
[Equalize, {}],
])
def test_augmentations_wont_change_shape_rgb(augmentation_cls, params, image, mask):
aug = augmentation_cls(p=1, **params)
Expand Down
107 changes: 107 additions & 0 deletions tests/test_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -923,3 +923,110 @@ def test_solarize_equal_to_pillow():
result_pil = np.array(ImageOps.solarize(img_pil, i))

assert np.all(result_albu == result_pil)


def test_equalize_checks():
    """Check that `F.equalize` rejects unsupported modes and incompatible masks."""
    img = np.random.randint(0, 255, [256, 256], dtype=np.uint8)

    with pytest.raises(ValueError) as exc_info:
        F.equalize(img, mode='other')
    assert str(exc_info.value) == "Unsupported equalization mode. Supports: ['cv', 'pil']. Got: other"

    # Only the mask shape matters for the validation below. The builtin `bool` is used
    # because the `np.bool` alias is deprecated and missing from some NumPy releases.
    mask = np.zeros([256, 256, 3], dtype=bool)
    with pytest.raises(ValueError) as exc_info:
        F.equalize(img, mask=mask)
    assert str(exc_info.value) == "Wrong mask shape. Image shape: (256, 256). Mask shape: (256, 256, 3)"

    img = np.random.randint(0, 255, [256, 256, 3], dtype=np.uint8)
    with pytest.raises(ValueError) as exc_info:
        F.equalize(img, mask=mask, by_channels=False)
    assert str(exc_info.value) == "When by_channels=False only 1-channel mask supports. Mask shape: (256, 256, 3)"


def test_equalize_grayscale():
    """Grayscale equalization must match OpenCV ('cv') and Pillow ('pil') exactly."""
    img = np.random.randint(0, 255, [256, 256], dtype=np.uint8)

    expected_cv = cv2.equalizeHist(img)
    assert np.all(expected_cv == F.equalize(img, mode='cv'))

    expected_pil = np.array(ImageOps.equalize(Image.fromarray(img)))
    assert np.all(expected_pil == F.equalize(img, mode='pil'))


def test_equalize_rgb():
    """RGB equalization must match the reference libraries in both channel modes."""
    img = np.random.randint(0, 255, [256, 256, 3], dtype=np.uint8)

    # Per-channel equalization.
    per_channel_cv = img.copy()
    for channel in range(3):
        per_channel_cv[..., channel] = cv2.equalizeHist(per_channel_cv[..., channel])
    per_channel_pil = np.array(ImageOps.equalize(Image.fromarray(img)))

    assert np.all(per_channel_cv == F.equalize(img, mode='cv'))
    assert np.all(per_channel_pil == F.equalize(img, mode='pil'))

    # Luminance-only equalization: only the Y plane of the YCrCb image is changed.
    ycrcb = cv2.cvtColor(img, cv2.COLOR_RGB2YCrCb)
    luma = ycrcb[..., 0]

    luma_cv = ycrcb.copy()
    luma_cv[..., 0] = cv2.equalizeHist(luma)
    expected_cv = cv2.cvtColor(luma_cv, cv2.COLOR_YCrCb2RGB)

    luma_pil = ycrcb.copy()
    luma_pil[..., 0] = np.array(ImageOps.equalize(Image.fromarray(luma)))
    expected_pil = cv2.cvtColor(luma_pil, cv2.COLOR_YCrCb2RGB)

    assert np.all(expected_cv == F.equalize(img, mode='cv', by_channels=False))
    assert np.all(expected_pil == F.equalize(img, mode='pil', by_channels=False))


def test_equalize_grayscale_mask():
    """Masked grayscale equalization must match the reference libraries.

    For 'cv' the reference is `cv2.equalizeHist` applied to the masked crop, so only
    that crop of the result is compared. Pillow supports masks natively, so the whole
    image is compared for 'pil'.
    """
    img = np.random.randint(0, 255, [256, 256], dtype=np.uint8)
    pil_img = Image.fromarray(img)

    # Builtin `bool` instead of the deprecated `np.bool` alias, which is missing from
    # some NumPy releases.
    mask = np.zeros([256, 256], dtype=bool)
    mask[:10, :10] = True
    pil_mask = Image.fromarray(mask)

    assert np.all(cv2.equalizeHist(img[:10, :10]) == F.equalize(img, mask=mask, mode='cv')[:10, :10])
    assert np.all(np.array(ImageOps.equalize(pil_img, pil_mask)) == F.equalize(img, mask=mask, mode='pil'))


def test_equalize_rgb_mask():
    """Masked RGB equalization must match the reference libraries.

    Covers a shared 1-channel mask (per-channel and luminance-only modes) and a
    3-channel mask with a different region selected for every channel.
    """
    img = np.random.randint(0, 255, [256, 256, 3], dtype=np.uint8)
    img_pil = Image.fromarray(img)

    # Builtin `bool` instead of the deprecated `np.bool` alias, which is missing from
    # some NumPy releases.
    mask = np.zeros([256, 256], dtype=bool)
    mask[:10, :10] = True
    mask_pil = Image.fromarray(mask)

    # 1-channel mask, per-channel equalization. The 'cv' reference is computed on the
    # masked crop, so only that crop of the result is compared.
    _img = img.copy()[:10, :10]
    for i in range(3):
        _img[..., i] = cv2.equalizeHist(_img[..., i])
    assert np.all(_img == F.equalize(img, mask, mode='cv')[:10, :10])
    assert np.all(np.array(ImageOps.equalize(img_pil, mask_pil)) == F.equalize(img, mask, mode='pil'))

    # 1-channel mask, luminance-only equalization.
    _img = cv2.cvtColor(img, cv2.COLOR_RGB2YCrCb)
    img_cv = _img.copy()[:10, :10]
    img_pil = _img.copy()
    img_cv[..., 0] = cv2.equalizeHist(img_cv[..., 0])
    img_pil[..., 0] = np.array(ImageOps.equalize(Image.fromarray(_img[..., 0]), mask_pil))
    img_cv = cv2.cvtColor(img_cv, cv2.COLOR_YCrCb2RGB)
    img_pil = cv2.cvtColor(img_pil, cv2.COLOR_YCrCb2RGB)
    assert np.all(img_cv == F.equalize(img, mask=mask, mode='cv', by_channels=False)[:10, :10])
    assert np.all(img_pil == F.equalize(img, mask=mask, mode='pil', by_channels=False))

    # 3-channel mask: every channel is analysed over its own region.
    mask = np.zeros([256, 256, 3], dtype=bool)
    mask[:10, :10, 0] = True
    mask[10:20, 10:20, 1] = True
    mask[20:30, 20:30, 2] = True
    img_r = img.copy()[:10, :10, 0]
    img_g = img.copy()[10:20, 10:20, 1]
    img_b = img.copy()[20:30, 20:30, 2]

    img_r = cv2.equalizeHist(img_r)
    img_g = cv2.equalizeHist(img_g)
    img_b = cv2.equalizeHist(img_b)

    result_img = F.equalize(img, mask=mask, mode='cv')
    assert np.all(img_r == result_img[:10, :10, 0])
    assert np.all(img_g == result_img[10:20, 10:20, 1])
    assert np.all(img_b == result_img[20:30, 20:30, 2])

    _img = img.copy()
    for i in range(3):
        _img[..., i] = np.array(ImageOps.equalize(Image.fromarray(_img[..., i]),
                                                  Image.fromarray(mask[..., i])))
    assert np.all(_img == F.equalize(img, mask=mask, mode='pil'))
5 changes: 5 additions & 0 deletions tests/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
[A.SmallestMaxSize, {}],
[A.LongestMaxSize, {}],
[A.Solarize, {}],
[A.Equalize, {}],
])
@pytest.mark.parametrize('p', [0.5, 1])
@pytest.mark.parametrize('seed', TEST_SEEDS)
Expand Down Expand Up @@ -160,6 +161,7 @@ def test_augmentations_serialization(augmentation_cls, params, p, seed, image, m
[A.SmallestMaxSize, {'max_size': 64, 'interpolation': cv2.INTER_CUBIC}],
[A.LongestMaxSize, {'max_size': 128, 'interpolation': cv2.INTER_CUBIC}],
[A.Solarize, {'threshold': 32}],
[A.Equalize, {'mode': 'pil', 'by_channels': False}],
])
@pytest.mark.parametrize('p', [0.5, 1])
@pytest.mark.parametrize('seed', TEST_SEEDS)
Expand Down Expand Up @@ -229,6 +231,7 @@ def test_augmentations_serialization_with_custom_parameters(
[A.LongestMaxSize, {}],
[A.RandomSizedBBoxSafeCrop, {'height': 50, 'width': 50}],
[A.Solarize, {}],
[A.Equalize, {}],
])
@pytest.mark.parametrize('p', [0.5, 1])
@pytest.mark.parametrize('seed', TEST_SEEDS)
Expand Down Expand Up @@ -284,6 +287,7 @@ def test_augmentations_for_bboxes_serialization(augmentation_cls, params, p, see
[A.RandomContrast, {}],
[A.RandomScale, {}],
[A.Solarize, {}],
[A.Equalize, {}],
])
@pytest.mark.parametrize('p', [0.5, 1])
@pytest.mark.parametrize('seed', TEST_SEEDS)
Expand Down Expand Up @@ -557,6 +561,7 @@ def test_transform_pipeline_serialization_with_keypoints(seed, image, keypoints,
[A.ToFloat, {}],
[A.FromFloat, {}],
[A.Solarize, {}],
[A.Equalize, {}],
])
@pytest.mark.parametrize('seed', TEST_SEEDS)
def test_additional_targets_for_image_only_serialization(augmentation_cls, params, image, seed):
Expand Down
1 change: 1 addition & 0 deletions tests/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,7 @@ def test_force_apply():
[A.FromFloat, {}],
[A.ChannelDropout, {}],
[A.Solarize, {}],
[A.Equalize, {}],
])
def test_additional_targets_for_image_only(augmentation_cls, params):
aug = A.Compose(
Expand Down

0 comments on commit 9f71038

Please sign in to comment.