Skip to content

Commit

Permalink
Add a condition to verify training status during image processing (ke…
Browse files Browse the repository at this point in the history
…ras-team#20650)

* Add a condition to verify training status during image processing

* resolve merge conflict

* fix transform_bounding_boxes logic

* add transform_bounding_boxes test
  • Loading branch information
shashaka authored Dec 18, 2024
1 parent ed1442e commit 4c7c4b5
Show file tree
Hide file tree
Showing 4 changed files with 220 additions and 84 deletions.
128 changes: 79 additions & 49 deletions keras/src/layers/preprocessing/image_preprocessing/mix_up.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
BaseImagePreprocessingLayer,
)
from keras.src.random import SeedGenerator
from keras.src.utils import backend_utils


@keras_export("keras.layers.MixUp")
Expand Down Expand Up @@ -66,36 +67,40 @@ def get_random_transformation(self, data, training=True, seed=None):
}

def transform_images(self, images, transformation=None, training=True):
images = self.backend.cast(images, self.compute_dtype)
mix_weight = transformation["mix_weight"]
permutation_order = transformation["permutation_order"]

mix_weight = self.backend.cast(
self.backend.numpy.reshape(mix_weight, [-1, 1, 1, 1]),
dtype=self.compute_dtype,
)

mix_up_images = self.backend.cast(
self.backend.numpy.take(images, permutation_order, axis=0),
dtype=self.compute_dtype,
)

images = mix_weight * images + (1.0 - mix_weight) * mix_up_images

def _mix_up_input(images, transformation):
images = self.backend.cast(images, self.compute_dtype)
mix_weight = transformation["mix_weight"]
permutation_order = transformation["permutation_order"]
mix_weight = self.backend.cast(
self.backend.numpy.reshape(mix_weight, [-1, 1, 1, 1]),
dtype=self.compute_dtype,
)
mix_up_images = self.backend.cast(
self.backend.numpy.take(images, permutation_order, axis=0),
dtype=self.compute_dtype,
)
images = mix_weight * images + (1.0 - mix_weight) * mix_up_images
return images

if training:
images = _mix_up_input(images, transformation)
return images

def transform_labels(self, labels, transformation, training=True):
mix_weight = transformation["mix_weight"]
permutation_order = transformation["permutation_order"]

labels_for_mix_up = self.backend.numpy.take(
labels, permutation_order, axis=0
)

mix_weight = self.backend.numpy.reshape(mix_weight, [-1, 1])

labels = mix_weight * labels + (1.0 - mix_weight) * labels_for_mix_up

def _mix_up_labels(labels, transformation):
mix_weight = transformation["mix_weight"]
permutation_order = transformation["permutation_order"]
labels_for_mix_up = self.backend.numpy.take(
labels, permutation_order, axis=0
)
mix_weight = self.backend.numpy.reshape(mix_weight, [-1, 1])
labels = (
mix_weight * labels + (1.0 - mix_weight) * labels_for_mix_up
)
return labels

if training:
labels = _mix_up_labels(labels, transformation)
return labels

def transform_bounding_boxes(
Expand All @@ -104,33 +109,58 @@ def transform_bounding_boxes(
transformation,
training=True,
):
permutation_order = transformation["permutation_order"]
boxes, classes = bounding_boxes["boxes"], bounding_boxes["classes"]
boxes_for_mix_up = self.backend.numpy.take(boxes, permutation_order)
classes_for_mix_up = self.backend.numpy.take(classes, permutation_order)
boxes = self.backend.numpy.concat([boxes, boxes_for_mix_up], axis=1)
classes = self.backend.numpy.concat(
[classes, classes_for_mix_up], axis=1
)
return {"boxes": boxes, "classes": classes}
def _mix_up_bounding_boxes(bounding_boxes, transformation):
if backend_utils.in_tf_graph():
self.backend.set_backend("tensorflow")

def transform_segmentation_masks(
self, segmentation_masks, transformation, training=True
):
mix_weight = transformation["mix_weight"]
permutation_order = transformation["permutation_order"]
permutation_order = transformation["permutation_order"]

mix_weight = self.backend.numpy.reshape(mix_weight, [-1, 1, 1, 1])
boxes, labels = bounding_boxes["boxes"], bounding_boxes["labels"]
boxes_for_mix_up = self.backend.numpy.take(
boxes, permutation_order, axis=0
)

segmentation_masks_for_mix_up = self.backend.numpy.take(
segmentation_masks, permutation_order
)
labels_for_mix_up = self.backend.numpy.take(
labels, permutation_order, axis=0
)
boxes = self.backend.numpy.concatenate(
[boxes, boxes_for_mix_up], axis=1
)

segmentation_masks = (
mix_weight * segmentation_masks
+ (1.0 - mix_weight) * segmentation_masks_for_mix_up
)
labels = self.backend.numpy.concatenate(
[labels, labels_for_mix_up], axis=0
)

self.backend.reset()

return {"boxes": boxes, "labels": labels}

if training:
bounding_boxes = _mix_up_bounding_boxes(
bounding_boxes, transformation
)
return bounding_boxes

def transform_segmentation_masks(
self, segmentation_masks, transformation, training=True
):
def _mix_up_segmentation_masks(segmentation_masks, transformation):
mix_weight = transformation["mix_weight"]
permutation_order = transformation["permutation_order"]
mix_weight = self.backend.numpy.reshape(mix_weight, [-1, 1, 1, 1])
segmentation_masks_for_mix_up = self.backend.numpy.take(
segmentation_masks, permutation_order
)
segmentation_masks = (
mix_weight * segmentation_masks
+ (1.0 - mix_weight) * segmentation_masks_for_mix_up
)
return segmentation_masks

if training:
segmentation_masks = _mix_up_segmentation_masks(
segmentation_masks, transformation
)
return segmentation_masks

def compute_output_shape(self, input_shape):
Expand Down
92 changes: 92 additions & 0 deletions keras/src/layers/preprocessing/image_preprocessing/mix_up_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@
import pytest
from tensorflow import data as tf_data

from keras.src import backend
from keras.src import layers
from keras.src import testing
from keras.src.backend import convert_to_tensor


class MixUpTest(testing.TestCase):
Expand All @@ -21,6 +23,14 @@ def test_layer(self):
run_training_check=not testing.tensorflow_uses_gpu(),
)

def test_mix_up_inference(self):
seed = 3481
layer = layers.MixUp(alpha=0.2)
np.random.seed(seed)
inputs = np.random.randint(0, 255, size=(224, 224, 3))
output = layer(inputs, training=False)
self.assertAllClose(inputs, output)

def test_mix_up_basic_functionality(self):
image = np.random.random((64, 64, 3))
mix_up_layer = layers.MixUp(alpha=1)
Expand Down Expand Up @@ -63,3 +73,85 @@ def test_tf_data_compatibility(self):
ds = tf_data.Dataset.from_tensor_slices(input_data).batch(2).map(layer)
for output in ds.take(1):
output.numpy()

def test_mix_up_bounding_boxes(self):
data_format = backend.config.image_data_format()
if data_format == "channels_last":
image_shape = (10, 8, 3)
else:
image_shape = (3, 10, 8)
input_image = np.random.random(image_shape)
bounding_boxes = {
"boxes": np.array(
[
[2, 1, 4, 3],
[6, 4, 8, 6],
]
),
"labels": np.array([1, 2]),
}
input_data = {"images": input_image, "bounding_boxes": bounding_boxes}

expected_boxes = [[2, 1, 4, 3, 6, 4, 8, 6], [6, 4, 8, 6, 2, 1, 4, 3]]

random_flip_layer = layers.MixUp(
data_format=data_format,
seed=42,
bounding_box_format="xyxy",
)

transformation = {
"mix_weight": convert_to_tensor([0.5, 0.5]),
"permutation_order": convert_to_tensor([1, 0]),
}
output = random_flip_layer.transform_bounding_boxes(
input_data["bounding_boxes"],
transformation=transformation,
training=True,
)
self.assertAllClose(output["boxes"], expected_boxes)

def test_mix_up_tf_data_bounding_boxes(self):
data_format = backend.config.image_data_format()
if data_format == "channels_last":
image_shape = (1, 10, 8, 3)
else:
image_shape = (1, 3, 10, 8)
input_image = np.random.random(image_shape)
bounding_boxes = {
"boxes": np.array(
[
[
[2, 1, 4, 3],
[6, 4, 8, 6],
]
]
),
"labels": np.array([[1, 2]]),
}

input_data = {"images": input_image, "bounding_boxes": bounding_boxes}
expected_boxes = [[2, 1, 4, 3, 6, 4, 8, 6], [6, 4, 8, 6, 2, 1, 4, 3]]

ds = tf_data.Dataset.from_tensor_slices(input_data)
layer = layers.MixUp(
data_format=data_format,
seed=42,
bounding_box_format="xyxy",
)

transformation = {
"mix_weight": convert_to_tensor([0.5, 0.5]),
"permutation_order": convert_to_tensor([1, 0]),
}
ds = ds.map(
lambda x: layer.transform_bounding_boxes(
x["bounding_boxes"],
transformation=transformation,
training=True,
)
)

output = next(iter(ds))
expected_boxes = np.array(expected_boxes)
self.assertAllClose(output["boxes"], expected_boxes)
76 changes: 41 additions & 35 deletions keras/src/layers/preprocessing/image_preprocessing/random_hue.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,46 +87,52 @@ def get_random_transformation(self, data, training=True, seed=None):
return {"factor": invert * factor * 0.5}

def transform_images(self, images, transformation=None, training=True):
images = self.backend.cast(images, self.compute_dtype)
images = self._transform_value_range(images, self.value_range, (0, 1))
adjust_factors = transformation["factor"]
adjust_factors = self.backend.cast(adjust_factors, images.dtype)
adjust_factors = self.backend.numpy.expand_dims(adjust_factors, -1)
adjust_factors = self.backend.numpy.expand_dims(adjust_factors, -1)

images = self.backend.image.rgb_to_hsv(
images, data_format=self.data_format
)

if self.data_format == "channels_first":
h_channel = images[:, 0, :, :] + adjust_factors
h_channel = self.backend.numpy.where(
h_channel > 1.0, h_channel - 1.0, h_channel
)
h_channel = self.backend.numpy.where(
h_channel < 0.0, h_channel + 1.0, h_channel
def _apply_random_hue(images, transformation):
images = self.backend.cast(images, self.compute_dtype)
images = self._transform_value_range(
images, self.value_range, (0, 1)
)
images = self.backend.numpy.stack(
[h_channel, images[:, 1, :, :], images[:, 2, :, :]], axis=1
adjust_factors = transformation["factor"]
adjust_factors = self.backend.cast(adjust_factors, images.dtype)
adjust_factors = self.backend.numpy.expand_dims(adjust_factors, -1)
adjust_factors = self.backend.numpy.expand_dims(adjust_factors, -1)
images = self.backend.image.rgb_to_hsv(
images, data_format=self.data_format
)
else:
h_channel = images[..., 0] + adjust_factors
h_channel = self.backend.numpy.where(
h_channel > 1.0, h_channel - 1.0, h_channel
)
h_channel = self.backend.numpy.where(
h_channel < 0.0, h_channel + 1.0, h_channel
if self.data_format == "channels_first":
h_channel = images[:, 0, :, :] + adjust_factors
h_channel = self.backend.numpy.where(
h_channel > 1.0, h_channel - 1.0, h_channel
)
h_channel = self.backend.numpy.where(
h_channel < 0.0, h_channel + 1.0, h_channel
)
images = self.backend.numpy.stack(
[h_channel, images[:, 1, :, :], images[:, 2, :, :]], axis=1
)
else:
h_channel = images[..., 0] + adjust_factors
h_channel = self.backend.numpy.where(
h_channel > 1.0, h_channel - 1.0, h_channel
)
h_channel = self.backend.numpy.where(
h_channel < 0.0, h_channel + 1.0, h_channel
)
images = self.backend.numpy.stack(
[h_channel, images[..., 1], images[..., 2]], axis=-1
)
images = self.backend.image.hsv_to_rgb(
images, data_format=self.data_format
)
images = self.backend.numpy.stack(
[h_channel, images[..., 1], images[..., 2]], axis=-1
images = self.backend.numpy.clip(images, 0, 1)
images = self._transform_value_range(
images, (0, 1), self.value_range
)
images = self.backend.image.hsv_to_rgb(
images, data_format=self.data_format
)
images = self.backend.cast(images, self.compute_dtype)
return images

images = self.backend.numpy.clip(images, 0, 1)
images = self._transform_value_range(images, (0, 1), self.value_range)
images = self.backend.cast(images, self.compute_dtype)
if training:
images = _apply_random_hue(images, transformation)
return images

def transform_labels(self, labels, transformation, training=True):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,14 @@ def test_layer(self):
expected_output_shape=(8, 3, 4, 3),
)

def test_random_hue_inference(self):
seed = 3481
layer = layers.RandomHue(0.2, [0, 1.0])
np.random.seed(seed)
inputs = np.random.randint(0, 255, size=(224, 224, 3))
output = layer(inputs, training=False)
self.assertAllClose(inputs, output)

def test_random_hue_value_range(self):
image = keras.random.uniform(shape=(3, 3, 3), minval=0, maxval=1)

Expand Down

0 comments on commit 4c7c4b5

Please sign in to comment.