Skip to content

Commit

Permalink
Add support of segmentation mask to fourier mix (keras-team#1991)
Browse files Browse the repository at this point in the history
* added segmentation mask support in fourier mix

* fix

* add test

* add demo

* update readme
  • Loading branch information
cosmo3769 authored Jul 31, 2023
1 parent b4f6cd6 commit 6005ff8
Show file tree
Hide file tree
Showing 4 changed files with 113 additions and 7 deletions.
34 changes: 34 additions & 0 deletions examples/layers/preprocessing/segmentation/fourier_mix_demo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# Copyright 2023 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""fourier_mix_demo.py shows how to use the FourierMix preprocessing layer.
Uses the oxford iiit pet_dataset. In this script the pets
are loaded, then are passed through the preprocessing layers.
Finally, they are shown using matplotlib.
"""
import demo_utils
import tensorflow as tf

from keras_cv.layers import preprocessing


def main():
ds = demo_utils.load_oxford_iiit_pet_dataset()
fouriermix = preprocessing.FourierMix(alpha=0.8)
ds = ds.map(fouriermix, num_parallel_calls=tf.data.AUTOTUNE)
demo_utils.visualize_dataset(ds)


if __name__ == "__main__":
main()
2 changes: 1 addition & 1 deletion keras_cv/layers/preprocessing/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ The provided table gives an overview of the different augmentation layers availa
| ChannelShuffle |||||
| CutMix |||||
| Equalization |||||
| FourierMix || |||
| FourierMix || |||
| Grayscale |||||
| GridMask |||||
| JitteredResize |||||
Expand Down
27 changes: 24 additions & 3 deletions keras_cv/layers/preprocessing/fourier_mix.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,18 +145,27 @@ def _binarise_mask(self, mask, lam, in_shape):
def _batch_augment(self, inputs):
images = inputs.get("images", None)
labels = inputs.get("labels", None)
if images is None or labels is None:
segmentation_masks = inputs.get("segmentation_masks", None)
if images is None or (labels is None and segmentation_masks is None):
raise ValueError(
"FourierMix expects inputs in a dictionary with format "
'{"images": images, "labels": labels}.'
'{"images": images, "segmentation_masks": segmentation_masks}.'
f"Got: inputs = {inputs}"
)
images, lambda_sample, permutation_order = self._fourier_mix(images)
images, masks, lambda_sample, permutation_order = self._fourier_mix(
images
)
if labels is not None:
labels = self._update_labels(
labels, lambda_sample, permutation_order
)
inputs["labels"] = labels
if segmentation_masks is not None:
segmentation_masks = self._update_segmentation_masks(
segmentation_masks, masks, permutation_order
)
inputs["segmentation_masks"] = segmentation_masks
inputs["images"] = images
return inputs

Expand Down Expand Up @@ -198,7 +207,7 @@ def _fourier_mix(self, images):
fmix_images = tf.gather(images, permutation_order)
images = masks * images + (1.0 - masks) * fmix_images

return images, lambda_sample, permutation_order
return images, masks, lambda_sample, permutation_order

def _update_labels(self, labels, lambda_sample, permutation_order):
labels_for_fmix = tf.gather(labels, permutation_order)
Expand All @@ -216,6 +225,18 @@ def _update_labels(self, labels, lambda_sample, permutation_order):
)
return labels

def _update_segmentation_masks(
self, segmentation_masks, masks, permutation_order
):
fmix_segmentation_masks = tf.gather(
segmentation_masks, permutation_order
)

segmentation_masks = (
masks * segmentation_masks + (1.0 - masks) * fmix_segmentation_masks
)
return segmentation_masks

def get_config(self):
config = {
"alpha": self.alpha,
Expand Down
57 changes: 54 additions & 3 deletions keras_cv/layers/preprocessing/fourier_mix_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,18 +26,35 @@ def test_return_shapes(self):
ys = tf.random.categorical(tf.math.log([[0.5, 0.5]]), 2)
ys = tf.squeeze(ys)
ys = tf.one_hot(ys, num_classes)
# randomly sample segmentation mask
ys_segmentation_masks = tf.cast(
tf.stack(
[2 * tf.ones((512, 512)), tf.ones((512, 512))],
axis=0,
),
tf.uint8,
)
ys_segmentation_masks = tf.one_hot(ys_segmentation_masks, 3)

layer = FourierMix()
outputs = layer({"images": xs, "labels": ys})
xs, ys = (
outputs = layer(
{
"images": xs,
"labels": ys,
"segmentation_masks": ys_segmentation_masks,
}
)
xs, ys, ys_segmentation_masks = (
outputs["images"],
outputs["labels"],
outputs["segmentation_masks"],
)

self.assertEqual(xs.shape, [2, 512, 512, 3])
self.assertEqual(ys.shape, [2, 10])
self.assertEqual(ys_segmentation_masks.shape, [2, 512, 512, 3])

def test_fourier_mix_call_results(self):
def test_fourier_mix_call_results_with_labels(self):
xs = tf.cast(
tf.stack(
[2 * tf.ones((4, 4, 3)), tf.ones((4, 4, 3))],
Expand All @@ -59,6 +76,40 @@ def test_fourier_mix_call_results(self):
self.assertNotAllClose(ys, 1.0)
self.assertNotAllClose(ys, 0.0)

def test_mix_up_call_results_with_masks(self):
xs = tf.cast(
tf.stack(
[2 * tf.ones((4, 4, 3)), tf.ones((4, 4, 3))],
axis=0,
),
tf.float32,
)
ys_segmentation_masks = tf.cast(
tf.stack(
[2 * tf.ones((4, 4)), tf.ones((4, 4))],
axis=0,
),
tf.uint8,
)
ys_segmentation_masks = tf.one_hot(ys_segmentation_masks, 3)

layer = FourierMix()
outputs = layer(
{"images": xs, "segmentation_masks": ys_segmentation_masks}
)
xs, ys_segmentation_masks = (
outputs["images"],
outputs["segmentation_masks"],
)

# None of the individual values should still be close to 1 or 0
self.assertNotAllClose(xs, 1.0)
self.assertNotAllClose(xs, 2.0)

# No masks should still be close to their originals
self.assertNotAllClose(ys_segmentation_masks, 1.0)
self.assertNotAllClose(ys_segmentation_masks, 0.0)

def test_in_tf_function(self):
xs = tf.cast(
tf.stack(
Expand Down

0 comments on commit 6005ff8

Please sign in to comment.