Skip to content

Commit

Permalink
Merge branch 'keras-team:master' into fix_drop_block
Browse files Browse the repository at this point in the history
  • Loading branch information
divyashreepathihalli authored Dec 15, 2023
2 parents 3e149b4 + b980f68 commit 1660b07
Show file tree
Hide file tree
Showing 40 changed files with 479 additions and 316 deletions.
41 changes: 24 additions & 17 deletions keras_cv/backend/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random as python_random

from keras_cv.backend import keras
from keras_cv.backend.config import keras_3

Expand All @@ -20,13 +22,19 @@
from keras_core.random import * # noqa: F403, F401


def _make_default_seed():
return python_random.randint(1, int(1e9))


class SeedGenerator:
def __new__(cls, seed=None, **kwargs):
if keras_3():
return keras.random.SeedGenerator(seed=seed, **kwargs)
return super().__new__(cls)

def __init__(self, seed=None):
if seed is None:
seed = _make_default_seed()
self._initial_seed = seed
self._current_seed = [0, seed]

Expand All @@ -42,22 +50,21 @@ def from_config(cls, config):
return cls(**config)


def _get_init_seed(seed):
if keras_3() and isinstance(seed, keras.random.SeedGenerator):
def _draw_seed(seed):
if keras_3():
# Keras 3 seed can be directly passed to random functions
return seed
if isinstance(seed, SeedGenerator):
seed = seed.next()
init_seed = seed[0]
if seed[1] is not None:
init_seed += seed[1]
init_seed = seed.next()
else:
init_seed = seed
if seed is None:
seed = _make_default_seed()
init_seed = [0, seed]
return init_seed


def normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):
init_seed = _get_init_seed(seed)
seed = _draw_seed(seed)
kwargs = {}
if dtype:
kwargs["dtype"] = dtype
Expand All @@ -66,23 +73,23 @@ def normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):
shape,
mean=mean,
stddev=stddev,
seed=init_seed,
seed=seed,
**kwargs,
)
else:
import tensorflow as tf

