Skip to content

Commit

Permalink
Changed JPEGCompress to ImageCompress for supporting webp (#312)
Browse files Browse the repository at this point in the history
* Changed JPEGCompress to ImageCompress for supporting webp

* Changed README.md

* Restored the API class: JPEGCompression

* Robust conditional for checking image type
  • Loading branch information
toshiks authored and ternaus committed Sep 11, 2019
1 parent d96dabe commit b612786
Show file tree
Hide file tree
Showing 6 changed files with 85 additions and 24 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ Pixel-level transforms will change just an input image and will leave any additi
- [IAASharpen](https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.imgaug.transforms.IAASharpen)
- [IAASuperpixels](https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.imgaug.transforms.IAASuperpixels)
- [ISONoise](https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.augmentations.transforms.ISONoise)
- [ImageCompression](https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.augmentations.transforms.ImageCompression)
- [InvertImg](https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.augmentations.transforms.InvertImg)
- [JpegCompression](https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.augmentations.transforms.JpegCompression)
- [MedianBlur](https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.augmentations.transforms.MedianBlur)
Expand Down
15 changes: 11 additions & 4 deletions albumentations/augmentations/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -617,21 +617,28 @@ def motion_blur(img, kernel):


@preserve_shape
def jpeg_compression(img, quality):
def image_compression(img, quality, image_type):
if image_type == '.jpeg' or image_type == '.jpg':
quality_flag = cv2.IMWRITE_JPEG_QUALITY
elif image_type == '.webp':
quality_flag = cv2.IMWRITE_WEBP_QUALITY
else:
NotImplementedError("Only '.jpg' and '.webp' compression transforms are implemented. ")

input_dtype = img.dtype
needs_float = False

if input_dtype == np.float32:
warn('Jpeg compression augmentation '
warn('Image compression augmentation '
'is most effective with uint8 inputs, '
'{} is used as input.'.format(input_dtype),
UserWarning)
img = from_float(img, dtype=np.dtype('uint8'))
needs_float = True
elif input_dtype not in (np.uint8, np.float32):
raise ValueError('Unexpected dtype {} for Jpeg augmentation'.format(input_dtype))
raise ValueError('Unexpected dtype {} for image augmentation'.format(input_dtype))

_, encoded_img = cv2.imencode('.jpg', img, (cv2.IMWRITE_JPEG_QUALITY, quality))
_, encoded_img = cv2.imencode(image_type, img, (int(quality_flag), quality))
img = cv2.imdecode(encoded_img, cv2.IMREAD_UNCHANGED)

if needs_float:
Expand Down
73 changes: 60 additions & 13 deletions albumentations/augmentations/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import math
import random
import warnings
from enum import Enum

import cv2
import numpy as np
Expand All @@ -20,7 +21,7 @@
'ElasticTransform', 'RandomGridShuffle', 'HueSaturationValue', 'PadIfNeeded', 'RGBShift',
'RandomBrightness', 'RandomContrast', 'MotionBlur', 'MedianBlur',
'GaussianBlur', 'GaussNoise', 'CLAHE', 'ChannelShuffle', 'InvertImg',
'ToGray', 'JpegCompression', 'Cutout', 'CoarseDropout', 'ToFloat',
'ToGray', 'JpegCompression', 'ImageCompression', 'Cutout', 'CoarseDropout', 'ToFloat',
'FromFloat', 'Crop', 'CropNonEmptyMaskIfExists', 'RandomScale', 'LongestMaxSize', 'SmallestMaxSize',
'Resize', 'RandomSizedCrop', 'RandomResizedCrop', 'RandomBrightnessContrast',
'RandomCropNearBBox', 'RandomSizedBBoxSafeCrop', 'RandomSnow',
Expand Down Expand Up @@ -1329,12 +1330,14 @@ def get_transform_init_args_names(self):
'min_height', 'min_width')


class JpegCompression(ImageOnlyTransform):
"""Decrease Jpeg compression of an image.
class ImageCompression(ImageOnlyTransform):
"""Decrease Jpeg, WebP compression of an image.
Args:
quality_lower (float): lower bound on the jpeg quality. Should be in [0, 100] range
quality_upper (float): upper bound on the jpeg quality. Should be in [0, 100] range
quality_lower (float): lower bound on the image quality.
Should be in [0, 100] range for jpeg and [1, 100] for webp.
quality_upper (float): upper bound on the image quality.
Should be in [0, 100] range for jpeg and [1, 100] for webp.
Targets:
image
Expand All @@ -1343,23 +1346,67 @@ class JpegCompression(ImageOnlyTransform):
uint8, float32
"""

def __init__(self, quality_lower=99, quality_upper=100, always_apply=False, p=0.5):
super(JpegCompression, self).__init__(always_apply, p)
class ImageCompressionType(Enum):
JPEG = 0
WEBP = 1

def __init__(self, quality_lower=99, quality_upper=100, compression_type=ImageCompressionType.JPEG,
always_apply=False, p=0.5):
super(ImageCompression, self).__init__(always_apply, p)

self.compression_type = compression_type
low_thresh_quality_assert = 0

if self.compression_type == ImageCompression.ImageCompressionType.WEBP:
low_thresh_quality_assert = 1

assert 0 <= quality_lower <= 100
assert 0 <= quality_upper <= 100
assert low_thresh_quality_assert <= quality_lower <= 100
assert low_thresh_quality_assert <= quality_upper <= 100

self.quality_lower = quality_lower
self.quality_upper = quality_upper

def apply(self, image, quality=100, **params):
return F.jpeg_compression(image, quality)
def apply(self, image, quality=100, image_type='.jpg', **params):
return F.image_compression(image, quality, image_type)

def get_params(self):
return {'quality': random.randint(self.quality_lower, self.quality_upper)}
image_type = '.jpg'

if self.compression_type == ImageCompression.ImageCompressionType.WEBP:
image_type = '.webp'

return {'quality': random.randint(self.quality_lower, self.quality_upper),
'image_type': image_type}

def get_transform_init_args_names(self):
return ('quality_lower', 'quality_upper')
return ('quality_lower', 'quality_upper', 'compression_type')


class JpegCompression(ImageCompression):
    """Decrease Jpeg compression of an image.

    Deprecated: thin subclass of ImageCompression that always selects the
    JPEG compression type. Instantiating it emits a DeprecationWarning;
    use ImageCompression instead.

    Args:
        quality_lower (float): lower bound on the jpeg quality. Should be in [0, 100] range
        quality_upper (float): upper bound on the jpeg quality. Should be in [0, 100] range

    Targets:
        image

    Image types:
        uint8, float32
    """

    def __init__(self, quality_lower=99, quality_upper=100, always_apply=False, p=0.5):
        super(JpegCompression, self).__init__(quality_lower=quality_lower, quality_upper=quality_upper,
                                              compression_type=ImageCompression.ImageCompressionType.JPEG,
                                              always_apply=always_apply, p=p)
        # stacklevel=2 attributes the warning to the code that instantiated
        # this class instead of to this constructor, so users can find the
        # call site they need to migrate.
        warnings.warn("This class has been deprecated. Please use ImageCompression",
                      DeprecationWarning, stacklevel=2)

    def get_transform_init_args(self):
        """Return the constructor arguments specific to this class.

        Only the quality bounds are reported: this constructor does not
        accept a ``compression_type`` argument, so it is left out.
        """
        return {
            'quality_lower': self.quality_lower,
            'quality_upper': self.quality_upper
        }


class RandomSnow(ImageOnlyTransform):
Expand Down
12 changes: 6 additions & 6 deletions tests/test_augmentations.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from albumentations import (
RandomCrop, PadIfNeeded, VerticalFlip, HorizontalFlip, Flip, Transpose,
RandomRotate90, Rotate, ShiftScaleRotate, CenterCrop, OpticalDistortion,
GridDistortion, ElasticTransform, RandomGridShuffle, ToGray, RandomGamma, JpegCompression,
GridDistortion, ElasticTransform, RandomGridShuffle, ToGray, RandomGamma, ImageCompression,
HueSaturationValue, RGBShift, Blur, MotionBlur, MedianBlur, GaussianBlur,
GaussNoise, CLAHE, ChannelShuffle, InvertImg, IAAEmboss, IAASuperpixels,
IAASharpen, IAAAdditiveGaussianNoise, IAAPiecewiseAffine, IAAPerspective,
Expand All @@ -18,7 +18,7 @@


@pytest.mark.parametrize(['augmentation_cls', 'params'], [
[JpegCompression, {}],
[ImageCompression, {}],
[HueSaturationValue, {}],
[RGBShift, {}],
[RandomBrightnessContrast, {}],
Expand Down Expand Up @@ -65,7 +65,7 @@ def test_image_only_augmentations(augmentation_cls, params, image, mask):
[ChannelShuffle, {}],
[InvertImg, {}],
[RandomGamma, {}],
[JpegCompression, {}],
[ImageCompression, {}],
[ToGray, {}],
[Cutout, {}],
[CoarseDropout, {}],
Expand Down Expand Up @@ -158,7 +158,7 @@ def test_imgaug_dual_augmentations(augmentation_cls, image, mask):

@pytest.mark.parametrize(['augmentation_cls', 'params'], [
[Cutout, {}],
[JpegCompression, {}],
[ImageCompression, {}],
[HueSaturationValue, {}],
[RGBShift, {}],
[RandomBrightnessContrast, {}],
Expand Down Expand Up @@ -269,7 +269,7 @@ def test_augmentations_wont_change_float_input(augmentation_cls, params, float_i
@pytest.mark.parametrize(['augmentation_cls', 'params'], [
[Cutout, {}],
[CoarseDropout, {}],
[JpegCompression, {}],
[ImageCompression, {}],
[RandomBrightnessContrast, {}],
[Blur, {}],
[MotionBlur, {}],
Expand Down Expand Up @@ -324,7 +324,7 @@ def test_augmentations_wont_change_shape_grayscale(augmentation_cls, params, ima
@pytest.mark.parametrize(['augmentation_cls', 'params'], [
[Cutout, {}],
[CoarseDropout, {}],
[JpegCompression, {}],
[ImageCompression, {}],
[HueSaturationValue, {}],
[RGBShift, {}],
[RandomBrightnessContrast, {}],
Expand Down
6 changes: 6 additions & 0 deletions tests/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@


@pytest.mark.parametrize(['augmentation_cls', 'params'], [
[A.ImageCompression, {}],
[A.JpegCompression, {}],
[A.HueSaturationValue, {}],
[A.RGBShift, {}],
Expand Down Expand Up @@ -72,6 +73,8 @@ def test_augmentations_serialization(augmentation_cls, params, p, seed, image, m


@pytest.mark.parametrize(['augmentation_cls', 'params'], [
[A.ImageCompression, {'quality_lower': 10, 'quality_upper': 80,
'compression_type': A.ImageCompression.ImageCompressionType.WEBP}],
[A.JpegCompression, {'quality_lower': 10, 'quality_upper': 80}],
[A.HueSaturationValue, {'hue_shift_limit': 70, 'sat_shift_limit': 95, 'val_shift_limit': 55}],
[A.RGBShift, {'r_shift_limit': 70, 'g_shift_limit': 80, 'b_shift_limit': 40}],
Expand Down Expand Up @@ -190,6 +193,7 @@ def test_augmentations_serialization_with_custom_parameters(


@pytest.mark.parametrize(['augmentation_cls', 'params'], [
[A.ImageCompression, {}],
[A.JpegCompression, {}],
[A.HueSaturationValue, {}],
[A.RGBShift, {}],
Expand Down Expand Up @@ -252,6 +256,7 @@ def test_augmentations_for_bboxes_serialization(augmentation_cls, params, p, see


@pytest.mark.parametrize(['augmentation_cls', 'params'], [
[A.ImageCompression, {}],
[A.JpegCompression, {}],
[A.HueSaturationValue, {}],
[A.RGBShift, {}],
Expand Down Expand Up @@ -542,6 +547,7 @@ def test_transform_pipeline_serialization_with_keypoints(seed, image, keypoints,
[A.ChannelShuffle, {}],
[A.GaussNoise, {}],
[A.Cutout, {}],
[A.ImageCompression, {}],
[A.JpegCompression, {}],
[A.HueSaturationValue, {}],
[A.RGBShift, {}],
Expand Down
2 changes: 1 addition & 1 deletion tests/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ def test_force_apply():
[A.GaussNoise, {}],
[A.Cutout, {}],
[A.CoarseDropout, {}],
[A.JpegCompression, {}],
[A.ImageCompression, {}],
[A.HueSaturationValue, {}],
[A.RGBShift, {}],
[A.RandomBrightnessContrast, {}],
Expand Down

0 comments on commit b612786

Please sign in to comment.