Skip to content

Commit

Permalink
Posterize transform from Pillow (#333)
Browse files Browse the repository at this point in the history
* Posterize transform from Pillow

* Readme updated - added Posterize transform

* Posterize transform, num_bits sanity check

* removed typo in docstring

* Fixed forgot return statement
  • Loading branch information
Dipet authored and ternaus committed Sep 12, 2019
1 parent 4e12c6e commit ad95fa0
Show file tree
Hide file tree
Showing 8 changed files with 138 additions and 4 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ Pixel-level transforms will change just an input image and will leave any additi
- [MedianBlur](https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.augmentations.transforms.MedianBlur)
- [MotionBlur](https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.augmentations.transforms.MotionBlur)
- [Normalize](https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.augmentations.transforms.Normalize)
- [Posterize](https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.augmentations.transforms.Posterize)
- [RGBShift](https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.augmentations.transforms.RGBShift)
- [RandomBrightness](https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.augmentations.transforms.RandomBrightness)
- [RandomBrightnessContrast](https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.augmentations.transforms.RandomBrightnessContrast)
Expand Down
45 changes: 45 additions & 0 deletions albumentations/augmentations/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,51 @@ def solarize(img, threshold=128):
return result_img


@preserve_shape
def posterize(img, bits):
"""Reduce the number of bits for each color channel.
Args:
img: image to posterize.
bits: number of high bits. Must be in range [0, 8]
"""
bits = np.uint8(bits)

assert img.dtype == np.uint8, 'Image must have uint8 channel type'
assert np.all((0 <= bits) & (bits <= 8)), "bits must be in range [0, 8]"

if not bits.shape or len(bits) == 1:
if bits == 0:
return np.zeros_like(img)
elif bits == 8:
return img.copy()

lut = np.arange(0, 256, dtype=np.uint8)
mask = ~np.uint8(2 ** (8 - bits) - 1)
lut &= mask

return cv2.LUT(img, lut)

assert is_rgb_image(img), 'If bits is iterable image must be RGB'

result_img = np.empty_like(img)
for i, channel_bits in enumerate(bits):
if channel_bits == 0:
result_img[..., i] = np.zeros_like(img[..., i])
continue
elif channel_bits == 8:
result_img[..., i] = img[..., i].copy()
continue

lut = np.arange(0, 256, dtype=np.uint8)
mask = ~np.uint8(2 ** (8 - channel_bits) - 1)
lut &= mask

result_img[..., i] = cv2.LUT(img[..., i], lut)

return result_img


def _equalize_pil(img, mask=None):
histogram = cv2.calcHist([img], [0], mask, [256], (0, 256)).ravel()
h = [_f for _f in histogram if _f]
Expand Down
43 changes: 42 additions & 1 deletion albumentations/augmentations/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
'Resize', 'RandomSizedCrop', 'RandomResizedCrop', 'RandomBrightnessContrast',
'RandomCropNearBBox', 'RandomSizedBBoxSafeCrop', 'RandomSnow',
'RandomRain', 'RandomFog', 'RandomSunFlare', 'RandomShadow', 'Lambda',
'ChannelDropout', 'ISONoise', 'Solarize', 'Equalize'
'ChannelDropout', 'ISONoise', 'Solarize', 'Equalize', 'Posterize'
]


Expand Down Expand Up @@ -1899,6 +1899,47 @@ def get_transform_init_args_names(self):
return ('threshold', )


class Posterize(ImageOnlyTransform):
"""Reduce the number of bits for each color channel.
Args:
num_bits ((int, int) or int,
or list of ints [r, g, b],
or list of ints [[r1, r1], [g1, g2], [b1, b2]]): number of high bits.
If num_bits is a single value, the range will be [num_bits, num_bits].
Must be in range [0, 8]. Default: 4.
p (float): probability of applying the transform. Default: 0.5.
Targets:
image
Image types:
uint8
"""

