Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Posterize transform from Pillow #333

Merged
merged 9 commits into from
Sep 12, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ Pixel-level transforms will change just an input image and will leave any additi
- [MedianBlur](https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.augmentations.transforms.MedianBlur)
- [MotionBlur](https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.augmentations.transforms.MotionBlur)
- [Normalize](https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.augmentations.transforms.Normalize)
- [Posterize](https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.augmentations.transforms.Posterize)
- [RGBShift](https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.augmentations.transforms.RGBShift)
- [RandomBrightness](https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.augmentations.transforms.RandomBrightness)
- [RandomBrightnessContrast](https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.augmentations.transforms.RandomBrightnessContrast)
Expand Down
45 changes: 45 additions & 0 deletions albumentations/augmentations/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,51 @@ def solarize(img, threshold=128):
return result_img


@preserve_shape
def posterize(img, bits):
"""Reduce the number of bits for each color channel.

Args:
img: image to posterize.
bits: number of high bits. Must be in range [0, 8]
"""
bits = np.uint8(bits)

assert img.dtype == np.uint8, 'Image must have uint8 channel type'
Dipet marked this conversation as resolved.
Show resolved Hide resolved
assert np.all((0 <= bits) & (bits <= 8)), "bits must be in range [0, 8]"

if not bits.shape or len(bits) == 1:
if bits == 0:
return np.zeros_like(img)
elif bits == 8:
return img.copy()

lut = np.arange(0, 256, dtype=np.uint8)
mask = ~np.uint8(2 ** (8 - bits) - 1)
lut &= mask

return cv2.LUT(img, lut)

assert is_rgb_image(img), 'If bits is iterable image must be RGB'

result_img = np.empty_like(img)
for i, channel_bits in enumerate(bits):
if channel_bits == 0:
result_img[..., i] = np.zeros_like(img[..., i])
continue
elif channel_bits == 8:
result_img[..., i] = img[..., i].copy()
continue

lut = np.arange(0, 256, dtype=np.uint8)
mask = ~np.uint8(2 ** (8 - channel_bits) - 1)
lut &= mask

result_img[..., i] = cv2.LUT(img[..., i], lut)

return result_img


def _equalize_pil(img, mask=None):
histogram = cv2.calcHist([img], [0], mask, [256], (0, 256)).ravel()
h = [_f for _f in histogram if _f]
Expand Down
43 changes: 42 additions & 1 deletion albumentations/augmentations/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
'Resize', 'RandomSizedCrop', 'RandomResizedCrop', 'RandomBrightnessContrast',
'RandomCropNearBBox', 'RandomSizedBBoxSafeCrop', 'RandomSnow',
'RandomRain', 'RandomFog', 'RandomSunFlare', 'RandomShadow', 'Lambda',
'ChannelDropout', 'ISONoise', 'Solarize', 'Equalize'
'ChannelDropout', 'ISONoise', 'Solarize', 'Equalize', 'Posterize'
]


Expand Down Expand Up @@ -1899,6 +1899,47 @@ def get_transform_init_args_names(self):
return ('threshold', )


class Posterize(ImageOnlyTransform):
"""Reduce the number of bits for each color channel.

Args:
num_bits ((int, int) or int,
or list of ints [r, g, b],
or list of ints [[r1, r1], [g1, g2], [b1, b2]]): number of high bits.
If num_bits is a single value, the range will be [num_bits, num_bits].
Must be in range [0, 8]. Default: 4.
p (float): probability of applying the transform. Default: 0.5.

Targets:
image

Image types:
uint8
"""

def __init__(self, num_bits=4, always_apply=False, p=0.5):
super(Posterize, self).__init__(always_apply, p)

if isinstance(num_bits, (list, tuple)):
if len(num_bits) == 3:
self.num_bits = [to_tuple(i, 0) for i in num_bits]
else:
self.num_bits = to_tuple(num_bits, 0)
else:
self.num_bits = to_tuple(num_bits, num_bits)

def apply(self, image, num_bits=1, **params):
return F.posterize(image, num_bits)

def get_params(self):
if len(self.num_bits) == 3:
return {'num_bits': [random.randint(i[0], i[1]) for i in self.num_bits]}
return {'num_bits': random.randint(self.num_bits[0], self.num_bits[1])}

def get_transform_init_args_names(self):
return ('num_bits',)


