Skip to content

Commit

Permalink
Support non-tensorflow backends in KerasCV's preprocessing layers (#2240)
Browse files Browse the repository at this point in the history

* Support non-TF backends in KerasCV's preprocessing layers

* Fix failing PyTorch tests on GPU

* ops.convert_to_tensor -> ops.convert_to_numpy

* Fix remaining instances

* More torch failures

* Remove unrelated change
  • Loading branch information
tirthasheshpatel authored Dec 15, 2023
1 parent b5e7e6a commit e2b627e
Show file tree
Hide file tree
Showing 35 changed files with 442 additions and 297 deletions.
40 changes: 20 additions & 20 deletions keras_cv/layers/preprocessing/aug_mix_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,16 +27,16 @@ def test_return_shapes(self):
xs = layer(xs)
ys_segmentation_masks = tf.ones((2, 512, 512, 3))
ys_segmentation_masks = layer(ys_segmentation_masks)
self.assertEqual(xs.shape, [2, 512, 512, 3])
self.assertEqual(ys_segmentation_masks.shape, [2, 512, 512, 3])
self.assertEqual(xs.shape, (2, 512, 512, 3))
self.assertEqual(ys_segmentation_masks.shape, (2, 512, 512, 3))

# greyscale
xs = tf.ones((2, 512, 512, 1))
xs = layer(xs)
ys_segmentation_masks = tf.ones((2, 512, 512, 1))
ys_segmentation_masks = layer(ys_segmentation_masks)
self.assertEqual(xs.shape, [2, 512, 512, 1])
self.assertEqual(ys_segmentation_masks.shape, [2, 512, 512, 1])
self.assertEqual(xs.shape, (2, 512, 512, 1))
self.assertEqual(ys_segmentation_masks.shape, (2, 512, 512, 1))

def test_in_single_image_and_mask(self):
layer = preprocessing.AugMix([0, 255])
Expand All @@ -54,8 +54,8 @@ def test_in_single_image_and_mask(self):
)

ys_segmentation_masks = layer(ys_segmentation_masks)
self.assertEqual(xs.shape, [512, 512, 3])
self.assertEqual(ys_segmentation_masks.shape, [512, 512, 3])
self.assertEqual(xs.shape, (512, 512, 3))
self.assertEqual(ys_segmentation_masks.shape, (512, 512, 3))

# greyscale
xs = tf.cast(
Expand All @@ -69,8 +69,8 @@ def test_in_single_image_and_mask(self):
dtype=tf.float32,
)
ys_segmentation_masks = layer(ys_segmentation_masks)
self.assertEqual(xs.shape, [512, 512, 1])
self.assertEqual(ys_segmentation_masks.shape, [512, 512, 1])
self.assertEqual(xs.shape, (512, 512, 1))
self.assertEqual(ys_segmentation_masks.shape, (512, 512, 1))

def test_non_square_images_and_masks(self):
layer = preprocessing.AugMix([0, 255])
Expand All @@ -80,16 +80,16 @@ def test_non_square_images_and_masks(self):
xs = layer(xs)
ys_segmentation_masks = tf.ones((2, 256, 512, 3))
ys_segmentation_masks = layer(ys_segmentation_masks)
self.assertEqual(xs.shape, [2, 256, 512, 3])
self.assertEqual(ys_segmentation_masks.shape, [2, 256, 512, 3])
self.assertEqual(xs.shape, (2, 256, 512, 3))
self.assertEqual(ys_segmentation_masks.shape, (2, 256, 512, 3))

# greyscale
xs = tf.ones((2, 256, 512, 1))
xs = layer(xs)
ys_segmentation_masks = tf.ones((2, 256, 512, 1))
ys_segmentation_masks = layer(ys_segmentation_masks)
self.assertEqual(xs.shape, [2, 256, 512, 1])
self.assertEqual(ys_segmentation_masks.shape, [2, 256, 512, 1])
self.assertEqual(xs.shape, (2, 256, 512, 1))
self.assertEqual(ys_segmentation_masks.shape, (2, 256, 512, 1))

