Skip to content

Commit

Permalink
Update base image aug layer tensor conversion (#2281)
Browse files Browse the repository at this point in the history
  • Loading branch information
sampathweb authored Jan 8, 2024
1 parent 72d6120 commit e25de5e
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 8 deletions.
21 changes: 13 additions & 8 deletions keras_cv/layers/preprocessing/base_image_augmentation_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

from keras_cv import bounding_box
from keras_cv.api_export import keras_cv_export
from keras_cv.backend import config
from keras_cv.backend import keras
from keras_cv.backend import ops
from keras_cv.backend import scope
Expand Down Expand Up @@ -411,14 +412,16 @@ def get_random_transformation(
def call(self, inputs):
# try to convert a given backend native tensor to TensorFlow tensor
# before passing it over to TFDataScope
is_tf_backend = config.backend() == "tensorflow"
is_in_tf_graph = not tf.executing_eagerly()
contains_ragged = lambda y: any(
tree.map_structure(
lambda x: isinstance(x, (tf.RaggedTensor, tf.SparseTensor)),
tree.flatten(y),
)
)
inputs_contain_ragged = contains_ragged(inputs)
if not inputs_contain_ragged:
if not is_tf_backend and not inputs_contain_ragged:
inputs = tree.map_structure(
lambda x: tf.convert_to_tensor(x), inputs
)
Expand All @@ -444,13 +447,15 @@ def call(self, inputs):
# backend native tensors. This is to avoid breaking TF data
# pipelines that can't easily be ported to become backend
# agnostic.
if not inputs_contain_ragged and not contains_ragged(outputs):
outputs = tree.map_structure(
# some layers return None, handle that case when
# converting to tensors
lambda x: ops.convert_to_tensor(x) if x is not None else x,
outputs,
)
# Skip this step for TF backend or if in `tf.graph` like `tf.data`.
if not is_tf_backend and not is_in_tf_graph:
if not inputs_contain_ragged and not contains_ragged(outputs):
outputs = tree.map_structure(
# some layers return None, handle that case when
# converting to tensors
lambda x: ops.convert_to_tensor(x) if x is not None else x,
outputs,
)
return outputs

def _augment(self, inputs):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -265,3 +265,11 @@ def in_tf_function(inputs):
self.assertNotAllClose(
segmentation_mask_diff[0], segmentation_mask_diff[1]
)

def test_augment_tf_data_pipeline(self):
image = np.random.random(size=(1, 8, 8, 3)).astype("float32")
tf_dataset = tf.data.Dataset.from_tensor_slices(image).map(
RandomAddLayer(fixed_value=2.0)
)
output = iter(tf_dataset).get_next()
self.assertAllClose(image[0] + 2.0, output)

0 comments on commit e25de5e

Please sign in to comment.