Add support of segmentation mask in Augmix layer #1988
aug_mix_demo.py (new file):

@@ -0,0 +1,34 @@
# Copyright 2023 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""aug_mix_demo.py shows how to use the AugMix preprocessing layer.

Uses the Oxford-IIIT Pet dataset. In this script the pet images
are loaded, then passed through the preprocessing layers.
Finally, they are shown using matplotlib.
"""
import demo_utils
import tensorflow as tf

from keras_cv.layers import preprocessing


def main():
    ds = demo_utils.load_oxford_iiit_pet_dataset()
    augmix = preprocessing.AugMix([0, 255])
    ds = ds.map(augmix, num_parallel_calls=tf.data.AUTOTUNE)
    demo_utils.visualize_dataset(ds)


if __name__ == "__main__":
    main()
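The demo's three-step pipeline (load the dataset, map the augmentation layer over it, visualize) can be mimicked without TensorFlow. Everything below is a hypothetical pure-Python stand-in for illustration only, not the KerasCV or tf.data API:

```python
# Toy stand-in for the demo pipeline: load -> map an augmentation -> consume.
def load_dataset():
    # Stand-in for demo_utils.load_oxford_iiit_pet_dataset().
    return [{"images": [0.0, 128.0, 255.0]}]

def augmix(sample):
    # Stand-in for preprocessing.AugMix([0, 255]): a dict-in / dict-out
    # transform that keeps pixel values inside the configured range.
    return {"images": [min(255.0, v * 0.5 + 64.0) for v in sample["images"]]}

ds = [augmix(s) for s in load_dataset()]  # analogue of ds.map(augmix, ...)
for sample in ds:
    assert all(0.0 <= v <= 255.0 for v in sample["images"])
```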
AugMix layer implementation:

@@ -328,6 +328,36 @@ def augment_image(self, image, transformation=None, **kwargs):
    def augment_label(self, label, transformation=None, **kwargs):
        return label

    def augment_segmentation_mask(
        self, segmentation_masks, transformation=None, **kwargs
    ):
        chain_mixing_weights = self._sample_from_dirichlet(
            tf.ones([self.num_chains]) * self.alpha
        )
        weight_sample = self._sample_from_beta(self.alpha, self.alpha)

        result = tf.zeros_like(segmentation_masks)
        curr_chain = tf.constant([0], dtype=tf.int32)

        (
            segmentation_masks,
            chain_mixing_weights,
            curr_chain,
            result,
        ) = tf.while_loop(
            lambda segmentation_masks, chain_mixing_weights, curr_chain, result: tf.less(  # noqa: E501
                curr_chain, self.num_chains
            ),
            self._loop_on_width,
            [segmentation_masks, chain_mixing_weights, curr_chain, result],
        )

        # Apply the mixing of segmentation_masks, similar to images
        result = (
            weight_sample * segmentation_masks + (1 - weight_sample) * result
        )
        return result

Review comment on the `_sample_from_dirichlet` call:

Reviewer: Sampling these random values differently for the image augmentation and the mask augmentation will cause the image and mask to be augmented inconsistently with one another (your demo image shows, for example, that the dog mask is rotated differently than the dog itself). We should implement and use

Author: Looking into it.
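The reviewer's point about inconsistent random sampling can be illustrated with a toy, TensorFlow-free sketch. All names below are hypothetical stand-ins, not the KerasCV API: drawing the random transformation parameters once and reusing them for both the image and the mask is what keeps the two aligned.

```python
import random

def sample_transformation(rng):
    # Draw all random parameters once, up front (the role a
    # get_random_transformation-style hook would play).
    return {"angle": rng.uniform(0.0, 360.0)}

def augment(tensor, transformation):
    # Stand-in for a geometric op: just report which angle was applied.
    return tensor, transformation["angle"]

# Buggy pattern: image and mask each draw their own parameters.
_, img_angle = augment("image", sample_transformation(random.Random(1)))
_, mask_angle = augment("mask", sample_transformation(random.Random(2)))
assert img_angle != mask_angle  # the mask rotates differently than the image

# Fixed pattern: sample once, reuse the same parameters for both inputs.
shared = sample_transformation(random.Random(3))
_, img_angle = augment("image", shared)
_, mask_angle = augment("mask", shared)
assert img_angle == mask_angle  # image and mask stay aligned
```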
    def get_config(self):
        config = {
            "value_range": self.value_range,
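The final step of `augment_segmentation_mask` mirrors the image path: the output of the augmentation chains is blended back with the original input by a Beta-sampled weight. A pure-Python sketch of just that arithmetic (no TensorFlow; `weight_sample` is a fixed number here rather than a Beta draw):

```python
def mix(original, chain_result, weight_sample):
    # result = w * original + (1 - w) * chain_output, elementwise
    return [
        weight_sample * o + (1.0 - weight_sample) * c
        for o, c in zip(original, chain_result)
    ]

original = [0.0, 1.0, 1.0]      # e.g. a flattened binary mask
chain_result = [1.0, 0.0, 1.0]  # output of the augmentation chains
print(mix(original, chain_result, 0.25))  # -> [0.75, 0.25, 1.0]
```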
Review thread on the demo's value-range handling:

Reviewer: Does this break other demos?

Reviewer: Actually this seems incorrect -- we're already scaling by mean and stddev, so this shouldn't be done.

Author: No, I checked the other demos as well.

Reviewer: Why make this change, though? It seems like it's not a reasonable transform given the existing input scale of the image.

Author: In this demo, the value range of the image is greater than 255. Since we set the AugMix value range to [0, 255], the demo does not work without making this change. Any better alternative? 🤔

Reviewer: Okay, this seems more or less reasonable. However, it does make resize_demo.py strange, as that file already manually rescales by 255.0 on line 34. Let's get rid of that rescaling there, and then this is fine.

Author: Done.
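The value-range discussion above boils down to linearly rescaling inputs into the [0, 255] range that AugMix is configured with. A minimal sketch, assuming a simple linear map; the helper name is hypothetical and this is not the demo's actual code:

```python
def rescale_to_255(values, lo, hi):
    # Hypothetical helper: linearly map [lo, hi] onto AugMix's [0, 255].
    scale = 255.0 / (hi - lo)
    return [(v - lo) * scale for v in values]

pixels = [0.0, 500.0, 1000.0]               # raw values above 255
print(rescale_to_255(pixels, 0.0, 1000.0))  # -> [0.0, 127.5, 255.0]
```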