diff --git a/albumentations/augmentations/transforms.py b/albumentations/augmentations/transforms.py index 5162a061e..6c9b85943 100644 --- a/albumentations/augmentations/transforms.py +++ b/albumentations/augmentations/transforms.py @@ -2430,7 +2430,7 @@ def get_params_dependent_on_targets(self, params): shape = [c] multiplier = np.random.uniform(self.multiplier[0], self.multiplier[1], shape) - if F.is_grayscale_image(img): + if F.is_grayscale_image(img) and img.ndim == 2: multiplier = np.squeeze(multiplier) return {"multiplier": multiplier} diff --git a/tests/test_transforms.py b/tests/test_transforms.py index 8e7be5143..a78d848a3 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -524,7 +524,13 @@ def test_resize_keypoints(): @pytest.mark.parametrize( - "image", [np.random.randint(0, 256, [256, 320], np.uint8), np.random.random([256, 320]).astype(np.float32)] + "image", + [ + np.random.randint(0, 256, [256, 320], np.uint8), + np.random.random([256, 320]).astype(np.float32), + np.random.randint(0, 256, [256, 320, 1], np.uint8), + np.random.random([256, 320, 1]).astype(np.float32), + ], ) def test_multiplicative_noise_grayscale(image): m = 0.5