diff --git a/keras_cv/layers/preprocessing/base_image_augmentation_layer.py b/keras_cv/layers/preprocessing/base_image_augmentation_layer.py index 2eba5aaa74..ef2e9cefe7 100644 --- a/keras_cv/layers/preprocessing/base_image_augmentation_layer.py +++ b/keras_cv/layers/preprocessing/base_image_augmentation_layer.py @@ -23,6 +23,7 @@ from keras_cv import bounding_box from keras_cv.api_export import keras_cv_export +from keras_cv.backend import config from keras_cv.backend import keras from keras_cv.backend import ops from keras_cv.backend import scope @@ -411,6 +412,8 @@ def get_random_transformation( def call(self, inputs): # try to convert a given backend native tensor to TensorFlow tensor # before passing it over to TFDataScope + is_tf_backend = config.backend() == "tensorflow" + is_in_tf_graph = not tf.executing_eagerly() contains_ragged = lambda y: any( tree.map_structure( lambda x: isinstance(x, (tf.RaggedTensor, tf.SparseTensor)), @@ -418,7 +421,7 @@ def call(self, inputs): ) ) inputs_contain_ragged = contains_ragged(inputs) - if not inputs_contain_ragged: + if not is_tf_backend and not inputs_contain_ragged: inputs = tree.map_structure( lambda x: tf.convert_to_tensor(x), inputs ) @@ -444,13 +447,15 @@ def call(self, inputs): # backend native tensors. This is to avoid breaking TF data # pipelines that can't easily be ported to become backend # agnostic. - if not inputs_contain_ragged and not contains_ragged(outputs): - outputs = tree.map_structure( - # some layers return None, handle that case when - # converting to tensors - lambda x: ops.convert_to_tensor(x) if x is not None else x, - outputs, - ) + # Skip this step for TF backend or if in `tf.graph` like `tf.data`. + if not is_tf_backend and not is_in_tf_graph: + if not inputs_contain_ragged and not contains_ragged(outputs): + outputs = tree.map_structure( + # some layers return None, handle that case when + # converting to tensors + lambda x: ops.convert_to_tensor(x) if x is not None else x, + outputs, + ) return outputs def _augment(self, inputs): diff --git a/keras_cv/layers/preprocessing/base_image_augmentation_layer_test.py b/keras_cv/layers/preprocessing/base_image_augmentation_layer_test.py index aed4dd3af0..edc484e694 100644 --- a/keras_cv/layers/preprocessing/base_image_augmentation_layer_test.py +++ b/keras_cv/layers/preprocessing/base_image_augmentation_layer_test.py @@ -265,3 +265,11 @@ def in_tf_function(inputs): self.assertNotAllClose( segmentation_mask_diff[0], segmentation_mask_diff[1] ) + + def test_augment_tf_data_pipeline(self): + image = np.random.random(size=(1, 8, 8, 3)).astype("float32") + tf_dataset = tf.data.Dataset.from_tensor_slices(image).map( + RandomAddLayer(fixed_value=2.0) + ) + output = iter(tf_dataset).get_next() + self.assertAllClose(image[0] + 2.0, output)