Skip to content

Commit

Permalink
Add support of segmentation mask in Augmix layer (keras-team#1988)
Browse files Browse the repository at this point in the history
* update augmix segmentation mask

* fix

* fix

* added demo

* add test

* update readme

* fix

* fix

* fix
  • Loading branch information
cosmo3769 authored Jul 31, 2023
1 parent baf37b9 commit 2dd44b9
Show file tree
Hide file tree
Showing 6 changed files with 127 additions and 5 deletions.
34 changes: 34 additions & 0 deletions examples/layers/preprocessing/segmentation/aug_mix_demo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# Copyright 2023 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""aug_mix_demo.py shows how to use the AugMix preprocessing layer.
Uses the oxford iiit pet_dataset. In this script the pets
are loaded, then are passed through the preprocessing layers.
Finally, they are shown using matplotlib.
"""
import demo_utils
import tensorflow as tf

from keras_cv.layers import preprocessing


def main():
ds = demo_utils.load_oxford_iiit_pet_dataset()
augmix = preprocessing.AugMix([0, 255])
ds = ds.map(augmix, num_parallel_calls=tf.data.AUTOTUNE)
demo_utils.visualize_dataset(ds)


if __name__ == "__main__":
main()
1 change: 1 addition & 0 deletions examples/layers/preprocessing/segmentation/demo_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
def normalize(input_image, input_mask):
input_image = tf.image.convert_image_dtype(input_image, tf.float32)
input_image = (input_image - mean) / tf.maximum(std, backend.epsilon())
input_image = input_image / 255
input_mask -= 1
return input_image, input_mask

Expand Down
2 changes: 1 addition & 1 deletion examples/layers/preprocessing/segmentation/resize_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def load_data():
)
return ds.map(
lambda inputs: {
"images": tf.cast(inputs["image"], dtype=tf.float32) / 255.0,
"images": tf.cast(inputs["image"], dtype=tf.float32),
"segmentation_masks": inputs["segmentation_mask"] - 1,
}
)
Expand Down
2 changes: 1 addition & 1 deletion keras_cv/layers/preprocessing/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ The provided table gives an overview of the different augmentation layers availa

| Layer Name | Vectorized | Segmentation Masks | BBoxes | Class Labels |
| :-- | :--: | :--: | :--: | :--: |
| AugMix || |||
| AugMix || |||
| AutoContrast |||||
| ChannelShuffle |||||
| CutMix |||||
Expand Down
52 changes: 51 additions & 1 deletion keras_cv/layers/preprocessing/aug_mix.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,12 +306,33 @@ def _apply_op(self, image, op_index):
)
return augmented

def augment_image(self, image, transformation=None, **kwargs):
def get_random_transformation(
self,
image=None,
label=None,
bounding_boxes=None,
keypoints=None,
segmentation_mask=None,
):
# Generate random values of chain_mixing_weights and weight_sample
chain_mixing_weights = self._sample_from_dirichlet(
tf.ones([self.num_chains]) * self.alpha
)
weight_sample = self._sample_from_beta(self.alpha, self.alpha)

# Create a transformation config containing the random values
transformation = {
"chain_mixing_weights": chain_mixing_weights,
"weight_sample": weight_sample,
}

return transformation

def augment_image(self, image, transformation=None, **kwargs):
# Extract chain_mixing_weights and weight_sample from the provided transformation # noqa: E501
chain_mixing_weights = transformation["chain_mixing_weights"]
weight_sample = transformation["weight_sample"]

result = tf.zeros_like(image)
curr_chain = tf.constant([0], dtype=tf.int32)

Expand All @@ -328,6 +349,35 @@ def augment_image(self, image, transformation=None, **kwargs):
def augment_label(self, label, transformation=None, **kwargs):
return label

def augment_segmentation_mask(
self, segmentation_masks, transformation=None, **kwargs
):
# Extract chain_mixing_weights and weight_sample from the provided transformation # noqa: E501
chain_mixing_weights = transformation["chain_mixing_weights"]
weight_sample = transformation["weight_sample"]

result = tf.zeros_like(segmentation_masks)
curr_chain = tf.constant([0], dtype=tf.int32)

(
segmentation_masks,
chain_mixing_weights,
curr_chain,
result,
) = tf.while_loop(
lambda segmentation_masks, chain_mixing_weights, curr_chain, result: tf.less( # noqa: E501
curr_chain, self.num_chains
),
self._loop_on_width,
[segmentation_masks, chain_mixing_weights, curr_chain, result],
)

# Apply the mixing of segmentation_masks similar to images
result = (
weight_sample * segmentation_masks + (1 - weight_sample) * result
)
return result

def get_config(self):
config = {
"value_range": self.value_range,
Expand Down
41 changes: 39 additions & 2 deletions keras_cv/layers/preprocessing/aug_mix_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,20 @@ def test_return_shapes(self):
# RGB
xs = tf.ones((2, 512, 512, 3))
xs = layer(xs)
ys_segmentation_masks = tf.ones((2, 512, 512, 3))
ys_segmentation_masks = layer(ys_segmentation_masks)
self.assertEqual(xs.shape, [2, 512, 512, 3])
self.assertEqual(ys_segmentation_masks.shape, [2, 512, 512, 3])

# greyscale
xs = tf.ones((2, 512, 512, 1))
xs = layer(xs)
ys_segmentation_masks = tf.ones((2, 512, 512, 1))
ys_segmentation_masks = layer(ys_segmentation_masks)
self.assertEqual(xs.shape, [2, 512, 512, 1])
self.assertEqual(ys_segmentation_masks.shape, [2, 512, 512, 1])

def test_in_single_image(self):
def test_in_single_image_and_mask(self):
layer = preprocessing.AugMix([0, 255])

# RGB
Expand All @@ -42,7 +48,14 @@ def test_in_single_image(self):
)

xs = layer(xs)
ys_segmentation_masks = tf.cast(
tf.ones((512, 512, 3)),
dtype=tf.float32,
)

ys_segmentation_masks = layer(ys_segmentation_masks)
self.assertEqual(xs.shape, [512, 512, 3])
self.assertEqual(ys_segmentation_masks.shape, [512, 512, 3])

# greyscale
xs = tf.cast(
Expand All @@ -51,43 +64,67 @@ def test_in_single_image(self):
)

xs = layer(xs)
ys_segmentation_masks = tf.cast(
tf.ones((512, 512, 1)),
dtype=tf.float32,
)
ys_segmentation_masks = layer(ys_segmentation_masks)
self.assertEqual(xs.shape, [512, 512, 1])
self.assertEqual(ys_segmentation_masks.shape, [512, 512, 1])

def test_non_square_images(self):
def test_non_square_images_and_masks(self):
layer = preprocessing.AugMix([0, 255])

# RGB
xs = tf.ones((2, 256, 512, 3))
xs = layer(xs)
ys_segmentation_masks = tf.ones((2, 256, 512, 3))
ys_segmentation_masks = layer(ys_segmentation_masks)
self.assertEqual(xs.shape, [2, 256, 512, 3])
self.assertEqual(ys_segmentation_masks.shape, [2, 256, 512, 3])

# greyscale
xs = tf.ones((2, 256, 512, 1))
xs = layer(xs)
ys_segmentation_masks = tf.ones((2, 256, 512, 1))
ys_segmentation_masks = layer(ys_segmentation_masks)
self.assertEqual(xs.shape, [2, 256, 512, 1])
self.assertEqual(ys_segmentation_masks.shape, [2, 256, 512, 1])

def test_single_input_args(self):
layer = preprocessing.AugMix([0, 255])

# RGB
xs = tf.ones((2, 512, 512, 3))
xs = layer(xs)
ys_segmentation_masks = tf.ones((2, 512, 512, 3))
ys_segmentation_masks = layer(ys_segmentation_masks)
self.assertEqual(xs.shape, [2, 512, 512, 3])
self.assertEqual(ys_segmentation_masks.shape, [2, 512, 512, 3])

# greyscale
xs = tf.ones((2, 512, 512, 1))
xs = layer(xs)
ys_segmentation_masks = tf.ones((2, 512, 512, 1))
ys_segmentation_masks = layer(ys_segmentation_masks)
self.assertEqual(xs.shape, [2, 512, 512, 1])
self.assertEqual(ys_segmentation_masks.shape, [2, 512, 512, 1])

def test_many_augmentations(self):
layer = preprocessing.AugMix([0, 255], chain_depth=[25, 26])

# RGB
xs = tf.ones((2, 512, 512, 3))
xs = layer(xs)
ys_segmentation_masks = tf.ones((2, 512, 512, 3))
ys_segmentation_masks = layer(ys_segmentation_masks)
self.assertEqual(xs.shape, [2, 512, 512, 3])
self.assertEqual(ys_segmentation_masks.shape, [2, 512, 512, 3])

# greyscale
xs = tf.ones((2, 512, 512, 1))
xs = layer(xs)
ys_segmentation_masks = tf.ones((2, 512, 512, 1))
ys_segmentation_masks = layer(ys_segmentation_masks)
self.assertEqual(xs.shape, [2, 512, 512, 1])
self.assertEqual(ys_segmentation_masks.shape, [2, 512, 512, 1])

0 comments on commit 2dd44b9

Please sign in to comment.