Skip to content

Commit

Permalink
Update pad if needed padding (#1714)
Browse files Browse the repository at this point in the history
* Removed default values from geometric functional functions

* Removed default values from geometric functional functions

* Removed default values from geometric functional functions

* Added tests to check parameters
  • Loading branch information
ternaus committed May 8, 2024
1 parent 2d5eca3 commit 6f7bc1d
Show file tree
Hide file tree
Showing 5 changed files with 101 additions and 58 deletions.
63 changes: 34 additions & 29 deletions albumentations/augmentations/geometric/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,13 @@
preserve_channel_dim,
)
from albumentations.core.bbox_utils import denormalize_bbox, normalize_bbox
from albumentations.core.types import BoxInternalType, ColorType, D4Type, KeypointInternalType
from albumentations.core.types import (
NUM_MULTI_CHANNEL_DIMENSIONS,
BoxInternalType,
ColorType,
D4Type,
KeypointInternalType,
)

__all__ = [
"optical_distortion",
Expand Down Expand Up @@ -70,7 +76,6 @@
]

TWO = 2
THREE = 3

ROT90_180_FACTOR = 2
ROT90_270_FACTOR = 3
Expand Down Expand Up @@ -239,8 +244,8 @@ def keypoint_d4(
def rotate(
img: np.ndarray,
angle: float,
interpolation: int = cv2.INTER_LINEAR,
border_mode: int = cv2.BORDER_REFLECT_101,
interpolation: int,
border_mode: int,
value: Optional[ColorType] = None,
) -> np.ndarray:
height, width = img.shape[:2]
Expand Down Expand Up @@ -272,7 +277,7 @@ def bbox_rotate(bbox: BoxInternalType, angle: float, method: str, rows: int, col
Returns:
A bounding box `(x_min, y_min, x_max, y_max)`.
References:
Reference:
https://arxiv.org/abs/2109.13488
"""
Expand Down Expand Up @@ -334,8 +339,8 @@ def elastic_transform(
alpha: float,
sigma: float,
alpha_affine: float,
interpolation: int = cv2.INTER_LINEAR,
border_mode: int = cv2.BORDER_REFLECT_101,
interpolation: int,
border_mode: int,
value: Optional[ColorType] = None,
random_state: Optional[np.random.RandomState] = None,
approximate: bool = False,
Expand Down Expand Up @@ -423,7 +428,7 @@ def elastic_transform(


@preserve_channel_dim
def resize(img: np.ndarray, height: int, width: int, interpolation: int = cv2.INTER_LINEAR) -> np.ndarray:
def resize(img: np.ndarray, height: int, width: int, interpolation: int) -> np.ndarray:
img_height, img_width = img.shape[:2]
if (height, width) == img.shape[:2]:
return img
Expand All @@ -432,7 +437,7 @@ def resize(img: np.ndarray, height: int, width: int, interpolation: int = cv2.IN


@preserve_channel_dim
def scale(img: np.ndarray, scale: float, interpolation: int = cv2.INTER_LINEAR) -> np.ndarray:
def scale(img: np.ndarray, scale: float, interpolation: int) -> np.ndarray:
height, width = img.shape[:2]
new_height, new_width = int(height * scale), int(width * scale)
return resize(img, new_height, new_width, interpolation)
Expand Down Expand Up @@ -531,7 +536,7 @@ def perspective_bbox(
)


def rotation2d_matrix_to_euler_angles(matrix: np.ndarray, y_up: bool = False) -> float:
def rotation2d_matrix_to_euler_angles(matrix: np.ndarray, y_up: bool) -> float:
"""Args:
matrix (np.ndarray): Rotation matrix
y_up (bool): whether the Y axis points up or down
Expand Down Expand Up @@ -610,7 +615,7 @@ def keypoint_affine(

x, y, a, s = keypoint[:4]
x, y = cv2.transform(np.array([[[x, y]]]), matrix.params[:2]).squeeze()
a += rotation2d_matrix_to_euler_angles(matrix.params[:2])
a += rotation2d_matrix_to_euler_angles(matrix.params[:2], y_up=False)
s *= np.max([scale["x"], scale["y"]])
return x, y, a, s

Expand Down Expand Up @@ -832,12 +837,12 @@ def from_distance_maps(
distance_maps: np.ndarray,
inverted: bool,
if_not_found_coords: Optional[Union[Sequence[int], Dict[str, Any]]],
threshold: Optional[float] = None,
threshold: Optional[float],
) -> List[Tuple[float, float]]:
"""Convert outputs of `to_distance_maps` to `KeypointsOnImage`.
This is the inverse of `to_distance_maps`.
"""
if distance_maps.ndim != THREE:
if distance_maps.ndim != NUM_MULTI_CHANNEL_DIMENSIONS:
msg = f"Expected three-dimensional input, got {distance_maps.ndim} dimensions and shape {distance_maps.shape}."
raise ValueError(msg)
height, width, nb_keypoints = distance_maps.shape
Expand Down Expand Up @@ -1168,8 +1173,8 @@ def pad(
img: np.ndarray,
min_height: int,
min_width: int,
border_mode: int = cv2.BORDER_REFLECT_101,
value: Optional[ColorType] = None,
border_mode: int,
value: Optional[ColorType],
) -> np.ndarray:
height, width = img.shape[:2]

Expand Down Expand Up @@ -1204,8 +1209,8 @@ def pad_with_params(
h_pad_bottom: int,
w_pad_left: int,
w_pad_right: int,
border_mode: int = cv2.BORDER_REFLECT_101,
value: Optional[ColorType] = None,
border_mode: int,
value: Optional[ColorType],
) -> np.ndarray:
pad_fn = _maybe_process_in_chunks(
cv2.copyMakeBorder,
Expand All @@ -1222,11 +1227,11 @@ def pad_with_params(
@preserve_channel_dim
def optical_distortion(
img: np.ndarray,
k: int = 0,
dx: int = 0,
dy: int = 0,
interpolation: int = cv2.INTER_LINEAR,
border_mode: int = cv2.BORDER_REFLECT_101,
k: int,
dx: int,
dy: int,
interpolation: int,
border_mode: int,
value: Optional[ColorType] = None,
) -> np.ndarray:
"""Barrel / pincushion distortion. Unconventional augment.
Expand Down Expand Up @@ -1255,11 +1260,11 @@ def optical_distortion(
@preserve_channel_dim
def grid_distortion(
img: np.ndarray,
num_steps: int = 10,
xsteps: Tuple[()] = (),
ysteps: Tuple[()] = (),
interpolation: int = cv2.INTER_LINEAR,
border_mode: int = cv2.BORDER_REFLECT_101,
num_steps: int,
xsteps: Tuple[()],
ysteps: Tuple[()],
interpolation: int,
border_mode: int,
value: Optional[ColorType] = None,
) -> np.ndarray:
height, width = img.shape[:2]
Expand Down Expand Up @@ -1317,8 +1322,8 @@ def elastic_transform_approx(
alpha: float,
sigma: float,
alpha_affine: float,
interpolation: int = cv2.INTER_LINEAR,
border_mode: int = cv2.BORDER_REFLECT_101,
interpolation: int,
border_mode: int,
value: Optional[ColorType] = None,
random_state: Optional[np.random.RandomState] = None,
) -> np.ndarray:
Expand Down
40 changes: 26 additions & 14 deletions albumentations/augmentations/geometric/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1188,22 +1188,26 @@ def apply_to_keypoint(


class PadIfNeeded(DualTransform):
"""Pad side of the image / max if side is less than desired number.
"""Pads the sides of an image if the image dimensions are less than the specified minimum dimensions.
If the `pad_height_divisor` or `pad_width_divisor` is specified, the function additionally ensures
that the image dimensions are divisible by these values.
Args:
min_height (int): minimal result image height.
min_width (int): minimal result image width.
pad_height_divisor (int): if not None, ensures image height is dividable by value of this argument.
pad_width_divisor (int): if not None, ensures image width is dividable by value of this argument.
position (Union[str, PositionType]): Position of the image. should be PositionType.CENTER or
PositionType.TOP_LEFT or PositionType.TOP_RIGHT or PositionType.BOTTOM_LEFT or PositionType.BOTTOM_RIGHT.
or PositionType.RANDOM. Default: PositionType.CENTER.
border_mode (OpenCV flag): OpenCV border mode.
value (int, float, list of int, list of float): padding value if border_mode is cv2.BORDER_CONSTANT.
mask_value (int, float,
list of int,
list of float): padding value for mask if border_mode is cv2.BORDER_CONSTANT.
p (float): probability of applying the transform. Default: 1.0.
min_height (int): Minimum desired height of the image. Ensures image height is at least this value.
min_width (int): Minimum desired width of the image. Ensures image width is at least this value.
pad_height_divisor (int, optional): If set, pads the image height to make it divisible by this value.
pad_width_divisor (int, optional): If set, pads the image width to make it divisible by this value.
position (Union[str, PositionType]): Position where the image is to be placed after padding.
Can be one of 'center', 'top_left', 'top_right', 'bottom_left', 'bottom_right', or 'random'.
Default is 'center'.
border_mode (int): Specifies the border mode to use if padding is required.
The default is `cv2.BORDER_REFLECT_101`. If `value` is provided while `border_mode` is
`cv2.BORDER_REFLECT_101`, the mode is switched to `cv2.BORDER_CONSTANT` automatically; with any other
mode that does not use a constant value, `border_mode` should be manually set to `cv2.BORDER_CONSTANT`.
value (Union[int, float, list[int], list[float]], optional): Value to fill the border pixels if
the border mode is `cv2.BORDER_CONSTANT`. Default is None.
mask_value (Union[int, float, list[int], list[float]], optional): Similar to `value` but used for padding masks.
Default is None.
p (float): Probability of applying the transform. Default is 1.0.
Targets:
image, mask, bboxes, keypoints
Expand Down Expand Up @@ -1269,6 +1273,14 @@ def validate_divisibility(self) -> Self:
if (self.min_width is None) == (self.pad_width_divisor is None):
msg = "Only one of 'min_width' and 'pad_width_divisor' parameters must be set"
raise ValueError(msg)

if self.value is not None and self.border_mode in {cv2.BORDER_REFLECT_101, cv2.BORDER_REFLECT101}:
self.border_mode = cv2.BORDER_CONSTANT

if self.border_mode == cv2.BORDER_CONSTANT and self.value is None:
msg = "If 'border_mode' is set to 'BORDER_CONSTANT', 'value' must be provided."
raise ValueError(msg)

return self

def __init__(
Expand Down
12 changes: 6 additions & 6 deletions tests/test_augmentations.py
Original file line number Diff line number Diff line change
Expand Up @@ -853,12 +853,12 @@ def test_pad_if_needed(augmentation_cls: Type[A.PadIfNeeded], params: Dict, imag
@pytest.mark.parametrize(
["params", "image_shape"],
[
[{"min_height": 10, "min_width": 12, "border_mode": 0, "value": 1, "position": "center"}, (5, 6)],
[{"min_height": 10, "min_width": 12, "border_mode": 0, "value": 1, "position": "top_left"}, (5, 6)],
[{"min_height": 10, "min_width": 12, "border_mode": 0, "value": 1, "position": "top_right"}, (5, 6)],
[{"min_height": 10, "min_width": 12, "border_mode": 0, "value": 1, "position": "bottom_left"}, (5, 6)],
[{"min_height": 10, "min_width": 12, "border_mode": 0, "value": 1, "position": "bottom_right"}, (5, 6)],
[{"min_height": 10, "min_width": 12, "border_mode": 0, "value": 1, "position": "random"}, (5, 6)],
[{"min_height": 10, "min_width": 12, "border_mode": cv2.BORDER_CONSTANT, "value": 1, "position": "center"}, (5, 6)],
[{"min_height": 10, "min_width": 12, "border_mode": cv2.BORDER_CONSTANT, "value": 1, "position": "top_left"}, (5, 6)],
[{"min_height": 10, "min_width": 12, "border_mode": cv2.BORDER_CONSTANT, "value": 1, "position": "top_right"}, (5, 6)],
[{"min_height": 10, "min_width": 12, "border_mode": cv2.BORDER_CONSTANT, "value": 1, "position": "bottom_left"}, (5, 6)],
[{"min_height": 10, "min_width": 12, "border_mode": cv2.BORDER_CONSTANT, "value": 1, "position": "bottom_right"}, (5, 6)],
[{"min_height": 10, "min_width": 12, "border_mode": cv2.BORDER_CONSTANT, "value": 1, "position": "random"}, (5, 6)],
],
)
def test_pad_if_needed_position(params, image_shape):
Expand Down
16 changes: 8 additions & 8 deletions tests/test_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ def test_compare_rotate_and_affine(image):
rotation_matrix = generate_rotation_matrix(image, 60)

# Apply rotation using FGeometric.rotate
rotated_img_1 = FGeometric.rotate(image, angle=60, border_mode = cv2.BORDER_CONSTANT, value = 0)
rotated_img_1 = FGeometric.rotate(image, angle=60, border_mode = cv2.BORDER_CONSTANT, value = 0, interpolation=cv2.INTER_LINEAR)

# Convert 2x3 cv2 matrix to 3x3 for skimage's ProjectiveTransform
full_matrix = np.vstack([rotation_matrix, [0, 0, 1]])
Expand Down Expand Up @@ -289,7 +289,7 @@ def test_pad(target):
img = np.array([[1, 2], [3, 4]], dtype=np.uint8)
expected = np.array([[4, 3, 4, 3], [2, 1, 2, 1], [4, 3, 4, 3], [2, 1, 2, 1]], dtype=np.uint8)
img, expected = convert_2d_to_target_format([img, expected], target=target)
padded = FGeometric.pad(img, min_height=4, min_width=4)
padded = FGeometric.pad(img, min_height=4, min_width=4, border_mode=cv2.BORDER_REFLECT_101, value=None)
assert np.array_equal(padded, expected)


Expand All @@ -300,7 +300,7 @@ def test_pad_float(target):
[[0.4, 0.3, 0.4, 0.3], [0.2, 0.1, 0.2, 0.1], [0.4, 0.3, 0.4, 0.3], [0.2, 0.1, 0.2, 0.1]], dtype=np.float32
)
img, expected = convert_2d_to_target_format([img, expected], target=target)
padded_img = FGeometric.pad(img, min_height=4, min_width=4)
padded_img = FGeometric.pad(img, min_height=4, min_width=4, value=None, border_mode=cv2.BORDER_REFLECT_101)
assert_array_almost_equal_nulp(padded_img, expected)


Expand Down Expand Up @@ -542,11 +542,11 @@ def test_from_float_unknown_dtype():


@pytest.mark.parametrize("target", ["image", "mask"])
def test_resize_default_interpolation(target):
def test_resize_linear_interpolation(target):
img = np.array([[1, 1, 1, 1], [2, 2, 2, 2], [3, 3, 3, 3], [4, 4, 4, 4]], dtype=np.uint8)
expected = np.array([[2, 2], [4, 4]], dtype=np.uint8)
img, expected = convert_2d_to_target_format([img, expected], target=target)
resized_img = FGeometric.resize(img, 2, 2)
resized_img = FGeometric.resize(img, 2, 2, interpolation=cv2.INTER_LINEAR)
height, width = resized_img.shape[:2]
assert height == 2
assert width == 2
Expand All @@ -569,7 +569,7 @@ def test_resize_nearest_interpolation(target):
def test_resize_different_height_and_width(target):
img = np.ones((100, 100), dtype=np.uint8)
img = convert_2d_to_target_format([img], target=target)
resized_img = FGeometric.resize(img, height=20, width=30)
resized_img = FGeometric.resize(img, height=20, width=30, interpolation=cv2.INTER_LINEAR)
height, width = resized_img.shape[:2]
assert height == 20
assert width == 30
Expand All @@ -585,7 +585,7 @@ def test_resize_default_interpolation_float(target):
)
expected = np.array([[0.15, 0.15], [0.35, 0.35]], dtype=np.float32)
img, expected = convert_2d_to_target_format([img, expected], target=target)
resized_img = FGeometric.resize(img, 2, 2)
resized_img = FGeometric.resize(img, 2, 2, interpolation=cv2.INTER_LINEAR)
height, width = resized_img.shape[:2]
assert height == 2
assert width == 2
Expand Down Expand Up @@ -972,7 +972,7 @@ def test_maybe_process_in_chunks():

for i in range(1, image.shape[-1] + 1):
before = image[:, :, :i]
after = FGeometric.rotate(before, angle=1)
after = FGeometric.rotate(before, angle=1, interpolation=cv2.INTER_LINEAR, border_mode=cv2.BORDER_REFLECT_101)
assert before.shape == after.shape


Expand Down
28 changes: 27 additions & 1 deletion tests/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1497,4 +1497,30 @@ def test_downscale_functionality(params, expected):
])
def test_downscale_invalid_input(params):
with pytest.raises(Exception):
aug = A.Downscale(**params, p=1)
A.Downscale(**params, p=1)


@pytest.mark.parametrize(
    "params, expected",
    [
        # Default values
        (
            {},
            {
                "min_height": 1024,
                "min_width": 1024,
                "position": A.PadIfNeeded.PositionType.CENTER,
                "border_mode": cv2.BORDER_REFLECT_101,
            },
        ),
        # Boundary values
        (
            {"min_height": 800, "min_width": 800},
            {"min_height": 800, "min_width": 800},
        ),
        (
            {"pad_height_divisor": 10, "min_height": None, "pad_width_divisor": 10, "min_width": None},
            {"pad_height_divisor": 10, "min_height": None, "pad_width_divisor": 10, "min_width": None},
        ),
        (
            {"position": "top_left"},
            {"position": A.PadIfNeeded.PositionType.TOP_LEFT},
        ),
        # Value handling when border_mode is BORDER_CONSTANT
        (
            {"border_mode": cv2.BORDER_CONSTANT, "value": 255},
            {"border_mode": cv2.BORDER_CONSTANT, "value": 255},
        ),
        # A fill value combined with the reflect mode is expected to end up as BORDER_CONSTANT
        (
            {"border_mode": cv2.BORDER_REFLECT_101, "value": 255},
            {"border_mode": cv2.BORDER_CONSTANT, "value": 255},
        ),
        (
            {"border_mode": cv2.BORDER_CONSTANT, "value": [0, 0, 0]},
            {"border_mode": cv2.BORDER_CONSTANT, "value": [0, 0, 0]},
        ),
        # Mask value handling
        (
            {"border_mode": cv2.BORDER_CONSTANT, "value": [0, 0, 0], "mask_value": 128},
            {"border_mode": cv2.BORDER_CONSTANT, "mask_value": 128, "value": [0, 0, 0]},
        ),
    ],
)
def test_pad_if_needed_functionality(params, expected):
    """Check the attributes stored on `A.PadIfNeeded` after construction.

    Each case builds the transform from `params` (always with `p=1`) and compares
    every attribute named in `expected` against the value found on the instance.
    """
    aug = A.PadIfNeeded(**params, p=1)

    # Read all attributes first, so a missing attribute surfaces before any comparison runs.
    observed = {}
    for name in expected:
        observed[name] = getattr(aug, name)

    for name, wanted in expected.items():
        assert observed[name] == wanted, f"Failed on {name} with value {wanted}"

0 comments on commit 6f7bc1d

Please sign in to comment.