def __init__(self, num_bits=4, always_apply=False, p=0.5):
super(Posterize, self).__init__(always_apply, p)

if isinstance(num_bits, (list, tuple)):
if len(num_bits) == 3:
self.num_bits = [to_tuple(i, 0) for i in num_bits]
else:
self.num_bits = to_tuple(num_bits, 0)
else:
self.num_bits = to_tuple(num_bits, num_bits)

def apply(self, image, num_bits=1, **params):
return F.posterize(image, num_bits)

def get_params(self):
if len(self.num_bits) == 3:
return {'num_bits': [random.randint(i[0], i[1]) for i in self.num_bits]}
return {'num_bits': random.randint(self.num_bits[0], self.num_bits[1])}

def get_transform_init_args_names(self):
return ('num_bits',)


class Equalize(ImageOnlyTransform):
"""Equalize the image histogram.
Expand Down
13 changes: 11 additions & 2 deletions benchmark/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,15 @@ def torchvision(self, img):
return torchvision.to_grayscale(img, num_output_channels=3)


class Posterize(BenchmarkTest):

def albumentations(self, img):
return albumentations.posterize(img, 4)

def pillow(self, img):
return ImageOps.posterize(img, 4)


def main():
args = parse_args()
if args.print_package_versions:
Expand All @@ -394,7 +403,7 @@ def main():
'keras',
'augmentor',
'solt',
'pillow',
'pillow'
]
data_dir = args.data_dir
paths = list(sorted(os.listdir(data_dir)))
Expand All @@ -417,12 +426,12 @@ def main():
RandomCrop64(),
PadToSize512(),
Resize512(),
Posterize(),
Solarize(),
Equalize(),
]
for library in libraries:
imgs = imgs_pillow if library in ('torchvision', 'augmentor', 'pillow') else imgs_cv2

pbar = tqdm(total=len(benchmarks))
for benchmark in benchmarks:
pbar.set_description('Current benchmark: {} | {}'.format(library, benchmark))
Expand Down
6 changes: 5 additions & 1 deletion tests/test_augmentations.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
Cutout, CoarseDropout, Normalize, ToFloat, FromFloat,
RandomBrightnessContrast, RandomSnow, RandomRain, RandomFog,
RandomSunFlare, RandomCropNearBBox, RandomShadow, RandomSizedCrop, RandomResizedCrop,
ChannelDropout, ISONoise, Solarize, Equalize, CropNonEmptyMaskIfExists)
ChannelDropout, ISONoise, Solarize, Equalize, CropNonEmptyMaskIfExists, Posterize)


@pytest.mark.parametrize(['augmentation_cls', 'params'], [
Expand Down Expand Up @@ -43,6 +43,7 @@
[ChannelDropout, {}],
[ISONoise, {}],
[Solarize, {}],
[Posterize, {}],
[Equalize, {}],
])
def test_image_only_augmentations(augmentation_cls, params, image, mask):
Expand Down Expand Up @@ -204,6 +205,7 @@ def test_imgaug_dual_augmentations(augmentation_cls, image, mask):
[ISONoise, {}],
[RandomGridShuffle, {}],
[Solarize, {}],
[Posterize, {}],
[Equalize, {}],
])
def test_augmentations_wont_change_input(augmentation_cls, params, image, mask):
Expand Down Expand Up @@ -292,6 +294,7 @@ def test_augmentations_wont_change_float_input(augmentation_cls, params, float_i
[FromFloat, {}],
[RandomGridShuffle, {}],
[Solarize, {}],
[Posterize, {}],
[Equalize, {}],
])
def test_augmentations_wont_change_shape_grayscale(augmentation_cls, params, image, mask):
Expand Down Expand Up @@ -360,6 +363,7 @@ def test_augmentations_wont_change_shape_grayscale(augmentation_cls, params, ima
[ISONoise, {}],
[RandomGridShuffle, {}],
[Solarize, {}],
[Posterize, {}],
[Equalize, {}],
])
def test_augmentations_wont_change_shape_rgb(augmentation_cls, params, image, mask):
Expand Down
28 changes: 28 additions & 0 deletions tests/test_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -963,6 +963,34 @@ def test_solarize_equal_to_pillow():
assert np.all(result_albu == result_pil)


def test_posterize():
img_cv = np.random.randint(0, 256, [256, 256, 3], dtype=np.uint8)
img_pil = Image.fromarray(img_cv)

assert np.all(F.posterize(img_cv, 4)) == np.all(np.array(ImageOps.posterize(img_pil, 4)))

bits = [3, 4, 5]
img_pil = []
for i, b in enumerate(bits):
img = Image.fromarray(img_cv[..., i])
img_pil.append(np.array(ImageOps.posterize(img, b)))
img_pil = cv2.merge(img_pil).astype(np.uint8)

assert np.all(F.posterize(img_cv, bits) == img_pil)


def test_posterize_checks():
img = np.random.random([256, 256, 3])
with pytest.raises(AssertionError) as exc_info:
F.posterize(img, 4)
assert str(exc_info.value) == 'Image must have uint8 channel type'

img = np.random.randint(0, 256, [256, 256], dtype=np.uint8)
with pytest.raises(AssertionError) as exc_info:
F.posterize(img, [1, 2, 3])
assert str(exc_info.value) == 'If bits is iterable image must be RGB'


def test_equalize_checks():
img = np.random.randint(0, 255, [256, 256], dtype=np.uint8)

Expand Down
5 changes: 5 additions & 0 deletions tests/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
[A.LongestMaxSize, {}],
[A.RandomGridShuffle, {}],
[A.Solarize, {}],
[A.Posterize, {}],
[A.Equalize, {}],
])
@pytest.mark.parametrize('p', [0.5, 1])
Expand Down Expand Up @@ -167,6 +168,7 @@ def test_augmentations_serialization(augmentation_cls, params, p, seed, image, m
[A.LongestMaxSize, {'max_size': 128, 'interpolation': cv2.INTER_CUBIC}],
[A.RandomGridShuffle, {'grid': (5, 5)}],
[A.Solarize, {'threshold': 32}],
[A.Posterize, {'num_bits': 1}],
[A.Equalize, {'mode': 'pil', 'by_channels': False}],
])
@pytest.mark.parametrize('p', [0.5, 1])
Expand Down Expand Up @@ -238,6 +240,7 @@ def test_augmentations_serialization_with_custom_parameters(
[A.LongestMaxSize, {}],
[A.RandomSizedBBoxSafeCrop, {'height': 50, 'width': 50}],
[A.Solarize, {}],
[A.Posterize, {}],
[A.Equalize, {}],
])
@pytest.mark.parametrize('p', [0.5, 1])
Expand Down Expand Up @@ -295,6 +298,7 @@ def test_augmentations_for_bboxes_serialization(augmentation_cls, params, p, see
[A.RandomContrast, {}],
[A.RandomScale, {}],
[A.Solarize, {}],
[A.Posterize, {}],
[A.Equalize, {}],
])
@pytest.mark.parametrize('p', [0.5, 1])
Expand Down Expand Up @@ -573,6 +577,7 @@ def test_transform_pipeline_serialization_with_keypoints(seed, image, keypoints,
[A.FromFloat, {}],
[A.RandomGridShuffle, {}],
[A.Solarize, {}],
[A.Posterize, {}],
[A.Equalize, {}],
])
@pytest.mark.parametrize('seed', TEST_SEEDS)
Expand Down
1 change: 1 addition & 0 deletions tests/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,7 @@ def test_force_apply():
[A.FromFloat, {}],
[A.ChannelDropout, {}],
[A.Solarize, {}],
[A.Posterize, {}],
[A.Equalize, {}],
])
def test_additional_targets_for_image_only(augmentation_cls, params):
Expand Down

0 comments on commit ad95fa0

Please sign in to comment.