class Equalize(ImageOnlyTransform):
"""Equalize the image histogram.

Expand Down
13 changes: 11 additions & 2 deletions benchmark/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,15 @@ def torchvision(self, img):
return torchvision.to_grayscale(img, num_output_channels=3)


class Posterize(BenchmarkTest):

def albumentations(self, img):
return albumentations.posterize(img, 4)

def pillow(self, img):
return ImageOps.posterize(img, 4)


def main():
args = parse_args()
if args.print_package_versions:
Expand All @@ -394,7 +403,7 @@ def main():
'keras',
'augmentor',
'solt',
'pillow',
'pillow'
]
data_dir = args.data_dir
paths = list(sorted(os.listdir(data_dir)))
Expand All @@ -417,12 +426,12 @@ def main():
RandomCrop64(),
PadToSize512(),
Resize512(),
Posterize(),
Solarize(),
Equalize(),
]
for library in libraries:
imgs = imgs_pillow if library in ('torchvision', 'augmentor', 'pillow') else imgs_cv2

pbar = tqdm(total=len(benchmarks))
for benchmark in benchmarks:
pbar.set_description('Current benchmark: {} | {}'.format(library, benchmark))
Expand Down
6 changes: 5 additions & 1 deletion tests/test_augmentations.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
Cutout, CoarseDropout, Normalize, ToFloat, FromFloat,
RandomBrightnessContrast, RandomSnow, RandomRain, RandomFog,
RandomSunFlare, RandomCropNearBBox, RandomShadow, RandomSizedCrop, RandomResizedCrop,
ChannelDropout, ISONoise, Solarize, Equalize, CropNonEmptyMaskIfExists)
ChannelDropout, ISONoise, Solarize, Equalize, CropNonEmptyMaskIfExists, Posterize)


@pytest.mark.parametrize(['augmentation_cls', 'params'], [
Expand Down Expand Up @@ -43,6 +43,7 @@
[ChannelDropout, {}],
[ISONoise, {}],
[Solarize, {}],
[Posterize, {}],
[Equalize, {}],
])
def test_image_only_augmentations(augmentation_cls, params, image, mask):
Expand Down Expand Up @@ -204,6 +205,7 @@ def test_imgaug_dual_augmentations(augmentation_cls, image, mask):
[ISONoise, {}],
[RandomGridShuffle, {}],
[Solarize, {}],
[Posterize, {}],
[Equalize, {}],
])
def test_augmentations_wont_change_input(augmentation_cls, params, image, mask):
Expand Down Expand Up @@ -292,6 +294,7 @@ def test_augmentations_wont_change_float_input(augmentation_cls, params, float_i
[FromFloat, {}],
[RandomGridShuffle, {}],
[Solarize, {}],
[Posterize, {}],
[Equalize, {}],
])
def test_augmentations_wont_change_shape_grayscale(augmentation_cls, params, image, mask):
Expand Down Expand Up @@ -360,6 +363,7 @@ def test_augmentations_wont_change_shape_grayscale(augmentation_cls, params, ima
[ISONoise, {}],
[RandomGridShuffle, {}],
[Solarize, {}],
[Posterize, {}],
[Equalize, {}],
])
def test_augmentations_wont_change_shape_rgb(augmentation_cls, params, image, mask):
Expand Down
28 changes: 28 additions & 0 deletions tests/test_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -963,6 +963,34 @@ def test_solarize_equal_to_pillow():
assert np.all(result_albu == result_pil)


def test_posterize():
img_cv = np.random.randint(0, 256, [256, 256, 3], dtype=np.uint8)
img_pil = Image.fromarray(img_cv)

assert np.all(F.posterize(img_cv, 4)) == np.all(np.array(ImageOps.posterize(img_pil, 4)))

bits = [3, 4, 5]
img_pil = []
for i, b in enumerate(bits):
img = Image.fromarray(img_cv[..., i])
img_pil.append(np.array(ImageOps.posterize(img, b)))
img_pil = cv2.merge(img_pil).astype(np.uint8)

assert np.all(F.posterize(img_cv, bits) == img_pil)


def test_posterize_checks():
img = np.random.random([256, 256, 3])
with pytest.raises(AssertionError) as exc_info:
F.posterize(img, 4)
assert str(exc_info.value) == 'Image must have uint8 channel type'

img = np.random.randint(0, 256, [256, 256], dtype=np.uint8)
with pytest.raises(AssertionError) as exc_info:
F.posterize(img, [1, 2, 3])
assert str(exc_info.value) == 'If bits is iterable image must be RGB'


def test_equalize_checks():
img = np.random.randint(0, 255, [256, 256], dtype=np.uint8)

Expand Down
5 changes: 5 additions & 0 deletions tests/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
[A.LongestMaxSize, {}],
[A.RandomGridShuffle, {}],
[A.Solarize, {}],
[A.Posterize, {}],
[A.Equalize, {}],
])
@pytest.mark.parametrize('p', [0.5, 1])
Expand Down Expand Up @@ -167,6 +168,7 @@ def test_augmentations_serialization(augmentation_cls, params, p, seed, image, m
[A.LongestMaxSize, {'max_size': 128, 'interpolation': cv2.INTER_CUBIC}],
[A.RandomGridShuffle, {'grid': (5, 5)}],
[A.Solarize, {'threshold': 32}],
[A.Posterize, {'num_bits': 1}],
[A.Equalize, {'mode': 'pil', 'by_channels': False}],
])
@pytest.mark.parametrize('p', [0.5, 1])
Expand Down Expand Up @@ -238,6 +240,7 @@ def test_augmentations_serialization_with_custom_parameters(
[A.LongestMaxSize, {}],
[A.RandomSizedBBoxSafeCrop, {'height': 50, 'width': 50}],
[A.Solarize, {}],
[A.Posterize, {}],
[A.Equalize, {}],
])
@pytest.mark.parametrize('p', [0.5, 1])
Expand Down Expand Up @@ -295,6 +298,7 @@ def test_augmentations_for_bboxes_serialization(augmentation_cls, params, p, see
[A.RandomContrast, {}],
[A.RandomScale, {}],
[A.Solarize, {}],
[A.Posterize, {}],
[A.Equalize, {}],
])
@pytest.mark.parametrize('p', [0.5, 1])
Expand Down Expand Up @@ -573,6 +577,7 @@ def test_transform_pipeline_serialization_with_keypoints(seed, image, keypoints,
[A.FromFloat, {}],
[A.RandomGridShuffle, {}],
[A.Solarize, {}],
[A.Posterize, {}],
[A.Equalize, {}],
])
@pytest.mark.parametrize('seed', TEST_SEEDS)
Expand Down
1 change: 1 addition & 0 deletions tests/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,7 @@ def test_force_apply():
[A.FromFloat, {}],
[A.ChannelDropout, {}],
[A.Solarize, {}],
[A.Posterize, {}],
[A.Equalize, {}],
])
def test_additional_targets_for_image_only(augmentation_cls, params):
Expand Down