def test_single_input_args(self):
layer = preprocessing.AugMix([0, 255])
Expand All @@ -99,16 +99,16 @@ def test_single_input_args(self):
xs = layer(xs)
ys_segmentation_masks = tf.ones((2, 512, 512, 3))
ys_segmentation_masks = layer(ys_segmentation_masks)
self.assertEqual(xs.shape, [2, 512, 512, 3])
self.assertEqual(ys_segmentation_masks.shape, [2, 512, 512, 3])
self.assertEqual(xs.shape, (2, 512, 512, 3))
self.assertEqual(ys_segmentation_masks.shape, (2, 512, 512, 3))

# greyscale
xs = tf.ones((2, 512, 512, 1))
xs = layer(xs)
ys_segmentation_masks = tf.ones((2, 512, 512, 1))
ys_segmentation_masks = layer(ys_segmentation_masks)
self.assertEqual(xs.shape, [2, 512, 512, 1])
self.assertEqual(ys_segmentation_masks.shape, [2, 512, 512, 1])
self.assertEqual(xs.shape, (2, 512, 512, 1))
self.assertEqual(ys_segmentation_masks.shape, (2, 512, 512, 1))

def test_many_augmentations(self):
layer = preprocessing.AugMix([0, 255], chain_depth=[25, 26])
Expand All @@ -118,13 +118,13 @@ def test_many_augmentations(self):
xs = layer(xs)
ys_segmentation_masks = tf.ones((2, 512, 512, 3))
ys_segmentation_masks = layer(ys_segmentation_masks)
self.assertEqual(xs.shape, [2, 512, 512, 3])
self.assertEqual(ys_segmentation_masks.shape, [2, 512, 512, 3])
self.assertEqual(xs.shape, (2, 512, 512, 3))
self.assertEqual(ys_segmentation_masks.shape, (2, 512, 512, 3))

# greyscale
xs = tf.ones((2, 512, 512, 1))
xs = layer(xs)
ys_segmentation_masks = tf.ones((2, 512, 512, 1))
ys_segmentation_masks = layer(ys_segmentation_masks)
self.assertEqual(xs.shape, [2, 512, 512, 1])
self.assertEqual(ys_segmentation_masks.shape, [2, 512, 512, 1])
self.assertEqual(xs.shape, (2, 512, 512, 1))
self.assertEqual(ys_segmentation_masks.shape, (2, 512, 512, 1))
66 changes: 33 additions & 33 deletions keras_cv/layers/preprocessing/auto_contrast_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,53 +12,53 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np

import tensorflow as tf

from keras_cv.backend import ops
from keras_cv.layers import preprocessing
from keras_cv.tests.test_case import TestCase


class AutoContrastTest(TestCase):
def test_constant_channels_dont_get_nanned(self):
img = tf.constant([1, 1], dtype=tf.float32)
img = tf.expand_dims(img, axis=-1)
img = tf.expand_dims(img, axis=-1)
img = tf.expand_dims(img, axis=0)
img = np.array([1, 1], dtype=np.float32)
img = np.expand_dims(img, axis=-1)
img = np.expand_dims(img, axis=-1)
img = np.expand_dims(img, axis=0)

layer = preprocessing.AutoContrast(value_range=(0, 255))
ys = layer(img)

self.assertTrue(tf.math.reduce_any(ys[0] == 1.0))
self.assertTrue(tf.math.reduce_any(ys[0] == 1.0))
self.assertTrue(np.any(ops.convert_to_numpy(ys[0]) == 1.0))
self.assertTrue(np.any(ops.convert_to_numpy(ys[0]) == 1.0))

def test_auto_contrast_expands_value_range(self):
img = tf.constant([0, 128], dtype=tf.float32)
img = tf.expand_dims(img, axis=-1)
img = tf.expand_dims(img, axis=-1)
img = tf.expand_dims(img, axis=0)
img = np.array([0, 128], dtype=np.float32)
img = np.expand_dims(img, axis=-1)
img = np.expand_dims(img, axis=-1)
img = np.expand_dims(img, axis=0)

layer = preprocessing.AutoContrast(value_range=(0, 255))
ys = layer(img)

