diff --git a/examples/layers/preprocessing/segmentation/aug_mix_demo.py b/examples/layers/preprocessing/segmentation/aug_mix_demo.py new file mode 100644 index 0000000000..3e6461d3c5 --- /dev/null +++ b/examples/layers/preprocessing/segmentation/aug_mix_demo.py @@ -0,0 +1,34 @@ +# Copyright 2023 The KerasCV Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""aug_mix_demo.py shows how to use the AugMix preprocessing layer. + +Uses the oxford iiit pet_dataset. In this script the pets +are loaded, then are passed through the preprocessing layers. +Finally, they are shown using matplotlib. +""" +import demo_utils +import tensorflow as tf + +from keras_cv.layers import preprocessing + + +def main(): + ds = demo_utils.load_oxford_iiit_pet_dataset() + augmix = preprocessing.AugMix([0, 255]) + ds = ds.map(augmix, num_parallel_calls=tf.data.AUTOTUNE) + demo_utils.visualize_dataset(ds) + + +if __name__ == "__main__": + main() diff --git a/examples/layers/preprocessing/segmentation/demo_utils.py b/examples/layers/preprocessing/segmentation/demo_utils.py index e7fc4dc2b0..d9611d7912 100644 --- a/examples/layers/preprocessing/segmentation/demo_utils.py +++ b/examples/layers/preprocessing/segmentation/demo_utils.py @@ -28,6 +28,7 @@ def normalize(input_image, input_mask): input_image = tf.image.convert_image_dtype(input_image, tf.float32) input_image = (input_image - mean) / tf.maximum(std, backend.epsilon()) + input_image = input_image / 255 input_mask -= 1 return input_image, input_mask diff --git a/examples/layers/preprocessing/segmentation/resize_demo.py b/examples/layers/preprocessing/segmentation/resize_demo.py index ae70b9a0fa..68adf51e54 100644 --- a/examples/layers/preprocessing/segmentation/resize_demo.py +++ b/examples/layers/preprocessing/segmentation/resize_demo.py @@ -31,7 +31,7 @@ def load_data(): ) return ds.map( lambda inputs: { - "images": tf.cast(inputs["image"], dtype=tf.float32) / 255.0, + "images": tf.cast(inputs["image"], dtype=tf.float32), "segmentation_masks": inputs["segmentation_mask"] - 1, } ) diff --git a/keras_cv/layers/preprocessing/README.md b/keras_cv/layers/preprocessing/README.md index 32c05e6969..fcfacd8cd3 100644 --- a/keras_cv/layers/preprocessing/README.md +++ b/keras_cv/layers/preprocessing/README.md @@ -6,7 +6,7 @@ The provided table gives an overview of the different augmentation layers availa | Layer Name | Vectorized | Segmentation Masks | BBoxes | Class Labels | | :-- | :--: | :--: | :--: | :--: | -| AugMix | ❌ | ❌ | ✅ | ✅ | +| AugMix | ❌ | ✅ | ✅ | ✅ | | AutoContrast | ✅ | ✅ | ✅ | ✅ | | ChannelShuffle | ✅ | ✅ | ✅ | ✅ | | CutMix | ❌ | ✅ | ❌ | ✅ | diff --git a/keras_cv/layers/preprocessing/aug_mix.py b/keras_cv/layers/preprocessing/aug_mix.py index 4949b9eefd..eb18d708df 100644 --- a/keras_cv/layers/preprocessing/aug_mix.py +++ b/keras_cv/layers/preprocessing/aug_mix.py @@ -306,12 +306,33 @@ def _apply_op(self, image, op_index): ) return augmented - def augment_image(self, image, transformation=None, **kwargs): + def get_random_transformation( + self, + image=None, + label=None, + bounding_boxes=None, + keypoints=None, + segmentation_mask=None, + ): + # Generate random values of chain_mixing_weights and weight_sample chain_mixing_weights = self._sample_from_dirichlet( tf.ones([self.num_chains]) * self.alpha ) weight_sample = self._sample_from_beta(self.alpha, self.alpha) + # Create a transformation config containing the random values + transformation = { + "chain_mixing_weights": chain_mixing_weights, + "weight_sample": weight_sample, + } + + return transformation + + def augment_image(self, image, transformation=None, **kwargs): + # Extract chain_mixing_weights and weight_sample from the provided transformation # noqa: E501 + chain_mixing_weights = transformation["chain_mixing_weights"] + weight_sample = transformation["weight_sample"] + result = tf.zeros_like(image) curr_chain = tf.constant([0], dtype=tf.int32) @@ -328,6 +349,35 @@ def augment_image(self, image, transformation=None, **kwargs): def augment_label(self, label, transformation=None, **kwargs): return label + def augment_segmentation_mask( + self, segmentation_masks, transformation=None, **kwargs + ): + # Extract chain_mixing_weights and weight_sample from the provided transformation # noqa: E501 + chain_mixing_weights = transformation["chain_mixing_weights"] + weight_sample = transformation["weight_sample"] + + result = tf.zeros_like(segmentation_masks) + curr_chain = tf.constant([0], dtype=tf.int32) + + ( + segmentation_masks, + chain_mixing_weights, + curr_chain, + result, + ) = tf.while_loop( + lambda segmentation_masks, chain_mixing_weights, curr_chain, result: tf.less( # noqa: E501 + curr_chain, self.num_chains + ), + self._loop_on_width, + [segmentation_masks, chain_mixing_weights, curr_chain, result], + ) + + # Apply the mixing of segmentation_masks similar to images + result = ( + weight_sample * segmentation_masks + (1 - weight_sample) * result + ) + return result + def get_config(self): config = { "value_range": self.value_range, diff --git a/keras_cv/layers/preprocessing/aug_mix_test.py b/keras_cv/layers/preprocessing/aug_mix_test.py index 26ccdb990d..a6b3a6d4ab 100644 --- a/keras_cv/layers/preprocessing/aug_mix_test.py +++ b/keras_cv/layers/preprocessing/aug_mix_test.py @@ -25,14 +25,20 @@ def test_return_shapes(self): # RGB xs = tf.ones((2, 512, 512, 3)) xs = layer(xs) + ys_segmentation_masks = tf.ones((2, 512, 512, 3)) + ys_segmentation_masks = layer(ys_segmentation_masks) self.assertEqual(xs.shape, [2, 512, 512, 3]) + self.assertEqual(ys_segmentation_masks.shape, [2, 512, 512, 3]) # greyscale xs = tf.ones((2, 512, 512, 1)) xs = layer(xs) + ys_segmentation_masks = tf.ones((2, 512, 512, 1)) + ys_segmentation_masks = layer(ys_segmentation_masks) self.assertEqual(xs.shape, [2, 512, 512, 1]) + self.assertEqual(ys_segmentation_masks.shape, [2, 512, 512, 1]) - def test_in_single_image(self): + def test_in_single_image_and_mask(self): layer = preprocessing.AugMix([0, 255]) # RGB @@ -42,7 +48,14 @@ def test_in_single_image(self): ) xs = layer(xs) + ys_segmentation_masks = tf.cast( + tf.ones((512, 512, 3)), + dtype=tf.float32, + ) + + ys_segmentation_masks = layer(ys_segmentation_masks) self.assertEqual(xs.shape, [512, 512, 3]) + self.assertEqual(ys_segmentation_masks.shape, [512, 512, 3]) # greyscale xs = tf.cast( @@ -51,20 +64,32 @@ def test_in_single_image(self): ) xs = layer(xs) + ys_segmentation_masks = tf.cast( + tf.ones((512, 512, 1)), + dtype=tf.float32, + ) + ys_segmentation_masks = layer(ys_segmentation_masks) self.assertEqual(xs.shape, [512, 512, 1]) + self.assertEqual(ys_segmentation_masks.shape, [512, 512, 1]) - def test_non_square_images(self): + def test_non_square_images_and_masks(self): layer = preprocessing.AugMix([0, 255]) # RGB xs = tf.ones((2, 256, 512, 3)) xs = layer(xs) + ys_segmentation_masks = tf.ones((2, 256, 512, 3)) + ys_segmentation_masks = layer(ys_segmentation_masks) self.assertEqual(xs.shape, [2, 256, 512, 3]) + self.assertEqual(ys_segmentation_masks.shape, [2, 256, 512, 3]) # greyscale xs = tf.ones((2, 256, 512, 1)) xs = layer(xs) + ys_segmentation_masks = tf.ones((2, 256, 512, 1)) + ys_segmentation_masks = layer(ys_segmentation_masks) self.assertEqual(xs.shape, [2, 256, 512, 1]) + self.assertEqual(ys_segmentation_masks.shape, [2, 256, 512, 1]) def test_single_input_args(self): layer = preprocessing.AugMix([0, 255]) @@ -72,12 +97,18 @@ def test_single_input_args(self): # RGB xs = tf.ones((2, 512, 512, 3)) xs = layer(xs) + ys_segmentation_masks = tf.ones((2, 512, 512, 3)) + ys_segmentation_masks = layer(ys_segmentation_masks) self.assertEqual(xs.shape, [2, 512, 512, 3]) + self.assertEqual(ys_segmentation_masks.shape, [2, 512, 512, 3]) # greyscale xs = tf.ones((2, 512, 512, 1)) xs = layer(xs) + ys_segmentation_masks = tf.ones((2, 512, 512, 1)) + ys_segmentation_masks = layer(ys_segmentation_masks) self.assertEqual(xs.shape, [2, 512, 512, 1]) + self.assertEqual(ys_segmentation_masks.shape, [2, 512, 512, 1]) def test_many_augmentations(self): layer = preprocessing.AugMix([0, 255], chain_depth=[25, 26]) @@ -85,9 +116,15 @@ def test_many_augmentations(self): # RGB xs = tf.ones((2, 512, 512, 3)) xs = layer(xs) + ys_segmentation_masks = tf.ones((2, 512, 512, 3)) + ys_segmentation_masks = layer(ys_segmentation_masks) self.assertEqual(xs.shape, [2, 512, 512, 3]) + self.assertEqual(ys_segmentation_masks.shape, [2, 512, 512, 3]) # greyscale xs = tf.ones((2, 512, 512, 1)) xs = layer(xs) + ys_segmentation_masks = tf.ones((2, 512, 512, 1)) + ys_segmentation_masks = layer(ys_segmentation_masks) self.assertEqual(xs.shape, [2, 512, 512, 1]) + self.assertEqual(ys_segmentation_masks.shape, [2, 512, 512, 1])