Skip to content

Commit

Permalink
Added RandomSizedCropTorchVision (#335)
Browse files Browse the repository at this point in the history
* [WIP] Added RandomSizedCropTorchVision

* Refactored RandomSizedCrop-*, updated input args, added tests

* Renamed classes to RandomSizedCrop and RandomResizedCrop

* Removed deprecation warning for RandomSizedCrop

* Added missing get_transform_init_args_names
  • Loading branch information
vfdev-5 authored and ternaus committed Sep 4, 2019
1 parent ac499d0 commit 4dbe41e
Show file tree
Hide file tree
Showing 5 changed files with 139 additions and 24 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@ Spatial-level transforms will simultaneously change both an input image as well
| [RandomCrop](https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.augmentations.transforms.RandomCrop) |||||
| [RandomCropNearBBox](https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.augmentations.transforms.RandomCropNearBBox) |||| |
| [RandomGridShuffle](https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.augmentations.transforms.RandomGridShuffle) ||| | |
| [RandomResizedCrop](https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.augmentations.transforms.RandomResizedCrop) |||||
| [RandomRotate90](https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.augmentations.transforms.RandomRotate90) |||||
| [RandomScale](https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.augmentations.transforms.RandomScale) |||||
| [RandomSizedBBoxSafeCrop](https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.augmentations.transforms.RandomSizedBBoxSafeCrop) |||| |
Expand Down
130 changes: 111 additions & 19 deletions albumentations/augmentations/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
'GaussianBlur', 'GaussNoise', 'CLAHE', 'ChannelShuffle', 'InvertImg',
'ToGray', 'JpegCompression', 'Cutout', 'CoarseDropout', 'ToFloat',
'FromFloat', 'Crop', 'RandomScale', 'LongestMaxSize', 'SmallestMaxSize',
'Resize', 'RandomSizedCrop', 'RandomBrightnessContrast',
'Resize', 'RandomSizedCrop', 'RandomResizedCrop', 'RandomBrightnessContrast',
'RandomCropNearBBox', 'RandomSizedBBoxSafeCrop', 'RandomSnow',
'RandomRain', 'RandomFog', 'RandomSunFlare', 'RandomShadow', 'Lambda',
'ChannelDropout', 'ISONoise', 'Solarize', 'Equalize'
Expand Down Expand Up @@ -666,7 +666,31 @@ def get_transform_init_args_names(self):
return ('max_part_shift',)


class RandomSizedCrop(DualTransform):
class _BaseRandomSizedCrop(DualTransform):
# Base class for RandomSizedCrop and RandomResizedCrop

def __init__(self, height, width, interpolation=cv2.INTER_LINEAR, always_apply=False, p=1.0):
super(_BaseRandomSizedCrop, self).__init__(always_apply, p)
self.height = height
self.width = width
self.interpolation = interpolation

def apply(self, img, crop_height=0, crop_width=0, h_start=0, w_start=0, interpolation=cv2.INTER_LINEAR, **params):
crop = F.random_crop(img, crop_height, crop_width, h_start, w_start)
return F.resize(crop, self.height, self.width, interpolation)

def apply_to_bbox(self, bbox, crop_height=0, crop_width=0, h_start=0, w_start=0, rows=0, cols=0, **params):
return F.bbox_random_crop(bbox, crop_height, crop_width, h_start, w_start, rows, cols)

def apply_to_keypoint(self, keypoint, crop_height=0, crop_width=0, h_start=0, w_start=0, rows=0, cols=0, **params):
keypoint = F.keypoint_random_crop(keypoint, crop_height, crop_width, h_start, w_start, rows, cols)
scale_x = self.width / crop_height
scale_y = self.height / crop_height
keypoint = F.keypoint_scale(keypoint, scale_x, scale_y)
return keypoint


class RandomSizedCrop(_BaseRandomSizedCrop):
"""Crop a random part of the input and rescale it to some size.
Args:
Expand All @@ -688,36 +712,104 @@ class RandomSizedCrop(DualTransform):

def __init__(self, min_max_height, height, width, w2h_ratio=1., interpolation=cv2.INTER_LINEAR,
always_apply=False, p=1.0):
super(RandomSizedCrop, self).__init__(always_apply, p)
self.height = height
self.width = width
self.interpolation = interpolation
super(RandomSizedCrop, self).__init__(height=height, width=width,
interpolation=interpolation,
always_apply=always_apply, p=p)
self.min_max_height = min_max_height
self.w2h_ratio = w2h_ratio

def apply(self, img, crop_height=0, crop_width=0, h_start=0, w_start=0, interpolation=cv2.INTER_LINEAR, **params):
crop = F.random_crop(img, crop_height, crop_width, h_start, w_start)
return F.resize(crop, self.height, self.width, interpolation)

def get_params(self):
crop_height = random.randint(self.min_max_height[0], self.min_max_height[1])
return {'h_start': random.random(),
'w_start': random.random(),
'crop_height': crop_height,
'crop_width': int(crop_height * self.w2h_ratio)}

def apply_to_bbox(self, bbox, crop_height=0, crop_width=0, h_start=0, w_start=0, rows=0, cols=0, **params):
return F.bbox_random_crop(bbox, crop_height, crop_width, h_start, w_start, rows, cols)
def get_transform_init_args_names(self):
return 'min_max_height', 'height', 'width', 'w2h_ratio', 'interpolation'

def apply_to_keypoint(self, keypoint, crop_height=0, crop_width=0, h_start=0, w_start=0, rows=0, cols=0, **params):
keypoint = F.keypoint_random_crop(keypoint, crop_height, crop_width, h_start, w_start, rows, cols)
scale_x = self.width / crop_height
scale_y = self.height / crop_height
keypoint = F.keypoint_scale(keypoint, scale_x, scale_y)
return keypoint

class RandomResizedCrop(_BaseRandomSizedCrop):
"""Torchvision's variant of crop a random part of the input and rescale it to some size.
Args:
height (int): height after crop and resize.
width (int): width after crop and resize.
scale ((float, float)): range of size of the origin size cropped
ratio ((float, float)): range of aspect ratio of the origin aspect ratio cropped
interpolation (OpenCV flag): flag that is used to specify the interpolation algorithm. Should be one of:
cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA, cv2.INTER_LANCZOS4.
Default: cv2.INTER_LINEAR.
p (float): probability of applying the transform. Default: 1.
Targets:
image, mask, bboxes, keypoints
Image types:
uint8, float32
"""

def __init__(self, height, width, scale=(0.08, 1.0), ratio=(0.75, 1.3333333333333333),
interpolation=cv2.INTER_LINEAR,
always_apply=False, p=1.0):

super(RandomResizedCrop, self).__init__(height=height, width=width,
interpolation=interpolation,
always_apply=always_apply, p=p)
self.scale = scale
self.ratio = ratio

def get_params_dependent_on_targets(self, params):
img = params['image']
area = img.shape[0] * img.shape[1]

for attempt in range(10):
target_area = random.uniform(*self.scale) * area
log_ratio = (math.log(self.ratio[0]), math.log(self.ratio[1]))
aspect_ratio = math.exp(random.uniform(*log_ratio))

w = int(round(math.sqrt(target_area * aspect_ratio)))
h = int(round(math.sqrt(target_area / aspect_ratio)))

if w <= img.shape[1] and h <= img.shape[0]:
i = random.randint(0, img.shape[0] - h)
j = random.randint(0, img.shape[1] - w)
return {
'crop_height': h,
'crop_width': w,
'h_start': i * 1.0 / (img.shape[0] - h + 1e-10),
'w_start': j * 1.0 / (img.shape[1] - w + 1e-10)
}

# Fallback to central crop
in_ratio = img.shape[1] / img.shape[0]
if in_ratio < min(self.ratio):
w = img.shape[1]
h = w / min(self.ratio)
elif in_ratio > max(self.ratio):
h = img.shape[0]
w = h * max(self.ratio)
else: # whole image
w = img.shape[1]
h = img.shape[0]
i = (img.shape[0] - h) // 2
j = (img.shape[1] - w) // 2
return {
'crop_height': h,
'crop_width': w,
'h_start': i * 1.0 / (img.shape[0] - h + 1e-10),
'w_start': j * 1.0 / (img.shape[1] - w + 1e-10)
}

def get_params(self):
return {}

@property
def targets_as_params(self):
return ['image']

def get_transform_init_args_names(self):
return ('min_max_height', 'height', 'width', 'w2h_ratio', 'interpolation')
return 'height', 'width', 'scale', 'ratio', 'interpolation'


class RandomSizedBBoxSafeCrop(DualTransform):
Expand Down
6 changes: 5 additions & 1 deletion tests/test_augmentations.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
IAASharpen, IAAAdditiveGaussianNoise, IAAPiecewiseAffine, IAAPerspective,
Cutout, CoarseDropout, Normalize, ToFloat, FromFloat,
RandomBrightnessContrast, RandomSnow, RandomRain, RandomFog,
RandomSunFlare, RandomCropNearBBox, RandomShadow, RandomSizedCrop,
RandomSunFlare, RandomCropNearBBox, RandomShadow, RandomSizedCrop, RandomResizedCrop,
ChannelDropout, ISONoise, Solarize, Equalize)


Expand Down Expand Up @@ -100,6 +100,7 @@ def test_image_only_augmentations_with_float_values(augmentation_cls, params, fl
[ElasticTransform, {}],
[CenterCrop, {'height': 10, 'width': 10}],
[RandomCrop, {'height': 10, 'width': 10}],
[RandomResizedCrop, {'height': 10, 'width': 10}],
[RandomSizedCrop, {'min_max_height': (4, 8), 'height': 10, 'width': 10}],
[ISONoise, {}],
[RandomGridShuffle, {}]
Expand All @@ -125,6 +126,7 @@ def test_dual_augmentations(augmentation_cls, params, image, mask):
[ElasticTransform, {}],
[CenterCrop, {'height': 10, 'width': 10}],
[RandomCrop, {'height': 10, 'width': 10}],
[RandomResizedCrop, {'height': 10, 'width': 10}],
[RandomSizedCrop, {'min_max_height': (4, 8), 'height': 10, 'width': 10}],
[RandomGridShuffle, {}]
])
Expand Down Expand Up @@ -184,6 +186,7 @@ def test_imgaug_dual_augmentations(augmentation_cls, image, mask):
[ElasticTransform, {}],
[CenterCrop, {'height': 10, 'width': 10}],
[RandomCrop, {'height': 10, 'width': 10}],
[RandomResizedCrop, {'height': 10, 'width': 10}],
[RandomSizedCrop, {'min_max_height': (4, 8), 'height': 10, 'width': 10}],
[Normalize, {}],
[GaussNoise, {}],
Expand Down Expand Up @@ -238,6 +241,7 @@ def test_augmentations_wont_change_input(augmentation_cls, params, image, mask):
[ElasticTransform, {}],
[CenterCrop, {'height': 10, 'width': 10}],
[RandomCrop, {'height': 10, 'width': 10}],
[RandomResizedCrop, {'height': 10, 'width': 10}],
[RandomSizedCrop, {'min_max_height': (4, 8), 'height': 10, 'width': 10}],
[Normalize, {}],
[GaussNoise, {}],
Expand Down
13 changes: 11 additions & 2 deletions tests/test_bbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
convert_bbox_from_albumentations, convert_bboxes_to_albumentations, convert_bboxes_from_albumentations
from albumentations.core.composition import Compose
from albumentations.core.transforms_interface import NoOp
from albumentations.augmentations.transforms import RandomSizedCrop, Rotate, RandomRotate90
from albumentations.augmentations.transforms import RandomSizedCrop, RandomResizedCrop, Rotate, RandomRotate90


@pytest.mark.parametrize(['bbox', 'expected'], [
Expand Down Expand Up @@ -186,7 +186,16 @@ def test_compose_with_bbox_noop_label_outside(bboxes, bbox_format, labels):
def test_random_sized_crop_size():
image = np.ones((100, 100, 3))
bboxes = [[0.2, 0.3, 0.6, 0.8], [0.3, 0.4, 0.7, 0.9, 99]]
aug = RandomSizedCrop((70, 90), 50, 50, p=1.)
aug = RandomSizedCrop(min_max_height=(70, 90), height=50, width=50, p=1.)
transformed = aug(image=image, bboxes=bboxes)
assert transformed['image'].shape == (50, 50, 3)
assert len(bboxes) == len(transformed['bboxes'])


def test_random_resized_crop_size():
image = np.ones((100, 100, 3))
bboxes = [[0.2, 0.3, 0.6, 0.8], [0.3, 0.4, 0.7, 0.9, 99]]
aug = RandomResizedCrop(height=50, width=50, p=1.)
transformed = aug(image=image, bboxes=bboxes)
assert transformed['image'].shape == (50, 50, 3)
assert len(bboxes) == len(transformed['bboxes'])
Expand Down
13 changes: 11 additions & 2 deletions tests/test_keypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
convert_keypoints_from_albumentations, convert_keypoint_to_albumentations, convert_keypoints_to_albumentations
from albumentations.core.composition import Compose
from albumentations.core.transforms_interface import NoOp
from albumentations.augmentations.transforms import RandomSizedCrop
from albumentations.augmentations.transforms import RandomSizedCrop, RandomResizedCrop
import albumentations.augmentations.functional as F


Expand Down Expand Up @@ -129,7 +129,16 @@ def test_compose_with_keypoint_noop_label_outside(keypoints, keypoint_format, la
def test_random_sized_crop_size():
image = np.ones((100, 100, 3))
keypoints = [[0.2, 0.3, 0.6, 0.8], [0.3, 0.4, 0.7, 0.9, 99]]
aug = RandomSizedCrop((70, 90), 50, 50, p=1.)
aug = RandomSizedCrop(min_max_height=(70, 90), height=50, width=50, p=1.)
transformed = aug(image=image, keypoints=keypoints)
assert transformed['image'].shape == (50, 50, 3)
assert len(keypoints) == len(transformed['keypoints'])


def test_random_resized_crop_size():
image = np.ones((100, 100, 3))
keypoints = [[0.2, 0.3, 0.6, 0.8], [0.3, 0.4, 0.7, 0.9, 99]]
aug = RandomResizedCrop(height=50, width=50, p=1.)
transformed = aug(image=image, keypoints=keypoints)
assert transformed['image'].shape == (50, 50, 3)
assert len(keypoints) == len(transformed['keypoints'])
Expand Down

0 comments on commit 4dbe41e

Please sign in to comment.