return tf.random.normal(
return tf.random.stateless_normal(
shape,
mean=mean,
stddev=stddev,
seed=init_seed,
seed=seed,
**kwargs,
)


def uniform(shape, minval=0.0, maxval=1.0, dtype=None, seed=None):
init_seed = _get_init_seed(seed)
init_seed = _draw_seed(seed)
kwargs = {}
if dtype:
kwargs["dtype"] = dtype
Expand All @@ -97,7 +104,7 @@ def uniform(shape, minval=0.0, maxval=1.0, dtype=None, seed=None):
else:
import tensorflow as tf

return tf.random.uniform(
return tf.random.stateless_uniform(
shape,
minval=minval,
maxval=maxval,
Expand All @@ -107,17 +114,17 @@ def uniform(shape, minval=0.0, maxval=1.0, dtype=None, seed=None):


def shuffle(x, axis=0, seed=None):
init_seed = _get_init_seed(seed)
init_seed = _draw_seed(seed)
if keras_3():
return keras.random.shuffle(x=x, axis=axis, seed=init_seed)
else:
import tensorflow as tf

return tf.random.shuffle(x=x, axis=axis, seed=init_seed)
return tf.random.stateless_shuffle(x=x, axis=axis, seed=init_seed)


def categorical(logits, num_samples, dtype=None, seed=None):
init_seed = _get_init_seed(seed)
init_seed = _draw_seed(seed)
kwargs = {}
if dtype:
kwargs["dtype"] = dtype
Expand All @@ -131,7 +138,7 @@ def categorical(logits, num_samples, dtype=None, seed=None):
else:
import tensorflow as tf

return tf.random.categorical(
return tf.random.stateless_categorical(
logits=logits,
num_samples=num_samples,
seed=init_seed,
Expand Down
40 changes: 20 additions & 20 deletions keras_cv/layers/preprocessing/aug_mix_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,16 +27,16 @@ def test_return_shapes(self):
xs = layer(xs)
ys_segmentation_masks = tf.ones((2, 512, 512, 3))
ys_segmentation_masks = layer(ys_segmentation_masks)
self.assertEqual(xs.shape, [2, 512, 512, 3])
self.assertEqual(ys_segmentation_masks.shape, [2, 512, 512, 3])
self.assertEqual(xs.shape, (2, 512, 512, 3))
self.assertEqual(ys_segmentation_masks.shape, (2, 512, 512, 3))

# greyscale
xs = tf.ones((2, 512, 512, 1))
xs = layer(xs)
ys_segmentation_masks = tf.ones((2, 512, 512, 1))
ys_segmentation_masks = layer(ys_segmentation_masks)
self.assertEqual(xs.shape, [2, 512, 512, 1])
self.assertEqual(ys_segmentation_masks.shape, [2, 512, 512, 1])
self.assertEqual(xs.shape, (2, 512, 512, 1))
self.assertEqual(ys_segmentation_masks.shape, (2, 512, 512, 1))

def test_in_single_image_and_mask(self):
layer = preprocessing.AugMix([0, 255])
Expand All @@ -54,8 +54,8 @@ def test_in_single_image_and_mask(self):
)

ys_segmentation_masks = layer(ys_segmentation_masks)
self.assertEqual(xs.shape, [512, 512, 3])
self.assertEqual(ys_segmentation_masks.shape, [512, 512, 3])
self.assertEqual(xs.shape, (512, 512, 3))
self.assertEqual(ys_segmentation_masks.shape, (512, 512, 3))

# greyscale
xs = tf.cast(
Expand All @@ -69,8 +69,8 @@ def test_in_single_image_and_mask(self):
dtype=tf.float32,
)
ys_segmentation_masks = layer(ys_segmentation_masks)
self.assertEqual(xs.shape, [512, 512, 1])
self.assertEqual(ys_segmentation_masks.shape, [512, 512, 1])
self.assertEqual(xs.shape, (512, 512, 1))
self.assertEqual(ys_segmentation_masks.shape, (512, 512, 1))

def test_non_square_images_and_masks(self):
layer = preprocessing.AugMix([0, 255])
Expand All @@ -80,16 +80,16 @@ def test_non_square_images_and_masks(self):
xs = layer(xs)
ys_segmentation_masks = tf.ones((2, 256, 512, 3))
ys_segmentation_masks = layer(ys_segmentation_masks)
self.assertEqual(xs.shape, [2, 256, 512, 3])
self.assertEqual(ys_segmentation_masks.shape, [2, 256, 512, 3])
self.assertEqual(xs.shape, (2, 256, 512, 3))
self.assertEqual(ys_segmentation_masks.shape, (2, 256, 512, 3))

# greyscale
xs = tf.ones((2, 256, 512, 1))
xs = layer(xs)
ys_segmentation_masks = tf.ones((2, 256, 512, 1))
ys_segmentation_masks = layer(ys_segmentation_masks)
self.assertEqual(xs.shape, [2, 256, 512, 1])
self.assertEqual(ys_segmentation_masks.shape, [2, 256, 512, 1])
self.assertEqual(xs.shape, (2, 256, 512, 1))
self.assertEqual(ys_segmentation_masks.shape, (2, 256, 512, 1))

def test_single_input_args(self):
layer = preprocessing.AugMix([0, 255])
Expand All @@ -99,16 +99,16 @@ def test_single_input_args(self):
xs = layer(xs)
ys_segmentation_masks = tf.ones((2, 512, 512, 3))
ys_segmentation_masks = layer(ys_segmentation_masks)
self.assertEqual(xs.shape, [2, 512, 512, 3])
self.assertEqual(ys_segmentation_masks.shape, [2, 512, 512, 3])
self.assertEqual(xs.shape, (2, 512, 512, 3))
self.assertEqual(ys_segmentation_masks.shape, (2, 512, 512, 3))

# greyscale
xs = tf.ones((2, 512, 512, 1))
xs = layer(xs)
ys_segmentation_masks = tf.ones((2, 512, 512, 1))
ys_segmentation_masks = layer(ys_segmentation_masks)
self.assertEqual(xs.shape, [2, 512, 512, 1])
self.assertEqual(ys_segmentation_masks.shape, [2, 512, 512, 1])
self.assertEqual(xs.shape, (2, 512, 512, 1))
self.assertEqual(ys_segmentation_masks.shape, (2, 512, 512, 1))

def test_many_augmentations(self):
layer = preprocessing.AugMix([0, 255], chain_depth=[25, 26])
Expand All @@ -118,13 +118,13 @@ def test_many_augmentations(self):
xs = layer(xs)
ys_segmentation_masks = tf.ones((2, 512, 512, 3))
ys_segmentation_masks = layer(ys_segmentation_masks)
self.assertEqual(xs.shape, [2, 512, 512, 3])
self.assertEqual(ys_segmentation_masks.shape, [2, 512, 512, 3])
self.assertEqual(xs.shape, (2, 512, 512, 3))
self.assertEqual(ys_segmentation_masks.shape, (2, 512, 512, 3))

# greyscale
xs = tf.ones((2, 512, 512, 1))
xs = layer(xs)
ys_segmentation_masks = tf.ones((2, 512, 512, 1))
ys_segmentation_masks = layer(ys_segmentation_masks)
self.assertEqual(xs.shape, [2, 512, 512, 1])
self.assertEqual(ys_segmentation_masks.shape, [2, 512, 512, 1])
self.assertEqual(xs.shape, (2, 512, 512, 1))
self.assertEqual(ys_segmentation_masks.shape, (2, 512, 512, 1))
66 changes: 33 additions & 33 deletions keras_cv/layers/preprocessing/auto_contrast_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,53 +12,53 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np

import tensorflow as tf

from keras_cv.backend import ops
from keras_cv.layers import preprocessing
from keras_cv.tests.test_case import TestCase


class AutoContrastTest(TestCase):
def test_constant_channels_dont_get_nanned(self):
img = tf.constant([1, 1], dtype=tf.float32)
img = tf.expand_dims(img, axis=-1)
img = tf.expand_dims(img, axis=-1)
img = tf.expand_dims(img, axis=0)
img = np.array([1, 1], dtype=np.float32)
img = np.expand_dims(img, axis=-1)
img = np.expand_dims(img, axis=-1)
img = np.expand_dims(img, axis=0)

layer = preprocessing.AutoContrast(value_range=(0, 255))
ys = layer(img)

self.assertTrue(tf.math.reduce_any(ys[0] == 1.0))
self.assertTrue(tf.math.reduce_any(ys[0] == 1.0))
self.assertTrue(np.any(ops.convert_to_numpy(ys[0]) == 1.0))
self.assertTrue(np.any(ops.convert_to_numpy(ys[0]) == 1.0))

def test_auto_contrast_expands_value_range(self):
img = tf.constant([0, 128], dtype=tf.float32)
img = tf.expand_dims(img, axis=-1)
img = tf.expand_dims(img, axis=-1)
img = tf.expand_dims(img, axis=0)
img = np.array([0, 128], dtype=np.float32)
img = np.expand_dims(img, axis=-1)
img = np.expand_dims(img, axis=-1)
img = np.expand_dims(img, axis=0)

layer = preprocessing.AutoContrast(value_range=(0, 255))
ys = layer(img)

self.assertTrue(tf.math.reduce_any(ys[0] == 0.0))
self.assertTrue(tf.math.reduce_any(ys[0] == 255.0))
self.assertTrue(np.any(ops.convert_to_numpy(ys[0]) == 0.0))
self.assertTrue(np.any(ops.convert_to_numpy(ys[0]) == 255.0))

def test_auto_contrast_different_values_per_channel(self):
img = tf.constant(
img = np.array(
[[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]],
dtype=tf.float32,
dtype=np.float32,
)
img = tf.expand_dims(img, axis=0)
img = np.expand_dims(img, axis=0)

layer = preprocessing.AutoContrast(value_range=(0, 255))
ys = layer(img)

self.assertTrue(tf.math.reduce_any(ys[0, ..., 0] == 0.0))
self.assertTrue(tf.math.reduce_any(ys[0, ..., 1] == 0.0))
self.assertTrue(np.any(ops.convert_to_numpy(ys[0, ..., 0]) == 0.0))
self.assertTrue(np.any(ops.convert_to_numpy(ys[0, ..., 1]) == 0.0))

self.assertTrue(tf.math.reduce_any(ys[0, ..., 0] == 255.0))
self.assertTrue(tf.math.reduce_any(ys[0, ..., 1] == 255.0))
self.assertTrue(np.any(ops.convert_to_numpy(ys[0, ..., 0]) == 255.0))
self.assertTrue(np.any(ops.convert_to_numpy(ys[0, ..., 1]) == 255.0))

self.assertAllClose(
ys,
Expand All @@ -71,25 +71,25 @@ def test_auto_contrast_different_values_per_channel(self):
)

def test_auto_contrast_expands_value_range_uint8(self):
img = tf.constant([0, 128], dtype=tf.uint8)
img = tf.expand_dims(img, axis=-1)
img = tf.expand_dims(img, axis=-1)
img = tf.expand_dims(img, axis=0)
img = np.array([0, 128], dtype=np.uint8)
img = np.expand_dims(img, axis=-1)
img = np.expand_dims(img, axis=-1)
img = np.expand_dims(img, axis=0)

layer = preprocessing.AutoContrast(value_range=(0, 255))
ys = layer(img)

self.assertTrue(tf.math.reduce_any(ys[0] == 0.0))
self.assertTrue(tf.math.reduce_any(ys[0] == 255.0))
self.assertTrue(np.any(ops.convert_to_numpy(ys[0]) == 0.0))
self.assertTrue(np.any(ops.convert_to_numpy(ys[0]) == 255.0))

def test_auto_contrast_properly_converts_value_range(self):
img = tf.constant([0, 0.5], dtype=tf.float32)
img = tf.expand_dims(img, axis=-1)
img = tf.expand_dims(img, axis=-1)
img = tf.expand_dims(img, axis=0)
img = np.array([0, 0.5], dtype=np.float32)
img = np.expand_dims(img, axis=-1)
img = np.expand_dims(img, axis=-1)
img = np.expand_dims(img, axis=0)

layer = preprocessing.AutoContrast(value_range=(0, 1))
ys = layer(img)

self.assertTrue(tf.math.reduce_any(ys[0] == 0.0))
self.assertTrue(tf.math.reduce_any(ys[0] == 1.0))
self.assertTrue(np.any(ops.convert_to_numpy(ys[0]) == 0.0))
self.assertTrue(np.any(ops.convert_to_numpy(ys[0]) == 1.0))
Loading

0 comments on commit 1660b07

Please sign in to comment.