self.assertTrue(tf.math.reduce_any(ys[0] == 0.0))
self.assertTrue(tf.math.reduce_any(ys[0] == 255.0))
self.assertTrue(np.any(ops.convert_to_numpy(ys[0]) == 0.0))
self.assertTrue(np.any(ops.convert_to_numpy(ys[0]) == 255.0))

def test_auto_contrast_different_values_per_channel(self):
img = tf.constant(
img = np.array(
[[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]],
dtype=tf.float32,
dtype=np.float32,
)
img = tf.expand_dims(img, axis=0)
img = np.expand_dims(img, axis=0)

layer = preprocessing.AutoContrast(value_range=(0, 255))
ys = layer(img)

self.assertTrue(tf.math.reduce_any(ys[0, ..., 0] == 0.0))
self.assertTrue(tf.math.reduce_any(ys[0, ..., 1] == 0.0))
self.assertTrue(np.any(ops.convert_to_numpy(ys[0, ..., 0]) == 0.0))
self.assertTrue(np.any(ops.convert_to_numpy(ys[0, ..., 1]) == 0.0))

self.assertTrue(tf.math.reduce_any(ys[0, ..., 0] == 255.0))
self.assertTrue(tf.math.reduce_any(ys[0, ..., 1] == 255.0))
self.assertTrue(np.any(ops.convert_to_numpy(ys[0, ..., 0]) == 255.0))
self.assertTrue(np.any(ops.convert_to_numpy(ys[0, ..., 1]) == 255.0))

self.assertAllClose(
ys,
Expand All @@ -71,25 +71,25 @@ def test_auto_contrast_different_values_per_channel(self):
)

def test_auto_contrast_expands_value_range_uint8(self):
img = tf.constant([0, 128], dtype=tf.uint8)
img = tf.expand_dims(img, axis=-1)
img = tf.expand_dims(img, axis=-1)
img = tf.expand_dims(img, axis=0)
img = np.array([0, 128], dtype=np.uint8)
img = np.expand_dims(img, axis=-1)
img = np.expand_dims(img, axis=-1)
img = np.expand_dims(img, axis=0)

layer = preprocessing.AutoContrast(value_range=(0, 255))
ys = layer(img)

self.assertTrue(tf.math.reduce_any(ys[0] == 0.0))
self.assertTrue(tf.math.reduce_any(ys[0] == 255.0))
self.assertTrue(np.any(ops.convert_to_numpy(ys[0]) == 0.0))
self.assertTrue(np.any(ops.convert_to_numpy(ys[0]) == 255.0))

def test_auto_contrast_properly_converts_value_range(self):
img = tf.constant([0, 0.5], dtype=tf.float32)
img = tf.expand_dims(img, axis=-1)
img = tf.expand_dims(img, axis=-1)
img = tf.expand_dims(img, axis=0)
img = np.array([0, 0.5], dtype=np.float32)
img = np.expand_dims(img, axis=-1)
img = np.expand_dims(img, axis=-1)
img = np.expand_dims(img, axis=0)

layer = preprocessing.AutoContrast(value_range=(0, 1))
ys = layer(img)

self.assertTrue(tf.math.reduce_any(ys[0] == 0.0))
self.assertTrue(tf.math.reduce_any(ys[0] == 1.0))
self.assertTrue(np.any(ops.convert_to_numpy(ys[0]) == 0.0))
self.assertTrue(np.any(ops.convert_to_numpy(ys[0]) == 1.0))
40 changes: 30 additions & 10 deletions keras_cv/layers/preprocessing/base_image_augmentation_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import keras
import tensorflow as tf
import tree

if hasattr(keras, "src"):
keras_backend = keras.src.backend
Expand All @@ -23,8 +24,8 @@
from keras_cv import bounding_box
from keras_cv.api_export import keras_cv_export
from keras_cv.backend import keras
from keras_cv.backend import ops
from keras_cv.backend import scope
from keras_cv.backend.config import keras_3
from keras_cv.utils import preprocessing

# In order to support both unbatched and batched inputs, the horizontal
Expand All @@ -42,15 +43,8 @@
USE_TARGETS = "use_targets"


base_class = (
keras.src.layers.preprocessing.tf_data_layer.TFDataLayer
if keras_3()
else keras.layers.Layer
)


@keras_cv_export("keras_cv.layers.BaseImageAugmentationLayer")
class BaseImageAugmentationLayer(base_class):
class BaseImageAugmentationLayer(keras.layers.Layer):
"""Abstract base layer for image augmentation.
This layer contains base functionalities for preprocessing layers which
Expand Down Expand Up @@ -415,6 +409,19 @@ def get_random_transformation(
return None

def call(self, inputs):
# try to convert a given backend native tensor to TensorFlow tensor
# before passing it over to TFDataScope
contains_ragged = lambda y: any(
tree.map_structure(
lambda x: isinstance(x, (tf.RaggedTensor, tf.SparseTensor)),
tree.flatten(y),
)
)
inputs_contain_ragged = contains_ragged(inputs)
if not inputs_contain_ragged:
inputs = tree.map_structure(
lambda x: tf.convert_to_tensor(x), inputs
)
with scope.TFDataScope():
inputs = self._ensure_inputs_are_compute_dtype(inputs)
inputs, metadata = self._format_inputs(inputs)
Expand All @@ -431,7 +438,20 @@ def call(self, inputs):
"rank 3 (HWC) or 4D (NHWC) tensors. Got shape: "
f"{images.shape}"
)
return outputs
# convert the outputs to backend native tensors if none of them
# contain RaggedTensors. Note that if the user passed in Raggeds
# but the outputs are dense, we still don't want to convert to
# backend native tensors. This is to avoid breaking TF data
# pipelines that can't easily be ported to become backend
# agnostic.
if not inputs_contain_ragged and not contains_ragged(outputs):
outputs = tree.map_structure(
# some layers return None, handle that case when
# converting to tensors
lambda x: ops.convert_to_tensor(x) if x is not None else x,
outputs,
)
return outputs

def _augment(self, inputs):
raw_image = inputs.get(IMAGES, None)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import pytest
import tensorflow as tf

from keras_cv import bounding_box
from keras_cv.backend import ops
from keras_cv.layers.preprocessing.base_image_augmentation_layer import (
BaseImageAugmentationLayer,
)
Expand Down Expand Up @@ -78,17 +80,17 @@ def test_augment_dict_return_type(self):

def test_augment_casts_dtypes(self):
add_layer = RandomAddLayer(fixed_value=2.0)
images = tf.ones((2, 8, 8, 3), dtype="uint8")
images = np.ones((2, 8, 8, 3), dtype="uint8")
output = add_layer(images)

self.assertAllClose(
tf.ones((2, 8, 8, 3), dtype="float32") * 3.0, output
np.ones((2, 8, 8, 3), dtype="float32") * 3.0, output
)

def test_augment_batch_images(self):
add_layer = RandomAddLayer()
images = np.random.random(size=(2, 8, 8, 3)).astype("float32")
output = add_layer(images)
output = ops.convert_to_numpy(add_layer(images))

diff = output - images
# Make sure the first image and second image get different augmentation
Expand Down Expand Up @@ -118,8 +120,8 @@ def test_augment_batch_images_and_targets(self):
targets = np.random.random(size=(2, 1)).astype("float32")
output = add_layer({"images": images, "targets": targets})

image_diff = output["images"] - images
label_diff = output["targets"] - targets
image_diff = ops.convert_to_numpy(output["images"]) - images
label_diff = ops.convert_to_numpy(output["targets"]) - targets
# Make sure the first image and second image get different augmentation
self.assertNotAllClose(image_diff[0], image_diff[1])
self.assertNotAllClose(label_diff[0], label_diff[1])
Expand Down Expand Up @@ -225,6 +227,7 @@ def test_augment_batch_image_and_localization_data(self):
segmentation_mask_diff[0], segmentation_mask_diff[1]
)

@pytest.mark.tf_only
def test_augment_all_data_in_tf_function(self):
add_layer = RandomAddLayer()
images = np.random.random(size=(2, 8, 8, 3)).astype("float32")
Expand Down
Loading

0 comments on commit e2b627e

Please sign in to comment.