From 1f6bf4aee77c68052ad0b018c38bda4f56ae6243 Mon Sep 17 00:00:00 2001 From: Divyashree Sreepathihalli Date: Mon, 11 Dec 2023 14:08:30 -0800 Subject: [PATCH] fix jax failing tests (#2231) --- keras_cv/layers/regularization/drop_path.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/keras_cv/layers/regularization/drop_path.py b/keras_cv/layers/regularization/drop_path.py index 4475e2365f..92056e0882 100644 --- a/keras_cv/layers/regularization/drop_path.py +++ b/keras_cv/layers/regularization/drop_path.py @@ -50,7 +50,7 @@ class DropPath(keras.layers.Layer): def __init__(self, rate=0.5, seed=None, **kwargs): super().__init__(**kwargs) self.rate = rate - self.seed = seed + self.seed = keras.random.SeedGenerator(seed=seed) def call(self, x, training=None): if self.rate == 0.0 or not training: