From 0d3ba37b19af86c704ea7624828490b9a3a732c9 Mon Sep 17 00:00:00 2001 From: Ugeun Park <37043543+shashaka@users.noreply.github.com> Date: Sun, 22 Dec 2024 11:02:03 +0900 Subject: [PATCH] Add random_color_jitter processing layer (#20673) * Add implementations for random_saturation * change parse_factor method to inner method. * Add implementations for random_color_jitter * Fix Randomhue (#20652) * Small fix in random hue * use self.backend for seed * test: add test for class weights (py_dataset adapter) (#20638) * test: add test for class weights (py_dataset adapter) * "call _standardize_batch from enqueuer" m * add more tests, handle pytorch astype issue m * convert to numpy to ensure consistent handling of operations * Fix paths for pytest in contribution guide (#20655) * Add preliminary support of OpenVINO as Keras 3 backend (#19727) * [POC][OV] Support OpenVINO as Keras 3 backend Signed-off-by: Kazantsev, Roman * Mark all unsupported ops from numpy space Signed-off-by: Kazantsev, Roman * Mark unsupported ops in core, image, and linalg spaces Signed-off-by: Kazantsev, Roman * Mark unsupported ops in math, nn, random, and rnn spaces Signed-off-by: Kazantsev, Roman * Fix sorting imports Signed-off-by: Kazantsev, Roman * Format imports Signed-off-by: Kazantsev, Roman * Fix sorting imports Signed-off-by: Kazantsev, Roman * Fix sorting imports Signed-off-by: Kazantsev, Roman * Fix inference Signed-off-by: Kazantsev, Roman * Remove openvino specific code in common part Signed-off-by: Kazantsev, Roman * Fix typo * Clean-up code Signed-off-by: Kazantsev, Roman * Recover imports Signed-off-by: Kazantsev, Roman * Sort imports properly Signed-off-by: Kazantsev, Roman * Format source code Signed-off-by: Kazantsev, Roman * Format the rest of source code Signed-off-by: Kazantsev, Roman * Continue format adjustment Signed-off-by: Kazantsev, Roman * Add OpenVINO dependency Signed-off-by: Kazantsev, Roman * Fix inference using OV backend Signed-off-by: Kazantsev, Roman * Support bert_base_en_uncased and mobilenet_v3_small from Keras Hub Signed-off-by: Kazantsev, Roman * Remove extra openvino specific code from layer.py Signed-off-by: Kazantsev, Roman * Apply code-style formatting Signed-off-by: Kazantsev, Roman * Apply code-style formatting Signed-off-by: Kazantsev, Roman * Fix remained code-style issue Signed-off-by: Kazantsev, Roman * Run tests for OpenVINO backend in GHA Signed-off-by: Kazantsev, Roman * Add config file for openvino backend validation Signed-off-by: Kazantsev, Roman * Add import test for openvino backend Signed-off-by: Kazantsev, Roman * Fix error in import_test.py Signed-off-by: Kazantsev, Roman * Add import_test for openvino backend Signed-off-by: Kazantsev, Roman * Add openvino specific integration tests in GHA Signed-off-by: Kazantsev, Roman * Exclude coverage for OpenVINO Signed-off-by: Kazantsev, Roman * remove coverage for openvino backend Signed-off-by: Kazantsev, Roman * Try layer tests for openvino backend Signed-off-by: Kazantsev, Roman * Run layer tests for openvino backend selectively Signed-off-by: Kazantsev, Roman * Mark enabled tests for openvino backend in a different way Signed-off-by: Kazantsev, Roman * Update .github/workflows/actions.yml * Fix import for BackendVariable Signed-off-by: Kazantsev, Roman * Fix errors in layer tests for openvino backend Signed-off-by: Kazantsev, Roman * Add test for Elu via openvino backend Signed-off-by: Kazantsev, Roman * Fix sorted imports Signed-off-by: Kazantsev, Roman * Extend testing for attention Signed-off-by: Kazantsev, Roman * Update keras/src/layers/attention/attention_test.py * Switch on activation tests for openvino backend Signed-off-by: Kazantsev, Roman * Switch on attention tests for openvino backend Signed-off-by: Kazantsev, Roman * Update keras/src/layers/attention/additive_attention_test.py * Update keras/src/layers/attention/grouped_query_attention_test.py * Run conv tests for openvino backend Signed-off-by: Kazantsev, Roman * Fix convolution in openvino backend Signed-off-by: Kazantsev, Roman * Work around constant creation for tuple Signed-off-by: Kazantsev, Roman * Work around constant creation in reshape Signed-off-by: Kazantsev, Roman * Run depthwise conv tests for openvino backend Signed-off-by: Kazantsev, Roman * Fix get_ov_output for other x types Signed-off-by: Kazantsev, Roman * Fix elu translation Signed-off-by: Kazantsev, Roman * Fix softmax and log_softmax for None axis Signed-off-by: Kazantsev, Roman * Run nn tests for openvino backend Signed-off-by: Kazantsev, Roman * Fix numpy operations for axis to be None Signed-off-by: Kazantsev, Roman * Run operation_test for openvino_backend Signed-off-by: Kazantsev, Roman * Switch on math_test for openvino backend Signed-off-by: Kazantsev, Roman * Switch on image tests for openvino backend Signed-off-by: Kazantsev, Roman * Switch on linalg test for openvino backend Signed-off-by: Kazantsev, Roman * Extend OpenVINOKerasTensor with new built-in methods and fix shape op Signed-off-by: Kazantsev, Roman * Switch on core tests for openvino backend Signed-off-by: Kazantsev, Roman * Use different way of OpenVINO model creation that supports call method Signed-off-by: Kazantsev, Roman * Unify integration test for openvino Signed-off-by: Kazantsev, Roman * Support new operations abs, mod, etc. Signed-off-by: Kazantsev, Roman * Add support for more operations like squeeze, max Signed-off-by: Kazantsev, Roman * Try to use excluded test files list Signed-off-by: Kazantsev, Roman * Apply formatting for normalization_test.py Signed-off-by: Kazantsev, Roman * Correct GHA yml file Signed-off-by: Kazantsev, Roman * Test that openvino backend is used Signed-off-by: Kazantsev, Roman * Revert testing change in excluded test files list Signed-off-by: Kazantsev, Roman * Include testing group Signed-off-by: Kazantsev, Roman * Include legacy test group Signed-off-by: Kazantsev, Roman * Exclude legacy group of tests Signed-off-by: Kazantsev, Roman * Include initializers tests Signed-off-by: Kazantsev, Roman * Skip tests for initializers group Signed-off-by: Kazantsev, Roman * Remove export test group from ignore Signed-off-by: Kazantsev, Roman * Include dtype_policies test group Signed-off-by: Kazantsev, Roman * Reduce ignored tests Signed-off-by: Kazantsev, Roman * Fix ops.cast Signed-off-by: Kazantsev, Roman * Add decorator for custom_gradient Signed-off-by: Kazantsev, Roman * Shorten line in custom_gradient Signed-off-by: Kazantsev, Roman * Ignore dtype_policy_map test Signed-off-by: Kazantsev, Roman * Include callback tests Signed-off-by: Kazantsev, Roman * Switch on backend tests Signed-off-by: Kazantsev, Roman * Exclude failing tests Signed-off-by: Kazantsev, Roman * Correct paths to excluded tests Signed-off-by: Kazantsev, Roman * Switch on some layers tests Signed-off-by: Kazantsev, Roman * Remove pytest.mark.openvino_backend Signed-off-by: Kazantsev, Roman * Register mark requires_trainable_backend Signed-off-by: Kazantsev, Roman * Ignore test files in a different way Signed-off-by: Kazantsev, Roman * Try different way to ignore test files Signed-off-by: Kazantsev, Roman * Fix GHA yml Signed-off-by: Kazantsev, Roman * Support tuple axis for logsumexp Signed-off-by: Kazantsev, Roman * Switch on some ops tests Signed-off-by: Kazantsev, Roman * Switch on some callbacks tests Signed-off-by: Kazantsev, Roman * Add openvino export Signed-off-by: Kazantsev, Roman * Update sklearn tests Signed-off-by: Kazantsev, Roman * Add a comment to skipp numerical_test Signed-off-by: Kazantsev, Roman * Add custom requirements file for OpenVINO Signed-off-by: Kazantsev, Roman * Add reqs of openvino installation for api changes check Signed-off-by: Kazantsev, Roman * Fix types of Variables and switch on some variables tests Signed-off-by: Kazantsev, Roman * Fix nightly code check Signed-off-by: Kazantsev, Roman --------- Signed-off-by: Kazantsev, Roman * Make sklearn dependency optional (#20657) * Add a condition to verify training status during image processing (#20650) * Add a condition to verify training status during image processing * resolve merge conflict * fix transform_bounding_boxes logic * add transform_bounding_boxes test * Fix recurrent dropout for GRU. (#20656) The simplified implementation, which used the same recurrent dropout masks for all the previous states didn't work and caused the training to not converge with large enough recurrent dropout values. This new implementation is now the same as Keras 2. Note that recurrent dropout requires "implementation 1" to be turned on. Fixes https://github.com/keras-team/keras/issues/20276 * Fix example title in probabilistic_metrics.py (#20662) * Change recurrent dropout implementation for LSTM. (#20663) This change is to make the implementation of recurrent dropout consistent with GRU (changed as of https://github.com/keras-team/keras/pull/20656 ) and Keras 2. Also fixed a bug where the GRU fix would break when using CUDNN with a dropout and no recurrent dropout. The solution is to create multiple masks only when needed (implementation == 1). Added coverage for the case when dropout is set and recurrent dropout is not set. * Never pass enable_xla=False or native_serialization=False in tests (#20664) These are invalid options in the latest version of jax2tf, they will just immediately throw. * Fix `PyDatasetAdapterTest::test_class_weight` test with Torch on GPU. (#20665) The test was failing because arrays on device and on cpu were compared. * Fix up torch GPU failing test for mix up (#20666) We need to make sure to use get any tensors places on cpu before using them in the tensorflow backend during preprocessing. * Add random_color_jitter processing layer * Add random_color_jitter test * Update test cases * Correct failed test case * Correct failed test case * Correct failed test case --------- Signed-off-by: Kazantsev, Roman Co-authored-by: IMvision12 <88665786+IMvision12@users.noreply.github.com> Co-authored-by: Enrico Co-authored-by: Marco Co-authored-by: Roman Kazantsev Co-authored-by: Matt Watson <1389937+mattdangerw@users.noreply.github.com> Co-authored-by: hertschuh <1091026+hertschuh@users.noreply.github.com> Co-authored-by: Jasmine Dhantule --- keras/api/_tf_keras/keras/layers/__init__.py | 3 + keras/api/layers/__init__.py | 3 + keras/src/layers/__init__.py | 3 + .../random_color_jitter.py | 197 ++++++++++++++++++ .../random_color_jitter_test.py | 135 ++++++++++++ 5 files changed, 341 insertions(+) create mode 100644 keras/src/layers/preprocessing/image_preprocessing/random_color_jitter.py create mode 100644 keras/src/layers/preprocessing/image_preprocessing/random_color_jitter_test.py diff --git a/keras/api/_tf_keras/keras/layers/__init__.py b/keras/api/_tf_keras/keras/layers/__init__.py index 71db20bf394..7245456cc18 100644 --- a/keras/api/_tf_keras/keras/layers/__init__.py +++ b/keras/api/_tf_keras/keras/layers/__init__.py @@ -155,6 +155,9 @@ from keras.src.layers.preprocessing.image_preprocessing.random_brightness import ( RandomBrightness, ) +from keras.src.layers.preprocessing.image_preprocessing.random_color_jitter import ( + RandomColorJitter, +) from keras.src.layers.preprocessing.image_preprocessing.random_contrast import ( RandomContrast, ) diff --git a/keras/api/layers/__init__.py b/keras/api/layers/__init__.py index 4c31ded2375..6a3e3b55f14 100644 --- a/keras/api/layers/__init__.py +++ b/keras/api/layers/__init__.py @@ -155,6 +155,9 @@ from keras.src.layers.preprocessing.image_preprocessing.random_brightness import ( RandomBrightness, ) +from keras.src.layers.preprocessing.image_preprocessing.random_color_jitter import ( + RandomColorJitter, +) from keras.src.layers.preprocessing.image_preprocessing.random_contrast import ( RandomContrast, ) diff --git a/keras/src/layers/__init__.py b/keras/src/layers/__init__.py index 584a3cdc1f4..303b3104a56 100644 --- a/keras/src/layers/__init__.py +++ b/keras/src/layers/__init__.py @@ -99,6 +99,9 @@ from keras.src.layers.preprocessing.image_preprocessing.random_brightness import ( RandomBrightness, ) +from keras.src.layers.preprocessing.image_preprocessing.random_color_jitter import ( + RandomColorJitter, +) from keras.src.layers.preprocessing.image_preprocessing.random_contrast import ( RandomContrast, ) diff --git a/keras/src/layers/preprocessing/image_preprocessing/random_color_jitter.py b/keras/src/layers/preprocessing/image_preprocessing/random_color_jitter.py new file mode 100644 index 00000000000..eee6f31b8e4 --- /dev/null +++ b/keras/src/layers/preprocessing/image_preprocessing/random_color_jitter.py @@ -0,0 +1,197 @@ +import keras.src.layers.preprocessing.image_preprocessing.random_brightness as random_brightness # noqa: E501 +import keras.src.layers.preprocessing.image_preprocessing.random_contrast as random_contrast # noqa: E501 +import keras.src.layers.preprocessing.image_preprocessing.random_hue as random_hue # noqa: E501 +import keras.src.layers.preprocessing.image_preprocessing.random_saturation as random_saturation # noqa: E501 +from keras.src.api_export import keras_export +from keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import ( # noqa: E501 + BaseImagePreprocessingLayer, +) +from keras.src.random.seed_generator import SeedGenerator +from keras.src.utils import backend_utils + + +@keras_export("keras.layers.RandomColorJitter") +class RandomColorJitter(BaseImagePreprocessingLayer): + """RandomColorJitter class randomly apply brightness, contrast, saturation + and hue image processing operation sequentially and randomly on the + input. + + Args: + value_range: the range of values the incoming images will have. + Represented as a two number tuple written [low, high]. + This is typically either `[0, 1]` or `[0, 255]` depending + on how your preprocessing pipeline is set up. + brightness_factor: Float or a list/tuple of 2 floats between -1.0 + and 1.0. The factor is used to determine the lower bound and + upper bound of the brightness adjustment. A float value will + be chosen randomly between the limits. When -1.0 is chosen, + the output image will be black, and when 1.0 is chosen, the + image will be fully white. When only one float is provided, + eg, 0.2, then -0.2 will be used for lower bound and 0.2 will + be used for upper bound. + contrast_factor: a positive float represented as fraction of value, + or a tuple of size 2 representing lower and upper bound. When + represented as a single float, lower = upper. The contrast + factor will be randomly picked between `[1.0 - lower, 1.0 + + upper]`. For any pixel x in the channel, the output will be + `(x - mean) * factor + mean` where `mean` is the mean value + of the channel. + saturation_factor: A tuple of two floats or a single float. `factor` + controls the extent to which the image saturation is impacted. + `factor=0.5` makes this layer perform a no-op operation. + `factor=0.0` makes the image fully grayscale. `factor=1.0` + makes the image fully saturated. Values should be between + `0.0` and `1.0`. If a tuple is used, a `factor` is sampled + between the two values for every image augmented. If a single + float is used, a value between `0.0` and the passed float is + sampled. To ensure the value is always the same, pass a tuple + with two identical floats: `(0.5, 0.5)`. + hue_factor: A single float or a tuple of two floats. `factor` + controls the extent to which the image hue is impacted. + `factor=0.0` makes this layer perform a no-op operation, + while a value of `1.0` performs the most aggressive contrast + adjustment available. If a tuple is used, a `factor` is + sampled between the two values for every image augmented. + If a single float is used, a value between `0.0` and the + passed float is sampled. In order to ensure the value is + always the same, please pass a tuple with two identical + floats: `(0.5, 0.5)`. + seed: Integer. Used to create a random seed. + """ + + def __init__( + self, + value_range=(0, 255), + brightness_factor=None, + contrast_factor=None, + saturation_factor=None, + hue_factor=None, + seed=None, + data_format=None, + **kwargs, + ): + super().__init__(data_format=data_format, **kwargs) + self.value_range = value_range + self.brightness_factor = brightness_factor + self.contrast_factor = contrast_factor + self.saturation_factor = saturation_factor + self.hue_factor = hue_factor + self.seed = seed + self.generator = SeedGenerator(seed) + + self.random_brightness = None + self.random_contrast = None + self.random_saturation = None + self.random_hue = None + + if self.brightness_factor is not None: + self.random_brightness = random_brightness.RandomBrightness( + factor=self.brightness_factor, + value_range=self.value_range, + seed=self.seed, + ) + + if self.contrast_factor is not None: + self.random_contrast = random_contrast.RandomContrast( + factor=self.contrast_factor, + value_range=self.value_range, + seed=self.seed, + ) + + if self.saturation_factor is not None: + self.random_saturation = random_saturation.RandomSaturation( + factor=self.saturation_factor, + value_range=self.value_range, + seed=self.seed, + ) + + if self.hue_factor is not None: + self.random_hue = random_hue.RandomHue( + factor=self.hue_factor, + value_range=self.value_range, + seed=self.seed, + ) + + def transform_images(self, images, transformation, training=True): + if training: + if backend_utils.in_tf_graph(): + self.backend.set_backend("tensorflow") + images = self.backend.cast(images, self.compute_dtype) + if self.brightness_factor is not None: + if backend_utils.in_tf_graph(): + self.random_brightness.backend.set_backend("tensorflow") + transformation = ( + self.random_brightness.get_random_transformation( + images, + seed=self._get_seed_generator(self.backend._backend), + ) + ) + images = self.random_brightness.transform_images( + images, transformation + ) + if self.contrast_factor is not None: + if backend_utils.in_tf_graph(): + self.random_contrast.backend.set_backend("tensorflow") + transformation = self.random_contrast.get_random_transformation( + images, seed=self._get_seed_generator(self.backend._backend) + ) + transformation["contrast_factor"] = self.backend.cast( + transformation["contrast_factor"], dtype=self.compute_dtype + ) + images = self.random_contrast.transform_images( + images, transformation + ) + if self.saturation_factor is not None: + if backend_utils.in_tf_graph(): + self.random_saturation.backend.set_backend("tensorflow") + transformation = ( + self.random_saturation.get_random_transformation( + images, + seed=self._get_seed_generator(self.backend._backend), + ) + ) + images = self.random_saturation.transform_images( + images, transformation + ) + if self.hue_factor is not None: + if backend_utils.in_tf_graph(): + self.random_hue.backend.set_backend("tensorflow") + transformation = self.random_hue.get_random_transformation( + images, seed=self._get_seed_generator(self.backend._backend) + ) + images = self.random_hue.transform_images( + images, transformation + ) + images = self.backend.cast(images, self.compute_dtype) + return images + + def transform_labels(self, labels, transformation, training=True): + return labels + + def transform_bounding_boxes( + self, + bounding_boxes, + transformation, + training=True, + ): + return bounding_boxes + + def transform_segmentation_masks( + self, segmentation_masks, transformation, training=True + ): + return segmentation_masks + + def compute_output_shape(self, input_shape): + return input_shape + + def get_config(self): + config = { + "value_range": self.value_range, + "brightness_factor": self.brightness_factor, + "contrast_factor": self.contrast_factor, + "saturation_factor": self.saturation_factor, + "hue_factor": self.hue_factor, + "seed": self.seed, + } + base_config = super().get_config() + return {**base_config, **config} diff --git a/keras/src/layers/preprocessing/image_preprocessing/random_color_jitter_test.py b/keras/src/layers/preprocessing/image_preprocessing/random_color_jitter_test.py new file mode 100644 index 00000000000..a465970b6b4 --- /dev/null +++ b/keras/src/layers/preprocessing/image_preprocessing/random_color_jitter_test.py @@ -0,0 +1,135 @@ +import numpy as np +import pytest +from tensorflow import data as tf_data + +from keras.src import backend +from keras.src import layers +from keras.src import testing + + +class RandomColorJitterTest(testing.TestCase): + @pytest.mark.requires_trainable_backend + def test_layer(self): + self.run_layer_test( + layers.RandomColorJitter, + init_kwargs={ + "value_range": (20, 200), + "brightness_factor": 0.2, + "contrast_factor": 0.2, + "saturation_factor": 0.2, + "hue_factor": 0.2, + "seed": 1, + }, + input_shape=(8, 3, 4, 3), + supports_masking=False, + expected_output_shape=(8, 3, 4, 3), + ) + + def test_random_color_jitter_inference(self): + seed = 3481 + layer = layers.RandomColorJitter( + value_range=(0, 1), + brightness_factor=0.1, + contrast_factor=0.2, + saturation_factor=0.9, + hue_factor=0.1, + ) + + np.random.seed(seed) + inputs = np.random.randint(0, 255, size=(224, 224, 3)) + output = layer(inputs, training=False) + self.assertAllClose(inputs, output) + + def test_brightness_only(self): + seed = 2390 + np.random.seed(seed) + + data_format = backend.config.image_data_format() + if data_format == "channels_last": + inputs = np.random.random((12, 8, 16, 3)) + else: + inputs = np.random.random((12, 3, 8, 16)) + + layer = layers.RandomColorJitter( + brightness_factor=[0.5, 0.5], seed=seed + ) + output = backend.convert_to_numpy(layer(inputs)) + + layer = layers.RandomBrightness(factor=[0.5, 0.5], seed=seed) + sub_output = backend.convert_to_numpy(layer(inputs)) + + self.assertAllClose(output, sub_output) + + def test_saturation_only(self): + seed = 2390 + np.random.seed(seed) + + data_format = backend.config.image_data_format() + if data_format == "channels_last": + inputs = np.random.random((12, 8, 16, 3)) + else: + inputs = np.random.random((12, 3, 8, 16)) + + layer = layers.RandomColorJitter( + saturation_factor=[0.5, 0.5], seed=seed + ) + output = layer(inputs) + + layer = layers.RandomSaturation(factor=[0.5, 0.5], seed=seed) + sub_output = layer(inputs) + + self.assertAllClose(output, sub_output) + + def test_hue_only(self): + seed = 2390 + np.random.seed(seed) + + data_format = backend.config.image_data_format() + if data_format == "channels_last": + inputs = np.random.random((12, 8, 16, 3)) + else: + inputs = np.random.random((12, 3, 8, 16)) + + layer = layers.RandomColorJitter(hue_factor=[0.5, 0.5], seed=seed) + output = layer(inputs) + + layer = layers.RandomHue(factor=[0.5, 0.5], seed=seed) + sub_output = layer(inputs) + + self.assertAllClose(output, sub_output) + + def test_contrast_only(self): + seed = 2390 + np.random.seed(seed) + + data_format = backend.config.image_data_format() + if data_format == "channels_last": + inputs = np.random.random((12, 8, 16, 3)) + else: + inputs = np.random.random((12, 3, 8, 16)) + + layer = layers.RandomColorJitter(contrast_factor=[0.5, 0.5], seed=seed) + output = layer(inputs) + + layer = layers.RandomContrast(factor=[0.5, 0.5], seed=seed) + sub_output = layer(inputs) + + self.assertAllClose(output, sub_output) + + def test_tf_data_compatibility(self): + data_format = backend.config.image_data_format() + if data_format == "channels_last": + input_data = np.random.random((2, 8, 8, 3)) + else: + input_data = np.random.random((2, 3, 8, 8)) + layer = layers.RandomColorJitter( + value_range=(0, 1), + brightness_factor=0.1, + contrast_factor=0.2, + saturation_factor=0.9, + hue_factor=0.1, + ) + + ds = tf_data.Dataset.from_tensor_slices(input_data).batch(2).map(layer) + for output in ds.take(1): + output.numpy()